{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
module Data.Continuous.Set.Internal
( Set(..)
, Inclusivity(..)
, empty
, universe
, null
, universal
, singleton
, append
, member
, equals
, showsPrec
) where
import Prelude hiding (lookup,showsPrec,concat,map,foldr,negate,null)
import Control.Monad.ST (ST,runST)
import Data.Bool (bool)
import Data.Word (Word8)
import Data.Primitive.Contiguous (Contiguous,Element,Mutable)
import Data.Primitive (PrimArray,MutablePrimArray)
import Data.Bits (unsafeShiftL,unsafeShiftR,(.|.),(.&.))
import qualified Data.Foldable as F
import qualified Prelude as P
import qualified Data.Primitive.Contiguous as I
import qualified Data.Concatenation as C
-- | A set of values described by intervals, stored in two arrays.
--
-- * The first field holds the boundary keys. The code in this module lays
--   them out as a low\/high pair for every bounded interval, followed by
--   the boundary of the interval reaching to negative infinity and then
--   the boundary of the interval reaching to positive infinity, for
--   whichever of those are present.
--
-- * The second field holds one 'Word8' of packed inclusivity flags per
--   bounded interval, followed by a final 'Word8' describing the two
--   unbounded edges (this last element is what 'edgeMetadata' reads).
data Set arr a = Set
  !(arr a)
  !(PrimArray Word8)
-- | Whether an interval boundary belongs to the interval.
--
-- The constructor order matters: the derived 'Ord' instance makes
-- 'Exclusive' smaller than 'Inclusive', and this module uses 'max' and
-- '>' on inclusivities when two boundaries share the same key.
data Inclusivity = Exclusive | Inclusive
  deriving (Eq,Ord,Show,Read)
-- | Description of one unbounded side of a 'Set', as packed into the final
-- element of the set's 'Word8' array.
data Edge
  = EdgeInclusive
    -- ^ The unbounded interval exists and includes its finite boundary.
  | EdgeExclusive
    -- ^ The unbounded interval exists and excludes its finite boundary.
  | EdgeAbsent
    -- ^ There is no unbounded interval on this side.
  | EdgeUniversal
    -- ^ Written for both sides by 'universe'; 'edges' treats it as an
    -- invariant violation everywhere else.
-- | Structural equality: two sets are equal when their key arrays are
-- equal and their inclusivity\/edge arrays are equal.
equals :: (Contiguous arr, Element arr a, Eq a) => Set arr a -> Set arr a -> Bool
equals left right = case (left, right) of
  (Set keysL incsL, Set keysR incsR)
    | I.equals keysL keysR -> incsL == incsR
    | otherwise -> False
-- | The set containing nothing: no keys, and a single metadata byte that
-- marks both unbounded edges as absent.
empty :: Contiguous arr => Set arr a
empty = Set I.empty
  (runST (I.replicateM 1 (edgePairToWord8 EdgeAbsent EdgeAbsent) >>= I.unsafeFreeze))
-- | The set containing everything: no keys, and a single metadata byte
-- that marks both unbounded edges as 'EdgeUniversal'.
universe :: Contiguous arr => Set arr a
universe = Set I.empty
  (runST (I.replicateM 1 (edgePairToWord8 EdgeUniversal EdgeUniversal) >>= I.unsafeFreeze))
-- | Is this the empty set? True when there are no keys and the first
-- metadata byte marks both unbounded edges as absent.
null :: Contiguous arr => Set arr a -> Bool
null (Set keys incs)
  | I.null keys = I.index incs 0 == edgePairToWord8 EdgeAbsent EdgeAbsent
  | otherwise = False
-- | Is this the universal set? True when there are no keys and the first
-- metadata byte marks both unbounded edges as 'EdgeUniversal'.
universal :: Contiguous arr => Set arr a -> Bool
universal (Set keys incs)
  | I.null keys = I.index incs 0 == edgePairToWord8 EdgeUniversal EdgeUniversal
  | otherwise = False
-- | Build a set consisting of a single interval.
--
-- The first argument is the lower bound and the second the upper bound;
-- 'Nothing' means the interval is unbounded on that side.
--
-- * Both 'Nothing': the result is 'universe'.
-- * Lower bound greater than upper bound: the result is 'empty'.
-- * Equal bounds: a one-point interval when both are 'Inclusive',
--   otherwise 'empty'.
singleton :: (Contiguous arr, Element arr a, Ord a)
  => Maybe (Inclusivity,a)
  -> Maybe (Inclusivity,a)
  -> Set arr a
singleton mlower mupper = case mlower of
  Nothing -> case mupper of
    Nothing -> universe
    -- (-∞,hi: one key, described only by the lower (negative-infinity) edge.
    Just (incHi,hi) -> runST $ do
      keys <- I.replicateM 1 hi >>= I.unsafeFreeze
      incs <- I.replicateM 1 (edgePairToWord8 (inclusivityToEdge incHi) EdgeAbsent) >>= I.unsafeFreeze
      return (Set keys incs)
  Just (incLo,lo) -> case mupper of
    -- lo,+∞): one key, described only by the upper (positive-infinity) edge.
    Nothing -> runST $ do
      keys <- I.replicateM 1 lo >>= I.unsafeFreeze
      incs <- I.replicateM 1 (edgePairToWord8 EdgeAbsent (inclusivityToEdge incLo)) >>= I.unsafeFreeze
      return (Set keys incs)
    Just (incHi,hi) -> case compare lo hi of
      LT -> unsafeSingleton incLo lo incHi hi
      -- A degenerate interval is only non-empty as the closed point [lo,lo].
      EQ | incLo == Inclusive && incHi == Inclusive -> unsafeSingleton Inclusive lo Inclusive lo
      _ -> empty
-- | Build a set holding exactly one bounded interval from @lo@ to @hi@
-- with the given inclusivities. Performs no ordering check on the two
-- bounds; callers are expected to have validated them.
unsafeSingleton :: (Contiguous arr, Element arr a) => Inclusivity -> a -> Inclusivity -> a -> Set arr a
unsafeSingleton incLo lo incHi hi = Set keys incs
  where
    -- Two keys: the interval's low and high boundary.
    keys = runST $ do
      m <- I.replicateM 2 lo
      I.write m 1 hi
      I.unsafeFreeze m
    -- One inclusivity byte for the interval, then the edge byte saying
    -- that neither unbounded interval is present.
    incs = runST $ do
      m <- I.new 2
      I.write m 0 (inclusivityPairToWord8 incLo incHi)
      I.write m 1 (edgePairToWord8 EdgeAbsent EdgeAbsent)
      I.unsafeFreeze m
-- | The set of everything except the given value: two unbounded
-- intervals that both stop, exclusively, at @x@.
except :: (Contiguous arr, Element arr a) => a -> Set arr a
except x = runST $ do
  keys <- I.replicateM 2 x >>= I.unsafeFreeze
  edgeWord <- I.new 1
  I.write edgeWord 0 (edgePairToWord8 EdgeExclusive EdgeExclusive)
  incs <- I.unsafeFreeze edgeWord
  return (Set keys incs)
-- | Build the union of an interval reaching down to negative infinity
-- (ending at the second argument) and an interval reaching up to positive
-- infinity (starting at the fourth argument).
--
-- When the two intervals overlap or touch, the result collapses to
-- 'universe'; when they share a boundary that both exclude, the result is
-- everything 'except' that boundary.
infinities :: (Contiguous arr, Element arr a, Ord a)
  => Inclusivity
  -> a
  -> Inclusivity
  -> a
  -> Set arr a
infinities negInfHiInc negInfHi posInfLoInc posInfLo =
  case compare negInfHi posInfLo of
    LT -> unsafeInfinities negInfHiInc negInfHi posInfLoInc posInfLo
    EQ | bothExclusive -> except negInfHi
    _ -> universe
  where
    bothExclusive = negInfHiInc == Exclusive && posInfLoInc == Exclusive
-- | Build a set made of two unbounded intervals and no bounded ones,
-- without checking that the two boundaries are correctly ordered.
unsafeInfinities :: (Contiguous arr, Element arr a) => Inclusivity -> a -> Inclusivity -> a -> Set arr a
unsafeInfinities negInfHiInc negInfHi posInfLoInc posInfLo = Set keys incs
  where
    -- The negative-infinity boundary comes first, then the positive one.
    keys = runST $ do
      m <- I.replicateM 2 negInfHi
      I.write m 1 posInfLo
      I.unsafeFreeze m
    -- Only the edge byte is needed since there are no bounded intervals.
    incs = runST $ do
      m <- I.new 1
      I.write m 0 (edgePairToWord8 (inclusivityToEdge negInfHiInc) (inclusivityToEdge posInfLoInc))
      I.unsafeFreeze m
-- | Union of two sets.
--
-- * If either set is empty the other is returned; if either is universal
--   it is returned.
-- * If neither set has bounded intervals, the unbounded edges of both are
--   combined (largest negative-infinity boundary, smallest
--   positive-infinity boundary) and the result is rebuilt through
--   'singleton' or 'infinities', which also collapse overlapping edges to
--   'universe'. Every combination of edges is handled.
-- * NOTE(review): the general case (at least one bounded interval) is not
--   implemented. It raises invariant 101 or 102 unless the combined set
--   has both unbounded edges, and otherwise still fails because
--   'eatFromPositiveInfinity' is a stub (invariant 110) and the local
--   @go@ rejects any remaining bounded interval (invariants 103-105).
append :: forall arr a. (Ord a, Contiguous arr, Element arr a) => Set arr a -> Set arr a -> Set arr a
append s1@(Set keys1 incs1) s2@(Set keys2 incs2)
  | null s1 = s2
  | null s2 = s1
  | universal s1 = s1
  | universal s2 = s2
  | pairsCount1 == 0 && pairsCount2 == 0 = case lowerPairRes of
      Nothing -> case upperPairRes of
        -- Only reachable if a non-empty set carries no intervals at all.
        Nothing -> error (errMsg 9)
        Just (posInfLoInc,posInfLo) -> singleton (Just (posInfLoInc,posInfLo)) Nothing
      Just (negInfHiInc,negInfHi) -> case upperPairRes of
        Nothing -> singleton Nothing (Just (negInfHiInc,negInfHi))
        -- 'infinities' normalises overlapping or touching edges.
        Just (posInfLoInc,posInfLo) -> infinities negInfHiInc negInfHi posInfLoInc posInfLo
  | otherwise = runST $ do
      let maxSz = pairsCount1 + pairsCount2 + 1
      keysMut <- I.new (maxSz * 2)
      incsMut <- I.new maxSz
      case lowerPairRes of
        Just (negInfHiIncOriginal,negInfHiOriginal) -> do
          let (negInfHiIncFinal,negInfHiFinal,ixInit1,ixInit2) = eatFromNegativeInfinity negInfHiIncOriginal negInfHiOriginal keys1 incs1 pairsCount1 keys2 incs2 pairsCount2
          case upperPairRes of
            Just (posInfLoIncOriginal,posInfLoOriginal) -> do
              let (posInfLoIncFinal,posInfLoFinal,ixLast1,ixLast2) = eatFromPositiveInfinity posInfLoIncOriginal posInfLoOriginal keys1 incs1 pairsCount1 keys2 incs2 pairsCount2
              finalIx <- go keysMut incsMut ixInit1 ixLast1 ixInit2 ixLast2 0 negInfHiIncFinal negInfHiFinal
              I.write incsMut finalIx (edgePairToWord8 (inclusivityToEdge negInfHiIncFinal) (inclusivityToEdge posInfLoIncFinal))
              I.write keysMut (finalIx * 2) negInfHiFinal
              I.write keysMut (finalIx * 2 + 1) posInfLoFinal
              keysFrozen <- I.resize keysMut (finalIx * 2 + 2) >>= I.unsafeFreeze
              -- Keep finalIx inclusivity bytes plus the edge byte that was
              -- just written at index finalIx.
              incsFrozen <- I.resize incsMut (finalIx + 1) >>= I.unsafeFreeze
              return (Set keysFrozen incsFrozen)
            Nothing -> error (errMsg 102)
        Nothing -> error (errMsg 101)
  where
    (lowerPair1,upperPair1,pairsCount1) = edges keys1 incs1
    (lowerPair2,upperPair2,pairsCount2) = edges keys2 incs2
    lowerPairRes = combineNegativeInfinities lowerPair1 lowerPair2
    upperPairRes = combinePositiveInfinities upperPair1 upperPair2
    -- Of two intervals reaching to positive infinity, the one that starts
    -- lower covers more; on a tie the more inclusive boundary wins.
    combinePositiveInfinities :: Maybe (Inclusivity,a) -> Maybe (Inclusivity,a) -> Maybe (Inclusivity,a)
    combinePositiveInfinities Nothing y = y
    combinePositiveInfinities x Nothing = x
    combinePositiveInfinities (Just (xinc,x)) (Just (yinc,y)) = case compare x y of
      LT -> Just (xinc,x)
      GT -> Just (yinc,y)
      EQ -> Just (max xinc yinc,y)
    -- NOTE(review): merging of bounded intervals is unfinished. This only
    -- succeeds when both index ranges are exhausted and neither input has
    -- a positive-infinity edge; the innermost case has no 'Just'
    -- alternative for the second set.
    go :: forall s.
         Mutable arr s a
      -> MutablePrimArray s Word8
      -> Int
      -> Int
      -> Int
      -> Int
      -> Int
      -> Inclusivity
      -> a
      -> ST s Int
    go !keysMut !incsMut !ix1 !ixLast1 !ix2 !ixLast2 !ixDst inc a = if ix1 <= ixLast1
      then error (errMsg 103)
      else if ix2 <= ixLast2
        then error (errMsg 104)
        else case upperPair1 of
          Just _ -> error (errMsg 105)
          Nothing -> case upperPair2 of
            Nothing -> return ixDst
-- | Combine the boundaries of two optional intervals that reach down to
-- negative infinity, keeping the one with the greater boundary key. When
-- the keys compare equal, the larger 'Inclusivity' is kept together with
-- the second argument's key.
combineNegativeInfinities :: Ord a => Maybe (Inclusivity,a) -> Maybe (Inclusivity,a) -> Maybe (Inclusivity,a)
combineNegativeInfinities mx my = case (mx, my) of
  (Nothing, Nothing) -> Nothing
  (Nothing, Just _) -> my
  (Just _, Nothing) -> mx
  (Just (xinc,x), Just (yinc,y)) -> case compare x y of
    GT -> mx
    LT -> my
    EQ -> Just (max xinc yinc,y)
-- | NOTE(review): unimplemented stub. The definition is a call to 'error'
-- with invariant number 110, so forcing it (or any component of its
-- result) fails. Judging by the name and signature it is presumably meant
-- to mirror 'eatFromNegativeInfinity' for the positive-infinity edge —
-- confirm before relying on that.
eatFromPositiveInfinity ::
     Inclusivity
  -> a
  -> arr a
  -> PrimArray Word8
  -> Int
  -> arr a
  -> PrimArray Word8
  -> Int
  -> (Inclusivity,a,Int,Int)
eatFromPositiveInfinity = error (errMsg 110)
-- | Grow the interval that reaches down to negative infinity by absorbing
-- bounded intervals of the second set that it overlaps or touches.
--
-- Arguments: the starting inclusivity and boundary of the negative-infinity
-- interval, then keys, inclusivity bytes and bounded-pair count for the
-- first set, then the same three for the second set.
--
-- Returns the final inclusivity and boundary, together with the index of
-- the first unabsorbed bounded interval in each set.
--
-- NOTE(review): bounded intervals of the first set are not handled; if the
-- first set has any (its pair count is positive) this raises invariant 111.
eatFromNegativeInfinity :: (Contiguous arr, Element arr a, Ord a)
  => Inclusivity
  -> a
  -> arr a
  -> PrimArray Word8
  -> Int
  -> arr a
  -> PrimArray Word8
  -> Int
  -> (Inclusivity,a,Int,Int)
eatFromNegativeInfinity negInfInc0 negInfHi0 keys1 incs1 sz1 keys2 incs2 sz2 =
  consume negInfInc0 negInfHi0 0 0
  where
    consume edgeInc edgeHi !ix1 !ix2
      | ix1 < sz1 = error (errMsg 111)
      | ix2 >= sz2 = finished
      | otherwise =
          let (# lo #) = I.index# keys2 (ix2 * 2)
              (# hi #) = I.index# keys2 (ix2 * 2 + 1)
              (loInc,hiInc) = indexInclusivityPair incs2 ix2
              -- The current interval starts inside (or at) the edge, so it
              -- is swallowed; the edge moves up only if the interval
              -- reaches further than the edge already does.
              absorb = case compare edgeHi hi of
                LT -> consume hiInc hi ix1 (ix2 + 1)
                GT -> consume edgeInc edgeHi ix1 (ix2 + 1)
                EQ -> consume (max hiInc edgeInc) hi ix1 (ix2 + 1)
           in case compare edgeHi lo of
                -- The interval starts strictly above the edge: stop here.
                LT -> finished
                GT -> absorb
                EQ
                  -- Both sides exclude the shared point, leaving a gap.
                  | edgeInc == Exclusive && loInc == Exclusive -> (Exclusive,edgeHi,ix1,ix2)
                  | otherwise -> absorb
      where
        finished = (edgeInc,edgeHi,ix1,ix2)
-- | Convert an 'Inclusivity' to the 'Edge' of a present unbounded interval.
inclusivityToEdge :: Inclusivity -> Edge
inclusivityToEdge inc = case inc of
  Inclusive -> EdgeInclusive
  Exclusive -> EdgeExclusive
-- | Encode an 'Inclusivity' as a single bit: 0 for 'Inclusive', 1 for
-- 'Exclusive'.
inclusivityToWord8 :: Inclusivity -> Word8
inclusivityToWord8 inc = case inc of
  Inclusive -> 0
  Exclusive -> 1
-- | Pack the inclusivities of a bounded interval's low and high boundary
-- into one byte: the low boundary occupies bit 1, the high boundary bit 0.
inclusivityPairToWord8 :: Inclusivity -> Inclusivity -> Word8
inclusivityPairToWord8 a b = lowBit .|. highBit
  where
    lowBit = inclusivityToWord8 a `unsafeShiftL` 1
    highBit = inclusivityToWord8 b
-- | Decode an inclusivity: 0 is 'Inclusive', any other value is
-- 'Exclusive'.
word8ToInclusivity :: Word8 -> Inclusivity
word8ToInclusivity w
  | w == 0 = Inclusive
  | otherwise = Exclusive
-- | Read the byte at the given index and decode it as the (low, high)
-- inclusivities of a bounded interval. Any value other than 0, 1 or 2
-- decodes as both boundaries exclusive.
indexInclusivityPair :: PrimArray Word8 -> Int -> (Inclusivity,Inclusivity)
indexInclusivityPair xs ix
  | w == 0 = (Inclusive,Inclusive)
  | w == 1 = (Inclusive,Exclusive)
  | w == 2 = (Exclusive,Inclusive)
  | otherwise = (Exclusive,Exclusive)
  where
    w = I.index xs ix
-- | Encode an 'Edge' as a two-bit number (0 through 3).
edgeToWord8 :: Edge -> Word8
edgeToWord8 edge = case edge of
  EdgeInclusive -> 0
  EdgeExclusive -> 1
  EdgeAbsent -> 2
  EdgeUniversal -> 3
-- | Decode an 'Edge' from its numeric encoding. Any value other than 0, 1
-- or 2 decodes as 'EdgeUniversal'.
word8ToEdge :: Word8 -> Edge
word8ToEdge 0 = EdgeInclusive
word8ToEdge 1 = EdgeExclusive
word8ToEdge 2 = EdgeAbsent
word8ToEdge _ = EdgeUniversal
-- | Pack two edges into one byte: the first (negative-infinity side) edge
-- occupies bits 2-3 and the second (positive-infinity side) bits 0-1.
edgePairToWord8 :: Edge -> Edge -> Word8
edgePairToWord8 a b = (edgeToWord8 a `unsafeShiftL` 2) .|. edgeToWord8 b
-- | Decode the pair of unbounded edges stored in the last byte of a set's
-- inclusivity array. The first component comes from the bits above the
-- low two, the second from the low two bits.
edgeMetadata :: PrimArray Word8 -> (Edge,Edge)
edgeMetadata xs =
  let w = I.index xs (I.size xs - 1)
      lowerBits = unsafeShiftR w 2
      upperBits = w .&. 0b11
   in (word8ToEdge lowerBits, word8ToEdge upperBits)
-- | Extract the unbounded edges of a set and count its bounded intervals.
--
-- Returns, in order: the boundary of the interval reaching to negative
-- infinity (if present), the boundary of the interval reaching to positive
-- infinity (if present), and the number of bounded intervals. The edge
-- boundaries are read from the end of the key array.
--
-- Raises an invariant error when either decoded edge is 'EdgeUniversal'
-- (numbers 2 through 5 depending on which combination was seen), so this
-- must not be applied to the universal set.
edges :: (Contiguous arr, Element arr a)
  => arr a
  -> PrimArray Word8
  -> (Maybe (Inclusivity,a), Maybe (Inclusivity,a),Int)
edges keys incs =
  let (lowerEdge,upperEdge) = edgeMetadata incs
      total = I.size keys
      lastKey = I.index keys (total - 1)
      secondToLastKey = I.index keys (total - 2)
      -- Number of bounded pairs once the given count of edge keys is
      -- discounted.
      pairsWithout edgeKeys = div (total - edgeKeys) 2
      -- No interval reaches to negative infinity.
      lowerMissing = case upperEdge of
        EdgeInclusive -> (Nothing,Just (Inclusive,lastKey),pairsWithout 1)
        EdgeExclusive -> (Nothing,Just (Exclusive,lastKey),pairsWithout 1)
        EdgeAbsent -> (Nothing,Nothing,pairsWithout 0)
        EdgeUniversal -> error (errMsg 3)
      -- An interval reaches to negative infinity with inclusivity inc;
      -- invariant names the error to raise on a universal upper edge.
      lowerPresent inc invariant = case upperEdge of
        EdgeInclusive -> (Just (inc,secondToLastKey),Just (Inclusive,lastKey),pairsWithout 2)
        EdgeExclusive -> (Just (inc,secondToLastKey),Just (Exclusive,lastKey),pairsWithout 2)
        EdgeAbsent -> (Just (inc,lastKey),Nothing,pairsWithout 1)
        EdgeUniversal -> error (errMsg invariant)
   in case lowerEdge of
        EdgeUniversal -> error (errMsg 2)
        EdgeAbsent -> lowerMissing
        EdgeInclusive -> lowerPresent Inclusive 4
        EdgeExclusive -> lowerPresent Exclusive 5
-- | Test whether a value belongs to a set.
--
-- The universal set contains every value. Otherwise the value is first
-- compared against the unbounded edges (if any); when it falls between
-- them, the bounded intervals are binary-searched by lower boundary and
-- the candidate interval is checked with its own inclusivity flags.
member :: forall arr a. (Contiguous arr, Element arr a, Ord a)
  => a
  -> Set arr a
  -> Bool
member val s@(Set keys incs)
  -- 'edges' rejects the universal encoding, so it is answered up front.
  | universal s = True
  | otherwise = case edges keys incs of
      (!mnegInfHi,!mposInfLo,!n) ->
        case mnegInfHi of
          Nothing -> case mposInfLo of
            Nothing -> go 0 (n - 1)
            Just (!posInfLoInc,!posInfLo) -> case compare val posInfLo of
              GT -> True
              LT -> go 0 (n - 1)
              EQ -> posInfLoInc == Inclusive
          Just (!negInfHiInc,!negInfHi) -> case mposInfLo of
            Nothing -> case compare val negInfHi of
              LT -> True
              GT -> go 0 (n - 1)
              EQ -> negInfHiInc == Inclusive
            Just (!posInfLoInc,!posInfLo) -> case compare val posInfLo of
              GT -> True
              LT -> case compare val negInfHi of
                GT -> go 0 (n - 1)
                LT -> True
                EQ -> negInfHiInc == Inclusive
              EQ -> posInfLoInc == Inclusive
  where
    -- Does the bounded interval at this index contain val, honouring the
    -- inclusivity of both of its boundaries?
    inInterval !ix =
      let !(# valLo #) = I.index# keys (2 * ix)
          !(# valHi #) = I.index# keys (2 * ix + 1)
       in case indexInclusivityPair incs ix of
            (Exclusive,Exclusive) -> val > valLo && val < valHi
            (Exclusive,Inclusive) -> val > valLo && val <= valHi
            (Inclusive,Exclusive) -> val >= valLo && val < valHi
            (Inclusive,Inclusive) -> val >= valLo && val <= valHi
    -- Binary search over bounded intervals start..end (inclusive) for the
    -- last interval whose lower boundary does not exceed val.
    go !start !end
      | end < start = False
      | end == start = inInterval start
      | otherwise =
          let !mid = div (end + start + 1) 2
              !valLo = I.index keys (2 * mid)
           in case P.compare val valLo of
                LT -> go start (mid - 1)
                -- Sitting exactly on a lower boundary only counts when
                -- that boundary is inclusive.
                EQ -> inInterval mid
                GT -> go mid end
{-# INLINEABLE member #-}
-- | Render the message used when an internal invariant, identified by
-- number, turns out not to hold.
errMsg :: Int -> String
errMsg n = "Data.Continuous.Set.Internal: invariant " ++ shows n " violated"
-- | List the first @n@ bounded intervals of a set, in storage order, as
-- (low inclusivity, low key, high inclusivity, high key) tuples.
toPairs :: (Contiguous arr, Element arr a) => Int -> Set arr a -> [(Inclusivity,a,Inclusivity,a)]
toPairs n (Set keys incs) =
  [ (incLo, I.index keys (2 * ix), incHi, I.index keys (2 * ix + 1))
  | ix <- [0 .. n - 1]
  , let (incLo,incHi) = indexInclusivityPair incs ix
  ]
-- | Render a set inside braces. The empty set is shown as @{}@ and the
-- universal set as a single interval from -∞ to +∞; every other set is
-- rendered by 'showListInf'. The precedence argument is ignored.
showsPrec :: (Contiguous arr, Element arr a, Show a)
  => Int
  -> Set arr a
  -> ShowS
showsPrec _ s@(Set keys incs)
  | null s = showString "{}"
  | universal s = showString "{(-∞,+∞)}"
  | otherwise = case edges keys incs of
      (lowerPair,upperPair,pairsCount) ->
        let contents = showListInf shows lowerPair (toPairs pairsCount s) upperPair
         in showChar '{' . contents . showChar '}'
-- | Render the intervals of a set as a comma-separated sequence with no
-- surrounding delimiters (the caller supplies the braces).
--
-- The optional interval reaching to negative infinity comes first, then
-- each bounded interval with @[@ \/ @]@ for inclusive and @(@ \/ @)@ for
-- exclusive boundaries, then the optional interval reaching to positive
-- infinity. With nothing to show, the output is empty.
showListInf :: (a -> ShowS) -> Maybe (Inclusivity,a) -> [(Inclusivity,a,Inclusivity,a)] -> Maybe (Inclusivity,a) -> ShowS
showListInf showx mnegInfHi pairs mposInfLo = case pieces of
  [] -> id
  firstPiece : otherPieces ->
    -- Prelude's foldr is hidden in this module, hence the qualified name.
    firstPiece . P.foldr (\piece rest -> showChar ',' . piece . rest) id otherPieces
  where
    pieces = lowerPiece ++ [showPair p | p <- pairs] ++ upperPiece
    lowerPiece = maybe [] (\(inc,x) -> [showNegInfHi showx inc x]) mnegInfHi
    upperPiece = maybe [] (\(inc,x) -> [showPosInfLo showx inc x]) mposInfLo
    showPair (incLo,lo,incHi,hi) =
      showChar (case incLo of {Inclusive -> '['; Exclusive -> '('})
        . showx lo
        . showChar ','
        . showx hi
        . showChar (case incHi of {Inclusive -> ']'; Exclusive -> ')'})
-- | Render the interval reaching down to negative infinity: an opening
-- @(-∞,@, the boundary, then @]@ if the boundary is inclusive or @)@ if it
-- is exclusive.
showNegInfHi :: (a -> ShowS) -> Inclusivity -> a -> ShowS
showNegInfHi showx inc x = showString "(-∞," . showx x . showChar closing
  where
    closing = case inc of
      Inclusive -> ']'
      Exclusive -> ')'
-- | Render the interval reaching up to positive infinity: @[@ if the
-- boundary is inclusive or @(@ if it is exclusive, the boundary, then a
-- closing @,+∞)@.
showPosInfLo :: (a -> ShowS) -> Inclusivity -> a -> ShowS
showPosInfLo showx inc x = showChar opening . showx x . showString ",+∞)"
  where
    opening = case inc of
      Inclusive -> '['
      Exclusive -> '('