commit5597a5c67ff91d6463ca8d487e8ff81766bb174dparent8199d0aa11068a7ca6fbffea94e781bbaebf8734Author:Sol <sol@plunder.tech>Date:Sat, 2 Sep 2023 12:45:46 -0400 rts: optimize setUnion in the same way. This changes setUnion so that it also only walks the overlapping region. This is a little bit trickier than the intersection case, because the initial non-overlapping sections also need to be copied into the result. Also, this has a fast-path for union of non-overlapping sets.Diffstat:

M | lib/Data/Sorted/Set.hs | | | 119 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------- |

1 file changed, 85 insertions(+), 34 deletions(-)diff --git a/lib/Data/Sorted/Set.hs b/lib/Data/Sorted/Set.hs@@ -46,8 +46,9 @@ import Data.Sorted.Search import Data.Sorted.Types import Prelude -import Data.Coerce (coerce) -import GHC.Exts (Int(..)) +import ClassyPrelude (when) +import Data.Coerce (coerce) +import GHC.Exts (Int(..)) -------------------------------------------------------------------------------- @@ -146,36 +147,84 @@ ssetUnion x@(SET xs) y@(SET ys) = ( _, 1 ) -> ssetInsert (ys!0) x ( xw, yw ) -> ssetUnionGeneric x xw y yw -ssetUnionGeneric :: Ord k => Set k -> Int -> Set k -> Int -> Set k -ssetUnionGeneric (SET xs) !xWid (SET ys) !yWid = runST do - let !rWid = xWid + yWid - - buf <- newArray rWid (error "ssetUnion: uninitialized") - - let go o i j = do - let xRemain = xWid - i - let yRemain = yWid - j - case (xRemain, yRemain) of - (0, 0) -> pure o - (0, _) -> do - copyArray buf o ys j yRemain - pure (o + yRemain) - (_, 0) -> do - copyArray buf o xs i xRemain - pure (o + xRemain) - (_, _) -> do - let x = xs ! i - let y = ys ! j - case compare x y of - EQ -> writeArray buf o x >> go (o+1) (i+1) (j+1) - LT -> writeArray buf o x >> go (o+1) (i+1) j - GT -> writeArray buf o y >> go (o+1) i (j+1) - - written <- go 0 0 0 - - SET <$> if written == rWid - then unsafeFreezeArray buf - else freezeArray buf 0 written +-- This assumes that neither of the inputs are empty. +-- +-- TODO: Skip the shrinking optimization if the sizes are small. +ssetUnionGeneric :: Ord a => Set a -> Int -> Set a -> Int -> Set a +ssetUnionGeneric (SET xs) !xWid (SET ys) !yWid = + let + xSmallest = xs ! 0 + xLargest = xs ! (xWid-1) + ySmallest = ys ! 0 + yLargest = ys ! (yWid-1) + in + + -- If there is no overlap, then the union is just array concatenation. + if xSmallest > yLargest then SET (ys <> xs) else + if ySmallest > xLargest then SET (xs <> ys) else + + -- Find the overlapping range of the sets so we can walk merely the + -- parts we know overlap + let + (xMin, yMin, beforeCount, initialArray) = + case compare xSmallest ySmallest of + EQ -> (0, 0, 0, xs) + GT -> let yMn = bsearchIndex xSmallest ys + in (0, yMn, yMn, ys) + LT -> let xMn = bsearchIndex ySmallest xs + in (xMn, 0, xMn, xs) + + (xMax, yMax, afterStart, afterCount, finalArray) = + case compare xLargest yLargest of + EQ -> (xWid, yWid, 0, 0, xs) + GT -> let xMx = bsearchPostIndex yLargest xs + in (xMx, yWid, xMx, xWid - xMx, xs) + LT -> let yMx = bsearchPostIndex xLargest ys + in (xWid, yMx, yMx, yWid - yMx, ys) + in + + coerce $ runST do + let xOverlapWidth = xMax - xMin + let yOverlapWidth = yMax - yMin + + let maxOverlapWidth = (xOverlapWidth + yOverlapWidth) + + let bufferWidth = beforeCount + afterCount + maxOverlapWidth + + buf <- newArray bufferWidth (error "setUnion: uninitialized") + + let go o i j = do + let xRemain = xMax - i + let yRemain = yMax - j + case (xRemain, yRemain) of + (0, 0) -> pure o + (0, _) -> do + copyArray buf o ys j yRemain + pure (o + yRemain) + (_, 0) -> do + copyArray buf o xs i xRemain + pure (o + xRemain) + (_, _) -> do + let x = xs ! i + let y = ys ! j + case compare x y of + EQ -> writeArray buf o x >> go (o+1) (i+1) (j+1) + LT -> writeArray buf o x >> go (o+1) (i+1) j + GT -> writeArray buf o y >> go (o+1) i (j+1) + + when (beforeCount > 0) do + copyArray buf 0 initialArray 0 beforeCount + + written <- go beforeCount xMin yMin + + when (afterCount > 0) do + copyArray buf written finalArray afterStart afterCount + + let totalWritten = written + afterCount + + if totalWritten == bufferWidth + then unsafeFreezeArray buf + else freezeArray buf 0 totalWritten {-# INLINE ssetIsEmpty #-} ssetIsEmpty :: Set k -> Bool @@ -254,7 +303,7 @@ ssetIntersectionGeneric (SET xs) !xWid (SET ys) !yWid = if xSmallest > yLargest then mempty else if ySmallest > xLargest then mempty else - -- Find the overlapping range of the the sets so we can walk merely the + -- Find the overlapping range of the sets so we can walk merely the -- parts we know overlap let (xMin, yMin) = @@ -271,7 +320,9 @@ ssetIntersectionGeneric (SET xs) !xWid (SET ys) !yWid = in coerce $ runST do - let rWid = min (xMax - xMin) (yMax - yMin) + let xOverlapWidth = xMax - xMin + let yOverlapWidth = yMax - yMin + let rWid = min xOverlapWidth yOverlapWidth buf <- newArray rWid (error "setIntersection: uninitialized") let go o i j = do if i >= xMax || j >= yMax then pure o else do