{-# LANGUAGE RecursiveDo #-}

-- |
-- Module     : Simulation.Aivika.Branch.Internal.BR
-- Copyright  : Copyright (c) 2016-2017, David Sorokin <david.sorokin@gmail.com>
-- License    : BSD3
-- Maintainer : David Sorokin <david.sorokin@gmail.com>
-- Stability  : experimental
-- Tested with: GHC 7.10.3
--
-- This module defines a branching computation.
--
module Simulation.Aivika.Branch.Internal.BR
       (BRParams(..),
        BR(..),
        invokeBR,
        runBR,
        newBRParams,
        newRootBRParams,
        branchLevel) where

import Data.IORef
import Data.Maybe

import Control.Applicative
import Control.Monad
import Control.Monad.Trans
import Control.Monad.Fix
import Control.Exception (throw, catch, finally)

import Simulation.Aivika.Trans.Exception

-- | The branching computation: a function from the parameters of the
-- branch it runs in ('BRParams') to an action in the underlying monad @m@.
newtype BR m a =
  BR
    { unBR :: BRParams -> m a
      -- ^ Unwrap the computation.
    }

-- | The parameters of the branch in which a computation runs.
data BRParams =
  BRParams
    { brId :: !Int
      -- ^ The branch identifier.
    , brIdGenerator :: IORef Int
      -- ^ The generator of identifiers.
    , brLevel :: !Int
      -- ^ The branch level.
    , brParent :: Maybe BRParams
      -- ^ The branch parent.
    , brUniqueRef :: IORef ()
      -- ^ The unique reference to which finalizers are attached,
      -- so that they are triggered by garbage collection.
    }

-- | Sequencing threads the same branch parameters through both the
-- computation and its continuation.
instance Monad m => Monad (BR m) where

  -- Canonical definition: 'return' is defined via 'pure' so that the two
  -- can never disagree (and GHC's -Wnoncanonical-monad-instances stays quiet).
  {-# INLINE return #-}
  return = pure

  {-# INLINE (>>=) #-}
  BR m >>= k = BR $ \ps ->
    m ps >>= \a -> unBR (k a) ps

-- | Both sides of '<*>' run with the same branch parameters; 'pure'
-- ignores them.
instance Applicative m => Applicative (BR m) where

  {-# INLINE pure #-}
  pure a = BR $ \_ -> pure a

  {-# INLINE (<*>) #-}
  BR mf <*> BR mx = BR $ \ps -> mf ps <*> mx ps

-- | Mapping applies the function to the result of the underlying action.
instance Functor m => Functor (BR m) where

  {-# INLINE fmap #-}
  fmap f (BR m) = BR $ \ps -> f <$> m ps

-- | An 'IO' action is lifted into the underlying monad; the branch
-- parameters are ignored.
instance MonadIO m => MonadIO (BR m) where

  {-# INLINE liftIO #-}
  liftIO io = BR $ \_ -> liftIO io

-- | Lifting wraps an underlying action so that it ignores the branch
-- parameters.
instance MonadTrans BR where

  {-# INLINE lift #-}
  lift m = BR $ \_ -> m

-- | The fixed point is taken in the underlying monad, with the branch
-- parameters fixed for every unfolding of the recursive computation.
instance MonadFix m => MonadFix (BR m) where

  mfix f = BR $ \ps -> mfix (\a -> unBR (f a) ps)

-- | Exception handling is delegated to the underlying monad. The protected
-- computation, the handler and the finalizer all run with the same branch
-- parameters.
instance MonadException m => MonadException (BR m) where

  catchComp m h =
    BR $ \ps ->
      invokeBR ps m `catchComp` \e -> invokeBR ps (h e)

  finallyComp m1 m2 =
    BR $ \ps ->
      invokeBR ps m1 `finallyComp` invokeBR ps m2

  throwComp e = BR $ const (throwComp e)

-- | Invoke the computation with the given branch parameters.
invokeBR :: BRParams -> BR m a -> m a
{-# INLINE invokeBR #-}
invokeBR ps m = unBR m ps

-- | Run the branching computation within a freshly created root branch
-- (see 'newRootBRParams').
runBR :: MonadIO m => BR m a -> m a
{-# INLINABLE runBR #-}
runBR m = liftIO newRootBRParams >>= unBR m

-- | Create the parameters of a new child branch of the given branch.
--
-- The child gets the next identifier from the generator it shares with
-- its parent, a level one greater than the parent's level, a reference
-- back to the parent, and its own unique reference.
newBRParams :: BRParams -> IO BRParams
newBRParams ps =
  do -- The strict variant stores an evaluated counter. The lazy
     -- 'atomicModifyIORef' would leave a thunk in the shared generator
     -- (the inner 'seq' only fires once the result pair is demanded),
     -- which can build up a chain of pending increments.
     newId <- atomicModifyIORef' (brIdGenerator ps) $ \a ->
       let b = a + 1 in (b, b)
     uniqueRef <- newIORef ()
     return BRParams { brId          = newId,
                       brIdGenerator = brIdGenerator ps,
                       brLevel       = 1 + brLevel ps,
                       brParent      = Just ps,
                       brUniqueRef   = uniqueRef }

-- | Create the parameters of a root branch: identifier 0, level 0, no
-- parent, a fresh identifier generator starting at 0 and a fresh unique
-- reference.
newRootBRParams :: IO BRParams
newRootBRParams =
  mkRoot <$> newIORef 0 <*> newIORef ()
  where
    mkRoot genId uniqueRef =
      BRParams { brId          = 0,
                 brIdGenerator = genId,
                 brLevel       = 0,
                 brParent      = Nothing,
                 brUniqueRef   = uniqueRef }

-- | Return the level of the current branch; the root branch has level 0.
branchLevel :: Monad m => BR m Int
{-# INLINABLE branchLevel #-}
branchLevel = BR (return . brLevel)