commit aed79e93a32adc084b17853c1ca9cda190fc6be2 parent a747684bc2379a8eb2f36ac04128de536905d41d Author: Sol <sol@plunder.tech> Date: Thu, 7 Sep 2023 21:18:21 -0400 rts: use divide-and-conquer for set-union This patch takes the "divide and conquer" algorithm that we used for set-intersection, and implements it on set-union. This is faster on all benchmarks, and *much* faster on some benchmarks. The code is a bit gnarly but all the tests pass. Diffstat:

M | lib/Data/Sorted/Set.hs | | | 152 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------- |

1 file changed, 109 insertions(+), 43 deletions(-)diff --git a/lib/Data/Sorted/Set.hs b/lib/Data/Sorted/Set.hs@@ -47,8 +47,9 @@ import Data.Sorted.Search import Data.Sorted.Types import Prelude -import Data.Coerce (coerce) -import GHC.Exts (Array#, Int(..), Int#, sizeofArray#, (+#)) +import Data.Coerce (coerce) +import Data.Functor (($>)) +import GHC.Exts (Array#, Int(..), Int#, sizeofArray#, (+#)) -- import qualified Fan.Prof as Prof @@ -140,8 +141,6 @@ ssetSize (SET a) = sizeofArray a -- O(n+m) given input sets of size n and m. -- -- We special-case sets of size zero and one. --- --- TODO: Make sure that GHC optimizes away this pattern match. ssetUnion :: Ord k => Set k -> Set k -> Set k ssetUnion x@(SET xs) y@(SET ys) = case (sizeofArray xs, sizeofArray ys) of @@ -192,47 +191,114 @@ ssetUnionGeneric (SET xs) !xWid (SET ys) !yWid = in (xMx, yWid, xMx, xWid - xMx, xs) LT -> let yMx = bsearchPostIndex xLargest ys in (xWid, yMx, yMx, yWid - yMx, ys) + xOverlapWidth = xMax - xMin + yOverlapWidth = yMax - yMin in + coerce $ + if (yOverlapWidth > xOverlapWidth) then + ssetUnionGenericSwapped + ys yMin yMax yOverlapWidth + xs xMin xMax xOverlapWidth + initialArray beforeCount + finalArray afterStart afterCount + else + ssetUnionGenericSwapped + xs xMin xMax xOverlapWidth + ys yMin yMax yOverlapWidth + initialArray beforeCount + finalArray afterStart afterCount + +-- TODO: Too many arguments, find a way to rejigger this! 
+ssetUnionGenericSwapped + :: Ord a + => Array a -> Int -> Int -> Int + -> Array a -> Int -> Int -> Int + -> Array a -> Int + -> Array a -> Int -> Int + -> Array a +ssetUnionGenericSwapped + xs xMin xMax xOverlapWidth + ys yMin yMax yOverlapWidth + initialArray beforeCount + finalArray finalStart finalCount = runST do + + let maxOverlapWidth = xOverlapWidth + yOverlapWidth + buf <- newArray maxOverlapWidth (error "setUnion: uninitialized") - coerce $ runST do - let xOverlapWidth = xMax - xMin - let yOverlapWidth = yMax - yMin - - let maxOverlapWidth = xOverlapWidth + yOverlapWidth - buf <- newArray maxOverlapWidth (error "setUnion: uninitialized") - - let go o i j = do - let xRemain = xMax - i - let yRemain = yMax - j - case (xRemain, yRemain) of - (0, 0) -> pure o - (0, _) -> do - copyArray buf o ys j yRemain - pure (o + yRemain) - (_, 0) -> do - copyArray buf o xs i xRemain - pure (o + xRemain) - (_, _) -> do - let x = xs ! i - let y = ys ! j - case compare x y of - EQ -> writeArray buf o x >> go (o+1) (i+1) (j+1) - LT -> writeArray buf o x >> go (o+1) (i+1) j - GT -> writeArray buf o y >> go (o+1) i (j+1) - - overlapCount <- go 0 xMin yMin - overlap <- unsafeFreezeArray buf - - if (overlapCount == maxOverlapWidth && afterCount+beforeCount == 0) - then do - pure overlap - else do - let totalCount = beforeCount + overlapCount + afterCount - res <- newArray totalCount (error "setUnion: uninitialized") - copyArray res 0 initialArray 0 beforeCount - copyArray res beforeCount overlap 0 overlapCount - copyArray res (beforeCount+overlapCount) finalArray afterStart afterCount - unsafeFreezeArray res + let go o iLow iEnd jLow jEnd = do + let iRemain = iEnd - iLow + let jRemain = jEnd - jLow + case (iRemain, jRemain) of + (0, 0) -> pure o + (0, _) -> copyArray buf o ys jLow jRemain $> (o + jRemain) + (_, 0) -> copyArray buf o xs iLow iRemain $> (o + iRemain) + (1, 1) -> do + let x = xs!iLow + let y = ys!jLow + case compare x y of + LT -> writeArray buf o x >> writeArray 
buf (o+1) y >> pure (o+2) + EQ -> writeArray buf o x >> pure (o+1) + GT -> writeArray buf o y >> writeArray buf (o+1) x >> pure (o+2) + (1, _) -> do + let x = xs!iLow + case bsearch_ x ys jLow jEnd of + (# _, 1# #) -> do + copyArray buf o ys jLow jRemain $> (o+jRemain) + (# jMid#, _ #) -> do + let jMid = I# jMid# + let nBelo = jMid - jLow + let nAbove = jEnd - jMid + copyArray buf o ys jLow nBelo + writeArray buf (o+nBelo) x + copyArray buf (o+nBelo+1) ys jMid nAbove + pure (o+nBelo+1+nAbove) + (_, 1) -> do + let y = ys!jLow + case bsearch_ y xs iLow iEnd of + (# _, 1# #) -> do + copyArray buf o xs iLow iRemain $> (o+iRemain) + (# iMid#, _ #) -> do + let iMid = I# iMid# + let nBelo = iMid - iLow + let nAbove = iEnd - iMid + copyArray buf o xs iLow nBelo + writeArray buf (o+nBelo) y + copyArray buf (o+nBelo+1) xs iMid nAbove + pure (o+nBelo+1+nAbove) + + (_, _) -> do + -- Get the middle value for the left-set. + let iMid = (iLow + iEnd) `shiftR` 1 + let iMidVal = xs ! iMid + let !(# jMid, found #) = bsearch_ iMidVal ys jLow jEnd + + -- Recurse to the left of the split on both. + o2 <- go o iLow iMid jLow (I# jMid) + + -- Always write out the pivot value. + writeArray buf o2 iMidVal + let o3 = o2+1 + + -- Skip over the pivot on recursion if it matched. + let iMid' = iMid + 1 + let jMid' = I# (jMid +# found) + + -- Recurse to the right of the split on both. + go o3 iMid' iEnd jMid' jEnd + + overlapCount <- go 0 xMin xMax yMin yMax + overlap <- unsafeFreezeArray buf + + if (overlapCount == maxOverlapWidth && finalCount+beforeCount == 0) + then do + pure overlap + else do + let totalCount = beforeCount + overlapCount + finalCount + res <- newArray totalCount (error "setUnion: uninitialized") + copyArray res 0 initialArray 0 beforeCount + copyArray res beforeCount overlap 0 overlapCount + copyArray res (beforeCount+overlapCount) finalArray finalStart finalCount + unsafeFreezeArray res {-# INLINE ssetIsEmpty #-} ssetIsEmpty :: Set k -> Bool