Tab.hs (17358B)
-- Copyright 2023 The Plunder Authors
-- Use of this source code is governed by a BSD-style license that can be
-- found in the LICENSE file.

-- | Maps and Sets as Sorted Vectors

{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -Werror #-}
{-# OPTIONS_GHC -Wno-orphans #-}
{-# LANGUAGE UnboxedTuples #-}

module Data.Sorted.Tab
    ( mkTab
    , tabSingleton
    , tabInsert
    , tabLookup
    , tabSize
    , tabElemAt
    , tabSplit
    , tabSplitAt
    , tabSpanAntitone
    , tabMap
    , tabMapWithKey
    , tabUnion
    , tabUnionWith
    , tabIntersection
    , tabDifference
    , tabLookupMin
    , tabLookupMax
    , tabAlter
    , tabDelete
    , tabMember
    , tabElemsList
    , tabElemsArray
    , tabFoldlWithKey'
    , tabFilterWithKey
    , tabKeysSet
    , tabKeysList
    , tabKeysArray
    , tabToAscPairsList
    , tabToDescPairsList
    )
where

import Prelude

import Control.Monad.ST
import Data.Containers
import Data.Foldable
import Data.MonoTraversable
import Data.Primitive.Array
import Data.Sorted.Row
import Data.Sorted.Search
import Data.Sorted.Set
import Data.Sorted.Types

import PlunderPrelude (on)
import GHC.Exts (Int(..), indexArray#, sizeofArray#, (+#))

-- import qualified Fan.Prof as Prof


-- Construction and Searching --------------------------------------------------

-- | The table with no entries: an empty key array and an empty value array.
{-# INLINE emptyTab #-}
emptyTab :: Tab k v
emptyTab = TAB mempty mempty

-- | A table holding exactly one key/value pair.
{-# INLINE tabSingleton #-}
tabSingleton :: k -> v -> Tab k v
tabSingleton key val = TAB keyRow valRow
  where
    keyRow = rowSingleton key
    valRow = rowSingleton val

-- Build a two-entry table without comparing the keys.
--
-- Precondition (unchecked): the first key MUST be strictly smaller than the
-- second; nothing here sorts or deduplicates them.
{-# INLINE tabUnsafeDuo #-}
tabUnsafeDuo :: k -> v -> k -> v -> Tab k v
tabUnsafeDuo loKey loVal hiKey hiVal = TAB keyRow valRow
  where
    keyRow = rowDuo loKey hiKey
    valRow = rowDuo loVal hiVal

-- | Insert a key/value pair. If the key is already present, overwrite the
-- value at the found index; otherwise insert the key and the value at the
-- found index of the relevant arrays.
tabInsert :: Ord k => k -> v -> Tab k v -> Tab k v
tabInsert k v (TAB ks@(Array ks#) vs) =
    -- Prof.withSimpleTracingEventPure "tabInsert" "sorted" $
    let !(# i#, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
    let i = I# i# in
    case found of
        1# -> TAB ks (rowUnsafePut i v vs)
        _  -> TAB (rowInsert i k ks) (rowInsert i v vs)

-- | Insert a key/value pair, combining with any existing value.
--
-- If @k@ is already present, the stored value becomes @merge v old@ (new
-- value first). Otherwise the key and value are inserted at the index
-- returned by the binary search.
tabInsertWith :: Ord k => (v -> v -> v) -> k -> v -> Tab k v -> Tab k v
tabInsertWith merge k v (TAB ks@(Array ks#) vs) =
    let !(# i#, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
    let i = I# i# in
    case found of
        1# -> TAB ks $ rowUnsafePut i (merge v (vs!i)) vs
        _  -> TAB (rowInsert i k ks) (rowInsert i v vs)

-- | Look up the value stored under a key.
--
-- Binary-searches the key array; on a hit, reads the value at the same
-- index of the value array.
tabLookup :: Ord k => k -> Tab k v -> Maybe v
tabLookup k (TAB (Array ks#) (Array vs#)) =
    -- Prof.withSimpleTracingEventPure "tabLookup" "sorted" $
    let !(# i, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
    case found of
        0# -> Nothing
        _  -> case indexArray# vs# i of
                  (# res #) -> Just res

-- | Number of entries (the length of the key array).
{-# INLINE tabSize #-}
tabSize :: Tab k v -> Int
tabSize (TAB ks _) = sizeofArray ks

-- | The entry at position @i@ in ascending key order.
--
-- Calls 'error' unless @0 <= i < tabSize tab@. The upper test must be
-- @>=@: with the previous @i > size@ check, @i == size@ (one past the last
-- element) slipped through to the array index, as did any negative @i@.
-- NOTE(review): '!' on these arrays may not bounds-check itself; this
-- guard is the only protection either way.
{-# INLINE tabElemAt #-}
tabElemAt :: Int -> Tab k v -> (k, v)
tabElemAt i (TAB ks vs) =
    if i < 0 || i >= sizeofArray ks
        then error "tabElemAt: out-of-bounds"
        else (ks!i, vs!i)

-- | Split by position: the first @i@ entries, and the rest.
--
-- Keys and values are cut at the same index, so both halves stay aligned.
{-# INLINE tabSplitAt #-}
tabSplitAt :: Int -> Tab k v -> (Tab k v, Tab k v)
tabSplitAt i (TAB ks vs) =
    ( TAB (rowTake i ks) (rowTake i vs)
    , TAB (rowDrop i ks) (rowDrop i vs)
    )

-- Find index, call split (TODO: What behavior on found vs not-found?
-- Avoid off-by-one-errors)
--
-- Resolution: 'tabSplit' follows the Data.Map 'split' convention documented
-- below; an exact match is excluded from both halves.

-- | Split the table around the key @k@.
--
-- Returns @(lo, hi)@: @lo@ holds the entries before the search position
-- @i@, and @hi@ the entries from the drop index onward. The drop index is
-- @i + found@ (@found@ is @1#@ on a hit, @0#@ on a miss), so if @k@ is
-- present its entry lands in neither half.
--
-- NOTE(review): reading @lo@ as "keys < k" and @hi@ as "keys > k" assumes
-- 'bsearch#' returns the insertion point on a miss, which is how
-- 'tabInsert' uses it.
{-# INLINE tabSplit #-}
tabSplit :: Ord k => k -> Tab k v -> (Tab k v, Tab k v)
tabSplit k (TAB ks@(Array ks#) vs) =
    let !(# i#, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#)
        i = I# i#
        j = I# (i# +# found)
    in
    ( TAB (rowTake i ks) (rowTake i vs)
    , TAB (rowDrop j ks) (rowDrop j vs)
    )

-- | Split the table after the leading run of keys that satisfy @f@.
--
-- @f@ is expected to be antitone over the sorted keys (all 'True' results
-- come before all 'False' ones). NOTE(review): that 'bfind' returns the
-- length of that 'True' prefix is inferred from its use here; confirm in
-- Data.Sorted.Search.
{-# INLINE tabSpanAntitone #-}
tabSpanAntitone :: (k -> Bool) -> Tab k v -> (Tab k v, Tab k v)
tabSpanAntitone f (TAB ks vs) =
    ( TAB (rowTake numTrue ks) (rowTake numTrue vs)
    , TAB (rowDrop numTrue ks) (rowDrop numTrue vs)
    )
  where
    numTrue = bfind f ks

-- | Map over the values, also passing each value's key. The key array is
-- reused as-is.
{-# INLINE tabMapWithKey #-}
tabMapWithKey :: (k -> v -> a) -> Tab k v -> Tab k a
tabMapWithKey f (TAB ks vs) = TAB ks (rowZipWith f ks vs)

-- | Map over the values. The key array is reused as-is.
{-# INLINE tabMap #-}
tabMap :: (a -> b) -> Tab k a -> Tab k b
tabMap f (TAB k v) = TAB k (f <$> v)

-- | Left-biased union: on a key collision the left table's value wins.
{-# INLINE tabUnion #-}
tabUnion :: Ord k => Tab k v -> Tab k v -> Tab k v
tabUnion = tabUnionWith const

-- | Union of two tables. For a key present in both, the stored value is
-- @merge leftVal rightVal@.
--
-- Empty and single-entry inputs are special-cased (a single entry becomes
-- one binary-search insert); otherwise this is an O(n + m) merge-walk.
tabUnionWith :: Ord k => (v -> v -> v) -> Tab k v -> Tab k v -> Tab k v
tabUnionWith merge x@(TAB xKeys xVals) y@(TAB yKeys yVals) =
    -- Prof.withSimpleTracingEventPure "tabUnion" "sorted" $
    case (sizeofArray xKeys, sizeofArray yKeys) of
        ( 0, _ ) -> y
        ( _, 0 ) -> x
        ( 1, 1 ) ->
            let xk = xKeys!0
                yk = yKeys!0
                xv = xVals!0
                yv = yVals!0
            in case compare xk yk of
                LT -> tabUnsafeDuo xk xv yk yv
                GT -> tabUnsafeDuo yk yv xk xv
                EQ -> tabSingleton xk (merge xv yv)
        -- tabInsertWith calls (merge new old); flip on the right-hand side
        -- so the left table's value is still passed first.
        ( 1, _ ) -> tabInsertWith merge (xKeys!0) (xVals!0) y
        ( _, 1 ) -> tabInsertWith (flip merge) (yKeys!0) (yVals!0) x
        ( xw, yw ) -> tabUnionWithGeneric merge x xw y yw

-- Merge-walk union. xWid/yWid are the sizes of the two tables. The output
-- buffers are sized for the worst case (no shared keys) and trimmed after.
tabUnionWithGeneric
    :: Ord k => (v -> v -> v) -> Tab k v -> Int -> Tab k v -> Int -> Tab k v
tabUnionWithGeneric merge (TAB xKeys xVals) !xWid (TAB yKeys yVals) !yWid =
    runST do
        let rWid = xWid + yWid

        valsBuf <- newArray rWid (error "tabUnionWith: uninitialized")
        keysBuf <- newArray rWid (error "tabUnionWith: uninitialized")

        -- o: next output slot; i, j: next unread index in x and y.
        let go o i j = do
                let xRemain = xWid - i
                let yRemain = yWid - j
                case (xRemain, yRemain) of
                    (0, 0) -> pure o
                    (0, _) -> do
                        -- x exhausted: bulk-copy the tail of y.
                        copyArray keysBuf o yKeys j yRemain
                        copyArray valsBuf o yVals j yRemain
                        pure (o + yRemain)
                    (_, 0) -> do
                        -- y exhausted: bulk-copy the tail of x.
                        copyArray keysBuf o xKeys i xRemain
                        copyArray valsBuf o xVals i xRemain
                        pure (o + xRemain)
                    (_, _) -> do
                        let x = xKeys ! i
                        let y = yKeys ! j
                        case compare x y of
                            EQ -> do
                                writeArray keysBuf o x
                                writeArray valsBuf o (merge (xVals!i) (yVals!j))
                                go (o+1) (i+1) (j+1)
                            LT -> do
                                writeArray keysBuf o x
                                writeArray valsBuf o (xVals!i)
                                go (o+1) (i+1) j
                            GT -> do
                                writeArray keysBuf o y
                                writeArray valsBuf o (yVals!j)
                                go (o+1) i (j+1)

        written <- go 0 0 0

        -- Shared keys collapse into one slot, leaving the buffer partly
        -- filled; in that case copy out only the written prefix.
        if written == rWid
            then TAB <$> unsafeFreezeArray keysBuf
                     <*> unsafeFreezeArray valsBuf
            else TAB <$> freezeArray keysBuf 0 written
                     <*> freezeArray valsBuf 0 written

-- | Entries of the first table whose keys do not occur in the second.
--
-- O(n + m) merge-walk over both sorted key arrays. The second table's values
-- are never read.
tabDifference :: Ord k => Tab k v -> Tab k v -> Tab k v
tabDifference x (TAB yKeys _) | null yKeys = x
tabDifference (TAB xKeys _) _ | null xKeys = mempty
tabDifference (TAB xKeys xVals) (TAB yKeys _) =
    -- Prof.withSimpleTracingEventPure "tabDifference" "sorted" $
    runST do
        let xWid = sizeofArray xKeys
        -- Measure y by the array the loop actually walks (its keys).
        let yWid = sizeofArray yKeys
        rKeys <- newArray xWid (error "tabDifference: uninitialized")
        rVals <- newArray xWid (error "tabDifference: uninitialized")
        -- o: next output slot; i, j: next unread index in x and y.
        let go o i j =
                if i >= xWid then pure o else
                if j >= yWid then do
                    -- y exhausted: everything left in x survives.
                    let extra = xWid - i
                    copyArray rKeys o xKeys i extra
                    copyArray rVals o xVals i extra
                    pure (o + extra)
                else do
                    let x = xKeys ! i
                    let y = yKeys ! j
                    case compare x y of
                        LT -> do
                            writeArray rKeys o x
                            writeArray rVals o (xVals!i)
                            go (o+1) (i+1) j
                        EQ -> go o (i+1) (j+1)
                        GT -> go o i (j+1)
        used <- go 0 0 0
        if used == 0 then
            pure mempty
        else if used == xWid then
            TAB <$> unsafeFreezeArray rKeys <*> unsafeFreezeArray rVals
        else
            TAB <$> freezeArray rKeys 0 used <*> freezeArray rVals 0 used

-- | Left-biased intersection: keeps the left table's values.
{-# INLINE tabIntersection #-}
tabIntersection :: Ord k => Tab k v -> Tab k v -> Tab k v
tabIntersection x y = tabIntersectionWith const x y

-- | Intersection of two tables. For each shared key the stored value is
-- @f leftVal rightVal@.
--
-- Empty and single-entry inputs are special-cased (a single entry becomes
-- one lookup); otherwise this is an O(n + m) merge-walk.
tabIntersectionWith :: Ord k => (v -> v -> v) -> Tab k v -> Tab k v -> Tab k v
tabIntersectionWith f x@(TAB xKeys xVals) y@(TAB yKeys yVals) =
    -- Prof.withSimpleTracingEventPure "tabIntersectionWith" "sorted" $
    case (sizeofArray xKeys, sizeofArray yKeys) of
        ( 0, _ ) -> mempty
        ( _, 0 ) -> mempty
        ( 1, 1 ) ->
            let xk = xKeys!0
            in if xk == (yKeys!0)
                then tabSingleton xk (f (xVals!0) (yVals!0))
                else mempty
        ( 1, _ ) ->
            let xk = xKeys!0 in
            case tabLookup xk y of
                Nothing -> mempty
                Just yv -> tabSingleton xk (f (xVals!0) yv)
        ( _, 1 ) ->
            let yk = yKeys!0 in
            case tabLookup yk x of
                Nothing -> mempty
                Just xv -> tabSingleton yk (f xv (yVals!0))
        ( xw, yw ) -> tabIntersectionWithGeneric f x xw y yw

-- Merge-walk intersection. xWid/yWid are the sizes of the two tables. The
-- result can be no larger than the smaller input; trimmed after the walk.
tabIntersectionWithGeneric
    :: Ord k => (v -> v -> v) -> Tab k v -> Int -> Tab k v -> Int -> Tab k v
tabIntersectionWithGeneric merge (TAB xKeys xVals) xWid (TAB yKeys yVals) yWid =
    runST do
        let rWid = min xWid yWid
        rKeys <- newArray rWid (error "tabIntersectionWith: uninitialized")
        rVals <- newArray rWid (error "tabIntersectionWith: uninitialized")
        -- o: next output slot; i, j: next unread index in x and y.
        let go o i j =
                if i >= xWid || j >= yWid then pure o else do
                    let x = xKeys ! i
                    let y = yKeys ! j
                    case compare x y of
                        EQ -> do
                            writeArray rKeys o x
                            writeArray rVals o $ merge (xVals!i) (yVals!j)
                            go (o+1) (i+1) (j+1)
                        LT -> go o (i+1) j
                        GT -> go o i (j+1)
        used <- go 0 0 0
        if used == rWid
            then TAB <$> unsafeFreezeArray rKeys <*> unsafeFreezeArray rVals
            else TAB <$> freezeArray rKeys 0 used <*> freezeArray rVals 0 used

-- | The entry with the smallest key, if the table is non-empty.
{-# INLINE tabLookupMin #-}
tabLookupMin :: Tab k v -> Maybe (k, v)
tabLookupMin (TAB k v) =
    if null k then Nothing else Just (k!0, v!0)

-- | The entry with the largest key, if the table is non-empty.
{-# INLINE tabLookupMax #-}
tabLookupMax :: Tab k v -> Maybe (k, v)
tabLookupMax (TAB k v) =
    case sizeofArray k of
        0 -> Nothing
        n -> let !i = n-1 in Just (k!i, v!i)

-- | Insert, update or delete the entry for @k@ in one search.
--
-- @f@ receives the current value (or 'Nothing'). Returning 'Nothing'
-- deletes the entry, or leaves the table unchanged if it was absent.
tabAlter :: Ord k => (Maybe v -> Maybe v) -> k -> Tab k v -> Tab k v
tabAlter f k tab@(TAB ks@(Array ks#) vs) =
    -- Prof.withSimpleTracingEventPure "tabAlter" "sorted" $
    let !(# i#, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
    let i = I# i# in
    case found of
        1# ->
            case f (Just (vs!i)) of
                Nothing -> TAB (rowUnsafeDelete i ks) (rowUnsafeDelete i vs)
                Just v  -> TAB ks (rowUnsafePut i v vs)
        _ ->
            case f Nothing of
                Nothing -> tab -- no change
                Just v  -> TAB (rowInsert i k ks) (rowInsert i v vs)

-- | Remove the entry for @k@. Returns the input table itself when @k@ is
-- absent.
tabDelete :: Ord k => k -> Tab k v -> Tab k v
tabDelete k tab@(TAB ks@(Array ks#) vs) =
    -- Prof.withSimpleTracingEventPure "tabDelete" "sorted" $
    case bsearch# compare k ks# 0# (sizeofArray# ks#) of
        (# _, 0# #) -> tab
        (# i#, _ #) -> TAB ks' vs'
          where
            i = I# i#
            !ks' = rowUnsafeDelete i ks
            !vs' = rowUnsafeDelete i vs

-- | Whether @k@ is a key of the table (delegates to the key set).
{-# INLINE tabMember #-}
tabMember :: Ord k => k -> Tab k v -> Bool
tabMember k (TAB ks _) = ssetMember k (SET ks)

-- | The keys as a 'Set', sharing the key array.
{-# INLINE tabKeysSet #-}
tabKeysSet :: Tab k v -> Set k
tabKeysSet (TAB k _) = SET k

-- | The underlying key array.
{-# INLINE tabKeysArray #-}
tabKeysArray :: Tab k v -> Array k
tabKeysArray (TAB k _) = k

-- | The keys in ascending order, as a list.
{-# INLINE tabKeysList #-}
tabKeysList :: Tab k v -> [k]
tabKeysList (TAB k _) = toList k

-- | Strict left fold over entries in ascending key order. The accumulator
-- is forced at every step.
{-# INLINE tabFoldlWithKey' #-}
tabFoldlWithKey' :: (a -> k -> v -> a) -> a -> Tab k v -> a
tabFoldlWithKey' f !x (TAB ks vs) =
    -- Prof.withSimpleTracingEventPure "tabFoldWithKey" "sorted" do
    go 0 x
  where
    !wid = sizeofArray ks

    go i !acc | i >= wid = acc
    go i !acc | otherwise = go (i+1) $ f acc (ks!i) (vs!i)

-- | Keep only the entries for which @f key value@ holds.
--
-- Returns the input table itself when every entry is kept.
tabFilterWithKey :: (k -> v -> Bool) -> Tab k v -> Tab k v
tabFilterWithKey f tab@(TAB ks vs) =
    -- Prof.withSimpleTracingEventPure "tabFilterWithKey" "sorted" $
    runST do
        let !wid = sizeofArray ks
        keysBuf <- newArray wid (error "tabFilterWithKey: uninitialized")
        valsBuf <- newArray wid (error "tabFilterWithKey: uninitialized")
        let go o i
              | i >= wid  = pure o
              | otherwise = do
                  let key = ks!i
                  let val = vs!i
                  if f key val then do
                      writeArray keysBuf o key
                      writeArray valsBuf o val
                      go (o+1) (i+1)
                  else
                      go o (i+1)
        written <- go 0 0
        if written == wid then pure tab else
            TAB <$> freezeArray keysBuf 0 written
                <*> freezeArray valsBuf 0 written

-- | The underlying value array (ascending key order).
{-# INLINE tabElemsArray #-}
tabElemsArray :: Tab k v -> Array v
tabElemsArray (TAB _ v) = v

-- | The values in ascending key order, as a list.
{-# INLINE tabElemsList #-}
tabElemsList :: Tab k v -> [v]
tabElemsList (TAB _ v) = toList v

{-
    Creates a table from a list of key-value pairs. If keys appear
    multiple times, values later in the list are used.

    Implementation:

    - Collect the list of pairs into an array.

    - Do a stable sort on the array (comparing only the keys), removing
      duplicates.

    - Split the resulting array out into a keys array and a values array.

    Note that rowSortUniqBy chooses earlier values, not later values.
    We resolve this by filling the array in reverse.

    TODO: We can skip a copy by directly creating a mutable array,
    and doing the sort on that, in place.
-}
tabFromPairsList :: Ord k => [(k,v)] -> Tab k v
tabFromPairsList pairs =
    -- Prof.withSimpleTracingEventPure "tabFromPairsList" "sorted" $
    let buf = rowSortUniqBy (on compare fst) $ arrayFromListRev pairs in
    TAB (fst <$> buf) (snd <$> buf)

-- | All entries in ascending key order (produced lazily).
{-# INLINE tabToAscPairsList #-}
tabToAscPairsList :: Tab k v -> [(k,v)]
tabToAscPairsList (TAB k v) = go 0
  where
    !len = sizeofArray k
    go i | i >= len = []
    go i = (k!i, v!i) : go (i+1)

-- | All entries in descending key order (produced lazily).
{-# INLINE tabToDescPairsList #-}
tabToDescPairsList :: Tab k v -> [(k,v)]
tabToDescPairsList (TAB k v) = go (sizeofArray k - 1)
  where
    go i | i < 0 = []
    go i = (k!i, v!i) : go (i-1)

-- | Pair a key set with a value array of the same length.
--
-- Calls 'error' if the sizes differ. The values are taken to correspond,
-- index by index, to the keys in the set's order.
{-# INLINE mkTab #-}
mkTab :: Set k -> Array v -> Tab k v
mkTab (SET k) v =
    if sizeofArray k /= sizeofArray v
        then error "mkTab: keys must have same size as values"
        else TAB k v

--------------------------------------------------------------------------------
-- TODO: Optimize and verify these instances

-- Folds run over the value array only, in ascending key order. 'length' and
-- 'null' read the key array's size directly instead of folding.
instance Foldable (Tab k) where
    foldMap f (TAB _ v) = foldMap f v
    length = tabSize
    null (TAB ks _) = null ks
    {-# INLINE foldMap #-}
    {-# INLINE length #-}
    {-# INLINE null #-}

instance MonoFunctor (Tab k v) where
    omap = fmap
    {-# INLINE omap #-}

-- Folds come from the class defaults (presumably via the Foldable instance);
-- size queries are answered in O(1) from the key array.
-- TODO: More explicit methods.
instance MonoFoldable (Tab k v) where
    olength = tabSize
    onull (TAB ks _) = null ks
    {-# INLINE olength #-}
    {-# INLINE onull #-}

instance MonoTraversable (Tab k v) where
    otraverse = traverse
    omapM = traverse
    {-# INLINE otraverse #-}
    {-# INLINE omapM #-}

type instance Element (Tab k v) = v

-- Left-biased: see 'tabUnion'.
instance Ord k => Semigroup (Tab k v) where
    (<>) = tabUnion
    {-# INLINE (<>) #-}

instance Ord k => Monoid (Tab k v) where
    mempty = emptyTab
    {-# INLINE mempty #-}

-- Assert that append (<>) never produces something smaller (this class
-- has no methods).
instance GrowingAppend (Tab k v) where

{-
    This explicitly implements all methods

    TODO: Better implementation of `unions`!
-}
instance Ord k => SetContainer (Tab k v) where
    type ContainerKey (Tab k v) = k
    member = tabMember
    notMember k t = not (tabMember k t)
    union = tabUnion
    difference = tabDifference
    intersection = tabIntersection
    keys = tabKeysList
    unions = ofoldl' tabUnion mempty
    {-# INLINE member #-}
    {-# INLINE notMember #-}
    {-# INLINE union #-}
    {-# INLINE unions #-}
    {-# INLINE difference #-}
    {-# INLINE intersection #-}
    {-# INLINE keys #-}

instance Functor (Tab k) where
    fmap = tabMap
    {-# INLINE fmap #-}

-- The key array is reused unchanged; effects run over the value array in
-- its stored (ascending key) order.
-- TODO: implement way more methods.
instance Traversable (Tab k) where
    traverse f (TAB ks vs) = TAB ks <$> traverse f vs
    sequenceA (TAB ks vs) = TAB ks <$> sequenceA vs

-- Map interface. Operations that have a dedicated Tab implementation are
-- wired in directly so they bypass the class defaults (NOTE(review): the
-- defaults are understood to round-trip through association lists; confirm
-- against the mono-traversable version in use).
--
-- Argument conventions line up: 'unionWith' passes (left, right) values,
-- as 'tabUnionWith' does; 'alterMap' has 'tabAlter''s exact contract.
-- TODO: Implement many more methods.
instance Ord k => IsMap (Tab k v) where
    type MapValue (Tab k v) = v
    lookup = tabLookup
    insertMap = tabInsert
    deleteMap = tabDelete
    singletonMap = tabSingleton
    mapFromList = tabFromPairsList
    mapToList = tabToAscPairsList
    alterMap = tabAlter
    unionWith = tabUnionWith
    mapWithKey = tabMapWithKey