{-# LANGUAGE BangPatterns #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeSynonymInstances #-} module Data.Graph.Dynamic.Internal.Splay ( Tree , singleton , cons , snoc , append , split , connected , root , aggregate , toList -- * Debugging only , readRoot , freeze , print , assertInvariants ) where import Control.Monad (when) import Control.Monad.Primitive (PrimMonad (..)) import qualified Data.Graph.Dynamic.Internal.Tree as Class import Data.Monoid ((<>)) import Data.Primitive.MutVar (MutVar) import qualified Data.Primitive.MutVar as MutVar import qualified Data.Tree as Tree import Prelude hiding (concat, print) import System.IO.Unsafe (unsafePerformIO) import Unsafe.Coerce (unsafeCoerce) data T s a v = T { tParent :: {-# UNPACK #-} !(Tree s a v) , tLeft :: {-# UNPACK #-} !(Tree s a v) , tRight :: {-# UNPACK #-} !(Tree s a v) , tLabel :: !a , tValue :: !v , tAgg :: !v } -- | NOTE (jaspervdj): There are two ways of indicating the parent / left / -- right is not set (we want to avoid Maybe's since they cause a lot of -- indirections). -- -- Imagine that we are considering tLeft. -- -- 1. We can set tLeft of x to the MutVar that holds the tree itself (i.e. a -- self-loop). -- 2. We can set tLeft to some nil value. -- -- They seem to offer similar performance. We choose to use the latter since it -- is less likely to end up in infinite loops that way, and additionally, we can -- move easily move e.g. x's left child to y's right child, even it is an empty -- child. nil :: Tree s a v nil = unsafeCoerce $ unsafePerformIO $ fmap Tree $ MutVar.newMutVar undefined {-# NOINLINE nil #-} newtype Tree s a v = Tree {unTree :: MutVar s (T s a v)} deriving (Eq) singleton :: PrimMonad m => a -> v -> m (Tree (PrimState m) a v) singleton tLabel tValue = fmap Tree $ MutVar.newMutVar $! T nil nil nil tLabel tValue tValue readRoot :: PrimMonad m => Tree (PrimState m) a v -> m (Tree (PrimState m) a v) readRoot tree = do T {..} <- MutVar.readMutVar (unTree tree) if tParent == nil then return tree else readRoot tParent -- | `lv` must be a singleton tree cons :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> Tree (PrimState m) a v -> m (Tree (PrimState m) a v) cons lt@(Tree lv) rt@(Tree rv) = do r <- MutVar.readMutVar rv MutVar.modifyMutVar' lv $ \l -> l {tRight = rt, tAgg = tAgg l <> tAgg r} MutVar.writeMutVar rv $! r {tParent = lt} return lt -- | `rv` must be a singleton tree snoc :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> Tree (PrimState m) a v -> m (Tree (PrimState m) a v) snoc lt@(Tree lv) rt@(Tree rv) = do l <- MutVar.readMutVar lv MutVar.modifyMutVar' rv $ \r -> r {tLeft = lt, tAgg = tAgg l <> tAgg r} MutVar.writeMutVar lv $! l {tParent = rt} return rt -- | Appends two trees. Returns the root of the tree. append :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> Tree (PrimState m) a v -> m (Tree (PrimState m) a v) append xt@(Tree _xv) yt@(Tree yv) = do rmt@(Tree rmv) <- getRightMost xt _ <- splay rmt y <- MutVar.readMutVar yv MutVar.modifyMutVar rmv $ \r -> r {tRight = yt, tAgg = tAgg r <> tAgg y} MutVar.writeMutVar yv $! y {tParent = rmt} return rmt where getRightMost tt@(Tree tv) = do t <- MutVar.readMutVar tv if tRight t == nil then return tt else getRightMost (tRight t) split :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> m (Maybe (Tree (PrimState m) a v), Maybe (Tree (PrimState m) a v)) split xt@(Tree xv) = do _ <- splay xt T {..} <- MutVar.readMutVar xv when (tLeft /= nil) (removeParent tLeft) -- Works even if l is x when (tRight /= nil) (removeParent tRight) MutVar.writeMutVar xv $ T {tAgg = tValue, ..} removeLeft xt removeRight xt return ( if tLeft == nil then Nothing else Just tLeft , if tRight == nil then Nothing else Just tRight ) connected :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> Tree (PrimState m) a v -> m Bool connected x y = do _ <- splay x x' <- splay y return $ x == x' root :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> m (Tree (PrimState m) a v) root x = do _ <- splay x return x label :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> m a label (Tree xv) = tLabel <$> MutVar.readMutVar xv aggregate :: (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> m v aggregate (Tree xv) = tAgg <$> MutVar.readMutVar xv -- | For debugging/testing. toList :: PrimMonad m => Tree (PrimState m) a v -> m [a] toList = go [] where go acc0 (Tree tv) = do T {..} <- MutVar.readMutVar tv acc1 <- if tRight == nil then return acc0 else go acc0 tRight let acc2 = tLabel : acc1 if tLeft == nil then return acc2 else go acc2 tLeft splay :: forall m a v. (PrimMonad m, Monoid v) => Tree (PrimState m) a v -> m (Tree (PrimState m) a v) -- Returns the old root. splay xt@(Tree xv) = do -- Note (jaspervdj): Rather than repeatedly reading from/writing to xv we -- read x once and thread its (continuously updated) value through the -- entire stack of `go` calls. -- -- The same is true for the left and right aggregates of x: they can be -- passed upwards rather than recomputed. x0 <- MutVar.readMutVar xv xla <- if tLeft x0 == nil then return mempty else tAgg <$> MutVar.readMutVar (unTree $ tLeft x0) xra <- if tRight x0 == nil then return mempty else tAgg <$> MutVar.readMutVar (unTree $ tRight x0) go xt xla xra x0 where go :: Tree (PrimState m) a v -> v -> v -> T (PrimState m) a v -> m (Tree (PrimState m) a v) go closestToRootFound xla xra !x = do let !(pt@(Tree pv)) = tParent x if pt == nil then do MutVar.writeMutVar xv x return closestToRootFound else do p <- MutVar.readMutVar pv let gt@(Tree gv) = tParent p let plt@(Tree plv) = tLeft p let prt@(Tree prv) = tRight p let xlt@(Tree xlv) = tLeft x let xrt@(Tree xrv) = tRight x if | gt == nil, plt == xt -> do -- ZIG (Right) -- -- p => x -- / \ -- x p -- \ / -- xr xr -- when (xrt /= nil) $ MutVar.modifyMutVar' xrv $ \xr -> xr {tParent = pt} pra <- if prt == nil then return mempty else tAgg <$> MutVar.readMutVar prv MutVar.writeMutVar pv $! p { tLeft = xrt , tParent = xt , tAgg = xra <> tValue p <> pra } MutVar.writeMutVar xv $! x { tAgg = tAgg p , tRight = pt , tParent = nil } return pt | gt == nil -> do -- ZIG (Left) -- -- p => x -- \ / -- x p -- / \ -- xl xl -- when (xlt /= nil) $ MutVar.modifyMutVar' xlv $ \xl -> xl {tParent = pt} pla <- if plt == nil then return mempty else tAgg <$> MutVar.readMutVar plv MutVar.writeMutVar pv $! p { tRight = xlt , tParent = xt , tAgg = pla <> tValue p <> xla } MutVar.writeMutVar xv $! x { tAgg = tAgg p , tLeft = pt , tParent = nil } return pt | otherwise -> do g <- MutVar.readMutVar gv let ggt@(Tree ggv) = tParent g let glt@(Tree glv) = tLeft g let grt@(Tree grv) = tRight g when (ggt /= nil) $ MutVar.modifyMutVar' ggv $ \gg -> if tLeft gg == gt then gg {tLeft = xt} else gg {tRight = xt} if | plt == xt && glt == pt -> do -- ZIGZIG (Right): -- -- gg gg -- | | -- g x -- / \ / \ -- p => p -- / \ / \ -- x pr xr g -- / \ / -- xr pr -- pra <- if prt == nil then return mempty else tAgg <$> MutVar.readMutVar prv gra <- if grt == nil then return mempty else tAgg <$> MutVar.readMutVar grv let !ga' = pra <> tValue g <> gra when (prt /= nil) $ MutVar.modifyMutVar' prv $ \pr -> pr {tParent = gt} MutVar.writeMutVar gv $! g { tParent = pt , tLeft = prt , tAgg = ga' } when (xrt /= nil) $ MutVar.modifyMutVar' xrv $ \xr -> xr {tParent = pt} let !pa' = xra <> tValue p <> ga' MutVar.writeMutVar pv $! p { tParent = xt , tLeft = xrt , tRight = gt , tAgg = pa' } go gt xla pa' $! x { tRight = pt , tAgg = tAgg g , tParent = ggt } | plv /= xv && glv /= pv -> do -- ZIGZIG (Left): -- -- gg gg -- | | -- g x -- / \ / \ -- p => p -- / \ / \ -- pl x g xl -- / \ / \ -- xl pl -- pla <- if plt == nil then return mempty else tAgg <$> MutVar.readMutVar plv gla <- if glt == nil then return mempty else tAgg <$> MutVar.readMutVar glv let !ga' = gla <> tValue g <> pla when (plt /= nil) $ MutVar.modifyMutVar' plv $ \pl -> pl {tParent = gt} MutVar.writeMutVar gv $! g { tParent = pt , tRight = plt , tAgg = ga' } when (xlt /= nil) $ MutVar.modifyMutVar' xlv $ \xl -> xl {tParent = pt} let !pa' = ga' <> tValue p <> xla MutVar.writeMutVar pv $! p { tParent = xt , tLeft = gt , tRight = xlt , tAgg = pa' } go gt pa' xra $! x { tLeft = pt , tAgg = tAgg g , tParent = ggt } | plv == xv -> do -- ZIGZIG (Left): -- -- gg gg -- | | -- g x -- \ / \ -- p => g p -- / \ / -- x xl xr -- / \ -- xl xr -- when (xlt /= nil) $ MutVar.modifyMutVar' xlv $ \xl -> xl {tParent = gt} gla <- if glt == nil then return mempty else tAgg <$> MutVar.readMutVar glv let !ga' = gla <> tValue g <> xla MutVar.writeMutVar gv $! g { tParent = xt , tRight = xlt , tAgg = ga' } when (xrt /= nil) $ MutVar.modifyMutVar' xrv $ \xr -> xr {tParent = pt} pra <- if prt == nil then return mempty else tAgg <$> MutVar.readMutVar prv let pa' = xra <> tValue p <> pra MutVar.writeMutVar pv $! p { tParent = xt , tLeft = xrt , tAgg = pa' } go gt ga' pa' $! x { tParent = ggt , tLeft = gt , tRight = pt , tAgg = tAgg g } | otherwise -> do -- ZIGZIG (Right): -- -- gg gg -- | | -- g x -- / / \ -- p => p g -- \ \ / -- x xl xr -- / \ -- xl xr -- when (xrt /= nil) $ MutVar.modifyMutVar' xrv $ \xr -> xr {tParent = gt} gra <- if grt == nil then return mempty else tAgg <$> MutVar.readMutVar grv let !ga' = xra <> tValue g <> gra MutVar.writeMutVar gv $! g { tParent = xt , tLeft = xrt , tAgg = ga' } when (xlt /= nil) $ MutVar.modifyMutVar' xlv $ \xl -> xl {tParent = pt} pla <- if plt == nil then return mempty else tAgg <$> MutVar.readMutVar plv let !pa' = pla <> tValue p <> xla MutVar.writeMutVar pv $! p { tParent = xt , tRight = xlt , tAgg = pa' } go gt pa' ga' $! x { tParent = ggt , tLeft = pt , tRight = gt , tAgg = tAgg g } removeParent, removeLeft, removeRight :: PrimMonad m => Tree (PrimState m) a v -- Parent -> m () removeParent (Tree x) = MutVar.modifyMutVar' x $ \x' -> x' {tParent = nil} removeLeft (Tree x) = MutVar.modifyMutVar' x $ \x' -> x' {tLeft = nil} removeRight (Tree x) = MutVar.modifyMutVar' x $ \x' -> x' {tRight = nil} -- | For debugging/testing. freeze :: PrimMonad m => Tree (PrimState m) a v -> m (Tree.Tree a) freeze (Tree tv) = do T {..} <- MutVar.readMutVar tv children <- sequence $ [freeze tLeft | tLeft /= nil] ++ [freeze tRight | tRight /= nil] return $ Tree.Node tLabel children print :: Show a => Tree (PrimState IO) a v -> IO () print = go 0 where go d (Tree tv) = do T {..} <- MutVar.readMutVar tv when (tLeft /= nil) $ go (d + 1) tLeft putStrLn $ replicate d ' ' ++ show tLabel when (tRight /= nil) $ go (d + 1) tRight assertInvariants :: (PrimMonad m, Monoid v, Eq v, Show v) => Tree (PrimState m) a v -> m () assertInvariants t = do _ <- computeAgg nil t return () where -- TODO: Check average computeAgg pt xt@(Tree xv) = do x' <- MutVar.readMutVar xv let p' = tParent x' when (pt /= p') $ fail "broken parent pointer" let l = tLeft x' let r = tRight x' la <- if l == nil then return mempty else computeAgg xt l ra <- if r == nil then return mempty else computeAgg xt r let actualAgg = la <> (tValue x') <> ra let storedAgg = tAgg x' when (actualAgg /= storedAgg) $ fail $ "error in stored aggregates: " ++ show storedAgg ++ ", actual: " ++ show actualAgg return actualAgg data TreeGen s = TreeGen instance Class.Tree Tree where type TreeGen Tree = TreeGen newTreeGen _ = return TreeGen singleton _ = singleton append = append split = split connected = connected root = root readRoot = readRoot label = label aggregate = aggregate toList = toList instance Class.TestTree Tree where print = print assertInvariants = assertInvariants assertSingleton = \_ -> return () assertRoot = \_ -> return ()