plunder

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

Set.hs (19091B)


      1 -- Copyright 2023 The Plunder Authors
      2 -- Use of this source code is governed by a BSD-style license that can be
      3 -- found in the LICENSE file.
      4 
      5 -- | Maps and Sets as Sorted Vectors
      6 
      7 {-# OPTIONS_GHC -Wall        #-}
      8 {-# OPTIONS_GHC -Werror      #-}
      9 {-# OPTIONS_GHC -Wno-orphans #-}
     10 {-# LANGUAGE UnboxedTuples #-}
     11 
     12 module Data.Sorted.Set
     13     ( ssetFromList
     14     , ssetToAscArray
     15     , ssetToAscList
     16     , ssetToDescList
     17     , ssetToArray
     18     , ssetSplitAt
     19     , ssetSpanAntitone
     20     , ssetIntersection
     21     , ssetDifference
     22     , ssetSingleton
     23     , ssetInsert
     24     , ssetLookupMin
     25     , ssetLookupMax
     26     , ssetDelete
     27     , ssetSize
     28     , ssetUnion
     29     , ssetIsEmpty
     30     , ssetFromDistinctAscList
     31     , ssetFindMax
     32     , ssetFindMin
     33     , ssetMember
     34     , ssetTake
     35     , ssetDrop
     36     )
     37 where
     38 
     39 import Control.Monad.ST
     40 import Data.Bits
     41 import Data.Containers
     42 import Data.Foldable
     43 import Data.MonoTraversable
     44 import Data.Primitive.Array
     45 import Data.Sorted.Row
     46 import Data.Sorted.Search
     47 import Data.Sorted.Types
     48 import Prelude
     49 
     50 import Data.Coerce  (coerce)
     51 import Data.Functor (($>))
     52 import GHC.Exts     (Array#, Int(..), Int#, sizeofArray#, (+#))
     53 
     54 -- import qualified Fan.Prof as Prof
     55 
     56 
     57 --------------------------------------------------------------------------------
     58 
-- | The set with no elements: a 'SET' wrapping the empty ('mempty')
-- underlying array.
{-# INLINE emptySet #-}
emptySet :: Set k
emptySet = SET mempty
     62 
     63 {-# INLINE ssetSingleton #-}
     64 ssetSingleton :: k -> Set k
     65 ssetSingleton k = SET (rowSingleton k)
     66 
-- | Build a two-element set from @x@ and @y@ without comparing them.
--
-- x must be less than y, otherwise the resulting set will be invalid.
-- (No 'Ord' constraint is taken, so the ordering cannot be checked
-- here; callers are responsible for it.)
{-# INLINE ssetUnsafeDuo #-}
ssetUnsafeDuo :: k -> k -> Set k
ssetUnsafeDuo x y = coerce (rowDuo x y)
  -- coerce here avoids the fmap.  SET is a newtype for row, so there
  -- is a type-safe cast between the two.
     73 
     74 -- Collect list into an mutable array.  Sort it, remove duplicates.
     75 --
     76 -- TODO: Avoid the copies by loading into a mutable array, sorting the
     77 -- mutable array, and then doing an unsafe freeze of that.
     78 ssetFromList :: Ord k => [k] -> Set k
     79 ssetFromList kList =
     80     -- Prof.withSimpleTracingEventPure "ssetFromList" "sorted" $
     81     SET $ rowSortUniqBy compare $ arrayFromList kList
     82 
     83 {-# INLINE ssetInsert# #-}
     84 ssetInsert# :: Ord a => a -> Array# a -> Int# -> Array# a
     85 ssetInsert# k ks# wid# =
     86     let !(# i, found #) = bsearch# compare k ks# 0# wid# in
     87     case found of
     88         0# -> rowUnsafeInsert# i k ks# wid#
     89         _  -> ks#
     90 
     91 {-# INLINE ssetInsert #-}
     92 ssetInsert :: Ord k => k -> Set k -> Set k
     93 ssetInsert k (SET (Array ks#)) =
     94     -- Prof.withSimpleTracingEventPure "setInsert" "sorted" $
     95     SET (Array (ssetInsert# k ks# (sizeofArray# ks#)))
     96 
     97 {-# INLINE ssetDelete #-}
     98 ssetDelete :: Ord k => k -> Set k -> Set k
     99 ssetDelete k set@(SET ks@(Array ks#)) =
    100     -- Prof.withSimpleTracingEventPure "setDelete" "sorted" $
    101     let !(# i, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
    102     case found of
    103         0# -> set
    104         _  -> SET (rowDelete (I# i) ks)
    105 
    106 {-# INLINE ssetLookupMin #-}
    107 ssetLookupMin :: Set a -> Maybe a
    108 ssetLookupMin (SET ks) | null ks = Nothing
    109 ssetLookupMin (SET ks)           = Just (ks!0)
    110 
    111 {-# INLINE ssetLookupMax #-}
    112 ssetLookupMax :: Set a -> Maybe a
    113 ssetLookupMax (SET ks) =
    114     if wid==0 then Nothing else Just (ks ! (wid-1))
    115   where
    116     !wid = sizeofArray ks
    117 
    118 {-# INLINE ssetToAscArray #-}
    119 ssetToAscArray :: Set k -> Array k
    120 ssetToAscArray (SET a) = a
    121 
    122 {-# INLINE ssetToAscList #-}
    123 ssetToAscList :: Set k -> [k]
    124 ssetToAscList (SET a) = toList a
    125 
    126 {-# INLINE ssetToArray #-}
    127 ssetToArray :: Set k -> Array k
    128 ssetToArray (SET a) = a
    129 
    130 {-# INLINE ssetToDescList #-}
    131 ssetToDescList :: Set k -> [k]
    132 ssetToDescList (SET a) = go (length a - 1)
    133   where
    134     go i | i<0 = []
    135     go i       = (a!i) : go (i-1)
    136 
    137 {-# INLINE ssetSize #-}
    138 ssetSize :: Set k -> Int
    139 ssetSize (SET a) = sizeofArray a
    140 
    141 -- O(n+m) given input sets of size n and m.
    142 --
    143 -- We special-case sets of size zero and one.
    144 ssetUnion :: Ord k => Set k -> Set k -> Set k
    145 ssetUnion x@(SET xs) y@(SET ys) =
    146     case (sizeofArray xs, sizeofArray ys) of
    147         ( 0,  _  ) -> y
    148         ( _,  0  ) -> x
    149         ( 1,  1  ) -> let xv = (xs!0)
    150                           yv = (ys!0)
    151                       in case compare xv yv of
    152                           EQ -> x
    153                           LT -> ssetUnsafeDuo xv yv
    154                           GT -> ssetUnsafeDuo yv xv
    155         ( 1,  _  ) -> ssetInsert (xs!0) y
    156         ( _,  1  ) -> ssetInsert (ys!0) x
    157         ( xw, yw ) -> ssetUnionGeneric x xw y yw
    158 
-- | General-case union, used by 'ssetUnion' once the size-0 and size-1
-- special cases have been ruled out.
--
-- This assumes that neither of the inputs are empty (slot 0 and slot
-- @wid-1@ of both arrays are read unconditionally).  @xWid@ and @yWid@
-- are the sizes of the respective underlying arrays, as passed by
-- 'ssetUnion'.
--
-- Strategy:
--
--   1. If the value ranges are disjoint, the result is plain array
--      concatenation.
--   2. Otherwise, narrow each array to the index range that can overlap
--      with the other array, merge only those ranges, and copy the
--      non-overlapping prefix and suffix verbatim.
--
-- TODO: Skip the shrinking optimization if the sizes are small.
ssetUnionGeneric :: Ord a => Set a -> Int -> Set a -> Int -> Set a
ssetUnionGeneric (SET xs) !xWid (SET ys) !yWid =
    -- Prof.withSimpleTracingEventPure "setUnionGeneric" "sorted" $
    let
        xSmallest = xs ! 0
        xLargest  = xs ! (xWid-1)
        ySmallest = ys ! 0
        yLargest  = ys ! (yWid-1)
    in

    -- If there is no overlap, then the union is just array concatenation.
    if xSmallest > yLargest then SET (ys <> xs) else
    if ySmallest > xLargest then SET (xs <> ys) else

    -- Find the overlapping range of the sets so we can walk merely the
    -- parts we know overlap
    let
        -- Lower bounds.  Whichever array starts with the smaller value
        -- contributes a prefix (`beforeCount` elements of `initialArray`,
        -- starting at index 0) that is copied without merging.
        -- NOTE(review): presumably `bsearchIndex k arr` is the index at
        -- which `k` sits or would be inserted in `arr` -- confirm against
        -- Data.Sorted.Search.
        (xMin, yMin, beforeCount, initialArray) =
            case compare xSmallest ySmallest of
                EQ -> (0, 0, 0, xs)
                GT -> let yMn = bsearchIndex xSmallest ys
                      in (0, yMn, yMn, ys)
                LT -> let xMn = bsearchIndex ySmallest xs
                      in (xMn, 0, xMn, xs)

        -- Upper bounds.  Whichever array ends with the larger value
        -- contributes a suffix (`afterCount` elements of `finalArray`,
        -- starting at `afterStart`) that is copied without merging.
        -- NOTE(review): presumably `bsearchPostIndex k arr` is the index
        -- just past where `k` sits or would sit in `arr` -- confirm.
        (xMax, yMax, afterStart, afterCount, finalArray) =
            case compare xLargest yLargest of
                EQ -> (xWid, yWid, 0, 0, xs)
                GT -> let xMx = bsearchPostIndex yLargest xs
                      in (xMx, yWid, xMx, xWid - xMx, xs)
                LT -> let yMx = bsearchPostIndex xLargest ys
                      in (xWid, yMx, yMx, yWid - yMx, ys)
        xOverlapWidth = xMax - xMin
        yOverlapWidth = yMax - yMin
    in
        -- The worker returns a bare array; `coerce` re-wraps it as a set.
        -- The array with the wider overlap slice is passed first; the
        -- prefix/suffix arguments are the same either way.
        coerce $
        if (yOverlapWidth > xOverlapWidth) then
            ssetUnionGenericSwapped
                ys yMin yMax yOverlapWidth
                xs xMin xMax xOverlapWidth
                initialArray beforeCount
                finalArray afterStart afterCount
        else
            ssetUnionGenericSwapped
                xs xMin xMax xOverlapWidth
                ys yMin yMax yOverlapWidth
                initialArray beforeCount
                finalArray afterStart afterCount
    210 
-- | Worker for 'ssetUnionGeneric': merges the overlapping slices of two
-- sorted arrays and glues on the non-overlapping prefix and suffix.
--
-- Arguments, in order:
--
--   * @xs xMin xMax xOverlapWidth@ -- first array, its half-open overlap
--     slice @[xMin, xMax)@, and that slice's width.
--   * @ys yMin yMax yOverlapWidth@ -- likewise for the second array.
--   * @initialArray beforeCount@ -- the first @beforeCount@ elements of
--     @initialArray@ are copied in front of the merged overlap.
--   * @finalArray finalStart finalCount@ -- @finalCount@ elements of
--     @finalArray@, starting at @finalStart@, are copied after it.
--
-- TODO: Too many arguments, find a way to rejigger this!
ssetUnionGenericSwapped
    :: Ord a
    => Array a -> Int -> Int -> Int
    -> Array a -> Int -> Int -> Int
    -> Array a -> Int
    -> Array a -> Int -> Int
    -> Array a
ssetUnionGenericSwapped
        xs xMin xMax xOverlapWidth
        ys yMin yMax yOverlapWidth
        initialArray beforeCount
        finalArray finalStart finalCount = runST do

    -- Upper bound on the merged size: reached only when the two slices
    -- share no elements.
    let maxOverlapWidth = xOverlapWidth + yOverlapWidth
    buf <- newArray maxOverlapWidth (error "setUnion: uninitialized")

    -- `go o iLow iEnd jLow jEnd` merges xs[iLow, iEnd) with ys[jLow, jEnd)
    -- into `buf` starting at offset `o`, and returns the offset just past
    -- the last element written.
    let go o iLow iEnd jLow jEnd = do
            let iRemain = iEnd - iLow
            let jRemain = jEnd - jLow
            case (iRemain, jRemain) of
                -- One (or both) sides exhausted: bulk-copy the other.
                (0, 0) -> pure o
                (0, _) -> copyArray buf o ys jLow jRemain $> (o + jRemain)
                (_, 0) -> copyArray buf o xs iLow iRemain $> (o + iRemain)
                -- One element on each side: emit them in order, once if equal.
                (1, 1) -> do
                    let x = xs!iLow
                    let y = ys!jLow
                    case compare x y of
                       LT -> writeArray buf o x >> writeArray buf (o+1) y >> pure (o+2)
                       EQ -> writeArray buf o x >> pure (o+1)
                       GT -> writeArray buf o y >> writeArray buf (o+1) x >> pure (o+2)
                -- A single x against several ys: search for it in the ys
                -- slice.
                (1, _) -> do
                    let x = xs!iLow
                    case bsearch_ x ys jLow jEnd of
                        -- Already present: the ys slice is the answer.
                        (# _, 1# #) -> do
                            copyArray buf o ys jLow jRemain $> (o+jRemain)
                        -- Absent: splice x in at the returned index.
                        (# jMid#, _ #) -> do
                            let jMid   = I# jMid#
                            let nBelo  = jMid - jLow
                            let nAbove = jEnd - jMid
                            copyArray buf o ys jLow nBelo
                            writeArray buf (o+nBelo) x
                            copyArray buf (o+nBelo+1) ys jMid nAbove
                            pure (o+nBelo+1+nAbove)
                -- Mirror image: a single y against several xs.
                (_, 1) -> do
                    let y = ys!jLow
                    case bsearch_ y xs iLow iEnd of
                        (# _, 1# #) -> do
                            copyArray buf o xs iLow iRemain $> (o+iRemain)
                        (# iMid#, _ #) -> do
                            let iMid   = I# iMid#
                            let nBelo  = iMid - iLow
                            let nAbove = iEnd - iMid
                            copyArray buf o xs iLow nBelo
                            writeArray buf (o+nBelo) y
                            copyArray buf (o+nBelo+1) xs iMid nAbove
                            pure (o+nBelo+1+nAbove)

                -- General case: split both slices around the middle
                -- element of the xs slice and recurse on each half.
                (_, _) -> do
                    -- Get the middle value for the left-set.
                    let iMid               = (iLow + iEnd) `shiftR` 1
                    let iMidVal            = xs ! iMid
                    let !(# jMid, found #) = bsearch_ iMidVal ys jLow jEnd

                    -- Recurse to the left of the split on both.
                    o2 <- go o iLow iMid jLow (I# jMid)

                    -- Always write out the pivot value.
                    writeArray buf o2 iMidVal
                    let o3 = o2+1

                    -- Skip over the pivot on recursion if it matched.
                    -- (`found` is added to the index directly, so it is
                    -- 1# on a match and 0# otherwise.)
                    let iMid' = iMid + 1
                    let jMid' = I# (jMid +# found)

                    -- Recurse to the right of the split on both.
                    go o3 iMid' iEnd jMid' jEnd

    overlapCount <- go 0 xMin xMax yMin yMax
    overlap      <- unsafeFreezeArray buf

    -- If the buffer was filled exactly and there is no prefix or suffix
    -- to attach, the buffer itself is the result.  Otherwise allocate an
    -- exactly-sized array and copy prefix, merged overlap, and suffix
    -- into it; only the first `overlapCount` slots of `overlap` are read,
    -- so its unwritten tail is never touched.
    if (overlapCount == maxOverlapWidth && finalCount+beforeCount == 0)
    then do
        pure overlap
    else do
        let totalCount = beforeCount + overlapCount + finalCount
        res <- newArray totalCount (error "setUnion: uninitialized")
        copyArray res 0                          initialArray 0          beforeCount
        copyArray res beforeCount                overlap      0          overlapCount
        copyArray res (beforeCount+overlapCount) finalArray   finalStart finalCount
        unsafeFreezeArray res
    302 
    303 {-# INLINE ssetIsEmpty #-}
    304 ssetIsEmpty :: Set k -> Bool
    305 ssetIsEmpty (SET a) = null a
    306 
    307 -- TODO: Should we check this invariant?
    308 {-# INLINE ssetFromDistinctAscList #-}
    309 ssetFromDistinctAscList :: [k] -> Set k
    310 ssetFromDistinctAscList ksList = SET (arrayFromList ksList)
    311 
    312 {-# INLINE ssetFindMax #-}
    313 ssetFindMax :: Set k -> k
    314 ssetFindMax (SET s) =
    315     case sizeofArray s of
    316         0 -> error "setFindMax: empty set"
    317         n -> s ! (n-1)
    318 
    319 {-# INLINE ssetFindMin #-}
    320 ssetFindMin :: Set k -> k
    321 ssetFindMin (SET s) =
    322     if null s
    323     then error "setFindMin: empty setE"
    324     else s!0
    325 
    326 -- Do a search, return True if if found something, otherwise False.
    327 {-# INLINE ssetMember #-}
    328 ssetMember :: Ord k => k -> Set k -> Bool
    329 ssetMember k (SET (Array ks#)) =
    330     case bsearch# compare k ks# 0# (sizeofArray# ks#) of
    331         (# _, 0# #) -> False
    332         (# _, _  #) -> True
    333 
    334 -- This doesn't affect the order invariants, so we just run the operation
    335 -- directly against the underlying array.
    336 {-# INLINE ssetTake #-}
    337 ssetTake :: Int -> Set k -> Set k
    338 ssetTake i (SET ks) =
    339     -- Prof.withSimpleTracingEventPure "setTake" "sorted" $
    340     SET (rowTake i ks)
    341 
    342 -- This doesn't affect the order invariants, so we just run the operation
    343 -- directly against the underlying array.
    344 {-# INLINE ssetDrop #-}
    345 ssetDrop :: Int -> Set k -> Set k
    346 ssetDrop i (SET ks) =
    347     -- Prof.withSimpleTracingEventPure "setDrop" "sorted" $
    348     SET (rowDrop i ks)
    349 
    350 -- Just split the underlying array, set invariants are not at risk.
    351 {-# INLINE ssetSplitAt #-}
    352 ssetSplitAt :: Int -> Set k -> (Set k, Set k)
    353 ssetSplitAt i (SET ks) = (SET (rowTake i ks), SET (rowDrop i ks))
    354 
-- | Set intersection, with special cases for inputs of size 0 and 1.
--
-- For larger inputs, the value ranges are compared first: disjoint
-- ranges give the empty set immediately; otherwise each array is
-- narrowed to the index range that can overlap the other, and only
-- those slices are handed to 'ssetIntersectionGeneric'.
ssetIntersection :: Ord a => Set a -> Set a -> Set a
ssetIntersection x@(SET xs) y@(SET ys) =
    -- Prof.withSimpleTracingEventPure "setIntersection" "sorted" $
    case (sizeofArray xs, sizeofArray ys) of
        -- Either side empty: nothing in common.
        ( 0,    _    ) -> mempty
        ( _,    0    ) -> mempty
        -- Singletons: one comparison or one membership test decides it,
        -- and the singleton input itself is reused as the result.
        ( 1,    1    ) -> if xs!0 == ys!0 then x else mempty
        ( 1,    _    ) -> if ssetMember (xs!0) y then x else mempty
        ( _,    1    ) -> if ssetMember (ys!0) x then y else mempty
        ( xWid, yWid ) ->
            let
                xSmallest = xs ! 0
                yLargest  = ys ! (yWid-1)
                xLargest  = xs ! (xWid-1)
                ySmallest = ys ! 0
            in

            -- If no overlap is possible, return empty.
            if xSmallest > yLargest then mempty else
            if ySmallest > xLargest then mempty else

            -- figure out which subregions overlap.
            -- NOTE(review): presumably `bsearchIndex k arr` is the index
            -- at which `k` sits or would be inserted, and
            -- `bsearchPostIndex k arr` the index just past it -- confirm
            -- against Data.Sorted.Search.
            let
                (xMin, yMin) =
                    case compare xSmallest ySmallest of
                        EQ -> (0, 0)
                        GT -> (0, bsearchIndex xSmallest ys)
                        LT -> (bsearchIndex ySmallest xs, 0)

                (xMax, yMax) =
                    case compare xLargest yLargest of
                        EQ -> (xWid, yWid)
                        GT -> (bsearchPostIndex yLargest xs, yWid)
                        LT -> (xWid, bsearchPostIndex xLargest ys)

                xSz = (xMax - xMin)
                ySz = (yMax - yMin)
            in
                -- Run the intersection algorithm against the overlapping
                -- region, with the smaller region as the left-hand-side.
                -- The worker returns a bare array; `coerce` re-wraps it
                -- as a set.
                coerce $
                if xSz > ySz
                then ssetIntersectionGeneric ys yMin yMax ySz xs xMin xMax xSz
                else ssetIntersectionGeneric xs xMin xMax xSz ys yMin yMax ySz
    400 
{-
    Performs intersection by divide-and-conquer on two slices of two
    sorted arrays.  Returns a sorted, unique array containing the
    intersection.

    Arguments: `xs xMin xMax xSz` is the first array, its half-open
    slice [xMin, xMax), and that slice's width; `ys yMin yMax ySz` is
    the same for the second array.
-}
ssetIntersectionGeneric
    :: Ord a
    => Array a -> Int -> Int -> Int
    -> Array a -> Int -> Int -> Int
    -> Array a
ssetIntersectionGeneric xs !xMin !xMax !xSz ys !yMin !yMax !ySz = runST do
    -- The intersection can be no larger than the smaller slice.
    let rWid = min ySz xSz

    buf <- newArray rWid (error "setIntersection: uninitialized")

    -- `go o iLow iEnd jLow jEnd` intersects xs[iLow, iEnd) with
    -- ys[jLow, jEnd), writing matches into `buf` from offset `o`, and
    -- returns the offset just past the last element written.
    let go o iLow iEnd jLow jEnd = do
              case (iEnd - iLow, jEnd - jLow) of
                  -- Either slice empty: nothing to emit.
                  (0, _) -> pure o
                  (_, 0) -> pure o
                  -- One element each: emit it only if they are equal.
                  (1, 1) ->
                      let xVal = xs!iLow; yVal = ys!jLow in
                      if xVal ==yVal
                      then writeArray buf o xVal >> pure (o+1)
                      else pure o

                  -- A single x: emit it iff the search finds it in the
                  -- ys slice.
                  (1, _) ->
                      let xVal = xs!iLow in
                      case bsearch_ xVal ys jLow jEnd of
                          (# _, 0# #) -> pure o
                          (# _, _  #) -> writeArray buf o xVal >> pure (o+1)

                  -- A single y: emit it iff the search finds it in the
                  -- xs slice.
                  (_, 1) ->
                      let yVal = ys!jLow in
                      case bsearch_ yVal xs iLow iEnd of
                          (# _, 0# #) -> pure o
                          (# _, _  #) -> writeArray buf o yVal >> pure (o+1)

                  (_, _) ->

                      -- Get the middle value for the left-set.
                      let iMid               = (iLow + iEnd) `shiftR` 1
                          iMidVal            = xs ! iMid
                          !(# jMid, found #) = bsearch_ iMidVal ys jLow jEnd
                      in do
                              -- Recurse to the left of the split on both.
                              o2 <- go o iLow iMid jLow (I# jMid)

                              -- Write out the pivot value, if it exists
                              -- in both arrays.
                              o3 <- case found of
                                        0# -> pure o2
                                        _  -> do writeArray buf o2 iMidVal
                                                 pure (o2+1)

                              -- Skip over the pivot if it matched.
                              -- (`found` is added to the index directly,
                              -- so it is 1# on a match and 0# otherwise.)
                              let iMid' = iMid + 1
                              let jMid' = I# (jMid +# found)

                              -- Recurse to the right of the split on both.
                              go o3  iMid' iEnd jMid' jEnd

    used <- go 0 xMin xMax yMin yMax

    -- If the buffer is completely filled it can be frozen in place;
    -- otherwise copy out just the `used` prefix so the result never
    -- contains the uninitialized-slot placeholder.
    if used == rWid
    then unsafeFreezeArray buf
    else freezeArray buf 0 used
    467 
    468 -- O(n) set difference.
    469 ssetDifference :: Ord a => Set a -> Set a -> Set a
    470 ssetDifference (SET xs) (SET ys) = runST do
    471     let xWid = sizeofArray xs
    472     let yWid = sizeofArray ys
    473     buf <- newArray xWid (error "setDifference: uninitialized")
    474     let go o i j =
    475             if i >= xWid then pure o else
    476             if j >= yWid then do
    477                 let extra = xWid - i
    478                 copyArray buf o xs i extra
    479                 pure (o + extra)
    480             else do
    481                 let x = xs ! i
    482                 let y = ys ! j
    483                 case compare x y of
    484                     LT -> writeArray buf o x >> go (o+1) (i+1) j
    485                     EQ -> go o (i+1) (j+1)
    486                     GT -> go o i (j+1)
    487     used <- go 0 0 0
    488     if used == xWid
    489     then SET <$> unsafeFreezeArray buf
    490     else SET <$> freezeArray buf 0 used
    491 
    492 -- Assuming that the predicate is monotone, find the point where the
    493 -- predicate stops holding, and split the set there.
    494 {-# INLINE ssetSpanAntitone #-}
    495 ssetSpanAntitone :: (a -> Bool) -> Set a -> (Set a, Set a)
    496 ssetSpanAntitone f (SET ks) =
    497     let numTrue = bfind f ks in
    498     ( SET $ rowTake numTrue ks
    499     , SET $ rowDrop numTrue ks
    500     )
    501 
    502 --------------------------------------------------------------------------------
    503 -- TODO: Optimize and verify these instances
    504 
-- No methods are overridden here, so every 'MonoFoldable' operation on a
-- set falls back to the class's default implementation.
instance MonoFoldable (Set a) where

-- The element type of a set is its key type.
type instance Element (Set a) = a

-- Assert that append (<>) never produces something smaller.
-- (Marker class only: the instance body is empty.)
instance GrowingAppend (Set k) where
    511 
    512 instance Ord k => SetContainer (Set k) where
    513     type ContainerKey (Set k) = k
    514     member = ssetMember
    515     notMember k s = not (ssetMember k s)
    516     union = ssetUnion
    517     difference = ssetDifference
    518     intersection = ssetIntersection
    519     keys = ssetToAscList
    520     {-# INLINE member #-}
    521     {-# INLINE notMember #-}
    522     {-# INLINE union #-}
    523     {-# INLINE difference #-}
    524     {-# INLINE intersection #-}
    525     {-# INLINE keys #-}
    526 
-- Appending two sets is their union ('ssetUnion').
instance Ord a => Semigroup (Set a) where
    (<>) = ssetUnion
    {-# INLINE (<>) #-}

-- The identity element for append is the empty set ('emptySet').
instance Ord a => Monoid (Set a) where
    mempty = emptySet
    {-# INLINE mempty #-}
    534 
    535 instance Ord k => IsSet (Set k) where
    536     insertSet = ssetInsert
    537     deleteSet = ssetDelete
    538     singletonSet = ssetSingleton
    539     setFromList = ssetFromList
    540     setToList = ssetToAscList
    541     {-# INLINE insertSet #-}
    542     {-# INLINE deleteSet #-}
    543     {-# INLINE singletonSet #-}
    544     {-# INLINE setFromList #-}
    545     {-# INLINE setToList #-}