-- This file is part of the 'union-find-array' library. It is licensed
-- under an MIT license. See the accompanying 'LICENSE' file for details.
--
-- Authors: Bertram Felgenhauer

{-# LANGUAGE GeneralizedNewtypeDeriving, RankNTypes, FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
-- |
-- Monadic interface for creating a disjoint set data structure.
--
module Control.Monad.Union (
  UnionM,
  Union (..),
  MonadUnion (..),
  Node,
  run,
  run',
) where

import Control.Monad.Union.Class
import qualified Data.Union.ST as US
import Data.Union.Type (Node (..), Union (..))

import Prelude hiding (lookup)
import Control.Monad
import Control.Monad.State
import Control.Monad.ST
import Control.Monad.Fix
import Control.Applicative
import Control.Arrow (first)

-- Internal state threaded through a 'UnionM' computation.
data UState s l = UState {
    -- Index that the next node created by 'new' will receive; 'new'
    -- increments it after every allocation.
    next   :: !Int,
    -- Mutable disjoint set forest holding the nodes created so far; 'new'
    -- replaces it with a grown forest once 'next' reaches its size.
    forest :: US.UnionST s l
}

-- | Union find monad.
--
-- A computation is a state transformer over 'ST' carrying the internal
-- 'UState'. The field is universally quantified over the state thread @s@,
-- which is what allows 'run' and 'run'' to discharge it with 'runST'.
newtype UnionM l a = U {
    runU :: forall s . StateT (UState s l) (ST s) a
}

-- The three instances below are written in the canonical arrangement:
-- 'pure' is the primitive, 'return' is defined through it, and 'fmap' and
-- '<*>' are derived from '>>='. Defining 'pure' via 'return' and 'return'
-- via 'pure' at the same time would loop, so these must be kept in step.
instance Functor (UnionM l) where
    fmap = liftM

instance Applicative (UnionM l) where
    -- Inject the value into the underlying state transformer.
    pure x = U (pure x)
    (<*>) = ap

instance Monad (UnionM l) where
    return = pure
    -- Sequence in the underlying state transformer. The continuation is an
    -- explicit lambda so that 'runU', whose result type is rank-2, is
    -- always applied to an argument.
    m >>= k = U (runU m >>= \x -> runU (k x))

-- Value recursion is delegated to the 'MonadFix' instance of the underlying
-- state transformer.
instance MonadFix (UnionM l) where
    -- The explicit lambda keeps 'runU' (rank-2 result type) fully applied.
    mfix f = U (mfix (\x -> runU (f x)))

-- | Run a union find computation.
--
-- The computation starts with no nodes ('next' is 0) and a freshly created
-- forest; only the computation's result is returned, the forest is
-- discarded.
run :: UnionM l a -> a
run m = runST $ do
    -- NOTE(review): the arguments to 'US.new' are presumably the initial
    -- capacity and a placeholder label for slots that have not been
    -- annotated yet -- confirm against "Data.Union.ST".
    dsf <- US.new 1 undefined
    evalStateT (runU m) UState{ next = 0, forest = dsf }

-- | Run a union find computation; also return the final disjoint set forest
-- for querying.
--
-- The result pair holds the frozen forest first and the computation's
-- result second.
run' :: UnionM l a -> (Union l, a)
run' m = runST $ do
    -- NOTE(review): the arguments to 'US.new' are presumably the initial
    -- capacity and a placeholder label for slots that have not been
    -- annotated yet -- confirm against "Data.Union.ST".
    dsf <- US.new 1 undefined
    (result, final) <- runStateT (runU m) UState{ next = 0, forest = dsf }
    -- Freezing is the last action on the forest: nothing in this function
    -- touches the mutable structure afterwards.
    frozen <- US.unsafeFreeze (forest final)
    return (frozen, result)

-- Every method reads the current forest from the state and delegates to the
-- corresponding "Data.Union.ST" operation, translating between 'Node' and
-- the underlying 'Int' index.
instance MonadUnion l (UnionM l) where
    -- Add a new node, with a given label. The node receives the index
    -- stored in 'next'; when that index no longer fits into the forest, the
    -- forest is first replaced by one grown to twice its size.
    new l = U $ do
        u <- get
        let n    = next u
            size = US.size (forest u)
        dsf <- if size <= n
            then lift $ US.grow (forest u) (2 * size)
            else return (forest u)
        -- Label the fresh slot before the node is handed out.
        lift $ US.annotate dsf n l
        put u{ forest = dsf, next = n + 1 }
        return (Node n)

    -- Find the node representing a given node, and its label.
    lookup (Node n) = U $ do
        dsf <- gets forest
        -- 'US.lookup' yields an index/label pair; wrap the index in 'Node'.
        first Node <$> lift (US.lookup dsf n)

    -- Merge two sets. The first argument is a function that takes the labels
    -- of the corresponding sets' representatives and computes a new label for
    -- the joined set. Returns Nothing if the given nodes are in the same set
    -- already.
    merge f (Node n) (Node m) = U $ do
        dsf <- gets forest
        lift $ US.merge dsf f n m

    -- Re-label a node.
    annotate (Node n) l = U $ do
        dsf <- gets forest
        lift $ US.annotate dsf n l

    -- Flatten the disjoint set forest for faster lookups.
    flatten = U $ do
        dsf <- gets forest
        lift $ US.flatten dsf