module Data.BalBST where
import Data.Bifunctor
import Data.Function (on)
import Data.Functor.Contravariant
import qualified Data.List as L
import Data.Maybe
import qualified Data.Tree as T
import Prelude hiding (lookup,null)
-- | Instructions for navigating a leaf-oriented search tree whose internal
-- nodes carry keys of type @k@ and whose leaves hold values of type @a@.
data TreeNavigator k a = Nav
  { goLeft     :: a -> k -> Bool
    -- ^ @goLeft x k@ is 'True' when @x@ should be placed in (or searched
    -- for in) the left subtree of a node whose key is @k@.
  , extractKey :: a -> a -> k
    -- ^ Builds the separator key of a new internal node from two values.
  }
-- | Pre-composes a function onto the values, so a navigator for @b@ can be
-- reused for any @a@ that maps to @b@. The key type is left untouched.
instance Contravariant (TreeNavigator k) where
  contramap f (Nav toLeft mkKey) =
    Nav { goLeft     = toLeft . f
        , extractKey = mkKey `on` f
        }
-- | Navigator for values that act as their own keys.
-- A value goes left when it is at most the key. A new node's key is the
-- smaller of the two values it is built from.
ordNav :: Ord a => TreeNavigator a a
ordNav = Nav { goLeft = (<=), extractKey = min }
-- | Navigator that orders values by the projection @f@; keys are projected
-- values.
ordNavBy :: Ord b => (a -> b) -> TreeNavigator b a
ordNavBy f = Nav { goLeft     = (<=) . f
                 , extractKey = \x y -> min (f x) (f y)
                 }
-- | A balanced search tree bundled with the navigator used to build and
-- query it. Both fields are strict.
data BalBST k a = BalBST
  { nav    :: !(TreeNavigator k a) -- ^ how values are compared and keyed
  , toTree :: !(Tree k a)          -- ^ the underlying tree
  }
-- | Shows only the underlying tree; the navigator (a record of functions)
-- is omitted.
instance (Show k, Show a) => Show (BalBST k a) where
  show (BalBST _ tree) = concat ["BalBST (", show tree, ")"]
-- | Colour of an internal node of the red-black tree.
data Color
  = Red
  | Black
  deriving (Show, Read, Eq, Ord)
-- | Height annotation stored in every internal node.
type Height = Int
-- | Leaf-oriented tree: values live only in 'Leaf's. Each 'Node' stores a
-- colour, a height annotation, its left subtree, a separator key and its
-- right subtree. Subtrees are lazy; every other field is strict.
data Tree k a
  = Empty
  | Leaf !a
  | Node !Color !Height (Tree k a) !k (Tree k a)
  deriving (Show, Eq)
-- | A tree containing no values, navigated by the given navigator.
empty :: TreeNavigator k a -> BalBST k a
empty navigator = BalBST { nav = navigator, toTree = Empty }
-- | Builds a tree by inserting every element of the list. The last element
-- is inserted first.
fromList :: TreeNavigator k a -> [a] -> BalBST k a
fromList n = build
  where
    build []       = empty n
    build (v : vs) = insert v (build vs)
-- | 'fromList' using the values' own 'Ord' instance ('ordNav').
fromList' :: Ord a => [a] -> BalBST a a
fromList' xs = fromList ordNav xs
-- | 'True' exactly when the underlying tree is 'Empty'.
null :: BalBST k a -> Bool
null bst = case toTree bst of
  Empty -> True
  _     -> False
-- | Finds the stored value equal to the query, if any.
--
-- Follows 'goLeft' from the root down to a single leaf, then compares that
-- leaf with the query using '=='.
lookup :: Eq a => a -> BalBST k a -> Maybe a
lookup x (BalBST n t) = go t
  where
    go tree = case tree of
      Empty -> Nothing
      Leaf y
        | x == y    -> Just y
        | otherwise -> Nothing
      Node _ _ l k r -> go (if goLeft n x k then l else r)
-- | Whether 'lookup' finds a value equal to the query.
member :: Eq a => a -> BalBST k a -> Bool
member x bst = isJust (lookup x bst)
-- | Inserts a value, rebalancing on the way back up and blackening the root.
--
-- The value descends by 'goLeft'. When it reaches a leaf, that leaf is
-- replaced by a red node of height 2 holding both values in the order
-- 'goLeft' prescribes. An 'Empty' tree becomes a single 'Leaf'.
insert :: a -> BalBST k a -> BalBST k a
insert x (BalBST n t) = BalBST n (blacken (go t))
  where
    go Empty = Leaf x
    go (Leaf y) =
      let key      = extractKey n x y
          xFirst   = goLeft n x key
          pick a b = if xFirst then a else b
      in red 2 (Leaf (pick x y)) key (Leaf (pick y x))
    go (Node c h l k r)
      | goLeft n x k = balance c h (go l) k r
      | otherwise    = balance c h l k (go r)
-- | The leftmost value together with a tree of all remaining values.
-- Returns 'Nothing' for an empty tree. The remaining tree is rebuilt with
-- 'joinWith' while unwinding.
minView :: BalBST k a -> Maybe (a, Tree k a)
minView (BalBST n t) = go t
  where
    go Empty            = Nothing
    go (Leaf x)         = Just (x, Empty)
    go (Node _ _ l _ r) = second (\rest -> joinWith n rest r) <$> go l
-- | The rightmost value together with a tree of all remaining values.
-- Returns 'Nothing' for an empty tree. The remaining tree is rebuilt with
-- 'joinWith' while unwinding.
maxView :: BalBST k a -> Maybe (a, Tree k a)
maxView (BalBST n t) = go t
  where
    go Empty            = Nothing
    go (Leaf x)         = Just (x, Empty)
    go (Node _ _ l _ r) = second (joinWith n l) <$> go r
-- | Concatenates two trees with 'joinWith', keeping the left tree's
-- navigator; the right tree's navigator is ignored.
join :: BalBST k a -> BalBST k a -> BalBST k a
join (BalBST n lt) (BalBST _ rt) = BalBST { nav = n, toTree = joinWith n lt rt }
-- | Concatenates two trees into one, rebalancing with 'balance' and
-- blackening the resulting root.
--
-- Assumes (not checked here) that every value of @tl@ belongs before every
-- value of @tr@. New separator keys are built with 'extractKey' from the
-- largest value on the left ('unsafeMax') and the smallest on the right
-- ('unsafeMin').
--
-- Whichever tree has the larger (or equal, favouring the left) 'height' is
-- descended along its inner spine until a node whose height equals the
-- other tree's height is found. The two are then combined under a new red
-- node.
--
-- NOTE(review): these base cases give a new red node height @h+1@ (or 2
-- over two leaves), whereas 'mkNode' gives a red node the same height as
-- its black children. Confirm which height convention is intended. If the
-- descent reaches a leaf while the other side is still a node, the
-- "absurd" errors below fire.
joinWith :: TreeNavigator k a -> Tree k a -> Tree k a -> Tree k a
joinWith Nav{..} tl tr
    | lh >= rh = blacken $ joinL tl tr
    | otherwise = blacken $ joinR tl tr
  where
    rh = height tr
    lh = height tl
    -- Left side at least as tall: walk down the right spine of the left tree.
    joinL Empty _ = Empty
    joinL l Empty = l
    joinL l@(Leaf x) r@(Leaf y) = red 2 l (extractKey x y) r
    joinL l@(Node c h ll k lr) r
      -- Found a spine node as tall as the right tree: hang both under a
      -- new red node keyed between the two.
      | h == rh = let lm = unsafeMax lr
                      rm = unsafeMin r
                  in balance Red (h+1) l (extractKey lm rm) r
      | otherwise = balance c h ll k (joinL lr r)
    -- Reached only when a Leaf on the spine meets a Node on the right.
    joinL _ _ = error "joinL. absurd"
    -- Right side taller: walk down the left spine of the right tree.
    joinR _ Empty = Empty
    joinR Empty r = r
    joinR l@(Leaf x) r@(Leaf y) = red 2 l (extractKey x y) r
    joinR l r@(Node c h rl k rr)
      | h == lh = let lm = unsafeMax l
                      rm = unsafeMin rl
                  in balance Red (h+1) l (extractKey lm rm) r
      | otherwise = balance c h (joinR l rl) k rr
    -- Reached only when a Node on the left meets a Leaf on the spine.
    joinR _ _ = error "joinR absurd"
-- | A pair that is strict in its first component and lazy in its second,
-- so the second component is computed only if it is demanded.
-- The derived 'Functor' maps over the second component.
data Pair a b = Pair
  { fst' :: !a
  , snd' :: b
  }
  deriving (Show, Eq, Functor, Foldable, Traversable)
-- | Gathers the first components of the pairs, together with the second
-- component of the last pair. An empty list yields the given default as
-- the second component.
collect :: b -> [Pair a b] -> Pair [a] b
collect def ps = case ps of
  [] -> Pair [] def
  _  -> Pair (fst' <$> ps) (snd' (last ps))
-- | Every value in left-to-right order, each paired (lazily) with a tree of
-- the values that come after it.
extractPrefix :: BalBST k a -> [Pair a (Tree k a)]
extractPrefix (BalBST n t) = go t
  where
    go Empty            = []
    go (Leaf x)         = [Pair x Empty]
    go (Node _ _ l _ r) =
      [ (\rest -> joinWith n rest r) <$> p | p <- go l ] ++ go r
-- | Every value in right-to-left order, each paired (lazily) with a tree of
-- the values that come before it.
extractSuffix :: BalBST k a -> [Pair a (Tree k a)]
extractSuffix (BalBST n t) = go t
  where
    go Empty            = []
    go (Leaf x)         = [Pair x Empty]
    go (Node _ _ l _ r) = [ joinWith n l <$> p | p <- go r ] ++ go l
-- | Three-way split result: the part before, a strict middle, and the part
-- after.
data Split a b
  = Split a !b a
  deriving (Show, Eq)
-- | Splits the tree around the value @x@. The result holds the values that
-- belong before @x@, the stored value equal to @x@ (if any), and the values
-- that belong after @x@.
--
-- Navigation uses 'goLeft'; '==' is checked only at the leaf reached.
split :: Eq a => a -> BalBST k a -> Split (Tree k a) (Maybe a)
split x (BalBST n@Nav{..} t) = split' t
  where
    split' Empty = Split Empty Nothing Empty
    split' l@(Leaf y)
      | x == y                    = Split Empty (Just y) Empty
      -- 'goLeft' holding means x sits to the left of y (the same convention
      -- 'insert' uses), so y belongs to the part *after* x.
      | goLeft x (extractKey x y) = Split Empty Nothing l
      | otherwise                 = Split l Nothing Empty
    split' (Node _ _ l k r)
      | goLeft x k = let Split l' mx r' = split' l
                     in Split l' mx (joinWith n r' r)
      | otherwise  = let Split l' mx r' = split' r
                     in Split (joinWith n l l') mx r'
-- | Splits the values into those failing @p@ and those satisfying it.
-- Assumes @p@ is monotone: 'False' for a prefix of the values and 'True'
-- for the rest. Only the leftmost value of a right subtree is tested before
-- the whole subtree is assigned to the 'True' side.
splitMonotone :: (a -> Bool) -> BalBST k a
              -> (BalBST k a, BalBST k a)
splitMonotone p (BalBST n t) = bimap wrap wrap (go t)
  where
    wrap = BalBST n
    go Empty = (Empty, Empty)
    go leaf@(Leaf y) = if p y then (Empty, leaf) else (leaf, Empty)
    go (Node _ _ l _ r)
      | p (unsafeMin r) = let (lo, mid) = go l in (lo, joinWith n mid r)
      | otherwise       = let (mid, hi) = go r in (joinWith n l mid, hi)
-- | Splits with 'splitMonotone' on @p@, then peels off the values adjacent
-- to the split point that satisfy @sel@.
--
-- The middle component pairs:
--
-- * the selected values taken from the end of the first part, in tree
--   order; and
-- * the selected values taken from the start of the second part, in tree
--   order.
--
-- The outer components hold what remains of each part.
splitExtract :: (a -> Bool) -> (a -> Bool) -> BalBST k a
             -> Split (BalBST k a) ([a],[a])
splitExtract p sel bst =
    Split (wrap lowRest) (reverse fromLow, fromHigh) (wrap highRest)
  where
    n = nav bst
    wrap = BalBST n
    (low, high) = splitMonotone p bst
    grab fallback = collect fallback . L.takeWhile (sel . fst')
    Pair fromLow lowRest = grab (toTree low) (extractSuffix low)
    Pair fromHigh highRest = grab (toTree high) (extractPrefix high)
-- | Label used when rendering a tree. It is either an internal node's
-- colour, height and key, or a stored value.
data T k a
  = Internal !Color !Height !k
  | Val !a
  deriving (Show, Eq, Ord)
-- | Converts to a "Data.Tree" rose tree for display. 'Empty' maps to
-- 'Nothing', and empty subtrees are dropped from a node's children.
toRoseTree :: Tree k a -> Maybe (T.Tree (T k a))
toRoseTree t = case t of
  Empty          -> Nothing
  Leaf x         -> Just (T.Node (Val x) [])
  Node c h l k r ->
    Just (T.Node (Internal c h k) (catMaybes [toRoseTree l, toRoseTree r]))
-- | Draws the tree with 'T.drawTree', or returns "Empty" when it has no
-- values.
showTree :: (Show k, Show a) => BalBST k a -> String
showTree bst = case toRoseTree (toTree bst) of
  Nothing   -> "Empty"
  Just rose -> T.drawTree (show <$> rose)
-- | The leftmost value of a tree; calls 'error' on 'Empty'.
unsafeMin :: Tree k a -> a
unsafeMin t = case t of
  Leaf x         -> x
  Node _ _ l _ _ -> unsafeMin l
  Empty          -> error "unsafeMin: Empty"
-- | The rightmost value of a tree; calls 'error' on 'Empty'.
unsafeMax :: Tree k a -> a
unsafeMax t = case t of
  Leaf x         -> x
  Node _ _ _ _ r -> unsafeMax r
  Empty          -> error "unsafeMax: Empty"
-- | The stored values, left to right.
toList :: BalBST k a -> [a]
toList bst = toList' (toTree bst)
-- | The values stored in the leaves, left to right.
--
-- An accumulator is threaded through the traversal so each value is consed
-- exactly once. The previous @toList' l ++ toList' r@ recopied left
-- sub-results at every level. The result is still produced lazily.
toList' :: Tree k a -> [a]
toList' tree = go tree []
  where
    go Empty            acc = acc
    go (Leaf x)         acc = x : acc
    go (Node _ _ l _ r) acc = go l (go r acc)
-- | Builds a black internal node.
black :: Height -> Tree k a -> k -> Tree k a -> Tree k a
black h l k r = Node Black h l k r
-- | Builds a red internal node.
red :: Height -> Tree k a -> k -> Tree k a -> Tree k a
red h l k r = Node Red h l k r
-- | Recolours a red root black. The height is kept as is, and any other
-- tree is returned unchanged.
blacken :: Tree k a -> Tree k a
blacken t = case t of
  Node Red h l k r -> Node Black h l k r
  _                -> t
-- | Okasaki-style rebalancing, used by 'insert' and 'joinWith'.
--
-- A black node with a red child that itself has a red child is rewritten by
-- 'mkNode' into a red node with two black children. The first four clauses
-- cover the red-red violation in each of the four grandchild positions. The
-- heights of the red nodes being removed are discarded, and every rebuilt
-- node receives @h@. Anything else is rebuilt unchanged by the final
-- fallback clause.
balance :: Color -> Height -> Tree k a -> k -> Tree k a -> Tree k a
balance Black h (Node Red _ (Node Red _ a x b) y c) z d = mkNode h a x b y c z d
balance Black h (Node Red _ a x (Node Red _ b y c)) z d = mkNode h a x b y c z d
balance Black h a x (Node Red _ (Node Red _ b y c) z d) = mkNode h a x b y c z d
balance Black h a x (Node Red _ b y (Node Red _ c z d)) = mkNode h a x b y c z d
balance co h a x b = Node co h a x b
-- | Builds a red node keyed @y@ over two black nodes. The first black node
-- has subtrees @a@ and @b@ (key @x@); the second has @c@ and @d@ (key @z@).
-- All three new nodes receive the height @h@.
mkNode :: Height
       -> Tree k a -> k -> Tree k a -> k -> Tree k a -> k -> Tree k a
       -> Tree k a
mkNode h a x b y c z d =
  let lower = Node Black h a x b
      upper = Node Black h c z d
  in Node Red h lower y upper
-- | Height annotation of a tree: 0 for 'Empty', 1 for a 'Leaf', and the
-- stored value for a 'Node'.
height :: Tree k a -> Height
height t = case t of
  Empty          -> 0
  Leaf _         -> 1
  Node _ h _ _ _ -> h