{-# LANGUAGE ConstraintKinds      #-}
{-# LANGUAGE CPP                  #-}
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE MagicHash            #-}
{-# LANGUAGE NoImplicitPrelude    #-}
{-# LANGUAGE Rank2Types           #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UnboxedTuples        #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wall #-}

-- | This library implements the ST2 monad, a type using GDP (ghosts of departed proofs)
--   to define shared regions of memory between local mutable state threads. This allows
--   us to define a region of the heap in more minute contours, with each state thread
--   having explicit access to regions in memory. This is achieved using the function `runST2`,
--   which in effect lets the user run a computation that makes use of two partially-overlapping
--   memory regions. Within that computation, the user can run sub-computations bound to one or
--   the other memory region. Furthermore, a sub-computation can move any variable that it owns
--   into the common overlap via `share`.
--
--   An example is shown below, where one sub-computation creates two cells: one
--   private, and the other shared. A second sub-computation has unconstrained access to the
--   shared cell. Yet even though the private reference is also in scope during the second
--   sub-computation, any attempts to access it there will fail to compile.
--
-- >>> :{
-- stSharingDemo :: Bool
-- stSharingDemo = runST2 $ do
--   -- In the "left" memory region, create and return
--   -- two references; one shared, and one not shared.
--   (secret, ref) <- liftL $ do
--     unshared <- newSTRef 42
--     shared   <- share =<< newSTRef 17
--     return (unshared, shared)
--   -- In the "right" memory region, mutate the shared
--   -- reference. If we attempt to access the non-shared
--   -- reference here, the program will not compile.
--   liftR $ do
--     let mine = use (symm ref)
--     x <- readSTRef mine
--     writeSTRef mine (x + 1)
--   -- Back in the "left" memory region, verify that the
--   -- unshared reference still holds its original value.
--   liftL $ do
--     check <- readSTRef secret
--     return (check == 42) 
-- :}
module Control.Monad.ST2
  ( -- * ST2 API
    ST(..), STRep
  , fixST
  , liftST

  , runST

  , STRef(..)
  , newSTRef
  , readSTRef
  , writeSTRef

    -- The infix type operator below is the alias for 'Common'; its
    -- (non-ASCII) symbol must be spelled out for the export to parse.
  , type (∩)
  , Common
  , share
  , liftL
  , liftR
  , use
  , symm
  , runST2

    -- * Unsafe API
  , toBaseST
  , fromBaseST

  , STret
  , unsafeInterleaveST, unsafeDupableInterleaveST
  , stToPrim
  , unsafePrimToST
  , unsafeSTToPrim
  , unsafeInlineST
  , stToIO
  , ioToST
  , RealWorld
  ) where

import           Control.Applicative (Applicative(pure, (*>), (<*>), liftA2))
import           Control.Exception.Base (catch, throwIO, NonTermination(..), BlockedIndefinitelyOnMVar(..))
import           Control.Monad (Monad(return, (>>=), (>>)), ap, liftM2)
import           Control.Monad.Primitive (PrimMonad(primitive, PrimState), PrimBase(internal), primToPrim, unsafePrimToPrim, unsafeInlinePrim)
import qualified Control.Monad.ST as BaseST
import           Data.Eq (Eq((==)))
import           Data.Function (($), (.))
import           Data.Functor (Functor(fmap))
#if !(MIN_VERSION_base(4,11,0))
import           Data.Monoid (Monoid(mempty, mappend))
#else
import           Data.Monoid (Monoid(mempty))
#endif
import           Data.Semigroup (Semigroup((<>)))
import           GHC.IO (IO(IO),unsafeDupableInterleaveIO)
import           GHC.MVar (readMVar, putMVar, newEmptyMVar)
import           GHC.Prim (State#, realWorld#, unsafeCoerce#, MutVar#, newMutVar#, readMutVar#, writeMutVar#, sameMutVar#, RealWorld, noDuplicate#)
import           GHC.Show (Show(showsPrec, showList), showString, showList__)
import           GHC.Types (RuntimeRep, TYPE, Any, isTrue#)
import           Theory.Named (type (~~))
import           Unsafe.Coerce (unsafeCoerce)

-- | Reinterpret an 'ST' action from this module as an action of the
--   @base@ library's 'BaseST.ST' monad over the same state index @s@.
--
--   The conversion is a bare 'unsafeCoerce' of the action.
{-# INLINE toBaseST #-}
toBaseST :: ST s a -> BaseST.ST s a
toBaseST action = unsafeCoerce action

-- | Reinterpret an action of the @base@ library's 'BaseST.ST' monad as an
--   'ST' action from this module over the same state index @s@.
--
--   The conversion is a bare 'unsafeCoerce' of the action.
{-# INLINE fromBaseST #-}
fromBaseST :: BaseST.ST s a -> ST s a
fromBaseST action = unsafeCoerce action

-- The state-transformer monad proper.  By default the monad is strict;
-- too many people got bitten by space leaks when it was lazy.

-- | The strict state-transformer monad.
-- A computation of type @'ST' s a@ transforms an internal state indexed
-- by @s@, and returns a value of type @a@.
-- The @s@ parameter is either
--
-- * an uninstantiated type variable (inside invocations of 'runST' or
--   'runST2'), or
--
-- * 'RealWorld' (inside invocations of 'stToIO').
--
-- It serves to keep the internal states of different invocations
-- of 'runST' separate from each other and from invocations of
-- 'stToIO'.
--
-- The '>>=' and '>>' operations are strict in the state (though not in
-- values stored in the state).  For example,
--
-- @'runST' (writeSTRef _|_ v >>= f) = _|_@
--
-- Representation note: the wrapped function threads a state token indexed
-- by @Any ~~ s@ rather than by @s@ itself.
newtype ST s a = ST (STRep (Any ~~ s) a)

-- | Convenience type alias for expressing ST computations
--   more succinctly: a function that takes a state token indexed by @s@
--   and returns the new state token paired with a result of type @a@.
type STRep s a = State# s -> (# State# s, a #)

-- | 'ST' is a primitive monad whose state token type is its index @s@.
instance PrimMonad (ST s) where
  type PrimState (ST s) = s
  {-# INLINE primitive #-}
  -- Re-index the raw state function to the @Any ~~ s@ token and wrap it.
  primitive rep = ST (repToAny# rep)

-- | An 'ST' action can be unwrapped back into its raw state function.
instance PrimBase (ST s) where
  {-# INLINE internal #-}
  -- Unwrap the newtype and re-index the state function back to @s@.
  internal st = case st of
    ST rep -> repFromAny# rep

-- | A simple (lifted) product type pairing the state token of an 'ST'
--   computation with the value it produced. It is the result type of 'liftST'.
data STret s a = STret (State# s) a

-- | Run an 'ST' action on a given state token and package the outcome as a
--   lifted 'STret' (new state token plus result) rather than an unboxed
--   tuple.
--
--   The incoming token is coerced to the index the action uses internally,
--   and the outgoing token is coerced back to @s@.
liftST :: ST s a -> State# s -> STret s a
liftST (ST m) s0 =
  case m (unsafeCoerce# s0) of
    (# s1, r #) -> STret (unsafeCoerce# s1) r

-- | An 'ST' action that applies the 'noDuplicate#' primitive to the state
--   token and returns unit.
{-# INLINE noDuplicateST #-}
noDuplicateST :: ST s ()
noDuplicateST = ST (\tok ->
  case noDuplicate# tok of
    tok' -> (# tok', () #))

-- | 'unsafeInterleaveST' defers an 'ST' computation lazily: given a value
-- of type @'ST' s a@, the computation is only performed when the resulting
-- @a@ is demanded.
--
-- It is 'unsafeDupableInterleaveST' applied to the action prefixed with
-- 'noDuplicateST'.
{-# INLINE unsafeInterleaveST #-}
unsafeInterleaveST :: ST s a -> ST s a
unsafeInterleaveST action = unsafeDupableInterleaveST guarded
  where
    -- Run the no-duplication guard first, then the caller's action.
    guarded = do
      noDuplicateST
      action

-- | 'unsafeDupableInterleaveST' allows an 'ST' computation to be deferred
-- lazily.  When passed a value of type @'ST' s a@, the 'ST' computation will
-- only be performed when the value of the @a@ is demanded.
--
-- The computation may be performed multiple times by different threads,
-- possibly at the same time. To prevent this, use 'unsafeInterleaveST' instead.
{-# NOINLINE unsafeDupableInterleaveST #-}
-- See Note [unsafeDupableInterleaveIO should not be inlined]
-- in GHC.IO.Unsafe
unsafeDupableInterleaveST :: ST s a -> ST s a
unsafeDupableInterleaveST (ST m) = ST ( \ s ->
    let
        -- Lazy thunk: the wrapped computation only runs when 'r' is forced.
        -- Its final state token is discarded.
        r = case m s of (# _, res #) -> res
    in
    -- The incoming state token is handed back unchanged.
    (# s, r #)
  )

-- | Embed a strict state transformer in an 'IO' action.
-- The computation must be indexed by 'RealWorld', the state supplied by
-- the 'IO' monad, which is distinct from the states used by invocations
-- of 'runST'.
stToIO :: ST RealWorld a -> IO a
stToIO st = case st of
  -- Only the index of the state token differs between the two
  -- representations, so the underlying function is coerced as-is.
  ST rep -> IO (unsafeCoerce rep)

-- | Turn an 'IO' action into an 'ST' action. The result type is pinned to
-- the 'RealWorld' state, so the resulting action cannot be passed to
-- 'runST'.
ioToST :: IO a -> ST RealWorld a
ioToST io = case io of
  -- Re-index the underlying state function by coercion and wrap it.
  IO rep -> ST (unsafeCoerce rep)

-- | Turn an 'IO' action into an 'ST' action over an /arbitrary/ state
-- index. This works because 'IO' and 'ST' share a representation up to
-- the type of the state token; the underlying function is coerced before
-- being applied to the token.
unsafeIOToST :: IO a -> ST s a
unsafeIOToST (IO io) = ST (\tok -> unsafeCoerce# io tok)

-- | Turn an 'ST' action over any state index into an 'IO' action.
-- This works because 'IO' and 'ST' share a representation up to the type
-- of the state token.
--
-- For an example demonstrating why this is unsafe, see
-- https://mail.haskell.org/pipermail/haskell-cafe/2009-April/060719.html
unsafeSTToIO :: ST s a -> IO a
unsafeSTToIO st = case st of
  ST rep -> IO (unsafeCoerce# rep)

-- | Run an 'ST' action inside any 'PrimMonad' whose state token type
--   matches the action's index. Implemented with 'primToPrim'.
{-# INLINE stToPrim #-}
stToPrim :: PrimMonad m => ST (PrimState m) a -> m a
stToPrim action = primToPrim action

-- | Turn an action of any 'PrimBase' monad into an 'ST' action with an
--   arbitrary state index, via 'unsafePrimToPrim'. This operation is highly
--   unsafe!
{-# INLINE unsafePrimToST #-}
unsafePrimToST :: PrimBase m => m a -> ST s a
unsafePrimToST action = unsafePrimToPrim action

-- | Convert an 'ST' action with an arbitrary state token to any 'PrimMonad'.
--   This operation is highly unsafe!
--
--   Only 'PrimMonad' is required of the target monad: 'unsafePrimToPrim'
--   needs 'PrimBase' solely for its /source/ monad, which here is 'ST'.
--   Every 'PrimBase' monad is also a 'PrimMonad', so callers that
--   previously supplied a 'PrimBase' instance continue to type-check.
unsafeSTToPrim :: PrimMonad m => ST s a -> m a
{-# INLINE unsafeSTToPrim #-}
unsafeSTToPrim = unsafePrimToPrim

-- | Pull the result out of an 'ST' computation via 'unsafeInlinePrim'.
--   Unlike 'runST', the index @s@ is not universally quantified here, so
--   the rest of the program may still refer to that state thread.
--   This operation is highly unsafe!
{-# INLINE unsafeInlineST #-}
unsafeInlineST :: ST s a -> a
unsafeInlineST action = unsafeInlinePrim action

-- | A value of type @STRef s a@ is a mutable variable in state thread @s@,
--   containing a value of type @a@. It is a thin wrapper around the
--   primitive 'MutVar#'.
data STRef s a = STRef (MutVar# s a)

-- | Allocate a fresh 'STRef' in the current state thread, initialised with
--   the given value.
newSTRef :: a -> ST s (STRef s a)
newSTRef initial = ST $ \tok0# ->
  -- The primitive works on a token indexed by @s@, so the token is
  -- re-indexed on the way in and on the way out.
  case newMutVar# initial (rwFromAny# tok0#) of
    (# tok1#, var# #) -> (# unsafeCoerce# tok1#, STRef var# #)

-- | Read the value currently stored in an 'STRef'.
readSTRef :: STRef s a -> ST s a
readSTRef (STRef var#) = ST $ \tok0# ->
  -- Re-index the state token around the primitive read.
  case readMutVar# var# (rwFromAny# tok0#) of
    (# tok1#, val #) -> (# rwToAny# tok1#, val #)

-- | Overwrite the contents of an 'STRef' with a new value.
writeSTRef :: STRef s a -> a -> ST s ()
writeSTRef (STRef var#) val = ST $ \tok0# ->
  -- The primitive write returns the next state token, which is re-indexed
  -- and paired with unit.
  (# rwToAny# (writeMutVar# var# val (rwFromAny# tok0#)), () #)

-- | Two references are equal exactly when they are the same mutable
--   variable (as decided by 'sameMutVar#'); contents are not compared.
instance Eq (STRef s a) where
  lhs == rhs =
    case lhs of
      STRef v1# -> case rhs of
        STRef v2# -> isTrue# (sameMutVar# v1# v2#)

-- | Mapping runs the action and applies the function (lazily) to its result.
instance Functor (ST s) where
  fmap f (ST m) = ST $ \s0 ->
    case m s0 of
      (# s1, r #) -> (# s1, f r #)

-- | 'pure' returns a value without touching the state token; the remaining
--   methods are defined through the 'Monad' instance.
instance Applicative (ST s) where
  {-# INLINE pure #-}
  {-# INLINE (*>) #-}
  pure x = ST $ \s -> (# s, x #)
  -- Sequence two actions, discarding the first result.
  first *> second = first >>= \_ -> second
  mf <*> mx = ap mf mx
  liftA2 f ma mb = liftM2 f ma mb

-- | Bind threads the state token through the first action and then through
--   the action produced by the continuation.
instance Monad (ST s) where
    {-# INLINE (>>=)  #-}
    (>>) = (*>)
    ST m >>= k = ST $ \s0 ->
      case m s0 of
        (# s1, r #) ->
          -- Unwrap the continuation's action and feed it the new token.
          case k r of
            ST next -> next s1

-- | Combine two actions by running both and combining their results
--   with the result type's '<>'.
instance Semigroup a => Semigroup (ST s a) where
  lhs <> rhs = liftA2 (<>) lhs rhs

-- | The identity action returns the result type's 'mempty'.
instance Monoid a => Monoid (ST s a) where
  mempty = pure mempty
#if !(MIN_VERSION_base(4,11,0))
  -- 'mappend' is only defined explicitly for base versions before 4.11.
  mappend lhs rhs = liftA2 mappend lhs rhs
#endif

-- | Actions are opaque, so every action renders as the same placeholder.
instance Show (ST s a) where
  showsPrec _prec _action = showString "<<ST action>>"
  showList actions        = showList__ (showsPrec 0) actions

-- | A pretty, infix type alias for 'Common': @s ∩ s'@ is the same type
--   as @'Common' s s'@.
type s ∩ s' = Common s s'

-- | A type that shows that the state threads s and s' refer to a heap
--   region that overlaps in some way such that s and s' can both access
--   the overlap, while maintaining privacy in their own heap regions.
--
--   The type has no constructors; it is only ever used as a type-level
--   index (for example on 'ST' and 'STRef').
data Common s s'

-- | Publish a reference owned by region @s@ into the overlap
--   @'Common' s s'@. The reference itself is returned unchanged; only its
--   region index is re-typed (via 'unsafeCoerce').
{-# INLINE share #-}
share :: STRef s a -> ST s (STRef (Common s s') a)
share ref = return (unsafeCoerce ref)

-- | Lift an 'ST' computation over the /left/ region @s@ into the combined
--   context @'Common' s s'@. The action is re-typed with 'unsafeCoerce'.
{-# INLINE liftL #-}
liftL :: ST s a -> ST (Common s s') a
liftL action = unsafeCoerce action

-- | Lift an 'ST' computation over the /right/ region @s'@ into the combined
--   context @'Common' s s'@. The action is re-typed with 'unsafeCoerce'.
{-# INLINE liftR #-}
liftR :: ST s' a -> ST (Common s s') a
liftR action = unsafeCoerce action

-- | View a reference living in the overlap @'Common' s s'@ as a reference
--   of the left region @s@, so it can be used by computations indexed by
--   @s@. The reference is re-typed with 'unsafeCoerce'.
{-# INLINE use #-}
use :: STRef (Common s s') a -> STRef s a
use ref = unsafeCoerce ref

-- | Swap the order of the two regions in the index of a shared 'STRef':
--   a reference in @'Common' s s'@ becomes one in @'Common' s' s@.
--   The reference is re-typed with 'unsafeCoerce'.
{-# INLINE symm #-}
symm :: STRef (Common s s') a -> STRef (Common s' s) a
symm ref = unsafeCoerce ref

-- | Return the value computed by a state transformer computation
--   over a shared heap. The @forall@ ensures that the internal state(s)
--   used by the 'ST' computation is inaccessible to the rest of the program.
--
--   The computation must be polymorphic in /both/ region indices @s@ and
--   @s'@. Its underlying state function is run by @runRegion#@, and the
--   final state token is discarded.
runST2 :: (forall s s'. ST (Common s s') a) -> a
{-# INLINE runST2 #-}
runST2 (ST st_rep) = case runRegion# st_rep of (# _, a #) -> a

-- | Return the value computed by a state transformer computation.
--   The @forall@ ensures that the internal state used by the 'ST'
--   computation is inaccessible to the rest of the program.
--
--   The underlying state function is run by @runRegion#@, and the final
--   state token is discarded.
runST :: (forall s. ST s a) -> a
{-# INLINE runST #-}
runST (ST st_rep) = case runRegion# st_rep of (# _, a #) -> a

-- | Allow the result of a state transformer computation to be used (lazily)
-- inside the computation.
--
-- Note that if @f@ is strict, @'fixST' f = _|_@.
fixST :: (a -> ST s a) -> ST s a
-- See Note [fixST]
fixST k = unsafeIOToST $ do
    -- The MVar will eventually hold the final result of @k@.
    m <- newEmptyMVar
    -- 'ans' is a lazy placeholder for that result: it only reads the MVar
    -- when forced. If forcing it is reported as blocking indefinitely on the
    -- MVar, that is re-thrown as 'NonTermination'.
    ans <- unsafeDupableInterleaveIO
             (readMVar m `catch` \BlockedIndefinitelyOnMVar ->
                                    throwIO NonTermination)
    -- Run the user's computation, handing it the lazy placeholder.
    result <- unsafeSTToIO (k ans)
    -- Publish the real result so that later demands on 'ans' succeed.
    putMVar m result
    return result

{- Note [fixST]
   ~~~~~~~~~~~~
For many years, we implemented fixST much like a pure fixpoint,
using liftST:
  fixST :: (a -> ST s a) -> ST s a
  fixST k = ST $ \ s ->
      let ans       = liftST (k r) s
          STret _ r = ans
      in
      case ans of STret s' x -> (# s', x #)
We knew that lazy blackholing could cause the computation to be re-run if the
result was demanded strictly, but we thought that would be okay in the case of
ST. However, that is not the case (see Trac #15349). Notably, the first time
the computation is executed, it may mutate variables that cause it to behave
*differently* the second time it's run. That may allow it to terminate when it
should not. More frighteningly, Arseniy Alekseyev produced a somewhat contrived
example ( https://mail.haskell.org/pipermail/libraries/2018-July/028889.html )
demonstrating that it can break reasonable assumptions in "trustworthy" code,
causing a memory safety violation. So now we implement fixST much like we do
fixIO. See also the implementation notes for fixIO. Simon Marlow wondered
whether we could get away with an IORef instead of an MVar. I believe we
cannot. The function passed to fixST may spark a parallel computation that
demands the final result. Such a computation should block until the final
result is available.
-}

-- | Run a state-token-consuming function by handing it the real-world
--   token (re-indexed to @Any ~~ s@). This is the primitive underneath
--   'runST' and 'runST2'.
--
--   This function must NOT be inlined. If the application to 'realWorld#'
--   becomes visible at a call site, primitive operations inside @m@ (for
--   example allocating a mutable variable) no longer depend on anything
--   bound locally, so the optimiser may float them out of their enclosing
--   function and share a single allocation between what should be separate
--   runs. Keeping the token application behind an opaque call prevents
--   that; it is the same precaution GHC takes for its own @runRW#@ (and,
--   historically, @runSTRep@).
runRegion# :: forall (r :: RuntimeRep) (o :: TYPE r) s.
           (State# (Any ~~ s) -> o) -> o
runRegion# m = m (rwToAny# realWorld#)
{-# NOINLINE runRegion# #-}

-- | Re-index an arbitrary state token to the @Any ~~ s@ index used inside
--   'ST'. A bare 'unsafeCoerce#' of the token.
{-# INLINE rwToAny# #-}
rwToAny# :: forall s s'. State# s' -> State# (Any ~~ s)
rwToAny# = \tok# -> unsafeCoerce# tok#

-- | Re-index a state token from the @Any ~~ s@ index used inside 'ST' to
--   an arbitrary index. A bare 'unsafeCoerce#' of the token.
{-# INLINE rwFromAny# #-}
rwFromAny# :: forall s s'. State# (Any ~~ s) -> State# s'
rwFromAny# = \tok# -> unsafeCoerce# tok#

--rwTupleFromAny# :: forall s a. (# State# (Any ~~ s), a #) -> (# State# s, a #)
--rwTupleFromAny# (# x, a #) = (# unsafeCoerce# x, a #)

-- | Re-index the state token inside a (token, result) unboxed pair to the
--   @Any ~~ s@ index used inside 'ST'; the result component is passed
--   through untouched.
{-# INLINE rwTupleToAny# #-}
rwTupleToAny# :: forall s a. (# State# s, a #) -> (# State# (Any ~~ s), a #)
rwTupleToAny# pair =
  case pair of
    (# tok#, val #) -> (# unsafeCoerce# tok#, val #)

-- | Re-index a raw state function over @s@ to one over @Any ~~ s@, the
--   form wrapped by the 'ST' constructor. A bare 'unsafeCoerce#' of the
--   function.
{-# INLINE repToAny# #-}
repToAny# :: (State# s -> (# State# s, a #)) -> STRep (Any ~~ s) a
repToAny# rep = unsafeCoerce# rep

-- | Re-index a state function over @Any ~~ s@ (the form wrapped by the
--   'ST' constructor) back to a raw state function over @s@. A bare
--   'unsafeCoerce#' of the function.
{-# INLINE repFromAny# #-}
repFromAny# :: STRep (Any ~~ s) a -> (State# s -> (# State# s, a #))
repFromAny# rep = unsafeCoerce# rep