{-# LANGUAGE RecordWildCards #-}
module Data.BalBST where

import           Data.Bifunctor
import           Data.Function (on)
import           Data.Functor.Contravariant
import qualified Data.List as L
import           Data.Maybe
import qualified Data.Tree as T
import           Prelude hiding (lookup,null)

--------------------------------------------------------------------------------

-- | Describes how to search in a tree.
--
-- A navigator routes an element of type @a@ past the routing key of type
-- @k@ stored in an internal node, and builds such routing keys from
-- elements.
data TreeNavigator k a = Nav
  { goLeft     :: a -> k -> Bool
    -- ^ @goLeft x k@ is 'True' when @x@ should be routed into the left
    -- subtree of a node whose routing key is @k@.
  , extractKey :: a -> a -> k
    -- ^ Computes the routing key of a new internal node from two elements
    -- that end up on different sides of it.
  }

-- | Pre-compose a navigator with a projection: an element of the new type
-- is first mapped with @f@ and then routed / keyed by the old navigator.
instance Contravariant (TreeNavigator k) where
  contramap f (Nav toLeft mkKey) = Nav (toLeft . f) (mkKey `on` f)


-- | Navigator for elements that serve as their own keys: an element goes
-- left when it is @<=@ the routing key, and the key between two elements
-- is the smaller of the two.
ordNav :: Ord a => TreeNavigator a a
ordNav = Nav { goLeft = (<=), extractKey = min }


-- | Navigator that orders elements by the projection @f@: routing keys are
-- projected values, an element goes left when its projection is @<=@ the
-- key, and the key between two elements is the smaller projection.
ordNavBy   :: Ord b => (a -> b) -> TreeNavigator b a
ordNavBy f = Nav { goLeft     = \x k -> f x <= k
                 , extractKey = min `on` f
                 }


-- instance Functor (TreeNavigator k) where
--   fmap f Nav{..} = Nav (\b k -> )



-- | A balanced binary search tree: a red-black 'Tree' bundled with the
-- navigator used to search in it and to insert into it.
data BalBST k a = BalBST
  { nav    :: !(TreeNavigator k a)  -- ^ how elements are routed and keyed
  , toTree :: !(Tree k a)           -- ^ the underlying red-black tree
  }

-- | Shows only the underlying tree; the navigator is a record of functions
-- and is left out.
instance (Show k, Show a) => Show (BalBST k a) where
  show (BalBST _ tree) = concat ["BalBST (", show tree, ")"]


-- | Colour of an internal node of the red-black tree.
data Color
  = Red
  | Black
  deriving (Show, Read, Eq, Ord)

-- | Height annotation cached in every internal 'Node'.
--
-- 'height' reports 0 for 'Empty' and 1 for a 'Leaf'; 'joinWith' compares
-- these values to decide at which node two trees are glued together.
-- NOTE(review): red nodes appear to store one more than the height of
-- their children (e.g. @red 2@ over two leaves in 'insert'), i.e. the
-- height they would have if they were black — confirm this invariant.
type Height = Int

-- | Red-black tree that stores its elements in the leaves only.
--
-- Internal nodes carry a colour, a cached 'Height' and a routing key of
-- type @k@ that is compared against elements with 'goLeft'.
data Tree k a
  = Empty
  | Leaf !a
  | Node !Color !Height (Tree k a) !k (Tree k a)
  deriving (Show, Eq)

--------------------------------------------------------------------------------

-- | Creates an empty BST that will use the given navigator.
empty   :: TreeNavigator k a -> BalBST k a
empty n = BalBST { nav = n, toTree = Empty }


-- | Builds a BST with the given navigator by inserting the elements one
-- at a time (the last element of the list is inserted first).
--
-- \(O(n\log n)\)
fromList      :: TreeNavigator k a -> [a] -> BalBST k a
fromList n xs = foldr insert (empty n) xs

-- | 'fromList' specialised to 'ordNav': the elements are their own keys.
fromList'    :: Ord a => [a] -> BalBST a a
fromList' xs = fromList ordNav xs


-- -- | \(O(n)\)
-- fromAscList :: TreeNavigator k a -> [a] -> BalBST k a
-- fromAscList = undefined


--------------------------------------------------------------------------------

-- | Check if the tree is empty, i.e. its underlying tree is 'Empty'.
null     :: BalBST k a -> Bool
null bst = case toTree bst of
             Empty -> True
             _     -> False

-- | Test if an element occurs in the BST and, if so, return the stored
-- copy. The search follows 'goLeft' down to a single leaf and compares
-- that leaf with '=='.
--
-- \(O(\log n)\)
lookup :: Eq a => a -> BalBST k a -> Maybe a
lookup x (BalBST n t) = go t
  where
    go Empty = Nothing
    go (Leaf y)
      | x == y    = Just y
      | otherwise = Nothing
    go (Node _ _ l k r) = go (if goLeft n x k then l else r)

-- | Test whether an element occurs in the BST (a 'lookup' that only
-- reports success).
--
-- \(O(\log n)\)
member     :: Eq a => a -> BalBST k a -> Bool
member x t = isJust (lookup x t)





-- | Insert an element in the BST.
--
-- Descends with 'goLeft' to a leaf and replaces that leaf by a red node
-- holding the old and the new element as its two leaves (ordered by
-- 'goLeft' against the key built with 'extractKey'). On the way back up,
-- every node is rebuilt with 'balance', and the root is blackened at the
-- end.
--
-- \(O(\log n)\)
insert :: a -> BalBST k a -> BalBST k a
insert x (BalBST n t) = BalBST n (blacken (go t))
  where
    toLeft = goLeft n x

    go Empty = Leaf x
    go (Leaf y) =
      let key = extractKey n x y
          (small, large)
            | toLeft key = (x, y)
            | otherwise  = (y, x)
      in red 2 (Leaf small) key (Leaf large)
    go (Node c h l k r)
      | toLeft k  = balance c h (go l) k r
      | otherwise = balance c h l k (go r)



-- delete = undefined

-- delete                        :: Eq a => a -> BalBST k a -> BalBST k a
-- delete x (BalBST n@Nav{..} t) = delete' t
--   where
--     delete' Empty      = Empty
--     delete' l@(Leaf y) = if x == y then Empty else l
--     delete' (Node c h l k r)
--       | goLeft x k     =


--------------------------------------------------------------------------------


-- | Extract the minimum from the tree, together with a tree of all the
-- remaining elements (rebuilt with 'joinWith'). Returns 'Nothing' for an
-- empty tree.
--
-- \(O(\log n)\)
minView              :: BalBST k a -> Maybe (a, Tree k a)
minView (BalBST n t) = go t
  where
    go Empty            = Nothing
    go (Leaf x)         = Just (x, Empty)
    go (Node _ _ l _ r) = do
      (m, rest) <- go l
      pure (m, joinWith n rest r)

-- | Extract the maximum from the tree, together with a tree of all the
-- remaining elements (rebuilt with 'joinWith'). Returns 'Nothing' for an
-- empty tree.
--
-- \(O(\log n)\)
maxView              :: BalBST k a -> Maybe (a, Tree k a)
maxView (BalBST n t) = go t
  where
    go Empty            = Nothing
    go (Leaf x)         = Just (x, Empty)
    go (Node _ _ l _ r) = do
      (m, rest) <- go r
      pure (m, joinWith n l rest)

-- | Joins two BSTs. Assumes that the ranges are disjoint, with every
-- element of the first BST coming before every element of the second.
-- The result uses the navigator of the first (left) BST; the navigator of
-- the second one is ignored.
--
-- \(O(\log n)\)
join                             :: BalBST k a -> BalBST k a -> BalBST k a
join (BalBST nv lt) (BalBST _ rt) = BalBST nv (joinWith nv lt rt)

-- | Joins two trees with a specific Tree Navigator. Assumes that every
-- element of the first tree comes before every element of the second.
--
-- The shorter tree is attached to the spine of the taller one, at the
-- first node whose height equals its own. It goes under a new red node
-- whose routing key is built with 'extractKey' from the maximum of the
-- left side and the minimum of the right side. 'balance' repairs red-red
-- violations on the way back up, and the root of the result is blackened.
--
-- \(O(\log n)\)
joinWith                 :: TreeNavigator k a -> Tree k a -> Tree k a -> Tree k a
joinWith Nav{..} tl0 tr0
    | lh >= rh           = blacken $ joinL tl tr
    | otherwise          = blacken $ joinR tl tr
  where
    -- The inputs may be red-rooted subtrees ('split' and 'splitMonotone'
    -- pass raw subtrees of an existing tree). A red node stores one more
    -- than the height of its children (cf. @red 2@ over two leaves in
    -- 'insert', and @Red (h+1)@ below). That is the height it would have
    -- if it were black, yet a red node adds no black level. Matching such
    -- a root on @h == rh@ / @h == lh@ would glue subtrees of different
    -- black heights together. Blackening the roots first is always valid
    -- for a red-black tree, and it makes the stored root heights exact.
    tl = blacken tl0
    tr = blacken tr0
    rh = height tr
    lh = height tl

    -- Walk down the right spine of the (at least as tall) left tree.
    joinL Empty      r           = r
    joinL l          Empty       = l
    joinL l@(Leaf x) r@(Leaf y)  = red 2 l (extractKey x y) r
    joinL l@(Node c h ll k lr) r
      | h == rh                  = let lm = unsafeMax lr
                                       rm = unsafeMin r
                                   in balance Red (h+1) l (extractKey lm rm) r
      | otherwise                = balance c h ll k (joinL lr r)
        -- lh >= rh, so a leaf is never paired with a taller node
    joinL _ _ = error "joinL. absurd"

    -- Walk down the left spine of the (strictly taller) right tree.
    joinR l          Empty       = l
    joinR Empty      r           = r
    joinR l@(Leaf x) r@(Leaf y)  = red 2 l (extractKey x y) r
    joinR l r@(Node c h rl k rr)
      | h == lh                  = let lm = unsafeMax l
                                       rm = unsafeMin rl
                                   in balance Red (h+1) l (extractKey lm rm) r
      | otherwise                = balance c h (joinR l rl) k rr
        -- lh < rh, so a leaf is never paired with a taller node
    joinR _ _ = error "joinR absurd"


--------------------------------------------------------------------------------
-- * Splitting and extracting

-- | A pair that is strict in its first component and lazy in the second.
-- 'fmap' maps over the (lazy) second component.
data Pair a b = Pair
  { fst' :: !a  -- ^ strict first component
  , snd' :: b   -- ^ lazy second component
  } deriving (Show, Eq, Functor, Foldable, Traversable)


-- | Collect the first components of a list of pairs. The second component
-- of the result is the second component of the last pair, or @def@ when
-- the list is empty.
collect        :: b -> [Pair a b] -> Pair [a] b
collect def ps = Pair (map fst' ps) lastSnd
  where
    lastSnd = case reverse ps of
                []      -> def
                (p : _) -> snd' p


-- | Extract a prefix from the tree, i.e. a repeated 'minView'. Elements
-- are produced left to right, and each is paired with the tree of the
-- elements that follow it.
--
-- \(O(\log n +k)\), where \(k\) is the size of the extracted part
extractPrefix              :: BalBST k a -> [Pair a (Tree k a)]
extractPrefix (BalBST n t) = go t
  where
    go Empty            = []
    go (Leaf x)         = [Pair x Empty]
    go (Node _ _ l _ r) =
      [ fmap (\rest -> joinWith n rest r) p | p <- go l ] ++ go r

-- | Extract a suffix from the tree, i.e. a repeated 'maxView'. Elements
-- are produced right to left, and each is paired with the tree of the
-- elements that precede it.
--
-- \(O(\log n +k)\), where \(k\) is the size of the extracted part
extractSuffix              :: BalBST k a -> [Pair a (Tree k a)]
extractSuffix (BalBST n t) = go t
  where
    go Empty            = []
    go (Leaf x)         = [Pair x Empty]
    go (Node _ _ l _ r) =
      [ fmap (joinWith n l) p | p <- go r ] ++ go l

-- | Result of splitting a tree: the part before the split point, the
-- (strict) value found at the split point, and the part after it.
data Split a b = Split a !b a
  deriving (Show, Eq)

-- | Splits the tree at @x@ into the elements routed before @x@, the
-- occurrence of @x@ (found with '==', if any), and the elements routed
-- after @x@. Note that if x occurs more often, no guarantees are given
-- which one is found.
--
-- \(O(\log n)\)
split                        :: Eq a => a -> BalBST k a -> Split (Tree k a) (Maybe a)
split x (BalBST n@Nav{..} t) = split' t
  where
    split' Empty                  = Split Empty Nothing Empty
    split' l@(Leaf y)
      | x == y                    = Split Empty (Just y) Empty
        -- x is routed to the left of y (as in 'insert'), so y comes after
        -- x and belongs in the right part; otherwise it belongs on the left.
      | goLeft x (extractKey x y) = Split Empty Nothing  l
      | otherwise                 = Split l     Nothing  Empty
    split' (Node _ _ l k r)
      | goLeft x k                = let Split l' mx r' = split' l
                                    in Split l' mx (joinWith n r' r)
      | otherwise                 = let Split l' mx r' = split' r
                                    in Split (joinWith n l l') mx r'

-- | Split based on a monotone predicate @p@, i.e. one that is 'False' on a
-- prefix of the elements and 'True' on the rest. The first BST holds the
-- elements on which @p@ is 'False', the second those on which it is
-- 'True'.
--
-- \(O(\log n)\)
splitMonotone                 :: (a -> Bool) -> BalBST k a
                              -> (BalBST k a, BalBST k a)
splitMonotone p (BalBST n t0) = bimap wrap wrap (go t0)
  where
    wrap = BalBST n

    go Empty = (Empty, Empty)
    go leaf@(Leaf y)
      | p y       = (Empty, leaf)
      | otherwise = (leaf, Empty)
    -- If the smallest element of r already satisfies p, then by
    -- monotonicity all of r does, and only l still needs splitting.
    go (Node _ _ l _ r)
      | p (unsafeMin r) = let (lo, mid) = go l in (lo, joinWith n mid r)
      | otherwise       = let (mid, hi) = go r in (joinWith n l mid, hi)


-- | Splits at a given monotone predicate @p@ (see 'splitMonotone'). It
-- then extracts, on each side of the split point, the longest run of
-- elements adjacent to it that satisfy @sel@.
--
-- The result holds, in order:
--
--   * the remaining left part;
--   * the two runs, each in left-to-right order;
--   * the remaining right part.
splitExtract           :: (a -> Bool) -> (a -> Bool) -> BalBST k a
                       -> Split (BalBST k a) ([a],[a])
splitExtract p sel bst =
    Split (BalBST n rest1) (reverse run1, run2) (BalBST n rest2)
  where
    n          = nav bst
    (lhs, rhs) = splitMonotone p bst

    takeRun dflt = collect dflt . L.takeWhile (sel . fst')

    -- extractSuffix walks right-to-left, hence the 'reverse' above
    Pair run1 rest1 = takeRun (toTree lhs) (extractSuffix lhs)
    Pair run2 rest2 = takeRun (toTree rhs) (extractPrefix rhs)


--------------------------------------------------------------------------------


-- | Node labels used when rendering a 'Tree' as a "Data.Tree" rose tree.
-- Internal nodes show colour, height and routing key; leaves show the
-- element.
data T k a
  = Internal !Color !Height !k
  | Val !a
  deriving (Show, Eq, Ord)

-- | Convert a tree into a "Data.Tree" rose tree for display. 'Empty'
-- becomes 'Nothing', and empty subtrees are dropped from a node's
-- children.
toRoseTree      :: Tree k a -> Maybe (T.Tree (T k a))
toRoseTree tree = case tree of
  Empty          -> Nothing
  Leaf x         -> Just (T.Node (Val x) [])
  Node c h l k r -> Just (T.Node (Internal c h k) (mapMaybe toRoseTree [l, r]))


-- | Render the tree as ASCII art with 'T.drawTree', or as @"Empty"@ when
-- there is nothing to draw.
showTree     :: (Show k, Show a) => BalBST k a -> String
showTree bst = case toRoseTree (toTree bst) of
                 Nothing   -> "Empty"
                 Just rose -> T.drawTree (show <$> rose)

-- | Get the minimum (leftmost leaf) of the tree by following left
-- children. Errors when the tree is empty.
--
-- \(O(\log n)\)
unsafeMin      :: Tree k a -> a
unsafeMin tree = case tree of
  Leaf x         -> x
  Node _ _ l _ _ -> unsafeMin l
  Empty          -> error "unsafeMin: Empty"

-- | Get the maximum (rightmost leaf) of the tree by following right
-- children. Errors when the tree is empty.
--
-- \(O(\log n)\)
unsafeMax      :: Tree k a -> a
unsafeMax tree = case tree of
  Leaf x         -> x
  Node _ _ _ _ r -> unsafeMax r
  Empty          -> error "unsafeMax: Empty"

-- | Extract all elements of the BST, in left-to-right order.
--
-- \(O(n)\)
toList     :: BalBST k a -> [a]
toList bst = toList' (toTree bst)

-- | Extract all elements in the tree, in left-to-right (leaf) order.
--
-- Elements are consed onto an accumulator instead of joining the lists of
-- the two subtrees with '++'. Left-nested appends copy an element once per
-- ancestor whose left subtree contains it, which costs \(O(n\log n)\) on a
-- balanced tree instead of the documented \(O(n)\). The result is still
-- produced lazily.
--
-- \(O(n)\)
toList'   :: Tree k a -> [a]
toList' t = go t []
  where
    go Empty            acc = acc
    go (Leaf x)         acc = x : acc
    go (Node _ _ l _ r) acc = go l (go r acc)


--------------------------------------------------------------------------------
-- * Helper stuff

-- | Build a black internal node with the given height, children and key.
black         :: Height -> Tree k a -> k -> Tree k a -> Tree k a
black h l k r = Node Black h l k r

-- | Build a red internal node with the given height, children and key.
red         :: Height -> Tree k a -> k -> Tree k a -> Tree k a
red h l k r = Node Red h l k r


-- | Recolour a red root black, keeping its height; any other tree is
-- returned unchanged.
blacken   :: Tree k a -> Tree k a
blacken t = case t of
  Node Red h l k r -> Node Black h l k r
  _                -> t

-- | Rebalance a node after an insertion or join, in the style of
-- Okasaki's functional red-black trees.
--
-- When a black node has a red child that itself has a red child, in any
-- of the four possible positions, the three nodes are rotated with
-- 'mkNode' into a red node with two black children. The inner red nodes'
-- heights are discarded, and @h@ is passed on to 'mkNode'. Clauses are
-- tried top to bottom, so if several red-red patterns are present the
-- first matching one wins. Every other combination, including any red
-- parent, becomes a plain node with the given colour and height.
balance  :: Color -> Height -> Tree k a -> k -> Tree k a -> Tree k a
balance Black h (Node Red _ (Node Red _ a x b) y c) z d = mkNode h a x b y c z d
balance Black h (Node Red _ a x (Node Red _ b y c)) z d = mkNode h a x b y c z d
balance Black h a x (Node Red _ (Node Red _ b y c) z d) = mkNode h a x b y c z d
balance Black h a x (Node Red _ b y (Node Red _ c z d)) = mkNode h a x b y c z d
balance co h a x b                                      = Node co h a x b

-- | Build the result of a 'balance' rotation: a red node keyed @y@ whose
-- children are the black nodes @(a x b)@ and @(c z d)@.
--
-- The two black children take the height @h@ of the black node being
-- rebalanced. The new red root sits one level above them, so it gets
-- @h + 1@. This is the same convention as the red nodes created by
-- 'insert' (@red 2@ over two leaves) and by 'joinWith' (@Red (h+1)@ over
-- two subtrees of height @h@). With only @h@, a root produced by a
-- rotation and then blackened would report a height one too low, and
-- 'joinWith' would attach trees at the wrong level.
mkNode                 :: Height
                       -> Tree k a -> k -> Tree k a -> k -> Tree k a  -> k -> Tree k a
                       -> Tree k a
mkNode h a x b y c z d = red (h + 1) (black h a x b) y (black h c z d)

-- | Cached height of a tree: 0 for 'Empty', 1 for a 'Leaf', and the
-- stored annotation for a 'Node'.
height   :: Tree k a -> Height
height t = case t of
  Empty          -> 0
  Leaf _         -> 1
  Node _ h _ _ _ -> h