      -- Copyright 2023 The Plunder Authors
      -- Use of this source code is governed by a BSD-style license that can be
      -- found in the LICENSE file.
-- | Sets represented as sorted, duplicate-free arrays.
      {-# OPTIONS_GHC -Wall        #-}
      {-# OPTIONS_GHC -Werror      #-}
      {-# OPTIONS_GHC -Wno-orphans #-}
     {-# LANGUAGE UnboxedTuples #-}
     module Data.Sorted.Set
     ( ssetFromList
     , ssetToAscArray
     , ssetToAscList
     , ssetToDescList
     , ssetToArray
     , ssetSplitAt
     , ssetSpanAntitone
     , ssetIntersection
     , ssetDifference
     , ssetSingleton
     , ssetInsert
     , ssetLookupMin
     , ssetLookupMax
     , ssetDelete
     , ssetSize
     , ssetUnion
     , ssetIsEmpty
     , ssetFromDistinctAscList
     , ssetFindMax
     , ssetFindMin
     , ssetMember
     , ssetTake
     , ssetDrop
     )
     where
     import Control.Monad.ST
     import Data.Bits
     import Data.Containers
     import Data.Foldable
     import Data.MonoTraversable
     import Data.Primitive.Array
     import Data.Sorted.Row
     import Data.Sorted.Search
     import Data.Sorted.Types
     import Prelude
     import Data.Coerce  (coerce)
     import Data.Functor (($>))
     import GHC.Exts     (Array#, Int(..), Int#, sizeofArray#, (+#))
     -- import qualified Fan.Prof as Prof
     --------------------------------------------------------------------------------
     {-# INLINE emptySet #-}
     emptySet :: Set k
     emptySet = SET mempty
     {-# INLINE ssetSingleton #-}
     ssetSingleton :: k -> Set k
     ssetSingleton k = SET (rowSingleton k)
     -- x must be less than y, otherwise the resulting set will be invalid.
     {-# INLINE ssetUnsafeDuo #-}
     ssetUnsafeDuo :: k -> k -> Set k
     ssetUnsafeDuo x y = coerce (rowDuo x y)
     -- coerce here avoids the fmap.  SET is a newtype for row, so there
     -- is a type-safe cast between the two.
     -- Collect list into an mutable array.  Sort it, remove duplicates.
     --
     -- TODO: Avoid the copies by loading into a mutable array, sorting the
     -- mutable array, and then doing an unsafe freeze of that.
     ssetFromList :: Ord k => [k] -> Set k
     ssetFromList kList =
     -- Prof.withSimpleTracingEventPure "ssetFromList" "sorted" $
     SET $ rowSortUniqBy compare $ arrayFromList kList
     {-# INLINE ssetInsert# #-}
     ssetInsert# :: Ord a => a -> Array# a -> Int# -> Array# a
     ssetInsert# k ks# wid# =
     let !(# i, found #) = bsearch# compare k ks# 0# wid# in
     case found of
     0# -> rowUnsafeInsert# i k ks# wid#
     _  -> ks#
     {-# INLINE ssetInsert #-}
     ssetInsert :: Ord k => k -> Set k -> Set k
     ssetInsert k (SET (Array ks#)) =
     -- Prof.withSimpleTracingEventPure "setInsert" "sorted" $
     SET (Array (ssetInsert# k ks# (sizeofArray# ks#)))
     {-# INLINE ssetDelete #-}
     ssetDelete :: Ord k => k -> Set k -> Set k
     ssetDelete k set@(SET ks@(Array ks#)) =
    -- Prof.withSimpleTracingEventPure "setDelete" "sorted" $
    let !(# i, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
    case found of
    0# -> set
    _  -> SET (rowDelete (I# i) ks)
    {-# INLINE ssetLookupMin #-}
    ssetLookupMin :: Set a -> Maybe a
    ssetLookupMin (SET ks) | null ks = Nothing
    ssetLookupMin (SET ks)           = Just (ks!0)
    {-# INLINE ssetLookupMax #-}
    ssetLookupMax :: Set a -> Maybe a
    ssetLookupMax (SET ks) =
    if wid==0 then Nothing else Just (ks ! (wid-1))
    where
    !wid = sizeofArray ks
    {-# INLINE ssetToAscArray #-}
    ssetToAscArray :: Set k -> Array k
    ssetToAscArray (SET a) = a
    {-# INLINE ssetToAscList #-}
    ssetToAscList :: Set k -> [k]
    ssetToAscList (SET a) = toList a
    {-# INLINE ssetToArray #-}
    ssetToArray :: Set k -> Array k
    ssetToArray (SET a) = a
    {-# INLINE ssetToDescList #-}
    ssetToDescList :: Set k -> [k]
    ssetToDescList (SET a) = go (length a - 1)
    where
    go i | i<0 = []
    go i       = (a!i) : go (i-1)
    {-# INLINE ssetSize #-}
    ssetSize :: Set k -> Int
    ssetSize (SET a) = sizeofArray a
    -- O(n+m) given input sets of size n and m.
    --
    -- We special-case sets of size zero and one.
    ssetUnion :: Ord k => Set k -> Set k -> Set k
    ssetUnion x@(SET xs) y@(SET ys) =
    case (sizeofArray xs, sizeofArray ys) of
    ( 0,  _  ) -> y
    ( _,  0  ) -> x
    ( 1,  1  ) -> let xv = (xs!0)
    yv = (ys!0)
    in case compare xv yv of
    EQ -> x
    LT -> ssetUnsafeDuo xv yv
    GT -> ssetUnsafeDuo yv xv
    ( 1,  _  ) -> ssetInsert (xs!0) y
    ( _,  1  ) -> ssetInsert (ys!0) x
    ( xw, yw ) -> ssetUnionGeneric x xw y yw
    -- This assumes that neither of the inputs are empty.
    --
    -- TODO: Skip the shrinking optimization if the sizes are small.
    ssetUnionGeneric :: Ord a => Set a -> Int -> Set a -> Int -> Set a
    ssetUnionGeneric (SET xs) !xWid (SET ys) !yWid =
    -- Prof.withSimpleTracingEventPure "setUnionGeneric" "sorted" $
    let
    xSmallest = xs ! 0
    xLargest  = xs ! (xWid-1)
    ySmallest = ys ! 0
    yLargest  = ys ! (yWid-1)
    in
    -- If there is no overlap, then the union is just array concatenation.
    if xSmallest > yLargest then SET (ys <> xs) else
    if ySmallest > xLargest then SET (xs <> ys) else
    -- Find the overlapping range of the sets so we can walk merely the
    -- parts we know overlap
    let
    (xMin, yMin, beforeCount, initialArray) =
    case compare xSmallest ySmallest of
    EQ -> (0, 0, 0, xs)
    GT -> let yMn = bsearchIndex xSmallest ys
    in (0, yMn, yMn, ys)
    LT -> let xMn = bsearchIndex ySmallest xs
    in (xMn, 0, xMn, xs)
    (xMax, yMax, afterStart, afterCount, finalArray) =
    case compare xLargest yLargest of
    EQ -> (xWid, yWid, 0, 0, xs)
    GT -> let xMx = bsearchPostIndex yLargest xs
    in (xMx, yWid, xMx, xWid - xMx, xs)
    LT -> let yMx = bsearchPostIndex xLargest ys
    in (xWid, yMx, yMx, yWid - yMx, ys)
    xOverlapWidth = xMax - xMin
    yOverlapWidth = yMax - yMin
    in
    coerce $
    if (yOverlapWidth > xOverlapWidth) then
    ssetUnionGenericSwapped
    ys yMin yMax yOverlapWidth
    xs xMin xMax xOverlapWidth
    initialArray beforeCount
    finalArray afterStart afterCount
    else
    ssetUnionGenericSwapped
    xs xMin xMax xOverlapWidth
    ys yMin yMax yOverlapWidth
    initialArray beforeCount
    finalArray afterStart
    211 -- TODO: Too many arguments, find a way to rejigger this!
    212 ssetUnionGenericSwapped
    213     :: Ord a
    214     => Array a -> Int -> Int -> Int
    215     -> Array a -> Int -> Int -> Int
    216     -> Array a -> Int
    217     -> Array a -> Int -> Int
    218     -> Array a
    219 ssetUnionGenericSwapped
    220         xs xMin xMax xOverlapWidth
    221         ys yMin yMax yOverlapWidth
    222         initialArray beforeCount
    223         finalArray finalStart finalCount = runST do
    225     let maxOverlapWidth = xOverlapWidth + yOverlapWidth
    226     buf <- newArray maxOverlapWidth (error "setUnion: uninitialized")
    228     let go o iLow iEnd jLow jEnd = do
    229             let iRemain = iEnd - iLow
    230             let jRemain = jEnd - jLow
    231             case (iRemain, jRemain) of
    232                 (0, 0) -> pure o
    233                 (0, _) -> copyArray buf o ys jLow jRemain $> (o + jRemain)
    234                 (_, 0) -> copyArray buf o xs iLow iRemain $> (o + iRemain)
    235                 (1, 1) -> do
    236                     let x = xs!iLow
    237                     let y = ys!jLow
    238                     case compare x y of
    239                        LT -> writeArray buf o x >> writeArray buf (o+1) y >> pure (o+2)
    240                        EQ -> writeArray buf o x >> pure (o+1)
    241                        GT -> writeArray buf o y >> writeArray buf (o+1) x >> pure (o+2)
    242                 (1, _) -> do
    243                     let x = xs!iLow
    244                     case bsearch_ x ys jLow jEnd of
    245                         (# _, 1# #) -> do
    246                             copyArray buf o ys jLow jRemain $> (o+jRemain)
    247                         (# jMid#, _ #) -> do
    248                             let jMid   = I# jMid#
    249                             let nBelo  = jMid - jLow
    250                             let nAbove = jEnd - jMid
    251                             copyArray buf o ys jLow nBelo
    252                             writeArray buf (o+nBelo) x
    253                             copyArray buf (o+nBelo+1) ys jMid nAbove
    254                             pure (o+nBelo+1+nAbove)
    255                 (_, 1) -> do
    256                     let y = ys!jLow
    257                     case bsearch_ y xs iLow iEnd of
    258                         (# _, 1# #) -> do
    259                             copyArray buf o xs iLow iRemain $> (o+iRemain)
    260                         (# iMid#, _ #) -> do
    261                             let iMid   = I# iMid#
    262                             let nBelo  = iMid - iLow
    263                             let nAbove = iEnd - iMid
    264                             copyArray buf o xs iLow nBelo
    265                             writeArray buf (o+nBelo) y
    266                             copyArray buf (o+nBelo+1) xs iMid nAbove
    267                             pure (o+nBelo+1+nAbove)
    269                 (_, _) -> do
    270                     -- Get the middle value for the left-set.
    271                     let iMid               = (iLow + iEnd) `shiftR` 1
    272                     let iMidVal            = xs ! iMid
    273                     let !(# jMid, found #) = bsearch_ iMidVal ys jLow jEnd
    275                     -- Recurse to the left of the split on both.
    276                     o2 <- go o iLow iMid jLow (I# jMid)
    278                     -- Always write out the pivot value.
    279                     writeArray buf o2 iMidVal
    280                     let o3 = o2+1
    282                     -- Skip over the pivot on recursion if it matched.
    283                     let iMid' = iMid + 1
    284                     let jMid' = I# (jMid +# found)
    286                     -- Recurse to the right of the split on both.
    287                     go o3 iMid' iEnd jMid' jEnd
    289     overlapCount <- go 0 xMin xMax yMin yMax
    290     overlap      <- unsafeFreezeArray buf
    292     if (overlapCount == maxOverlapWidth && finalCount+beforeCount == 0)
    293     then do
    294         pure overlap
    295     else do
    296         let totalCount = beforeCount + overlapCount + finalCount
    297         res <- newArray totalCount (error "setUnion: uninitialized")
    298         copyArray res 0                          initialArray 0          beforeCount
    299         copyArray res beforeCount                overlap      0          overlapCount
    300         copyArray res (beforeCount+overlapCount) finalArray   finalStart finalCount
    301         unsafeFreezeArray res
    303 {-# INLINE ssetIsEmpty #-}
    304 ssetIsEmpty :: Set k -> Bool
    305 ssetIsEmpty (SET a) = null a
    307 -- TODO: Should we check this invariant?
    308 {-# INLINE ssetFromDistinctAscList #-}
    309 ssetFromDistinctAscList :: [k] -> Set k
    310 ssetFromDistinctAscList ksList = SET (arrayFromList ksList)
    312 {-# INLINE ssetFindMax #-}
    313 ssetFindMax :: Set k -> k
    314 ssetFindMax (SET s) =
    315     case sizeofArray s of
    316         0 -> error "setFindMax: empty set"
    317         n -> s ! (n-1)
    319 {-# INLINE ssetFindMin #-}
    320 ssetFindMin :: Set k -> k
    321 ssetFindMin (SET s) =
    322     if null s
    323     then error "setFindMin: empty setE"
    324     else s!0
    326 -- Do a search, return True if if found something, otherwise False.
    327 {-# INLINE ssetMember #-}
    328 ssetMember :: Ord k => k -> Set k -> Bool
    329 ssetMember k (SET (Array ks#)) =
    330     case bsearch# compare k ks# 0# (sizeofArray# ks#) of
    331         (# _, 0# #) -> False
    332         (# _, _  #) -> True
    334 -- This doesn't affect the order invariants, so we just run the operation
    335 -- directly against the underlying array.
    336 {-# INLINE ssetTake #-}
    337 ssetTake :: Int -> Set k -> Set k
    338 ssetTake i (SET ks) =
    339     -- Prof.withSimpleTracingEventPure "setTake" "sorted" $
    340     SET (rowTake i ks)
    342 -- This doesn't affect the order invariants, so we just run the operation
    343 -- directly against the underlying array.
    344 {-# INLINE ssetDrop #-}
    345 ssetDrop :: Int -> Set k -> Set k
    346 ssetDrop i (SET ks) =
    347     -- Prof.withSimpleTracingEventPure "setDrop" "sorted" $
    348     SET (rowDrop i ks)
    350 -- Just split the underlying array, set invariants are not at risk.
    351 {-# INLINE ssetSplitAt #-}
    352 ssetSplitAt :: Int -> Set k -> (Set k, Set k)
    353 ssetSplitAt i (SET ks) = (SET (rowTake i ks), SET (rowDrop i ks))
    355 -- O(n) set intersection.  Special cases for (size=0 and size=1)
    356 ssetIntersection :: Ord a => Set a -> Set a -> Set a
    357 ssetIntersection x@(SET xs) y@(SET ys) =
    358     -- Prof.withSimpleTracingEventPure "setIntersection" "sorted" $
    359     case (sizeofArray xs, sizeofArray ys) of
    360         ( 0,    _    ) -> mempty
    361         ( _,    0    ) -> mempty
    362         ( 1,    1    ) -> if xs!0 == ys!0 then x else mempty
    363         ( 1,    _    ) -> if ssetMember (xs!0) y then x else mempty
    364         ( _,    1    ) -> if ssetMember (ys!0) x then y else mempty
    365         ( xWid, yWid ) ->
    366             let
    367                 xSmallest = xs ! 0
    368                 yLargest  = ys ! (yWid-1)
    369                 xLargest  = xs ! (xWid-1)
    370                 ySmallest = ys ! 0
    371             in
    373             -- If no overlap is possible, return empty.
    374             if xSmallest > yLargest then mempty else
    375             if ySmallest > xLargest then mempty else
    377             -- figure out which subregions overlap.
    378             let
    379                 (xMin, yMin) =
    380                     case compare xSmallest ySmallest of
    381                         EQ -> (0, 0)
    382                         GT -> (0, bsearchIndex xSmallest ys)
    383                         LT -> (bsearchIndex ySmallest xs, 0)
    385                 (xMax, yMax) =
    386                     case compare xLargest yLargest of
    387                         EQ -> (xWid, yWid)
    388                         GT -> (bsearchPostIndex yLargest xs, yWid)
    389                         LT -> (xWid, bsearchPostIndex xLargest ys)
    391                 xSz = (xMax - xMin)
    392                 ySz = (yMax - yMin)
    393             in
    394                 -- Run the intersection algorithm against the overlapping
    395                 -- region, with the smaller region as the left-hand-side.
    396                 coerce $
    397                 if xSz > ySz
    398                 then ssetIntersectionGeneric ys yMin yMax ySz xs xMin xMax xSz
    399                 else ssetIntersectionGeneric xs xMin xMax xSz ys yMin yMax ySz
    401 {-
    402     Performs intersection by divide-and-conquor on two non-empty slices
    403     for two sorted arrays.  Returns a sorted, unique array containing
    404     the intersection.
    405 -}
    406 ssetIntersectionGeneric
    407     :: Ord a
    408     => Array a -> Int -> Int -> Int
    409     -> Array a -> Int -> Int -> Int
    410     -> Array a
    411 ssetIntersectionGeneric xs !xMin !xMax !xSz ys !yMin !yMax !ySz = runST do
    412     let rWid = min ySz xSz
    414     buf <- newArray rWid (error "setIntersection: uninitialized")
    416     let go o iLow iEnd jLow jEnd = do
    417               case (iEnd - iLow, jEnd - jLow) of
    418                   (0, _) -> pure o
    419                   (_, 0) -> pure o
    420                   (1, 1) ->
    421                       let xVal = xs!iLow; yVal = ys!jLow in
    422                       if xVal ==yVal
    423                       then writeArray buf o xVal >> pure (o+1)
    424                       else pure o
    426                   (1, _) ->
    427                       let xVal = xs!iLow in
    428                       case bsearch_ xVal ys jLow jEnd of
    429                           (# _, 0# #) -> pure o
    430                           (# _, _  #) -> writeArray buf o xVal >> pure (o+1)
    432                   (_, 1) ->
    433                       let yVal = ys!jLow in
    434                       case bsearch_ yVal xs iLow iEnd of
    435                           (# _, 0# #) -> pure o
    436                           (# _, _  #) -> writeArray buf o yVal >> pure (o+1)
    438                   (_, _) ->
    440                       -- Get the middle value for the left-set.
    441                       let iMid               = (iLow + iEnd) `shiftR` 1
    442                           iMidVal            = xs ! iMid
    443                           !(# jMid, found #) = bsearch_ iMidVal ys jLow jEnd
    444                       in do
    445                               -- Recurse to the left of the split on both.
    446                               o2 <- go o iLow iMid jLow (I# jMid)
    448                               -- Write out the pivot value, if it exists
    449                               -- in both arrays.
    450                               o3 <- case found of
    451                                         0# -> pure o2
    452                                         _  -> do writeArray buf o2 iMidVal
    453                                                  pure (o2+1)
    455                               -- Skip over the pivot if it matched.
    456                               let iMid' = iMid + 1
    457                               let jMid' = I# (jMid +# found)
    459                               -- Recurse to the right of the split on both.
    460                               go o3  iMid' iEnd jMid' jEnd
    462     used <- go 0 xMin xMax yMin yMax
    464     if used == rWid
    465     then unsafeFreezeArray buf
    466     else freezeArray buf 0 used
    468 -- O(n) set difference.
    469 ssetDifference :: Ord a => Set a -> Set a -> Set a
    470 ssetDifference (SET xs) (SET ys) = runST do
    471     let xWid = sizeofArray xs
    472     let yWid = sizeofArray ys
    473     buf <- newArray xWid (error "setDifference: uninitialized")
    474     let go o i j =
    475             if i >= xWid then pure o else
    476             if j >= yWid then do
    477                 let extra = xWid - i
    478                 copyArray buf o xs i extra
    479                 pure (o + extra)
    480             else do
    481                 let x = xs ! i
    482                 let y = ys ! j
    483                 case compare x y of
    484                     LT -> writeArray buf o x >> go (o+1) (i+1) j
    485                     EQ -> go o (i+1) (j+1)
    486                     GT -> go o i (j+1)
    487     used <- go 0 0 0
    488     if used == xWid
    489     then SET <$> unsafeFreezeArray buf
    490     else SET <$> freezeArray buf 0 used
    492 -- Assuming that the predicate is monotone, find the point where the
    493 -- predicate stops holding, and split the set there.
    494 {-# INLINE ssetSpanAntitone #-}
    495 ssetSpanAntitone :: (a -> Bool) -> Set a -> (Set a, Set a)
    496 ssetSpanAntitone f (SET ks) =
    497     let numTrue = bfind f ks in
    498     ( SET $ rowTake numTrue ks
    499     , SET $ rowDrop numTrue ks
    500     )
    502 --------------------------------------------------------------------------------
    503 -- TODO: Optimize and verify these instances
    505 instance MonoFoldable (Set a) where
    507 type instance Element (Set a) = a
    509 -- Assert that append (<>) never produces something smaller.
    510 instance GrowingAppend (Set k) where
    512 instance Ord k => SetContainer (Set k) where
    513     type ContainerKey (Set k) = k
    514     member = ssetMember
    515     notMember k s = not (ssetMember k s)
    516     union = ssetUnion
    517     difference = ssetDifference
    518     intersection = ssetIntersection
    519     keys = ssetToAscList
    520     {-# INLINE member #-}
    521     {-# INLINE notMember #-}
    522     {-# INLINE union #-}
    523     {-# INLINE difference #-}
    524     {-# INLINE intersection #-}
    525     {-# INLINE keys #-}
    527 instance Ord a => Semigroup (Set a) where
    528     (<>) = ssetUnion
    529     {-# INLINE (<>) #-}
    531 instance Ord a => Monoid (Set a) where
    532     mempty = emptySet
    533     {-# INLINE mempty #-}
    535 instance Ord k => IsSet (Set k) where
    536     insertSet = ssetInsert
    537     deleteSet = ssetDelete
    538     singletonSet = ssetSingleton
    539     setFromList = ssetFromList
    540     setToList = ssetToAscList
    541     {-# INLINE insertSet #-}
    542     {-# INLINE deleteSet #-}
    543     {-# INLINE singletonSet #-}
    544     {-# INLINE setFromList #-}
    545     {-# INLINE setToList #-}