      -- Copyright 2023 The Plunder Authors
      -- Use of this source code is governed by a BSD-style license that can be
      -- found in the LICENSE file.
      -- | Maps and Sets as Sorted Vectors
      {-# OPTIONS_GHC -Wall        #-}
      {-# OPTIONS_GHC -Werror      #-}
      {-# OPTIONS_GHC -Wno-orphans #-}
     {-# LANGUAGE UnboxedTuples #-}
     module Data.Sorted.Tab
     ( mkTab
     , tabSingleton
     , tabInsert
     , tabLookup
     , tabSize
     , tabElemAt
     , tabSplit
     , tabSplitAt
     , tabSpanAntitone
     , tabMap
     , tabMapWithKey
     , tabUnion
     , tabUnionWith
     , tabIntersection
     , tabDifference
     , tabLookupMin
     , tabLookupMax
     , tabAlter
     , tabDelete
     , tabMember
     , tabElemsList
     , tabElemsArray
     , tabFoldlWithKey'
     , tabFilterWithKey
     , tabKeysSet
     , tabKeysList
     , tabKeysArray
     , tabToAscPairsList
     , tabToDescPairsList
     )
     where
     import Prelude
     import Control.Monad.ST
     import Data.Containers
     import Data.Foldable
     import Data.MonoTraversable
     import Data.Primitive.Array
     import Data.Sorted.Row
     import Data.Sorted.Search
     import Data.Sorted.Set
     import Data.Sorted.Types
     import PlunderPrelude (on)
     import GHC.Exts       (Int(..), indexArray#, sizeofArray#, (+#))
     -- import qualified Fan.Prof as Prof
     -- Searching -------------------------------------------------------------------
{-# INLINE emptyTab #-}
-- | The table with no entries: an empty key row and an empty value row.
-- This is what the 'Monoid' instance uses as 'mempty'.
emptyTab :: Tab k v
emptyTab = TAB mempty mempty
     {-# INLINE tabSingleton #-}
     tabSingleton :: k -> v -> Tab k v
     tabSingleton k v = TAB (rowSingleton k) (rowSingleton v)
     -- The first key MUST be strictly smaller than the second.
     {-# INLINE tabUnsafeDuo #-}
     tabUnsafeDuo :: k -> v -> k -> v -> Tab k v
     tabUnsafeDuo xk xv yk yv =
     TAB (rowDuo xk yk) (rowDuo xv yv)
     -- If found, update the values array at the found index.  Otherwise insert
     -- the key and value at the found-index of the relevent arrays.
     tabInsert :: Ord k => k -> v -> Tab k v -> Tab k v
     tabInsert k v (TAB ks@(Array ks#) vs) =
     -- Prof.withSimpleTracingEventPure "tabInsert" "sorted" $
     let !(# i#, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
     let i = I# i# in
     case found of
     1# -> TAB ks (rowUnsafePut i v vs)
     _  -> TAB (rowInsert i k ks) (rowInsert i v vs)
     -- If found, merge the two values with (merge newVal oldVal).  Otherwise
     -- insert the key and value at the found-index of the relevant arrays.
     tabInsertWith :: Ord k => (v -> v -> v) -> k -> v -> Tab k v -> Tab k v
     tabInsertWith merge k v (TAB ks@(Array ks#) vs) =
     let !(# i#, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
     let i = I# i# in
     case found of
     1# -> TAB ks $ rowUnsafePut i (merge v (vs!i)) vs
     _  -> TAB (rowInsert i k ks) (rowInsert i v vs)
    -- Do a search on the keys set, if we found a match, return the matching
    -- value in the values array.
    tabLookup :: Ord k => k -> Tab k v -> Maybe v
    tabLookup k (TAB (Array ks#) (Array vs#)) =
    -- Prof.withSimpleTracingEventPure "tabLookup" "sorted" $
    let !(# i, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
    case found of
    0# -> Nothing
    _  -> case indexArray# vs# i of
    (# res #) -> Just res
    {-# INLINE tabSize #-}
    tabSize :: Tab k v -> Int
    tabSize (TAB ks _) = sizeofArray ks
    {-# INLINE tabElemAt #-}
    tabElemAt :: Int -> Tab k v -> (k, v)
    tabElemAt i (TAB ks vs) =
    if i > length ks
    then error "tabElemAt: out-of-bounds"
    else (ks!i, vs!i)
    {-# INLINE tabSplitAt #-}
    tabSplitAt :: Int -> Tab k v -> (Tab k v, Tab k v)
    tabSplitAt i (TAB ks vs) =
    ( TAB (rowTake i ks) (rowTake i vs)
    , TAB (rowDrop i ks) (rowDrop i vs)
    )
    -- Find index, call split (TODO: What behavior on found vs not-found?
    -- Avoid off-by-one-errors)
    {-# INLINE tabSplit #-}
    tabSplit :: Ord k => k -> Tab k v -> (Tab k v, Tab k v)
    tabSplit k (TAB ks@(Array ks#) vs) =
    let !(# i#, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#)
    i = I# i#
    j = I# (i# +# found)
    in 
    ( TAB (rowTake i ks) (rowTake i vs)
    , TAB (rowDrop j ks) (rowDrop j vs)
    )
    {-# INLINE tabSpanAntitone #-}
    tabSpanAntitone :: (k -> Bool) -> Tab k v -> (Tab k v, Tab k v)
    tabSpanAntitone f (TAB ks vs) =
    ( TAB (rowTake numTrue ks) (rowTake numTrue vs)
    , TAB (rowDrop numTrue ks) (rowDrop numTrue vs)
    )
    where
    numTrue = bfind f ks
    {-# INLINE tabMapWithKey #-}
    tabMapWithKey :: (k -> v -> a) -> Tab k v -> Tab k a
    tabMapWithKey f (TAB ks vs) = TAB ks (rowZipWith f ks vs)
    {-# INLINE tabMap #-}
    tabMap :: (a -> b) -> Tab k a -> Tab k b
    tabMap f (TAB k v) = TAB k (f <$> v)
    {-# INLINE tabUnion #-}
    tabUnion :: Ord k => Tab k v -> Tab k v -> Tab k v
    tabUnion = tabUnionWith const
    -- O(n) union
    tabUnionWith :: Ord k => (v -> v -> v) -> Tab k v -> Tab k v -> Tab k v
    tabUnionWith merge x@(TAB xKeys xVals) y@(TAB yKeys yVals) =
    -- Prof.withSimpleTracingEventPure "tabUnion" "sorted" $
    case (sizeofArray xKeys, sizeofArray yKeys) of
    ( 0,  _  ) -> y
    ( _,  0  ) -> x
    ( 1,  1  ) -> let xk = xKeys!0
    yk = yKeys!0
    xv = xVals!0
    yv = yVals!0
    in case compare xk yk of
    LT -> tabUnsafeDuo xk xv yk yv
    GT -> tabUnsafeDuo yk yv xk xv
    EQ -> tabSingleton xk (merge xv yv)
    ( 1,  _  ) -> tabInsertWith merge        (xKeys!0) (xVals!0) y
    ( _,  1  ) -> tabInsertWith (flip merge) (yKeys!0) (yVals!0) x
    ( xw, yw ) -> tabUnionWithGeneric merge x xw y yw
    tabUnionWithGeneric
    :: Ord k => (v -> v -> v) -> Tab k v -> Int -> Tab k v -> Int -> Tab k v
    tabUnionWithGeneric merge (TAB xKeys xVals) !xWid (TAB yKeys yVals) !yWid =
    runST do
    let rWid = xWid + yWid
    valsBuf <- newArray rWid (error "ssetUnion: uninitialized")
    keysBuf <- newArray rWid (error "ssetUnion: uninitialized")
    let go o i j = do
    let xRemain = xWid - i
    let yRemain = yWid - j
    case (xRemain, yRemain) of
    (0, 0) -> pure o
    (0, _) -> do
    copyArray keysBuf o yKeys j yRemain
    copyArray valsBuf o yVals j yRemain
    pure (o + yRemain)
    200                 (_, 0) -> do
    201                     copyArray keysBuf o xKeys i xRemain
    202                     copyArray valsBuf o xVals i xRemain
    203                     pure (o + xRemain)
    204                 (_, _) -> do
    205                     let x = xKeys ! i
    206                     let y = yKeys ! j
    207                     case compare x y of
    208                         EQ -> do writeArray keysBuf o x
    209                                  writeArray valsBuf o (merge (xVals!i) (yVals!j))
    210                                  go (o+1) (i+1) (j+1)
    211                         LT -> do writeArray keysBuf o x
    212                                  writeArray valsBuf o (xVals!i)
    213                                  go (o+1) (i+1) j
    214                         GT -> do writeArray keysBuf o y
    215                                  writeArray valsBuf o (yVals!j)
    216                                  go (o+1) i     (j+1)
    218     written <- go 0 0 0
    220     if written == rWid
    221     then TAB <$> unsafeFreezeArray keysBuf
    222              <*> unsafeFreezeArray valsBuf
    223     else TAB <$> freezeArray keysBuf 0 written
    224              <*> freezeArray valsBuf 0 written
    226 -- O(n) tab difference
    227 tabDifference :: Ord k => Tab k v -> Tab k v -> Tab k v
    228 tabDifference x (TAB yKeys _) | null yKeys = x
    229 tabDifference (TAB xKeys _) _ | null xKeys = mempty
    230 tabDifference (TAB xKeys xVals) (TAB yKeys yVals) =
    231     -- Prof.withSimpleTracingEventPure "tabDifference" "sorted" $
    232     runST do
    233     let xWid = sizeofArray xKeys
    234     let yWid = sizeofArray yVals
    235     rKeys <- newArray xWid (error "tabDifference: uninitialized")
    236     rVals <- newArray xWid (error "tabDifference: uninitialized")
    237     let go o i j =
    238             if i >= xWid then pure o else
    239             if j >= yWid then do
    240                 let extra = xWid - i
    241                 copyArray rKeys o xKeys i extra
    242                 copyArray rVals o xVals i extra
    243                 pure (o + extra)
    244             else do
    245                 let x = xKeys ! i
    246                 let y = yKeys ! j
    247                 case compare x y of
    248                     LT -> do
    249                         writeArray rKeys o x
    250                         writeArray rVals o (xVals!i)
    251                         go (o+1) (i+1) j
    252                     EQ -> go o (i+1) (j+1)
    253                     GT -> go o i (j+1)
    254     used <- go 0 0 0
    255     if used == 0 then
    256         pure mempty
    257     else if used == xWid then
    258         TAB <$> unsafeFreezeArray rKeys <*> unsafeFreezeArray rVals
    259     else
    260         TAB <$> freezeArray rKeys 0 used <*> freezeArray rVals 0 used
    262 {-# INLINE tabIntersection #-}
    263 tabIntersection :: Ord k => Tab k v -> Tab k v -> Tab k v
    264 tabIntersection x y = tabIntersectionWith const x y
    266 tabIntersectionWith :: Ord k => (v -> v -> v) -> Tab k v -> Tab k v -> Tab k v
    267 tabIntersectionWith f x@(TAB xKeys xVals) y@(TAB yKeys yVals) =
    268     -- Prof.withSimpleTracingEventPure "tabIntersectionWith" "sorted" $
    269     case (sizeofArray xKeys, sizeofArray yKeys) of
    270         ( 0,  _  ) -> mempty
    271         ( _,  0  ) -> mempty
    272         ( 1,  1  ) -> let xk = xKeys!0
    273                       in if xk == (yKeys!0)
    274                          then tabSingleton xk (f (xVals!0) (yVals!0))
    275                          else mempty
    276         ( 1,  _  ) -> let xk = xKeys!0 in
    277                       case tabLookup xk y of
    278                           Nothing -> mempty
    279                           Just yv -> tabSingleton xk (f (xVals!0) yv)
    280         ( _,  1  ) -> let yk = yKeys!0 in
    281                       case tabLookup yk x of
    282                           Nothing -> mempty
    283                           Just xv -> tabSingleton yk (f xv (yVals!0))
    284         ( xw, yw ) -> tabIntersectionWithGeneric f x xw y yw
    286 tabIntersectionWithGeneric
    287     :: Ord k => (v -> v -> v) -> Tab k v -> Int -> Tab k v -> Int -> Tab k v
    288 tabIntersectionWithGeneric merge (TAB xKeys xVals) xWid (TAB yKeys yVals) yWid =
    289   runST do
    290     let rWid = min xWid yWid
    291     rKeys <- newArray rWid (error "setIntersection: uninitialized")
    292     rVals <- newArray rWid (error "setIntersection: uninitialized")
    293     let go o i j =
    294             if i >= xWid || j >= yWid then pure o else do
    295                 let x = xKeys ! i
    296                 let y = yKeys ! j
    297                 case compare x y of
    298                     EQ -> do
    299                         writeArray rKeys o x
    300                         writeArray rVals o $ merge (xVals!i) (yVals!j)
    301                         go (o+1) (i+1) (j+1)
    302                     LT -> go o (i+1) j
    303                     GT -> go o i (j+1)
    304     used <- go 0 0 0
    305     if used == rWid
    306     then TAB <$> unsafeFreezeArray rKeys <*> unsafeFreezeArray rVals
    307     else TAB <$> freezeArray rKeys 0 used <*> freezeArray rVals 0 used
    309 {-# INLINE tabLookupMin #-}
    310 tabLookupMin :: Tab k v -> Maybe (k, v)
    311 tabLookupMin (TAB k v) =
    312     if null k then Nothing else Just (k!0, v!0)
    314 {-# INLINE tabLookupMax #-}
    315 tabLookupMax :: Tab k v -> Maybe (k, v)
    316 tabLookupMax (TAB k v) =
    317     case sizeofArray k of
    318         0 -> Nothing
    319         n -> let !i = n-1 in Just (k!i, v!i)
    321 tabAlter :: Ord k => (Maybe v -> Maybe v) -> k -> Tab k v -> Tab k v
    322 tabAlter f k tab@(TAB ks@(Array ks#) vs) =
    323     -- Prof.withSimpleTracingEventPure "tabAlter" "sorted" $
    324     let !(# i#, found #) = bsearch# compare k ks# 0# (sizeofArray# ks#) in
    325     let i = I# i# in
    326     case found of
    327         1# ->
    328             case f (Just (vs!i)) of
    329                 Nothing -> TAB (rowUnsafeDelete i ks) (rowUnsafeDelete i vs)
    330                 Just v  -> TAB ks                     (rowUnsafePut i v vs)
    331         _ ->
    332             case f Nothing of
    333                 Nothing -> tab -- no change
    334                 Just v  -> TAB (rowInsert i k ks) (rowInsert i v vs)
    336 tabDelete :: Ord k => k -> Tab k v -> Tab k v
    337 tabDelete k tab@(TAB ks@(Array ks#) vs) =
    338     -- Prof.withSimpleTracingEventPure "tabDelete" "sorted" $
    339     case bsearch# compare k ks# 0# (sizeofArray# ks#) of
    340         (# _,  0# #) -> tab
    341         (# i#, _  #) -> TAB ks' vs'
    342             where i = I# i#
    343                   !ks' = rowUnsafeDelete i ks
    344                   !vs' = rowUnsafeDelete i vs
    346 {-# INLINE tabMember #-}
    347 tabMember :: Ord k => k -> Tab k v -> Bool
    348 tabMember k (TAB ks _) = ssetMember k (SET ks)
    350 {-# INLINE tabKeysSet #-}
    351 tabKeysSet :: Tab k v -> Set k
    352 tabKeysSet (TAB k _) = SET k
    354 {-# INLINE tabKeysArray #-}
    355 tabKeysArray :: Tab k v -> Array k
    356 tabKeysArray (TAB k _) = k
    358 {-# INLINE tabKeysList #-}
    359 tabKeysList :: Tab k v -> [k]
    360 tabKeysList (TAB k _) = toList k
    362 {-# INLINE tabFoldlWithKey' #-}
    363 tabFoldlWithKey' :: (a -> k -> v -> a) -> a -> Tab k v -> a
    364 tabFoldlWithKey' f !x (TAB ks vs) =
    365     -- Prof.withSimpleTracingEventPure "tabFoldWithKey" "sorted" do
    366     go 0 x
    367   where
    368     !wid = sizeofArray ks
    370     go i !acc | i >= wid  = acc
    371     go i !acc | otherwise = go (i+1) $ f acc (ks!i) (vs!i)
    373 tabFilterWithKey :: (k -> v -> Bool) -> Tab k v -> Tab k v
    374 tabFilterWithKey f tab@(TAB ks vs) =
    375   -- Prof.withSimpleTracingEventPure "tabFilterWithKey" "sorted" $
    376   runST do
    377     let !wid = sizeofArray ks
    378     keysBuf <- newArray wid (error "tabFilterWithKey: uninitialized")
    379     valsBuf <- newArray wid (error "tabFilterWithKey: uninitialized")
    380     let go o i | i >= wid  = pure o
    381         go o i | otherwise = do
    382             let key = ks!i
    383             let val = vs!i
    384             if f key val then do
    385                 writeArray keysBuf o key
    386                 writeArray valsBuf o val
    387                 go (o+1) (i+1)
    388             else do
    389                 go o (i+1)
    390     written <- go 0 0
    391     if written == wid then pure tab else
    392         TAB <$> freezeArray keysBuf 0 written
    393             <*> freezeArray valsBuf 0 written
    395 {-# INLINE tabElemsArray #-}
    396 tabElemsArray :: Tab k v -> Array v
    397 tabElemsArray (TAB _ v) = v
    399 {-# INLINE tabElemsList #-}
    400 tabElemsList :: Tab k v -> [v]
    401 tabElemsList (TAB _ v) = toList v
    403 {-
    404     Creates a table from a list of key-value pairs.  If keys appear
    405     multiple times, values later in the list are used.
    407     Implementation:
    409     - Collect the list of pairs into an array.
    411     - Do a stable sort on the array (comparing only the keys), removing
    412       duplicates.
    414     - Split the resulting array out into a keys array and a values array.
    416     Note that rowSortUniqBy chooses earlier values, not later values.
    417     We resolve this by filling the array in reverse.
    419     TODO: We can skip a copy by directly creating a mutable array,
    420     and doing the sort on that, in place.
    421 -}
    422 tabFromPairsList :: Ord k => [(k,v)] -> Tab k v
    423 tabFromPairsList pairs =
    424     -- Prof.withSimpleTracingEventPure "tabFromPairsList" "sorted" $
    425     let buf = rowSortUniqBy (on compare fst) $ arrayFromListRev pairs in
    426     TAB (fst <$> buf) (snd <$> buf)
    428 {-# INLINE tabToAscPairsList #-}
    429 tabToAscPairsList :: Tab k v -> [(k,v)]
    430 tabToAscPairsList (TAB k v) = go 0
    431   where
    432     !len = sizeofArray k
    433     go i | i >= len = []
    434     go i            = (k!i, v!i) : go (i+1)
    436 {-# INLINE tabToDescPairsList #-}
    437 tabToDescPairsList :: Tab k v -> [(k,v)]
    438 tabToDescPairsList (TAB k v) = go (length k - 1)
    439   where
    440     go i | i < 0 = []
    441     go i         = (k!i, v!i) : go (i-1)
    443 {-# INLINE mkTab #-}
    444 mkTab :: Set k -> Array v -> Tab k v
    445 mkTab (SET k) v =
    446     if sizeofArray k /= sizeofArray v
    447     then error "mkTab: keys must have same size as values"
    448     else TAB k v
--------------------------------------------------------------------------------
-- TODO: Optimize and verify these instances
    453 instance Foldable (Tab k) where
    454     foldMap f (TAB _ v) = foldMap f v
    455     {-# INLINE foldMap #-}
    457 instance MonoFunctor (Tab k v) where
    458     omap = fmap
    459     {-# INLINE omap #-}
    461 -- TODO: Explicit instances
    462 instance MonoFoldable (Tab k v) where
    464 instance MonoTraversable (Tab k v) where
    465     otraverse = traverse
    466     omapM = traverse
    467     {-# INLINE otraverse #-}
    468     {-# INLINE omapM #-}
    470 type instance Element (Tab k v) = v
    472 instance Ord k => Semigroup (Tab k v) where
    473     (<>) = tabUnion
    474     {-# INLINE (<>) #-}
    476 instance Ord k => Monoid (Tab k v) where
    477     mempty = emptyTab
    478     {-# INLINE mempty #-}
    480 -- Assert that append (<>) never produces something smaller (this class
    481 -- how no methods).
    482 instance GrowingAppend (Tab k v) where
    484 {-
    485     This explicitly implements all methods
    487     TODO: Better implementation of `unions`!
    488 -}
    489 instance Ord k => SetContainer (Tab k v) where
    490     type ContainerKey (Tab k v) = k
    491     member = tabMember
    492     notMember k t = not (tabMember k t)
    493     union = tabUnion
    494     difference = tabDifference
    495     intersection = tabIntersection
    496     keys = tabKeysList
    497     unions = ofoldl' tabUnion mempty
    498     {-# INLINE member #-}
    499     {-# INLINE notMember #-}
    500     {-# INLINE union #-}
    501     {-# INLINE unions #-}
    502     {-# INLINE difference #-}
    503     {-# INLINE intersection #-}
    504     {-# INLINE keys #-}
    506 instance Functor (Tab k) where
    507     fmap = tabMap
    508     {-# INLINE fmap #-}
    510 -- TODO: implement way more methods.
    511 instance Traversable (Tab k) where
    512     traverse f (TAB ks vs) = TAB ks <$> traverse f vs
    513     sequenceA (TAB ks vs)  = TAB ks <$> sequenceA vs
    515 -- TODO: Implement many more methods.
    516 instance Ord k => IsMap (Tab k v) where
    517     type MapValue (Tab k v) = v
    518     lookup = tabLookup
    519     insertMap = tabInsert
    520     deleteMap = tabDelete
    521     singletonMap = tabSingleton
    522     mapFromList = tabFromPairsList
    523     mapToList = tabToAscPairsList