Set.hs (19091B)
-- Copyright 2023 The Plunder Authors
-- Use of this source code is governed by a BSD-style license that can be
-- found in the LICENSE file.

-- | Maps and Sets as Sorted Vectors

{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -Werror #-}
{-# OPTIONS_GHC -Wno-orphans #-}
{-# LANGUAGE UnboxedTuples #-}

module Data.Sorted.Set
    ( ssetFromList
    , ssetToAscArray
    , ssetToAscList
    , ssetToDescList
    , ssetToArray
    , ssetSplitAt
    , ssetSpanAntitone
    , ssetIntersection
    , ssetDifference
    , ssetSingleton
    , ssetInsert
    , ssetLookupMin
    , ssetLookupMax
    , ssetDelete
    , ssetSize
    , ssetUnion
    , ssetIsEmpty
    , ssetFromDistinctAscList
    , ssetFindMax
    , ssetFindMin
    , ssetMember
    , ssetTake
    , ssetDrop
    )
where

import Prelude

import Control.Monad.ST
import Data.Bits
import Data.Coerce (coerce)
import Data.Foldable
import Data.Functor (($>))
import GHC.Exts (Array#, Int(..), Int#, sizeofArray#, (+#))

import Data.Containers
import Data.MonoTraversable
import Data.Primitive.Array

import Data.Sorted.Row
import Data.Sorted.Search
import Data.Sorted.Types

-- import qualified Fan.Prof as Prof


--------------------------------------------------------------------------------

-- The set with no elements: a 'SET' wrapping an empty array.
{-# INLINE emptySet #-}
emptySet :: Set k
emptySet = SET mempty

-- A set containing exactly one element.
{-# INLINE ssetSingleton #-}
ssetSingleton :: k -> Set k
ssetSingleton = SET . rowSingleton

-- Two-element set built directly from @lo@ and @hi@.
--
-- Precondition (unchecked): @lo < hi@. Passing them in any other order
-- yields a set whose sorted/distinct invariant is broken.
{-# INLINE ssetUnsafeDuo #-}
ssetUnsafeDuo :: k -> k -> Set k
ssetUnsafeDuo lo hi = coerce (rowDuo lo hi)
  -- 'SET' is a newtype over the row array, so 'coerce' converts the
  -- freshly built row into a set at no cost (no element-wise 'fmap').

-- Collect the list into an array, then sort it and remove duplicates.
--
-- TODO: Avoid the copies by loading into a mutable array, sorting the
-- mutable array, and then doing an unsafe freeze of that.
ssetFromList :: Ord k => [k] -> Set k
ssetFromList =
    -- Prof.withSimpleTracingEventPure "ssetFromList" "sorted" $
    SET . rowSortUniqBy compare . arrayFromList

-- Insert @key@ into the raw sorted array @arr#@, searching the index range
-- [0, len#). When the search reports the key as present, the input array
-- is returned untouched; otherwise the key is spliced in at the reported
-- position.
{-# INLINE ssetInsert# #-}
ssetInsert# :: Ord a => a -> Array# a -> Int# -> Array# a
ssetInsert# key arr# len# =
    case bsearch# compare key arr# 0# len# of
        (# slot#, 0# #) -> rowUnsafeInsert# slot# key arr# len#
        (# _,     _  #) -> arr#

-- Add an element to the set; a set that already contains it is unchanged.
{-# INLINE ssetInsert #-}
ssetInsert :: Ord k => k -> Set k -> Set k
ssetInsert key (SET (Array arr#)) =
    -- Prof.withSimpleTracingEventPure "setInsert" "sorted" $
    SET (Array (ssetInsert# key arr# (sizeofArray# arr#)))

-- Remove an element from the set; when it is absent the original set value
-- is returned as-is.
{-# INLINE ssetDelete #-}
ssetDelete :: Ord k => k -> Set k -> Set k
ssetDelete key whole@(SET row@(Array arr#)) =
    -- Prof.withSimpleTracingEventPure "setDelete" "sorted" $
    case bsearch# compare key arr# 0# (sizeofArray# arr#) of
        (# _,     0# #) -> whole
        (# slot#, _  #) -> SET (rowDelete (I# slot#) row)

-- First element of the backing array (the minimum), or Nothing if empty.
{-# INLINE ssetLookupMin #-}
ssetLookupMin :: Set a -> Maybe a
ssetLookupMin (SET row) =
    if null row then Nothing else Just (row ! 0)

-- Last element of the backing array (the maximum), or Nothing if empty.
{-# INLINE ssetLookupMax #-}
ssetLookupMax :: Set a -> Maybe a
ssetLookupMax (SET row)
    | n == 0    = Nothing
    | otherwise = Just (row ! (n - 1))
  where
    !n = sizeofArray row

-- The backing array itself, which is kept in ascending order.
{-# INLINE ssetToAscArray #-}
ssetToAscArray :: Set k -> Array k
ssetToAscArray (SET row) = row

-- Elements in ascending order.
{-# INLINE ssetToAscList #-}
ssetToAscList :: Set k -> [k]
ssetToAscList (SET row) = toList row

-- The backing array itself (same as 'ssetToAscArray').
{-# INLINE ssetToArray #-}
ssetToArray :: Set k -> Array k
ssetToArray (SET row) = row

-- Elements in descending order, produced lazily from the last index down.
{-# INLINE ssetToDescList #-}
ssetToDescList :: Set k -> [k]
ssetToDescList (SET row) = [ row ! i | i <- [n - 1, n - 2 .. 0] ]
  where
    n = length row

-- Number of elements.
{-# INLINE ssetSize #-}
ssetSize :: Set k -> Int
ssetSize (SET row) = sizeofArray row

-- O(n+m) given input sets of size n and m.
--
-- Sets of size zero and one are special-cased; everything else goes to
-- 'ssetUnionGeneric'.
ssetUnion :: Ord k => Set k -> Set k -> Set k
ssetUnion lhs@(SET ls) rhs@(SET rs) =
    case (sizeofArray ls, sizeofArray rs) of
        (0, _)   -> rhs
        (_, 0)   -> lhs
        (1, 1)   -> pairUp (ls ! 0) (rs ! 0)
        (1, _)   -> ssetInsert (ls ! 0) rhs
        (_, 1)   -> ssetInsert (rs ! 0) lhs
        (lw, rw) -> ssetUnionGeneric lhs lw rhs rw
  where
    -- Union of two singletons: equal elements keep the left set,
    -- otherwise build a two-element set in ascending order.
    pairUp a b =
        case compare a b of
            LT -> ssetUnsafeDuo a b
            GT -> ssetUnsafeDuo b a
            EQ -> lhs

-- Union of two sets whose sizes are given as @xWid@ and @yWid@.
-- This assumes that neither of the inputs is empty.
--
-- TODO: Skip the shrinking optimization if the sizes are small.
ssetUnionGeneric :: Ord a => Set a -> Int -> Set a -> Int -> Set a
ssetUnionGeneric (SET xs) !xWid (SET ys) !yWid
    -- If there is no overlap, the union is just array concatenation.
    | xLo > yHi = SET (ys <> xs)
    | yLo > xHi = SET (xs <> ys)
    -- Otherwise merge only the overlapping slices; the wider overlap is
    -- passed first.
    | yOverlap > xOverlap =
        SET $ ssetUnionGenericSwapped
            ys yMin yMax yOverlap
            xs xMin xMax xOverlap
            initialArray beforeCount
            finalArray afterStart afterCount
    | otherwise =
        SET $ ssetUnionGenericSwapped
            xs xMin xMax xOverlap
            ys yMin yMax yOverlap
            initialArray beforeCount
            finalArray afterStart afterCount
  where
    -- Prof.withSimpleTracingEventPure "setUnionGeneric" "sorted" $
    xLo = xs ! 0
    xHi = xs ! (xWid - 1)
    yLo = ys ! 0
    yHi = ys ! (yWid - 1)

    -- Start of the overlapping range in each array, plus the prefix that
    -- precedes it (taken from whichever set starts first).
    (xMin, yMin, beforeCount, initialArray) =
        case compare xLo yLo of
            EQ -> (0, 0, 0, xs)
            GT -> let cut = bsearchIndex xLo ys in (0, cut, cut, ys)
            LT -> let cut = bsearchIndex yLo xs in (cut, 0, cut, xs)

    -- End of the overlapping range in each array, plus the suffix that
    -- follows it (taken from whichever set ends last).
    (xMax, yMax, afterStart, afterCount, finalArray) =
        case compare xHi yHi of
            EQ -> (xWid, yWid, 0, 0, xs)
            GT -> let cut = bsearchPostIndex yHi xs
                  in (cut, yWid, cut, xWid - cut, xs)
            LT -> let cut = bsearchPostIndex xHi ys
                  in (xWid, cut, cut, yWid - cut, ys)

    xOverlap = xMax - xMin
    yOverlap = yMax - yMin

-- TODO: Too many arguments, find a way to rejigger this!
--
-- Merge the slices xs[xMin, xMax) and ys[yMin, yMax) into a scratch buffer
-- by divide-and-conquer, then place the prefix
-- initialArray[0, beforeCount) before it and the suffix
-- finalArray[finalStart, finalStart + finalCount) after it.
--
-- Arguments, in order:
--   * xs xMin xMax xOverlapWidth   -- slice that is split at its midpoint
--   * ys yMin yMax yOverlapWidth   -- slice that is binary-searched
--   * initialArray beforeCount     -- prefix copied before the overlap
--   * finalArray finalStart finalCount -- suffix copied after the overlap
--
-- When the merged overlap fills the whole scratch buffer and there is no
-- prefix or suffix, the buffer is returned without a further copy.
ssetUnionGenericSwapped
    :: Ord a
    => Array a -> Int -> Int -> Int
    -> Array a -> Int -> Int -> Int
    -> Array a -> Int
    -> Array a -> Int -> Int
    -> Array a
ssetUnionGenericSwapped
    xs xMin xMax xOverlapWidth
    ys yMin yMax yOverlapWidth
    initialArray beforeCount
    finalArray finalStart finalCount = runST do

    -- The merged overlap can never be longer than both slices combined.
    let maxOverlapWidth = xOverlapWidth + yOverlapWidth
    buf <- newArray maxOverlapWidth (error "setUnion: uninitialized")

    -- go o iLow iEnd jLow jEnd merges xs[iLow, iEnd) with ys[jLow, jEnd)
    -- into buf starting at offset o, and returns the next free offset.
    let go o iLow iEnd jLow jEnd = do
          let iRemain = iEnd - iLow
          let jRemain = jEnd - jLow
          case (iRemain, jRemain) of
            (0, 0) -> pure o
            (0, _) -> copyArray buf o ys jLow jRemain $> (o + jRemain)
            (_, 0) -> copyArray buf o xs iLow iRemain $> (o + iRemain)
            (1, 1) -> do
              let x = xs ! iLow
              let y = ys ! jLow
              case compare x y of
                LT -> writeArray buf o x >> writeArray buf (o+1) y >> pure (o+2)
                EQ -> writeArray buf o x >> pure (o+1)
                GT -> writeArray buf o y >> writeArray buf (o+1) x >> pure (o+2)
            (1, _) -> do
              -- A single x: copy the ys run, splicing x in at its search
              -- position unless the search reports it as already present.
              let x = xs ! iLow
              case bsearch_ x ys jLow jEnd of
                (# _, 1# #) -> do
                  copyArray buf o ys jLow jRemain $> (o + jRemain)
                (# jMid#, _ #) -> do
                  let jMid = I# jMid#
                  let nBelo = jMid - jLow
                  let nAbove = jEnd - jMid
                  copyArray buf o ys jLow nBelo
                  writeArray buf (o + nBelo) x
                  copyArray buf (o + nBelo + 1) ys jMid nAbove
                  pure (o + nBelo + 1 + nAbove)
            (_, 1) -> do
              -- Mirror image of the case above, with a single y.
              let y = ys ! jLow
              case bsearch_ y xs iLow iEnd of
                (# _, 1# #) -> do
                  copyArray buf o xs iLow iRemain $> (o + iRemain)
                (# iMid#, _ #) -> do
                  let iMid = I# iMid#
                  let nBelo = iMid - iLow
                  let nAbove = iEnd - iMid
                  copyArray buf o xs iLow nBelo
                  writeArray buf (o + nBelo) y
                  copyArray buf (o + nBelo + 1) xs iMid nAbove
                  pure (o + nBelo + 1 + nAbove)

            (_, _) -> do
              -- Get the middle value for the left-set.
              let iMid = (iLow + iEnd) `shiftR` 1
              let iMidVal = xs ! iMid
              let !(# jMid, found #) = bsearch_ iMidVal ys jLow jEnd

              -- Recurse to the left of the split on both.
              o2 <- go o iLow iMid jLow (I# jMid)

              -- Always write out the pivot value.
              writeArray buf o2 iMidVal
              let o3 = o2 + 1

              -- Skip over the pivot on recursion if it matched.
              let iMid' = iMid + 1
              let jMid' = I# (jMid +# found)

              -- Recurse to the right of the split on both.
              go o3 iMid' iEnd jMid' jEnd

    overlapCount <- go 0 xMin xMax yMin yMax
    overlap <- unsafeFreezeArray buf

    if overlapCount == maxOverlapWidth && finalCount + beforeCount == 0
        then do
            pure overlap
        else do
            let totalCount = beforeCount + overlapCount + finalCount
            res <- newArray totalCount (error "setUnion: uninitialized")
            copyArray res 0 initialArray 0 beforeCount
            copyArray res beforeCount overlap 0 overlapCount
            copyArray res (beforeCount + overlapCount) finalArray finalStart finalCount
            unsafeFreezeArray res

-- True when the set has no elements.
{-# INLINE ssetIsEmpty #-}
ssetIsEmpty :: Set k -> Bool
ssetIsEmpty (SET a) = null a

-- Wrap a list as a set without sorting or de-duplicating it. The caller
-- must supply distinct elements in ascending order.
--
-- TODO: Should we check this invariant?
{-# INLINE ssetFromDistinctAscList #-}
ssetFromDistinctAscList :: [k] -> Set k
ssetFromDistinctAscList ksList = SET (arrayFromList ksList)

-- Largest element; calls 'error' on the empty set.
{-# INLINE ssetFindMax #-}
ssetFindMax :: Set k -> k
ssetFindMax (SET s) =
    case sizeofArray s of
        0 -> error "setFindMax: empty set"
        n -> s ! (n-1)

-- Smallest element; calls 'error' on the empty set.
{-# INLINE ssetFindMin #-}
ssetFindMin :: Set k -> k
ssetFindMin (SET s) =
    if null s
        then error "setFindMin: empty set"
        else s ! 0

-- Binary-search for the element: True if it is in the set, otherwise False.
{-# INLINE ssetMember #-}
ssetMember :: Ord k => k -> Set k -> Bool
ssetMember k (SET (Array ks#)) =
    -- The search covers the whole backing array; only the "found" flag
    -- of the result is inspected.
    case bsearch# compare k ks# 0# (sizeofArray# ks#) of
        (# _, 0# #) -> False
        (# _, _ #)  -> True

-- The first @i@ elements. Taking a prefix of a sorted, distinct array keeps
-- it sorted and distinct, so the operation runs directly on the array.
{-# INLINE ssetTake #-}
ssetTake :: Int -> Set k -> Set k
ssetTake i (SET ks) =
    -- Prof.withSimpleTracingEventPure "setTake" "sorted" $
    SET (rowTake i ks)

-- All but the first @i@ elements. Dropping a prefix keeps the order
-- invariants, so the operation runs directly on the array.
{-# INLINE ssetDrop #-}
ssetDrop :: Int -> Set k -> Set k
ssetDrop i (SET ks) =
    -- Prof.withSimpleTracingEventPure "setDrop" "sorted" $
    SET (rowDrop i ks)

-- Split the underlying array at index @i@; both halves stay valid sets.
{-# INLINE ssetSplitAt #-}
ssetSplitAt :: Int -> Set k -> (Set k, Set k)
ssetSplitAt i (SET ks) = (SET (rowTake i ks), SET (rowDrop i ks))

-- O(n) set intersection. Sizes 0 and 1 are special-cased; otherwise only
-- the sub-ranges of the two arrays that can overlap are examined.
ssetIntersection :: Ord a => Set a -> Set a -> Set a
ssetIntersection x@(SET xs) y@(SET ys) =
    -- Prof.withSimpleTracingEventPure "setIntersection" "sorted" $
    case (sizeofArray xs, sizeofArray ys) of
        ( 0, _ ) -> mempty
        ( _, 0 ) -> mempty
        ( 1, 1 ) -> if xs ! 0 == ys ! 0 then x else mempty
        ( 1, _ ) -> if ssetMember (xs ! 0) y then x else mempty
        ( _, 1 ) -> if ssetMember (ys ! 0) x then y else mempty
        ( xWid, yWid ) ->
            let
                xSmallest = xs ! 0
                yLargest  = ys ! (yWid-1)
                xLargest  = xs ! (xWid-1)
                ySmallest = ys ! 0
            in

            -- If no overlap is possible, return empty.
            if xSmallest > yLargest then mempty else
            if ySmallest > xLargest then mempty else

            -- Figure out which subregions overlap.
            let
                (xMin, yMin) =
                    case compare xSmallest ySmallest of
                        EQ -> (0, 0)
                        GT -> (0, bsearchIndex xSmallest ys)
                        LT -> (bsearchIndex ySmallest xs, 0)

                (xMax, yMax) =
                    case compare xLargest yLargest of
                        EQ -> (xWid, yWid)
                        GT -> (bsearchPostIndex yLargest xs, yWid)
                        LT -> (xWid, bsearchPostIndex xLargest ys)

                xSz = (xMax - xMin)
                ySz = (yMax - yMin)
            in
            -- Run the intersection algorithm against the overlapping
            -- region, with the smaller region as the left-hand side.
            coerce $
                if xSz > ySz
                    then ssetIntersectionGeneric ys yMin yMax ySz xs xMin xMax xSz
                    else ssetIntersectionGeneric xs xMin xMax xSz ys yMin yMax ySz

{-
Performs intersection by divide-and-conquer on the slices xs[xMin, xMax)
and ys[yMin, yMax) of two sorted arrays, where xSz and ySz are the slice
widths. Either slice may be empty; e.g. xs=[1,2,10], ys=[3,4] leaves a
zero-width xs slice. Returns a sorted, unique array containing the
intersection. The output buffer is sized to the smaller slice and trimmed
if not completely filled.
-}
ssetIntersectionGeneric
    :: Ord a
    => Array a -> Int -> Int -> Int
    -> Array a -> Int -> Int -> Int
    -> Array a
ssetIntersectionGeneric xs !xMin !xMax !xSz ys !yMin !yMax !ySz = runST do
    let rWid = min ySz xSz

    buf <- newArray rWid (error "setIntersection: uninitialized")

    -- go o iLow iEnd jLow jEnd writes the common elements of xs[iLow, iEnd)
    -- and ys[jLow, jEnd) into buf from offset o; returns the next offset.
    let go o iLow iEnd jLow jEnd = do
          case (iEnd - iLow, jEnd - jLow) of
            (0, _) -> pure o
            (_, 0) -> pure o
            (1, 1) ->
              let xVal = xs ! iLow
                  yVal = ys ! jLow
              in if xVal == yVal
                   then writeArray buf o xVal >> pure (o+1)
                   else pure o

            (1, _) ->
              let xVal = xs ! iLow in
              case bsearch_ xVal ys jLow jEnd of
                (# _, 0# #) -> pure o
                (# _, _ #)  -> writeArray buf o xVal >> pure (o+1)

            (_, 1) ->
              let yVal = ys ! jLow in
              case bsearch_ yVal xs iLow iEnd of
                (# _, 0# #) -> pure o
                (# _, _ #)  -> writeArray buf o yVal >> pure (o+1)

            (_, _) ->
              -- Get the middle value for the left-set.
              let iMid = (iLow + iEnd) `shiftR` 1
                  iMidVal = xs ! iMid
                  !(# jMid, found #) = bsearch_ iMidVal ys jLow jEnd
              in do
                -- Recurse to the left of the split on both.
                o2 <- go o iLow iMid jLow (I# jMid)

                -- Write out the pivot value, if it exists in both arrays.
                o3 <- case found of
                  0# -> pure o2
                  _  -> writeArray buf o2 iMidVal >> pure (o2+1)

                -- Skip over the pivot if it matched.
                let iMid' = iMid + 1
                let jMid' = I# (jMid +# found)

                -- Recurse to the right of the split on both.
                go o3 iMid' iEnd jMid' jEnd

    used <- go 0 xMin xMax yMin yMax

    if used == rWid
        then unsafeFreezeArray buf
        else freezeArray buf 0 used

-- O(n+m) set difference: the elements of the first set that are not in
-- the second.
--
-- Like union and intersection, trivial cases are short-circuited. If either
-- set is empty, or the two value ranges do not overlap, nothing can be
-- removed, and the first set is returned without allocating.
ssetDifference :: Ord a => Set a -> Set a -> Set a
ssetDifference lhs@(SET xs) (SET ys)
    | xWid == 0 || yWid == 0   = lhs
    | xs ! (xWid - 1) < ys ! 0 = lhs
    | ys ! (yWid - 1) < xs ! 0 = lhs
    | otherwise = runST do
        buf <- newArray xWid (error "setDifference: uninitialized")
        -- Linear merge: keep x-elements that are strictly below the
        -- current y, skip matches, and advance y when it is smaller.
        let go o i j
              | i >= xWid = pure o
              | j >= yWid = do
                  let extra = xWid - i
                  copyArray buf o xs i extra
                  pure (o + extra)
              | otherwise = do
                  let xv = xs ! i
                      yv = ys ! j
                  case compare xv yv of
                    LT -> writeArray buf o xv >> go (o + 1) (i + 1) j
                    EQ -> go o (i + 1) (j + 1)
                    GT -> go o i (j + 1)
        used <- go 0 0 0
        if used == xWid
            then SET <$> unsafeFreezeArray buf
            else SET <$> freezeArray buf 0 used
  where
    !xWid = sizeofArray xs
    !yWid = sizeofArray ys

-- Assuming the predicate is antitone over the ascending elements (True for
-- some prefix, then False for the rest), split the set at the point where
-- the predicate stops holding.
{-# INLINE ssetSpanAntitone #-}
ssetSpanAntitone :: (a -> Bool) -> Set a -> (Set a, Set a)
ssetSpanAntitone f (SET ks) = (SET (rowTake cut ks), SET (rowDrop cut ks))
  where
    -- Length of the prefix on which the predicate holds.
    cut = bfind f ks

--------------------------------------------------------------------------------
-- TODO: Optimize and verify these instances

-- No methods are defined here; every fold uses the class defaults.
instance MonoFoldable (Set a)

type instance Element (Set a) = a

-- Assert that append (<>) never produces something smaller.
instance GrowingAppend (Set k)

instance Ord k => SetContainer (Set k) where
    type ContainerKey (Set k) = k

    {-# INLINE member #-}
    member = ssetMember

    {-# INLINE notMember #-}
    notMember k = not . ssetMember k

    {-# INLINE union #-}
    union = ssetUnion

    {-# INLINE difference #-}
    difference = ssetDifference

    {-# INLINE intersection #-}
    intersection = ssetIntersection

    {-# INLINE keys #-}
    keys = ssetToAscList

instance Ord a => Semigroup (Set a) where
    {-# INLINE (<>) #-}
    (<>) = ssetUnion

instance Ord a => Monoid (Set a) where
    {-# INLINE mempty #-}
    mempty = emptySet

instance Ord k => IsSet (Set k) where
    {-# INLINE insertSet #-}
    insertSet = ssetInsert

    {-# INLINE deleteSet #-}
    deleteSet = ssetDelete

    {-# INLINE singletonSet #-}
    singletonSet = ssetSingleton

    {-# INLINE setFromList #-}
    setFromList = ssetFromList

    {-# INLINE setToList #-}
    setToList = ssetToAscList