module Data.Set.BUSplay (
Splay(..)
, empty
, singleton
, insert
, fromList
, toList
, member
, delete
, deleteMin
, deleteMax
, null
, union
, intersection
, difference
, minimum
, maximum
, valid
, (===)
, showSet
, printSet
) where
import Data.List (foldl')
import Prelude hiding (minimum, maximum, null)
-- | A set stored as a binary search tree that is restructured by splaying
-- on access. Operations in this module expect the in-order sequence of
-- elements to be strictly increasing (see 'valid').
data Splay a
  = Leaf                        -- ^ empty tree
  | Node (Splay a) a (Splay a)  -- ^ left subtree, element, right subtree
  deriving (Show)
-- | Set equality: two trees are equal when their in-order element sequences
-- are equal, whatever their shapes. Use '===' to also compare shapes.
instance (Eq a) => Eq (Splay a) where
  (==) lhs rhs = toList lhs == toList rhs
-- | Structural equality: the two trees must have the same shape and hold
-- equal elements at corresponding nodes.
(===) :: Eq a => Splay a -> Splay a -> Bool
t1 === t2 = case (t1, t2) of
  (Leaf, Leaf) -> True
  (Node l1 x1 r1, Node l2 x2 r2) -> and [x1 == x2, l1 === l2, r1 === r2]
  _ -> False
-- | One step of a search path. @L x r@ records that the search went left
-- from a node holding @x@ whose right subtree is @r@; @R x l@ records that
-- it went right from a node holding @x@ whose left subtree is @l@.
data Direction a
  = L a (Splay a)
  | R a (Splay a)
  deriving (Show)

-- | The steps from the current position back up to the root, nearest
-- ancestor first.
type Path a = [Direction a]
-- | Descend from the root looking for @k@. Returns the subtree rooted at
-- the node holding @k@ ('Leaf' if the search fell off the tree) together
-- with the path back to the root, nearest ancestor first.
search :: Ord a => a -> Splay a -> (Splay a, Path a)
search k = descend []
  where
    descend trail Leaf = (Leaf, trail)
    descend trail node@(Node l x r) =
      case compare k x of
        LT -> descend (L x r : trail) l
        GT -> descend (R x l : trail) r
        EQ -> (node, trail)
-- | Follow left children down to the 'Leaf' below the smallest element.
-- Returns that 'Leaf' and the path back to the root, nearest ancestor first.
searchMin :: Splay a -> (Splay a, Path a)
searchMin = descend []
  where
    descend trail Leaf = (Leaf, trail)
    descend trail (Node l x r) = descend (L x r : trail) l
-- | Follow right children down to the 'Leaf' below the largest element.
-- Returns that 'Leaf' and the path back to the root, nearest ancestor first.
searchMax :: Splay a -> (Splay a, Path a)
searchMax = descend []
  where
    descend trail Leaf = (Leaf, trail)
    descend trail (Node l x r) = descend (R x l : trail) r
-- | Rebuild a tree from a focused subtree and the path back to the root
-- (as produced by 'search', 'searchMin' or 'searchMax'), rotating the
-- focused node upward so that it becomes the new root.
--
-- The path lists ancestors nearest-first. Each equation consumes one
-- ancestor (zig) or two ancestors (zig-zig / zig-zag), following the
-- bottom-up splaying scheme.
splay :: Splay a -> Path a -> Splay a
-- No ancestors left: the focused subtree is the whole tree.
splay t [] = t
-- The search fell off the tree at a 'Leaf': splay the last node visited
-- (the nearest ancestor) instead.
splay Leaf (L x r : bs) = splay (Node Leaf x r) bs
splay Leaf (R x l : bs) = splay (Node l x Leaf) bs
-- zig: x is the left child of the root y; rotate right.
splay (Node a x b) [L y c] = Node a x (Node b y c)
-- zig: y is the right child of the root x; rotate left.
splay (Node b y c) [R x a] = Node (Node a x b) y c
-- zig-zig: x is the left child of y, which is the left child of z.
splay (Node a x b) (L y c : L z d : bs)
  = splay (Node a x (Node b y (Node c z d))) bs
-- zig-zag: x is the right child of y, which is the left child of z.
splay (Node b x c) (R y a : L z d : bs)
  = splay (Node (Node a y b) x (Node c z d)) bs
-- zig-zig: z is the right child of y, which is the right child of x.
splay (Node c z d) (R y b : R x a : bs)
  = splay (Node (Node (Node a x b) y c) z d) bs
-- zig-zag: x is the left child of y, which is the right child of z.
splay (Node b x c) (L y d : R z a : bs)
  = splay (Node (Node a z b) x (Node c y d)) bs
-- | The set with no elements.
empty :: Splay a
empty =
  Leaf
-- | 'True' exactly when the set has no elements.
null :: Splay a -> Bool
null t = case t of
  Leaf -> True
  Node {} -> False
-- | A set holding exactly one element.
singleton :: a -> Splay a
singleton x = Node nil x nil
  where
    nil = Leaf
-- | Insert an element, making it the new root. Elements smaller than it
-- go to the left and larger ones to the right. If an equal element is
-- already present, it is replaced by the new one.
--
-- The split is forced with @case@ before the new root is built, so each
-- call does its work immediately. With a lazy @where@ pattern binding,
-- a strict fold such as 'fromList' only forces the outer 'Node'. A chain
-- of suspended splits then accumulates and is evaluated all at once later.
insert :: Ord a => a -> Splay a -> Splay a
insert x t = case split x t of
  (smaller, _, larger) -> Node smaller x larger
-- | Build a set by inserting the list's elements from left to right.
fromList :: Ord a => [a] -> Splay a
fromList xs = foldl' step empty xs
  where
    step acc x = insert x acc
-- | The elements in in-order (ascending for a valid tree), built with a
-- difference list so that the output is produced without repeated appends.
toList :: Splay a -> [a]
toList t = build t []
  where
    build Leaf = id
    build (Node l x r) = build l . (x :) . build r
-- | Membership test. Every lookup splays, so the restructured tree is
-- returned alongside the answer, and callers should continue with it. If
-- the element is found, it ends up at the root. Otherwise the last node
-- visited by the search does.
member :: Ord a => a -> Splay a -> (Bool, Splay a)
member x t =
  let (found, path) = search x t
  in case found of
       Leaf -> (False, splay Leaf path)
       s -> (True, splay s path)
-- | The smallest element, together with the tree after splaying that
-- element to the root. Calls 'error' on an empty tree.
minimum :: Splay a -> (a, Splay a)
minimum t =
  let (focus, path) = searchMin t
  in case splay focus path of
       Leaf -> error "minimum"
       s@(Node _ x _) -> (x, s)
-- | The largest element, together with the tree after splaying that
-- element to the root. Calls 'error' on an empty tree.
maximum :: Splay a -> (a, Splay a)
maximum t =
  let (focus, path) = searchMax t
  in case splay focus path of
       Leaf -> error "maximum"
       s@(Node _ x _) -> (x, s)
-- | Remove the smallest element. Once the minimum has been splayed to the
-- root, its left subtree is empty, so the right subtree is the result.
-- Calls 'error' on an empty tree.
deleteMin :: Splay a -> Splay a
deleteMin Leaf = error "deleteMin"
deleteMin t
  | (_, Node Leaf _ rest) <- minimum t = rest
  | otherwise = error "deleteMin"
-- | Remove the largest element. Once the maximum has been splayed to the
-- root, its right subtree is empty, so the left subtree is the result.
-- Calls 'error' on an empty tree.
deleteMax :: Splay a -> Splay a
deleteMax Leaf = error "deleteMax"
deleteMax t
  | (_, Node rest _ Leaf) <- maximum t = rest
  | otherwise = error "deleteMax"
-- | Remove an element if it is present. Otherwise, return the tree as
-- restructured by the lookup. A found element is splayed to the root and
-- its two subtrees are then joined with 'merge'.
delete :: Ord a => a -> Splay a -> Splay a
delete x t = case t of
  Leaf -> Leaf
  _ -> case member x t of
    (False, s) -> s
    (True, Node l _ r) -> merge l r
    _ -> error "delete"
-- | Elements in either set. The first tree is split around each root of
-- the second, and the halves are combined recursively. When both sets hold
-- an equal element, the one from the second set is kept.
union :: Ord a => Splay a -> Splay a -> Splay a
union t1 t2 = case (t1, t2) of
  (_, Leaf) -> t1
  (Leaf, _) -> t2
  (_, Node l x r) ->
    let (lo, _, hi) = split x t1
    in Node (union lo l) x (union hi r)
-- | Elements in both sets. The first tree is split around each root @x@ of
-- the second. @x@ is kept only if the split found it, and the recursive
-- halves are joined directly or with 'merge'. Kept elements come from the
-- second set.
intersection :: Ord a => Splay a -> Splay a -> Splay a
intersection t1 t2 = case (t1, t2) of
  (Leaf, _) -> Leaf
  (_, Leaf) -> Leaf
  (_, Node l x r) ->
    let (lo, present, hi) = split x t1
        below = intersection lo l
        above = intersection hi r
    in if present then Node below x above else merge below above
-- | Elements of the first set that are not in the second.
--
-- For each root @x@ of the second set, the first set is split around @x@
-- (dropping @x@ if present) and the halves are processed recursively.
-- Every element of the left result is smaller than @x@ and every element
-- of the right result is larger. The two results are therefore joined with
-- 'merge' rather than the general, re-splitting 'union'.
difference :: Ord a => Splay a -> Splay a -> Splay a
difference Leaf _ = Leaf
difference t1 Leaf = t1
difference t1 (Node l x r) = merge (difference l' l) (difference r' r)
  where
    (l', _, r') = split x t1
-- | Join two trees, assuming (without checking) that every element of the
-- first is smaller than every element of the second. The maximum of the
-- first tree is splayed to its root, which leaves that root's right subtree
-- empty, and the second tree is attached there.
merge :: Splay a -> Splay a -> Splay a
merge t1 t2 = case (t1, t2) of
  (Leaf, _) -> t2
  (_, Leaf) -> t1
  _ ->
    -- Lazy pattern binding: the match is only checked when a field is used.
    let (_, Node l x Leaf) = maximum t1
    in Node l x t2
-- | Split around @x@ into the elements smaller than @x@, whether @x@ was
-- present, and the elements larger than @x@ (@x@ itself is dropped).
split :: Ord a => a -> Splay a -> (Splay a, Bool, Splay a)
split x t = case t of
  Leaf -> (Leaf, False, Leaf)
  _ -> case member x t of
    (True, Node l _ r) -> (l, True, r)
    -- Not found: the root y is the last node the search visited, so no
    -- element lies strictly between x and y, and y belongs on one side.
    (False, Node l y r) -> aside (compare x y) l y r
    _ -> error "split"
  where
    aside LT l y r = (l, False, Node Leaf y r)
    aside GT l y r = (Node l y Leaf, False, r)
    aside EQ _ _ _ = error "split"
-- | Check the search-tree invariant: the in-order elements are strictly
-- increasing.
valid :: Ord a => Splay a -> Bool
valid = isOrdered
-- | 'True' when each in-order element is strictly smaller than the next.
isOrdered :: Ord a => Splay a -> Bool
isOrdered t = and (zipWith (<) xs (drop 1 xs))
  where
    xs = toList t
-- | Render the tree as text, one node per line, starting with no prefix at
-- the root.
showSet :: Show a => Splay a -> String
showSet tree = showSet' "" tree
-- | Render a subtree under the given line prefix. Each node prints its
-- element, then its left and then right child on lines starting with the
-- prefix and @"+ "@. Children are rendered with the prefix extended by one
-- space, and an empty subtree renders as a bare newline.
showSet' :: Show a => String -> Splay a -> String
showSet' _ Leaf = "\n"
showSet' pref (Node l x r) =
  concat
    [ show x, "\n"
    , pref, "+ ", showSet' indent l
    , pref, "+ ", showSet' indent r
    ]
  where
    indent = " " ++ pref
-- | Write the 'showSet' rendering of the tree to standard output.
printSet :: Show a => Splay a -> IO ()
printSet t = putStr (showSet t)