module Data.IntTrie
( IntTrie, identity, apply, modify, modify', overwrite,
mirror, modifyAscList, modifyDescList )
where
import Control.Applicative
import Control.Arrow (first, second)
import Data.Bits
import Data.Function (fix)
import Data.Monoid (Monoid(..))
import Data.Semigroup (Semigroup(..))
-- | A map from every integer index to a value, split into three parts:
-- the entries for negative indices (stored by magnitude, so index @-n@
-- lives at position @n@), the entry for index 0, and the entries for
-- positive indices.
data IntTrie a = IntTrie
  (BitTrie a) -- negative indices, keyed by magnitude
  a           -- index 0
  (BitTrie a) -- positive indices

-- | An infinite binary trie over the positive integers.
--
-- The root value belongs to position 1. The first subtree holds the even
-- positions @2n@ (at position @n@ inside it); the second holds the odd
-- positions @2n+1@ (at position @n@ inside it).
--
-- The fields are intentionally lazy: tries are infinite and are built
-- cyclically (e.g. by 'pure'), which strict fields would make diverge.
data BitTrie a = BitTrie
  a           -- value at position 1
  (BitTrie a) -- even positions
  (BitTrie a) -- odd positions
-- | Maps a function over every value in the (infinite) trie.
instance Functor BitTrie where
  -- The argument is destructured through a lazy where-binding, so the
  -- result node exists before the input node is evaluated.
  fmap f node = BitTrie (f here) (fmap f evens) (fmap f odds)
    where
      BitTrie here evens odds = node
-- | Pointwise (zip-like) application over tries of identical shape.
instance Applicative BitTrie where
  -- A single node whose subtrees are itself: a finite, cyclic value that
  -- stores @x@ at every position.
  pure x = loop
    where
      loop = BitTrie x loop loop
  -- Both arguments are destructured lazily, so the result node is built
  -- without evaluating either input node.
  fs <*> xs = BitTrie (f x) (fEvens <*> xEvens) (fOdds <*> xOdds)
    where
      BitTrie f fEvens fOdds = fs
      BitTrie x xEvens xOdds = xs
-- | Combines two tries position by position with the element type's '<>'.
instance Semigroup a => Semigroup (BitTrie a) where
  xs <> ys = (<>) <$> xs <*> ys
-- | Position-wise monoid: the identity stores 'mempty' everywhere.
instance Monoid a => Monoid (BitTrie a) where
  mempty = pure mempty
  -- Defined directly from the element's 'mappend', since the instance
  -- context only provides @Monoid a@.
  mappend xs ys = mappend <$> xs <*> ys
-- | Maps a function over the value at every index.
instance Functor IntTrie where
  -- The trie is destructured through a lazy where-binding, so the result
  -- constructor is available without evaluating the input.
  fmap f trie = IntTrie (fmap f negatives) (f zero) (fmap f positives)
    where
      IntTrie negatives zero positives = trie
-- | Pointwise application: the function stored at each index is applied
-- to the value stored at the same index. @pure x@ stores @x@ at every
-- index.
instance Applicative IntTrie where
  -- One constant 'BitTrie' is shared by both branches.
  pure x = IntTrie everywhere x everywhere
    where
      everywhere = pure x
  -- Irrefutable patterns, matching the 'Functor' instance and the
  -- 'BitTrie' '<*>': the result constructor is produced without forcing
  -- either argument. A trie defined in terms of itself through '<*>'
  -- therefore does not diverge just by being evaluated to WHNF.
  ~(IntTrie fneg fz fpos) <*> ~(IntTrie xneg xz xpos) =
    IntTrie (fneg <*> xneg) (fz xz) (fpos <*> xpos)
-- | Combines two tries index by index with the element type's '<>'.
instance Semigroup a => Semigroup (IntTrie a) where
  xs <> ys = (<>) <$> xs <*> ys
-- | Index-wise monoid: the identity stores 'mempty' at every index.
instance Monoid a => Monoid (IntTrie a) where
  mempty = pure mempty
  -- Defined directly from the element's 'mappend', since the instance
  -- context only provides @Monoid a@.
  mappend xs ys = mappend <$> xs <*> ys
-- | Look up the value stored at an index.
--
-- Index 0 is stored directly. A positive index @x@ is found at position
-- @x@ of the positive branch. A negative index @x@ is found at position
-- @negate x@ of the negative branch, which is keyed by magnitude.
--
-- NOTE: for GHC's fixed-width signed types (e.g. 'Int'),
-- @negate minBound == minBound@, so looking up @minBound@ does not
-- terminate.
apply :: (Ord b, Num b, Bits b) => IntTrie a -> b -> a
apply (IntTrie neg z pos) x =
  case compare x 0 of
    -- The walk needs a positive position. Passing a negative index to
    -- 'applyPositive' would never terminate, because arithmetic shifts
    -- settle at -1.
    LT -> applyPositive neg (negate x)
    EQ -> z
    GT -> applyPositive pos x
-- | Walk a 'BitTrie' to the value at a positive position.
--
-- The position is consumed from its least significant bit. A set bit
-- descends into the odd subtree and a clear bit into the even subtree,
-- until only the leading 1 remains.
--
-- The position must be at least 1. For 0 or a negative value the shifted
-- position never reaches 1, and the walk does not terminate.
applyPositive :: (Num b, Bits b) => BitTrie a -> b -> a
applyPositive (BitTrie here evens odds) pos
  | pos == 1 = here
  | otherwise = applyPositive child (pos `shiftR` 1)
  where
    child = if testBit pos 0 then odds else evens
-- | The trie that stores each index at itself.
--
-- The positive branch holds @n@ at position @n@. The negative branch,
-- keyed by magnitude, holds @-n@ at position @n@. Index 0 holds @0@.
identity :: (Num a, Bits a) => IntTrie a
identity = IntTrie negatives 0 identityPositive
  where
    negatives = negate <$> identityPositive
-- | The 'BitTrie' that holds @n@ at every position @n@.
--
-- Built self-referentially. Position @n@ of the even subtree is global
-- position @2n@, obtained by doubling. Position @n@ of the odd subtree is
-- @2n+1@, obtained by doubling and setting the low bit.
identityPositive :: (Num a, Bits a) => BitTrie a
identityPositive = self
  where
    self = BitTrie 1 (fmap double self) (fmap doublePlusOne self)
    double n = n `shiftL` 1
    doublePlusOne n = double n .|. 1
-- | Lazily apply a function to the value at one index.
--
-- The trie is matched irrefutably, and only the path to the index is
-- rebuilt; every other subtree is shared with the input. Negative indices
-- are addressed by magnitude in the negative branch, as in 'apply'.
--
-- NOTE: for fixed-width signed types @negate minBound == minBound@, so
-- @minBound@ does not address a valid position in the negative branch.
modify :: (Ord b, Num b, Bits b) => b -> (a -> a) -> IntTrie a -> IntTrie a
modify x f ~(IntTrie neg z pos) =
  case compare x 0 of
    -- Must pass the magnitude. A negative position never reaches 1 in
    -- 'modifyPositive', so the entry would silently stay unmodified.
    LT -> IntTrie (modifyPositive (negate x) f neg) z pos
    EQ -> IntTrie neg (f z) pos
    GT -> IntTrie neg z (modifyPositive x f pos)
-- | Lazily apply a function to the value at a positive position of a
-- 'BitTrie'. Only the nodes on the path are rebuilt.
--
-- The trie is matched irrefutably, so a rebuilt node is returned before
-- the input node is evaluated. The position is expected to be at least 1.
modifyPositive :: (Num b, Bits b) => b -> (a -> a) -> BitTrie a -> BitTrie a
modifyPositive pos f ~(BitTrie here evens odds)
  | pos == 1 = BitTrie (f here) evens odds
  | testBit pos 0 = BitTrie here evens (descend odds)
  | otherwise = BitTrie here (descend evens) odds
  where
    descend = modifyPositive (pos `shiftR` 1) f
-- | Strict variant of 'modify'.
--
-- The trie is matched strictly. The rebuilt branch (or the new value at
-- index 0) is forced with '$!' before the result is constructed.
-- Negative indices are addressed by magnitude, as in 'apply'.
--
-- NOTE: for fixed-width signed types @negate minBound == minBound@;
-- modifying @minBound@ does not terminate.
modify' :: (Ord b, Num b, Bits b) => b -> (a -> a) -> IntTrie a -> IntTrie a
modify' x f (IntTrie neg z pos) =
  case compare x 0 of
    -- Must pass the magnitude. With a negative position the strict walk
    -- in modifyPositive' never reaches 1 and never terminates.
    LT -> (IntTrie $! modifyPositive' (negate x) f neg) z pos
    EQ -> (IntTrie neg $! f z) pos
    GT -> IntTrie neg z $! modifyPositive' x f pos
-- | Strict variant of 'modifyPositive'.
--
-- Each rebuilt child on the path, and the final new value, is forced to
-- WHNF before its parent node is built. The position must be at least 1;
-- otherwise the strict descent does not terminate.
modifyPositive' :: (Num b, Bits b) => b -> (a -> a) -> BitTrie a -> BitTrie a
modifyPositive' pos f (BitTrie here evens odds)
  | pos == 1 =
      let here' = f here
      in here' `seq` BitTrie here' evens odds
  | testBit pos 0 =
      let odds' = descend odds
      in odds' `seq` BitTrie here evens odds'
  | otherwise =
      let evens' = descend evens
      in evens' `seq` BitTrie here evens' odds
  where
    descend = modifyPositive' (pos `shiftR` 1) f
-- | Replace the value at one index, lazily (implemented with 'modify').
overwrite :: (Ord b, Num b, Bits b) => b -> a -> IntTrie a -> IntTrie a
overwrite i v = modify i (\_ -> v)
-- | Reflect the trie about index 0 by swapping the negative and positive
-- branches, so the value at @i@ moves to @negate i@.
--
-- The input is destructured lazily, so the result constructor is
-- available without evaluating the argument.
mirror :: IntTrie a -> IntTrie a
mirror trie = IntTrie positives zero negatives
  where
    IntTrie negatives zero positives = trie
-- | Apply a list of @(index, function)@ modifications, given in strictly
-- ascending index order, in a single lazy pass over the trie.
--
-- The list is split at the first non-negative index:
--
-- * the negative prefix is reversed and negated into ascending
--   magnitudes for the negative branch;
-- * a leading index 0 updates the zero entry;
-- * the remainder goes to the positive branch.
--
-- An empty list returns the input trie unchanged. Non-ascending input
-- may surface as an 'error' from 'modifyAscListPositive' when the
-- affected part of the trie is forced, but it is not guaranteed to be
-- detected.
modifyAscList :: (Ord b, Num b, Bits b) => [(b, a -> a)] -> IntTrie a -> IntTrie a
modifyAscList ifs ~t@(IntTrie neg z pos) = rebuild (break ((>= 0) . fst) ifs)
  where
    rebuild ([], []) = t
    rebuild (negIfs, (0, f) : posIfs) =
      IntTrie (onNegative negIfs) (f z) (modifyAscListPositive posIfs pos)
    rebuild (negIfs, posIfs) =
      IntTrie (onNegative negIfs) z (modifyAscListPositive posIfs pos)
    onNegative = modifyAscListPositive . map (first negate) . reverse
-- | Like 'modifyAscList', but for modifications given in strictly
-- descending index order.
--
-- Mirroring the trie and negating every index turns a descending list
-- into an ascending one. The result is then mirrored back.
modifyDescList :: (Ord b, Num b, Bits b) => [(b, a -> a)] -> IntTrie a -> IntTrie a
modifyDescList ifs t = mirror (modifyAscList reflected (mirror t))
  where
    reflected = map (first negate) ifs
-- | Apply ascending @(position, function)@ modifications to a 'BitTrie'.
--
-- A leading position 1 updates the root. The remaining entries are split
-- by parity with 'partitionIndices' and halved, so each subtree receives
-- its own ascending list of local positions. The trie is matched
-- irrefutably, so untouched parts are never forced.
--
-- A position 0 cannot occur in well-formed input: positions start at 1,
-- and halving only yields 0 from 1. Meeting one is reported with 'error'
-- as a non-monotonic input.
modifyAscListPositive :: (Ord b, Num b, Bits b) => [(b, a -> a)] -> BitTrie a -> BitTrie a
modifyAscListPositive [] t = t
modifyAscListPositive ((0, _) : _) _ =
  error "modifyAscList: expected strictly monotonic indices"
modifyAscListPositive ifs@((i, f) : rest) ~(BitTrie here evens odds) =
  BitTrie here' (modifyAscListPositive (halve evenIfs) evens)
                (modifyAscListPositive (halve oddIfs) odds)
  where
    (here', remaining)
      | i == 1 = (f here, rest)
      | otherwise = (here, ifs)
    (oddIfs, evenIfs) = partitionIndices remaining
    halve = map (first (`shiftR` 1))
-- | Split ascending @(index, function)@ entries into (odd, even) lists by
-- the parity of the index, preserving order.
--
-- When the next index @y@ goes to the same side as the current entry, a
-- no-op entry @(fst y - 1, id)@ is emitted on the other side. Its index
-- has the other parity and lies between the neighbouring entries. As a
-- result, the next element of each output list is known from the current
-- and next input entries, without scanning ahead for the next index of
-- that parity. This keeps the split productive on long or infinite
-- ascending input.
--
-- With non-strictly-ascending input a pad can end up as position 0 after
-- halving, which 'modifyAscListPositive' reports as an error.
partitionIndices :: (Num b, Bits b) => [(b, a -> a)] -> ([(b, a -> a)], [(b, a -> a)])
partitionIndices [] = ([], [])
partitionIndices [x] = if testBit (fst x) 0 then ([x], []) else ([], [x])
partitionIndices (x : xs@(y : _)) =
  case testBit (fst x) 0 of
    False -> (if testBit (fst y) 0 then odds else pad : odds, x : evens)
    True -> (x : odds, if testBit (fst y) 0 then pad : evens else evens)
  where
    -- Lazy: the rest of the input is only split on demand.
    ~(odds, evens) = partitionIndices xs
    -- One below the next index, i.e. the opposite parity to y.
    pad = (fst y - 1, id)