plunder

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

Tab.hs (17358B)


      1 -- Copyright 2023 The Plunder Authors
      2 -- Use of this source code is governed by a BSD-style license that can be
      3 -- found in the LICENSE file.
      4 
      5 -- | Maps and Sets as Sorted Vectors
      6 
      7 {-# OPTIONS_GHC -Wall        #-}
      8 {-# OPTIONS_GHC -Werror      #-}
      9 {-# OPTIONS_GHC -Wno-orphans #-}
     10 {-# LANGUAGE UnboxedTuples #-}
     11 
     12 module Data.Sorted.Tab
     13     ( mkTab
     14     , tabSingleton
     15     , tabInsert
     16     , tabLookup
     17     , tabSize
     18     , tabElemAt
     19     , tabSplit
     20     , tabSplitAt
     21     , tabSpanAntitone
     22     , tabMap
     23     , tabMapWithKey
     24     , tabUnion
     25     , tabUnionWith
     26     , tabIntersection
     27     , tabDifference
     28     , tabLookupMin
     29     , tabLookupMax
     30     , tabAlter
     31     , tabDelete
     32     , tabMember
     33     , tabElemsList
     34     , tabElemsArray
     35     , tabFoldlWithKey'
     36     , tabFilterWithKey
     37     , tabKeysSet
     38     , tabKeysList
     39     , tabKeysArray
     40     , tabToAscPairsList
     41     , tabToDescPairsList
     42     )
     43 where
     44 
     45 import Prelude
     46 
     47 import Control.Monad.ST
     48 import Data.Containers
     49 import Data.Foldable
     50 import Data.MonoTraversable
     51 import Data.Primitive.Array
     52 import Data.Sorted.Row
     53 import Data.Sorted.Search
     54 import Data.Sorted.Set
     55 import Data.Sorted.Types
     56 
     57 import PlunderPrelude (on)
     58 import GHC.Exts       (Int(..), indexArray#, sizeofArray#, (+#))
     59 
     60 -- import qualified Fan.Prof as Prof
     61 
     62 
     63 -- Searching -------------------------------------------------------------------
     64 
     65 {-# INLINE emptyTab #-}
     66 emptyTab :: Tab k v
     67 emptyTab = TAB mempty mempty
     68 
     69 {-# INLINE tabSingleton #-}
     70 tabSingleton :: k -> v -> Tab k v
     71 tabSingleton k v = TAB (rowSingleton k) (rowSingleton v)
     72 
     73 -- The first key MUST be strictly smaller than the second.
     74 {-# INLINE tabUnsafeDuo #-}
     75 tabUnsafeDuo :: k -> v -> k -> v -> Tab k v
     76 tabUnsafeDuo xk xv yk yv =
     77     TAB (rowDuo xk yk) (rowDuo xv yv)
     78 
     79 -- If found, update the values array at the found index.  Otherwise insert
     80 -- the key and value at the found-index of the relevent arrays.
     81 tabInsert :: Ord k => k -> v -> Tab k v -> Tab k v
     82 tabInsert k v (TAB ks@(Array ks#) vs) =
     83     -- Prof.withSimpleTracingEventPure "tabInsert" "sorted" $
     84     let !(# i#, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
     85     let i = I# i# in
     86     case found of
     87         1# -> TAB ks (rowUnsafePut i v vs)
     88         _  -> TAB (rowInsert i k ks) (rowInsert i v vs)
     89 
     90 -- If found, merge the two values with (merge newVal oldVal).  Otherwise
     91 -- insert the key and value at the found-index of the relevant arrays.
     92 tabInsertWith :: Ord k => (v -> v -> v) -> k -> v -> Tab k v -> Tab k v
     93 tabInsertWith merge k v (TAB ks@(Array ks#) vs) =
     94     let !(# i#, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
     95     let i = I# i# in
     96     case found of
     97         1# -> TAB ks $ rowUnsafePut i (merge v (vs!i)) vs
     98         _  -> TAB (rowInsert i k ks) (rowInsert i v vs)
     99 
    100 -- Do a search on the keys set, if we found a match, return the matching
    101 -- value in the values array.
    102 tabLookup :: Ord k => k -> Tab k v -> Maybe v
    103 tabLookup k (TAB (Array ks#) (Array vs#)) =
    104     -- Prof.withSimpleTracingEventPure "tabLookup" "sorted" $
    105     let !(# i, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
    106     case found of
    107         0# -> Nothing
    108         _  -> case indexArray# vs# i of
    109                   (# res #) -> Just res
    110 
    111 {-# INLINE tabSize #-}
    112 tabSize :: Tab k v -> Int
    113 tabSize (TAB ks _) = sizeofArray ks
    114 
    115 {-# INLINE tabElemAt #-}
    116 tabElemAt :: Int -> Tab k v -> (k, v)
    117 tabElemAt i (TAB ks vs) =
    118     if i > length ks
    119     then error "tabElemAt: out-of-bounds"
    120     else (ks!i, vs!i)
    121 
    122 {-# INLINE tabSplitAt #-}
    123 tabSplitAt :: Int -> Tab k v -> (Tab k v, Tab k v)
    124 tabSplitAt i (TAB ks vs) =
    125     ( TAB (rowTake i ks) (rowTake i vs)
    126     , TAB (rowDrop i ks) (rowDrop i vs)
    127     )
    128 
    129 -- Find index, call split (TODO: What behavior on found vs not-found?
    130 -- Avoid off-by-one-errors)
    131 {-# INLINE tabSplit #-}
    132 tabSplit :: Ord k => k -> Tab k v -> (Tab k v, Tab k v)
    133 tabSplit k (TAB ks@(Array ks#) vs) =
    134     let !(# i#, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#)
    135         i = I# i#
    136         j = I# (i# +# found)
    137     in 
    138         ( TAB (rowTake i ks) (rowTake i vs)
    139         , TAB (rowDrop j ks) (rowDrop j vs)
    140         )
    141 
    142 {-# INLINE tabSpanAntitone #-}
    143 tabSpanAntitone :: (k -> Bool) -> Tab k v -> (Tab k v, Tab k v)
    144 tabSpanAntitone f (TAB ks vs) =
    145     ( TAB (rowTake numTrue ks) (rowTake numTrue vs)
    146     , TAB (rowDrop numTrue ks) (rowDrop numTrue vs)
    147     )
    148   where
    149     numTrue = bfind f ks
    150 
    151 {-# INLINE tabMapWithKey #-}
    152 tabMapWithKey :: (k -> v -> a) -> Tab k v -> Tab k a
    153 tabMapWithKey f (TAB ks vs) = TAB ks (rowZipWith f ks vs)
    154 
    155 {-# INLINE tabMap #-}
    156 tabMap :: (a -> b) -> Tab k a -> Tab k b
    157 tabMap f (TAB k v) = TAB k (f <$> v)
    158 
    159 {-# INLINE tabUnion #-}
    160 tabUnion :: Ord k => Tab k v -> Tab k v -> Tab k v
    161 tabUnion = tabUnionWith const
    162 
    163 -- O(n) union
    164 tabUnionWith :: Ord k => (v -> v -> v) -> Tab k v -> Tab k v -> Tab k v
    165 tabUnionWith merge x@(TAB xKeys xVals) y@(TAB yKeys yVals) =
    166     -- Prof.withSimpleTracingEventPure "tabUnion" "sorted" $
    167     case (sizeofArray xKeys, sizeofArray yKeys) of
    168         ( 0,  _  ) -> y
    169         ( _,  0  ) -> x
    170         ( 1,  1  ) -> let xk = xKeys!0
    171                           yk = yKeys!0
    172                           xv = xVals!0
    173                           yv = yVals!0
    174                       in case compare xk yk of
    175                           LT -> tabUnsafeDuo xk xv yk yv
    176                           GT -> tabUnsafeDuo yk yv xk xv
    177                           EQ -> tabSingleton xk (merge xv yv)
    178         ( 1,  _  ) -> tabInsertWith merge        (xKeys!0) (xVals!0) y
    179         ( _,  1  ) -> tabInsertWith (flip merge) (yKeys!0) (yVals!0) x
    180         ( xw, yw ) -> tabUnionWithGeneric merge x xw y yw
    181 
    182 tabUnionWithGeneric
    183     :: Ord k => (v -> v -> v) -> Tab k v -> Int -> Tab k v -> Int -> Tab k v
    184 tabUnionWithGeneric merge (TAB xKeys xVals) !xWid (TAB yKeys yVals) !yWid =
    185   runST do
    186     let rWid = xWid + yWid
    187 
    188     valsBuf <- newArray rWid (error "ssetUnion: uninitialized")
    189     keysBuf <- newArray rWid (error "ssetUnion: uninitialized")
    190 
    191     let go o i j = do
    192             let xRemain = xWid - i
    193             let yRemain = yWid - j
    194             case (xRemain, yRemain) of
    195                 (0, 0) -> pure o
    196                 (0, _) -> do
    197                     copyArray keysBuf o yKeys j yRemain
    198                     copyArray valsBuf o yVals j yRemain
    199                     pure (o + yRemain)
    200                 (_, 0) -> do
    201                     copyArray keysBuf o xKeys i xRemain
    202                     copyArray valsBuf o xVals i xRemain
    203                     pure (o + xRemain)
    204                 (_, _) -> do
    205                     let x = xKeys ! i
    206                     let y = yKeys ! j
    207                     case compare x y of
    208                         EQ -> do writeArray keysBuf o x
    209                                  writeArray valsBuf o (merge (xVals!i) (yVals!j))
    210                                  go (o+1) (i+1) (j+1)
    211                         LT -> do writeArray keysBuf o x
    212                                  writeArray valsBuf o (xVals!i)
    213                                  go (o+1) (i+1) j
    214                         GT -> do writeArray keysBuf o y
    215                                  writeArray valsBuf o (yVals!j)
    216                                  go (o+1) i     (j+1)
    217 
    218     written <- go 0 0 0
    219 
    220     if written == rWid
    221     then TAB <$> unsafeFreezeArray keysBuf
    222              <*> unsafeFreezeArray valsBuf
    223     else TAB <$> freezeArray keysBuf 0 written
    224              <*> freezeArray valsBuf 0 written
    225 
    226 -- O(n) tab difference
    227 tabDifference :: Ord k => Tab k v -> Tab k v -> Tab k v
    228 tabDifference x (TAB yKeys _) | null yKeys = x
    229 tabDifference (TAB xKeys _) _ | null xKeys = mempty
    230 tabDifference (TAB xKeys xVals) (TAB yKeys yVals) =
    231     -- Prof.withSimpleTracingEventPure "tabDifference" "sorted" $
    232     runST do
    233     let xWid = sizeofArray xKeys
    234     let yWid = sizeofArray yVals
    235     rKeys <- newArray xWid (error "tabDifference: uninitialized")
    236     rVals <- newArray xWid (error "tabDifference: uninitialized")
    237     let go o i j =
    238             if i >= xWid then pure o else
    239             if j >= yWid then do
    240                 let extra = xWid - i
    241                 copyArray rKeys o xKeys i extra
    242                 copyArray rVals o xVals i extra
    243                 pure (o + extra)
    244             else do
    245                 let x = xKeys ! i
    246                 let y = yKeys ! j
    247                 case compare x y of
    248                     LT -> do
    249                         writeArray rKeys o x
    250                         writeArray rVals o (xVals!i)
    251                         go (o+1) (i+1) j
    252                     EQ -> go o (i+1) (j+1)
    253                     GT -> go o i (j+1)
    254     used <- go 0 0 0
    255     if used == 0 then
    256         pure mempty
    257     else if used == xWid then
    258         TAB <$> unsafeFreezeArray rKeys <*> unsafeFreezeArray rVals
    259     else
    260         TAB <$> freezeArray rKeys 0 used <*> freezeArray rVals 0 used
    261 
    262 {-# INLINE tabIntersection #-}
    263 tabIntersection :: Ord k => Tab k v -> Tab k v -> Tab k v
    264 tabIntersection x y = tabIntersectionWith const x y
    265 
    266 tabIntersectionWith :: Ord k => (v -> v -> v) -> Tab k v -> Tab k v -> Tab k v
    267 tabIntersectionWith f x@(TAB xKeys xVals) y@(TAB yKeys yVals) =
    268     -- Prof.withSimpleTracingEventPure "tabIntersectionWith" "sorted" $
    269     case (sizeofArray xKeys, sizeofArray yKeys) of
    270         ( 0,  _  ) -> mempty
    271         ( _,  0  ) -> mempty
    272         ( 1,  1  ) -> let xk = xKeys!0
    273                       in if xk == (yKeys!0)
    274                          then tabSingleton xk (f (xVals!0) (yVals!0))
    275                          else mempty
    276         ( 1,  _  ) -> let xk = xKeys!0 in
    277                       case tabLookup xk y of
    278                           Nothing -> mempty
    279                           Just yv -> tabSingleton xk (f (xVals!0) yv)
    280         ( _,  1  ) -> let yk = yKeys!0 in
    281                       case tabLookup yk x of
    282                           Nothing -> mempty
    283                           Just xv -> tabSingleton yk (f xv (yVals!0))
    284         ( xw, yw ) -> tabIntersectionWithGeneric f x xw y yw
    285 
    286 tabIntersectionWithGeneric
    287     :: Ord k => (v -> v -> v) -> Tab k v -> Int -> Tab k v -> Int -> Tab k v
    288 tabIntersectionWithGeneric merge (TAB xKeys xVals) xWid (TAB yKeys yVals) yWid =
    289   runST do
    290     let rWid = min xWid yWid
    291     rKeys <- newArray rWid (error "setIntersection: uninitialized")
    292     rVals <- newArray rWid (error "setIntersection: uninitialized")
    293     let go o i j =
    294             if i >= xWid || j >= yWid then pure o else do
    295                 let x = xKeys ! i
    296                 let y = yKeys ! j
    297                 case compare x y of
    298                     EQ -> do
    299                         writeArray rKeys o x
    300                         writeArray rVals o $ merge (xVals!i) (yVals!j)
    301                         go (o+1) (i+1) (j+1)
    302                     LT -> go o (i+1) j
    303                     GT -> go o i (j+1)
    304     used <- go 0 0 0
    305     if used == rWid
    306     then TAB <$> unsafeFreezeArray rKeys <*> unsafeFreezeArray rVals
    307     else TAB <$> freezeArray rKeys 0 used <*> freezeArray rVals 0 used
    308 
    309 {-# INLINE tabLookupMin #-}
    310 tabLookupMin :: Tab k v -> Maybe (k, v)
    311 tabLookupMin (TAB k v) =
    312     if null k then Nothing else Just (k!0, v!0)
    313 
    314 {-# INLINE tabLookupMax #-}
    315 tabLookupMax :: Tab k v -> Maybe (k, v)
    316 tabLookupMax (TAB k v) =
    317     case sizeofArray k of
    318         0 -> Nothing
    319         n -> let !i = n-1 in Just (k!i, v!i)
    320 
    321 tabAlter :: Ord k => (Maybe v -> Maybe v) -> k -> Tab k v -> Tab k v
    322 tabAlter f k tab@(TAB ks@(Array ks#) vs) =
    323     -- Prof.withSimpleTracingEventPure "tabAlter" "sorted" $
    324     let !(# i#, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
    325     let i = I# i# in
    326     case found of
    327         1# ->
    328             case f (Just (vs!i)) of
    329                 Nothing -> TAB (rowUnsafeDelete i ks) (rowUnsafeDelete i vs)
    330                 Just v  -> TAB ks                     (rowUnsafePut i v vs)
    331         _ ->
    332             case f Nothing of
    333                 Nothing -> tab -- no change
    334                 Just v  -> TAB (rowInsert i k ks) (rowInsert i v vs)
    335 
    336 tabDelete :: Ord k => k -> Tab k v -> Tab k v
    337 tabDelete k tab@(TAB ks@(Array ks#) vs) =
    338     -- Prof.withSimpleTracingEventPure "tabDelete" "sorted" $
    339     case bsearch# compare k ks# 0# (sizeofArray# ks#) of
    340         (# _,  0# #) -> tab
    341         (# i#, _  #) -> TAB ks' vs'
    342             where i = I# i#
    343                   !ks' = rowUnsafeDelete i ks
    344                   !vs' = rowUnsafeDelete i vs
    345 
    346 {-# INLINE tabMember #-}
    347 tabMember :: Ord k => k -> Tab k v -> Bool
    348 tabMember k (TAB ks _) = ssetMember k (SET ks)
    349 
    350 {-# INLINE tabKeysSet #-}
    351 tabKeysSet :: Tab k v -> Set k
    352 tabKeysSet (TAB k _) = SET k
    353 
    354 {-# INLINE tabKeysArray #-}
    355 tabKeysArray :: Tab k v -> Array k
    356 tabKeysArray (TAB k _) = k
    357 
    358 {-# INLINE tabKeysList #-}
    359 tabKeysList :: Tab k v -> [k]
    360 tabKeysList (TAB k _) = toList k
    361 
    362 {-# INLINE tabFoldlWithKey' #-}
    363 tabFoldlWithKey' :: (a -> k -> v -> a) -> a -> Tab k v -> a
    364 tabFoldlWithKey' f !x (TAB ks vs) =
    365     -- Prof.withSimpleTracingEventPure "tabFoldWithKey" "sorted" do
    366     go 0 x
    367   where
    368     !wid = sizeofArray ks
    369 
    370     go i !acc | i >= wid  = acc
    371     go i !acc | otherwise = go (i+1) $ f acc (ks!i) (vs!i)
    372 
    373 tabFilterWithKey :: (k -> v -> Bool) -> Tab k v -> Tab k v
    374 tabFilterWithKey f tab@(TAB ks vs) =
    375   -- Prof.withSimpleTracingEventPure "tabFilterWithKey" "sorted" $
    376   runST do
    377     let !wid = sizeofArray ks
    378     keysBuf <- newArray wid (error "tabFilterWithKey: uninitialized")
    379     valsBuf <- newArray wid (error "tabFilterWithKey: uninitialized")
    380     let go o i | i >= wid  = pure o
    381         go o i | otherwise = do
    382             let key = ks!i
    383             let val = vs!i
    384             if f key val then do
    385                 writeArray keysBuf o key
    386                 writeArray valsBuf o val
    387                 go (o+1) (i+1)
    388             else do
    389                 go o (i+1)
    390     written <- go 0 0
    391     if written == wid then pure tab else
    392         TAB <$> freezeArray keysBuf 0 written
    393             <*> freezeArray valsBuf 0 written
    394 
    395 {-# INLINE tabElemsArray #-}
    396 tabElemsArray :: Tab k v -> Array v
    397 tabElemsArray (TAB _ v) = v
    398 
    399 {-# INLINE tabElemsList #-}
    400 tabElemsList :: Tab k v -> [v]
    401 tabElemsList (TAB _ v) = toList v
    402 
    403 {-
    404     Creates a table from a list of key-value pairs.  If keys appear
    405     multiple times, values later in the list are used.
    406 
    407     Implementation:
    408 
    409     - Collect the list of pairs into an array.
    410 
    411     - Do a stable sort on the array (comparing only the keys), removing
    412       duplicates.
    413 
    414     - Split the resulting array out into a keys array and a values array.
    415 
    416     Note that rowSortUniqBy chooses earlier values, not later values.
    417     We resolve this by filling the array in reverse.
    418 
    419     TODO: We can skip a copy by directly creating a mutable array,
    420     and doing the sort on that, in place.
    421 -}
    422 tabFromPairsList :: Ord k => [(k,v)] -> Tab k v
    423 tabFromPairsList pairs =
    424     -- Prof.withSimpleTracingEventPure "tabFromPairsList" "sorted" $
    425     let buf = rowSortUniqBy (on compare fst) $ arrayFromListRev pairs in
    426     TAB (fst <$> buf) (snd <$> buf)
    427 
    428 {-# INLINE tabToAscPairsList #-}
    429 tabToAscPairsList :: Tab k v -> [(k,v)]
    430 tabToAscPairsList (TAB k v) = go 0
    431   where
    432     !len = sizeofArray k
    433     go i | i >= len = []
    434     go i            = (k!i, v!i) : go (i+1)
    435 
    436 {-# INLINE tabToDescPairsList #-}
    437 tabToDescPairsList :: Tab k v -> [(k,v)]
    438 tabToDescPairsList (TAB k v) = go (length k - 1)
    439   where
    440     go i | i < 0 = []
    441     go i         = (k!i, v!i) : go (i-1)
    442 
    443 {-# INLINE mkTab #-}
    444 mkTab :: Set k -> Array v -> Tab k v
    445 mkTab (SET k) v =
    446     if sizeofArray k /= sizeofArray v
    447     then error "mkTab: keys must have same size as values"
    448     else TAB k v
    449 
    450 --------------------------------------------------------------------------------
    451 -- TODO: Optimize and verify these instances
    452 
    453 instance Foldable (Tab k) where
    454     foldMap f (TAB _ v) = foldMap f v
    455     {-# INLINE foldMap #-}
    456 
    457 instance MonoFunctor (Tab k v) where
    458     omap = fmap
    459     {-# INLINE omap #-}
    460 
    461 -- TODO: Explicit instances
    462 instance MonoFoldable (Tab k v) where
    463 
    464 instance MonoTraversable (Tab k v) where
    465     otraverse = traverse
    466     omapM = traverse
    467     {-# INLINE otraverse #-}
    468     {-# INLINE omapM #-}
    469 
-- | The mono-traversable element type of a table is its value type; keys
-- are not elements.
type instance Element (Tab k v) = v
    471 
    472 instance Ord k => Semigroup (Tab k v) where
    473     (<>) = tabUnion
    474     {-# INLINE (<>) #-}
    475 
-- The identity for the union-based (<>) is the table with no entries.
instance Ord k => Monoid (Tab k v) where
    mempty = emptyTab
    {-# INLINE mempty #-}
    479 
-- Assert that append (<>) never produces something smaller: (<>) is
-- tabUnion, whose result keeps every key of both inputs.  (This class has
-- no methods.)
instance GrowingAppend (Tab k v) where
    483 
    484 {-
    485     This explicitly implements all methods
    486 
    487     TODO: Better implementation of `unions`!
    488 -}
    489 instance Ord k => SetContainer (Tab k v) where
    490     type ContainerKey (Tab k v) = k
    491     member = tabMember
    492     notMember k t = not (tabMember k t)
    493     union = tabUnion
    494     difference = tabDifference
    495     intersection = tabIntersection
    496     keys = tabKeysList
    497     unions = ofoldl' tabUnion mempty
    498     {-# INLINE member #-}
    499     {-# INLINE notMember #-}
    500     {-# INLINE union #-}
    501     {-# INLINE unions #-}
    502     {-# INLINE difference #-}
    503     {-# INLINE intersection #-}
    504     {-# INLINE keys #-}
    505 
    506 instance Functor (Tab k) where
    507     fmap = tabMap
    508     {-# INLINE fmap #-}
    509 
    510 -- TODO: implement way more methods.
    511 instance Traversable (Tab k) where
    512     traverse f (TAB ks vs) = TAB ks <$> traverse f vs
    513     sequenceA (TAB ks vs)  = TAB ks <$> sequenceA vs
    514 
    515 -- TODO: Implement many more methods.
    516 instance Ord k => IsMap (Tab k v) where
    517     type MapValue (Tab k v) = v
    518     lookup = tabLookup
    519     insertMap = tabInsert
    520     deleteMap = tabDelete
    521     singletonMap = tabSingleton
    522     mapFromList = tabFromPairsList
    523     mapToList = tabToAscPairsList