{-# LANGUAGE GeneralizedNewtypeDeriving, RankNTypes, FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Control.Monad.Union (
UnionM,
Union (..),
MonadUnion (..),
Node,
run,
run',
) where
import Control.Monad.Union.Class
import qualified Data.Union.ST as US
import Data.Union.Type (Node (..), Union (..))
import Prelude hiding (lookup)
import Control.Monad
import Control.Monad.State
import Control.Monad.ST
import Control.Monad.Fix
import Control.Applicative
import Control.Arrow (first)
-- | Internal state threaded through every 'UnionM' computation.
data UState s l = UState
    { next   :: !Int
      -- ^ Index that the next call to 'new' will hand out. Nodes are
      -- numbered consecutively from 0, so this is also the number of nodes
      -- created so far.
    , forest :: US.UnionST s l
      -- ^ The mutable disjoint-set forest, living in the 'ST' thread @s@.
      -- Its capacity ('US.size') may exceed 'next'.
    }
-- | The union-find monad. A @'UnionM' l a@ computation manipulates a
-- disjoint-set forest whose nodes carry labels of type @l@ and produces a
-- result of type @a@.
--
-- The 'ST' state thread @s@ is quantified inside the newtype, so no mutable
-- forest can leak out of a computation. The constructor is not exported,
-- so outside this module computations are executed with 'run' or 'run''.
newtype UnionM l a = U {
    runU :: forall s . StateT (UState s l) (ST s) a
}
-- | Bind runs the first computation and passes its result to the
-- continuation, threading the shared 'UState' (and its 'ST' thread) through
-- both.
--
-- 'return' is deliberately left to its default ('pure'). Defining it
-- separately is flagged by @-Wnoncanonical-monad-instances@.
instance Monad (UnionM l) where
    m >>= k = U (runU m >>= \x -> runU (k x))

-- | Apply a pure function to the result; the union-find state is untouched.
instance Functor (UnionM l) where
    fmap f m = U (fmap f (runU m))

-- | 'pure' yields a value without touching the forest. '<*>' runs the
-- function computation and then the argument computation, threading the
-- state through both.
instance Applicative (UnionM l) where
    pure x = U (pure x)
    mf <*> mx = U (runU mf <*> runU mx)
-- | Value recursion is delegated to the 'MonadFix' instance of the underlying
-- @'StateT' ('UState' s l) ('ST' s)@ stack. The result of the recursive
-- computation is fed back to itself lazily.
instance MonadFix (UnionM l) where
    mfix f = U (mfix (\x -> runU (f x)))
-- | Run a union-find computation and return its result.
--
-- The computation starts from a minimal forest (@'US.new' 1@) with no nodes
-- allocated (@next = 0@). 'new' grows the forest on demand. The label passed
-- to 'US.new' is only a placeholder for slots that have not been allocated
-- yet. Assuming 'US.new' stores it lazily (the original passed 'undefined'
-- here), it is never forced unless such a slot is read.
run :: UnionM l a -> a
run m = runST $ do
    u <- US.new 1 (error "Control.Monad.Union.run: label of an unallocated node was forced")
    evalStateT (runU m) UState{ next = 0, forest = u }
-- | Like 'run', but additionally return the final disjoint-set forest,
-- frozen into an immutable 'Union', so it can be queried after the
-- computation has finished.
--
-- As in 'run', the label given to 'US.new' is only a placeholder for slots
-- that have not been allocated yet.
run' :: UnionM l a -> (Union l, a)
run' m = runST $ do
    u <- US.new 1 (error "Control.Monad.Union.run': label of an unallocated node was forced")
    (result, final) <- runStateT (runU m) UState{ next = 0, forest = u }
    -- Freezing without a copy is safe: the mutable forest is never touched
    -- again after this point, and its 's'-indexed type keeps it inside runST.
    frozen <- US.unsafeFreeze (forest final)
    return (frozen, result)
-- | Run an 'ST' action against the current forest without replacing it.
-- Shared by the 'MonadUnion' methods that only operate on the forest in
-- place (everything except 'new', which may swap in a grown forest).
withForest :: (forall s. US.UnionST s l -> ST s a) -> UnionM l a
withForest act = U (gets forest >>= \dsf -> lift (act dsf))

instance MonadUnion l (UnionM l) where
    -- Allocate a fresh node labelled @l@. Indices are handed out
    -- sequentially from 'next'. When the forest has no free slot left, it is
    -- first grown to at least twice its current size.
    new l = U $ do
        st <- get
        let n        = next st
            dsf      = forest st
            capacity = US.size dsf
        dsf' <- if capacity <= n
            -- 'max (n + 1)' guarantees that slot n exists even for an empty
            -- forest; for a non-empty one it is plain doubling.
            then lift (US.grow dsf (max (n + 1) (2 * capacity)))
            else return dsf
        lift (US.annotate dsf' n l)
        put st{ forest = dsf', next = n + 1 }
        return (Node n)

    -- Look a node up in the forest. The Int index that 'US.lookup' returns
    -- (presumably the representative of the node's set) is wrapped back into
    -- a 'Node' and paired with the label 'US.lookup' reports.
    lookup (Node n) = withForest (\dsf -> first Node <$> US.lookup dsf n)

    -- Merge the sets containing the two nodes, combining their labels with
    -- @f@. The 'Maybe' result is passed through from 'US.merge' unchanged.
    merge f (Node n) (Node m) = withForest (\dsf -> US.merge dsf f n m)

    -- Overwrite the label stored for the given node.
    annotate (Node n) l = withForest (\dsf -> US.annotate dsf n l)

    -- Delegate to 'US.flatten' on the current forest. Presumably this
    -- compresses all paths; see "Data.Union.ST" to confirm.
    flatten = withForest US.flatten