{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
module Treap.Pure
(
Size (..)
, Priority (..)
, Treap (..)
, empty
, one
, size
, sizeInt
, monoid
, at
, query
, splitAt
, merge
, take
, drop
, rotate
, insert
, delete
, recalculate
) where
import Prelude hiding (drop, lookup, splitAt, take)
import Control.DeepSeq (NFData)
import Data.Foldable (foldl')
import Data.Word (Word64)
import GHC.Exts (IsList (..))
import GHC.Generics (Generic)
import Treap.Measured (Measured (..))
-- | Number of elements stored in a subtree. Every 'Node' carries its own
-- 'Size', so reading the size of a treap does not require a traversal.
newtype Size = Size { unSize :: Int }
  deriving stock (Generic, Read, Show)
  deriving newtype (NFData, Num, Ord, Eq)
-- | Heap key of a node. The treap keeps the node with the larger
-- 'Priority' above the smaller one (see 'insert' and 'merge').
newtype Priority = Priority { unPriority :: Word64 }
  deriving stock (Generic, Read, Show)
  deriving newtype (NFData, Ord, Eq)
-- | Implicit treap: a binary tree whose in-order traversal is the element
-- sequence (elements are addressed by position, not by key), heap-ordered
-- by 'Priority'. Each node caches the 'Size' of its subtree and the
-- monoidal summary @m@ of the elements in it.
data Treap m a
  = -- | Cached size, priority, cached measure, element, left and right
    -- subtrees.
    Node !Size !Priority !m a !(Treap m a) !(Treap m a)
  | -- | The treap with no elements.
    Empty
  deriving stock (Show, Read, Eq, Generic)
  deriving anyclass (NFData)

-- | Folds elements in sequence (in-order) order: left subtree, then the
-- node's own element, then the right subtree.
--
-- A stock-derived instance would visit the constructor fields in
-- declaration order (element before left subtree), i.e. pre-order, which
-- disagrees with the positional order used by indexing and splitting.
instance Foldable (Treap m) where
  foldMap _ Empty = mempty
  foldMap f (Node _ _ _ a l r) = foldMap f l <> f a <> foldMap f r

  foldr _ z Empty = z
  foldr f z (Node _ _ _ a l r) = foldr f (f a (foldr f z r)) l

  -- The root caches the element count, so no traversal is needed.
  length Empty = 0
  length (Node (Size s) _ _ _ _ _) = s

  null Empty = True
  null Node {} = False
-- | The measure of a whole treap is the summary cached at its root; an
-- empty treap measures as 'mempty'.
instance Monoid m => Measured m (Treap m a) where
  measure Empty = mempty
  measure (Node _ _ cached _ _ _) = cached
  {-# INLINE measure #-}
-- | Converts between a 'Treap' and the in-order list of its
-- @(priority, element)@ pairs.
instance Measured m a => IsList (Treap m a) where
  type Item (Treap m a) = (Priority, a)

  -- Each element is inserted at the position equal to its list index,
  -- i.e. appended at the end, so the in-order sequence of the result
  -- matches the input list.
  fromList :: [(Priority, a)] -> Treap m a
  fromList =
    foldl' (\t (i, p, a) -> insert i p a t) Empty
      . zipWith (\i (p, a) -> (i, p, a)) [0 ..]
  {-# INLINE fromList #-}

  -- In-order traversal with an accumulator. This builds the list in O(n)
  -- for any tree shape; the previous @toList l ++ ...@ form was quadratic
  -- on left-leaning trees.
  toList :: Treap m a -> [(Priority, a)]
  toList t = go t []
    where
      go Empty acc = acc
      go (Node _ p _ a l r) acc = go l ((p, a) : go r acc)
-- | The treap that contains no elements.
empty :: Treap m a
empty = Empty
{-# INLINE empty #-}
-- | A single-element treap: a leaf node of size one whose cached summary
-- is the measure of the element itself.
one :: Measured m a => Priority -> a -> Treap m a
one prio x = Node (Size 1) prio summary x Empty Empty
  where
    summary = measure x
{-# INLINE one #-}
-- | Number of elements in the treap, read from the root's cached field.
size :: Treap m a -> Size
size Empty = Size 0
size (Node cachedSize _ _ _ _ _) = cachedSize
{-# INLINE size #-}
-- | 'size' unwrapped to a plain 'Int'.
sizeInt :: Treap m a -> Int
sizeInt t = unSize (size t)
{-# INLINE sizeInt #-}
-- | Monoidal summary of every element in the treap, as cached at the
-- root; 'mempty' for an empty treap.
monoid :: Monoid m => Treap m a -> m
monoid Empty = mempty
monoid (Node _ _ cached _ _ _) = cached
{-# INLINE monoid #-}
-- | Element at zero-based position @i@, or 'Nothing' when @i@ is outside
-- @[0, size)@. Walks a single root-to-leaf path, steering by the cached
-- sizes of left subtrees.
at :: forall m a . Int -> Treap m a -> Maybe a
at i t
  | i < 0 || i >= sizeInt t = Nothing
  | otherwise = descend i t
  where
    descend :: Int -> Treap m a -> Maybe a
    descend _ Empty = Nothing
    descend k (Node _ _ _ x left right)
      | k < leftSize = descend k left
      | k == leftSize = Just x
      -- Skip the left subtree and the current node.
      | otherwise = descend (k - leftSize - 1) right
      where
        leftSize = sizeInt left
-- | Monoidal summary of the elements at positions @[from, to)@.
-- An empty or inverted range (@to <= from@) yields 'mempty'.
query :: forall m a . Measured m a => Int -> Int -> Treap m a -> m
query from to t
  | from >= to = mempty
  | otherwise = monoid middle
  where
    -- Keep the first @to@ elements, then drop the first @from@ of those.
    prefix = fst (splitAt to t)
    middle = snd (splitAt from prefix)
-- | Smart constructor for an inner node: the size and measure given here
-- are placeholders that 'recalculate' immediately replaces with values
-- derived from the children and the element.
new :: Measured m a => Priority -> a -> Treap m a -> Treap m a -> Treap m a
new prio x left right = recalculate (Node (Size 0) prio mempty x left right)
{-# INLINE new #-}
-- | Splits a treap into its first @i@ elements and the remaining ones.
-- Non-positive @i@ gives @(empty, t)@; @i >= size t@ gives @(t, empty)@.
splitAt :: forall m a . Measured m a => Int -> Treap m a -> (Treap m a, Treap m a)
splitAt i t
  | i <= 0 = (empty, t)
  | i >= sizeInt t = (t, empty)
  | otherwise = cut i t
  where
    cut :: Int -> Treap m a -> (Treap m a, Treap m a)
    cut _ Empty = (Empty, Empty)
    cut k (Node _ prio _ x left right)
      -- The split point lies inside the left subtree.
      | k < leftSize =
          let (!before, !leftRest) = cut k left
          in (before, new prio x leftRest right)
      -- The left subtree is exactly the prefix; this node starts the rest.
      | k == leftSize = (left, new prio x Empty right)
      -- The split point lies inside the right subtree.
      | otherwise =
          let (!rightFront, !after) = cut (k - leftSize - 1) right
          in (new prio x left rightFront, after)
      where
        leftSize = sizeInt left
-- | Concatenates two treaps: every element of the first comes before
-- every element of the second in the result. The root with the strictly
-- greater 'Priority' stays on top (ties go to the right operand), so
-- the heap ordering is preserved.
merge :: Measured m a => Treap m a -> Treap m a -> Treap m a
merge Empty r = r
merge l Empty = l
merge l@(Node _ p1 _ a1 l1 r1) r@(Node _ p2 _ a2 l2 r2)
  -- 'new' already recalculates the cached size and measure; wrapping it
  -- in another 'recalculate' only redid that work.
  | p1 > p2 = new p1 a1 l1 (merge r1 r)
  | otherwise = new p2 a2 (merge l l2) r2
-- | The first @n@ elements of the treap. Non-positive @n@ gives an empty
-- treap; @n >= size t@ returns @t@ unchanged.
take :: forall m a . Measured m a => Int -> Treap m a -> Treap m a
take n t
  | n <= 0 = Empty
  | n >= sizeInt t = t
  | otherwise = prefix n t
  where
    prefix :: Int -> Treap m a -> Treap m a
    prefix _ Empty = Empty
    prefix 0 _ = Empty
    prefix k (Node _ prio _ x left right)
      | k < leftSize = prefix k left
      | k == leftSize = left
      -- Keep the whole left subtree and this node, then part of the right.
      | otherwise = new prio x left (prefix (k - leftSize - 1) right)
      where
        leftSize = sizeInt left
-- | The treap without its first @n@ elements. Non-positive @n@ returns
-- @t@ unchanged; @n >= size t@ gives an empty treap.
drop :: forall m a . Measured m a => Int -> Treap m a -> Treap m a
drop n t
  | n <= 0 = t
  | n >= sizeInt t = Empty
  | otherwise = suffix n t
  where
    suffix :: Int -> Treap m a -> Treap m a
    suffix _ Empty = Empty
    suffix 0 tree = tree
    suffix k (Node _ prio _ x left right)
      | k < leftSize = new prio x (suffix k left) right
      | k == leftSize = new prio x Empty right
      -- Discard the whole left subtree and this node.
      | otherwise = suffix (k - leftSize - 1) right
      where
        leftSize = sizeInt left
-- | Cyclically moves the first @n `mod` size@ elements to the end. Because
-- 'mod' with a positive divisor is non-negative, a negative @n@ rotates
-- the other way.
rotate :: forall m a . Measured m a => Int -> Treap m a -> Treap m a
rotate _ Empty = Empty
rotate 0 t = t
rotate n t = merge back front
  where
    -- Only evaluated for a non-empty treap, so the divisor is positive.
    (front, back) = splitAt (n `mod` sizeInt t) t
-- | Inserts element @a@ with priority @p@ so that it ends up at position
-- @i@. The position is clamped to @[0, size t]@: negative indices prepend
-- and indices past the end append.
insert :: forall m a . Measured m a => Int -> Priority -> a -> Treap m a -> Treap m a
insert i p a t = go (max 0 (min i (sizeInt t))) t
  where
    go :: Int -> Treap m a -> Treap m a
    go _ Empty = one p a
    go k node@(Node _ tp _ ta l r)
      -- The existing root keeps its place; descend towards position k.
      -- ('new' recalculates the cached fields itself, so no extra
      -- 'recalculate' is needed.)
      | p <= tp =
          let lSize = sizeInt l
          in if k <= lSize
               then new tp ta (go k l) r
               else new tp ta l (go (k - lSize - 1) r)
      -- The new element outranks this subtree's root: split the subtree
      -- at k and hang both halves under the new node.
      | otherwise =
          let (!newL, !newR) = splitAt k node
          in new p a newL newR
-- | Removes the element at position @i@. An index outside @[0, size)@
-- returns the treap unchanged.
delete :: forall m a . Measured m a => Int -> Treap m a -> Treap m a
delete i t
  | i < 0 = t
  | i >= sizeInt t = t
  | otherwise = go i t
  where
    go :: Int -> Treap m a -> Treap m a
    go _ Empty = Empty
    go k (Node _ p _ a l r) =
      let lSize = sizeInt l
      in case compare k lSize of
           -- Found it: its children are concatenated in its place.
           EQ -> merge l r
           -- 'new' already recalculates the cached size and measure.
           LT -> new p a (go k l) r
           GT -> new p a l (go (k - lSize - 1) r)
-- | Recomputes the root's cached size and measure from its children and
-- its element, discarding whatever values were stored there. The
-- children's own cached values are used as-is.
recalculate :: Measured m a => Treap m a -> Treap m a
recalculate t = case t of
  Empty -> Empty
  Node _ prio _ x left right ->
    let total = 1 + size left + size right
        summary = measure left <> measure x <> measure right
    in Node total prio summary x left right