{-# LANGUAGE Unsafe #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash, UnboxedTuples, RankNTypes #-}
{-# OPTIONS_HADDOCK not-home #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Monad.ST.Lazy.Imp
-- Copyright   :  (c) The University of Glasgow 2001
-- License     :  BSD-style (see the file libraries/base/LICENSE)
-- 
-- Maintainer  :  libraries@haskell.org
-- Stability   :  provisional
-- Portability :  non-portable (requires universal quantification for runST)
--
-- This module presents an identical interface to "Control.Monad.ST",
-- except that the monad delays evaluation of 'ST' operations until
-- a value depending on them is required.
--
-----------------------------------------------------------------------------

module Control.Monad.ST.Lazy.Imp (
        -- * The 'ST' monad
        ST,
        runST,
        fixST,

        -- * Converting between strict and lazy 'ST'
        strictToLazyST, lazyToStrictST,

        -- * Converting 'ST' To 'IO'
        RealWorld,
        stToIO,

        -- * Unsafe operations
        unsafeInterleaveST,
        unsafeIOToST
    ) where

import Control.Monad.Fix

import qualified Control.Monad.ST as ST
import qualified Control.Monad.ST.Unsafe as ST

import qualified GHC.ST as GHC.ST
import GHC.Base
import qualified Control.Monad.Fail as Fail

-- | The lazy @'ST'@ monad.
-- The ST monad allows for destructive updates, but is escapable (unlike IO).
-- A computation of type @'ST' s a@ returns a value of type @a@, and
-- executes in "thread" @s@. The @s@ parameter is either
--
-- * an uninstantiated type variable (inside invocations of 'runST'), or
--
-- * 'RealWorld' (inside invocations of 'stToIO').
--
-- It serves to keep the internal states of different invocations of
-- 'runST' separate from each other and from invocations of 'stToIO'.
--
-- The '>>=' and '>>' operations are not strict in the state.  For example,
--
-- @'runST' (writeSTRef _|_ v >>= readSTRef _|_ >> return 2) = 2@
newtype ST s a = ST
  { unST :: State s -> (a, State s)
    -- ^ Run the computation starting from a lifted state token, returning
    -- its result paired with the state token reached afterwards.
  }

-- | A lifted state token. This can be imagined as a moment in the timeline
-- of a lazy state thread. Forcing the token forces all delayed actions in
-- the thread up until that moment to be performed.
--
-- The single constructor 'S#' boxes GHC's unlifted @State# s@ token, so a
-- value of this type can be a lazy thunk; pattern matching on 'S#' (as
-- 'strictToLazyST' and 'lazyToStrictST' do) is what forces it.
data State s = S# (State# s)

{- Note [Lazy ST and multithreading]

We used to imagine that passing a polymorphic state token was all that we
needed to keep state threads separate (see Launchbury and Peyton Jones, 1994:
https://www.microsoft.com/en-us/research/publication/lazy-functional-state-threads/).
But this breaks down in the face of concurrency (see #11760). Whereas a strict
ST computation runs to completion before producing anything, a value produced
by running a lazy ST computation may contain a thunk that, when forced, will
lead to further stateful computations. If such a thunk is entered by more than
one thread, then they may both read from and write to the same references and
arrays, interfering with each other. To work around this, any time we lazily
suspend execution of a lazy ST computation, we bind the result pair to a
NOINLINE binding (ensuring that it is not duplicated) and calculate that
pair using (unsafePerformIO . evaluate), ensuring that only one thread will
enter the thunk. We still use lifted state tokens to actually drive execution,
so in these cases we effectively deal with *two* state tokens: the lifted
one we get from the previous computation, and the unlifted one we pull out of
thin air. -}

{- Note [Lazy ST: not producing lazy pairs]

The fixST and strictToLazyST functions used to construct functions that
produced lazy pairs. Why don't we need that laziness? The ST type is kept
abstract, so no one outside this module can ever get their hands on a (result,
State s) pair. We ourselves never match on such pairs when performing ST
computations unless we also force one of their components. So no one should be
able to detect the change. By refraining from producing such thunks (which
reference delayed ST computations), we avoid having to ask whether we have to
wrap them up with unsafePerformIO. See Note [Lazy ST and multithreading]. -}

-- | This is a terrible hack to prevent a thunk from being entered twice.
-- Simon Peyton Jones would very much like to be rid of it.
--
-- Semantically the identity. Every lazily suspended step of a lazy 'ST'
-- computation is wrapped in 'noDup' and bound with NOINLINE, so that the
-- suspended work is done by only one thread even if several threads force
-- the same thunk (see Note [Lazy ST and multithreading]).
noDup :: a -> a
noDup a = runRW# (\s ->
  -- The state token is unlifted, so this case really performs
  -- noDuplicate# before returning the argument.
  case noDuplicate# s of
    _ -> a)

-- | @since 2.01
instance Functor (ST s) where
    -- Map over the result without touching the state. The tuple pattern
    -- is a lazy let binding, so @m@ only runs once @r@ or @new_s@ is demanded.
    fmap f m = ST $ \ s ->
      let
        -- See Note [Lazy ST and multithreading]
        {-# NOINLINE res #-}
        res = noDup (unST m s)
        (r,new_s) = res
      in
        (f r,new_s)

    -- Replace the result with @x@; @m@ is still sequenced into the state,
    -- but its own result is discarded.
    x <$ m = ST $ \ s ->
      let
        {-# NOINLINE s' #-}
        -- See Note [Lazy ST and multithreading]
        s' = noDup (snd (unST m s))
      in (x, s')

-- | @since 2.01
instance Applicative (ST s) where
    -- Return @a@ and pass the incoming state through unchanged.
    pure a = ST $ \ s -> (a,s)

    fm <*> xm = ST $ \ s ->
       let
         {-# NOINLINE res1 #-}
         !res1 = unST fm s
         !(f, s') = res1

         {-# NOINLINE res2 #-}
         -- See Note [Lazy ST and multithreading]
         res2 = noDup (unST xm s')
         (x, s'') = res2
       in (f x, s'')
    -- Why can we use a strict binding for res1? If someone
    -- forces the (f x, s'') pair, then they must need
    -- f or s''. To get s'', they need s'.

    liftA2 f m n = ST $ \ s ->
      let
        {-# NOINLINE res1 #-}
        -- See Note [Lazy ST and multithreading]
        res1 = noDup (unST m s)
        (x, s') = res1

        {-# NOINLINE res2 #-}
        res2 = noDup (unST n s')
        (y, s'') = res2
      in (f x y, s'')
    -- We don't get to be strict in liftA2, but we clear out a
    -- NOINLINE in comparison to the default definition, which may
    -- help the simplifier.

    -- Sequence @m@ into the state only; its result is discarded and the
    -- state it produces is not forced until @n@ (or a later step) needs it.
    m *> n = ST $ \s ->
       let
         {-# NOINLINE s' #-}
         -- See Note [Lazy ST and multithreading]
         s' = noDup (snd (unST m s))
       in unST n s'

    m <* n = ST $ \s ->
       let
         {-# NOINLINE res1 #-}
         !res1 = unST m s
         !(mr, s') = res1

         {-# NOINLINE s'' #-}
         -- See Note [Lazy ST and multithreading]
         s'' = noDup (snd (unST n s'))
       in (mr, s'')
    -- Why can we use a strict binding for res1? The same reason as
    -- in <*>. If someone demands the (mr, s'') pair, then they will
    -- force mr or s''. To get s'', they need s'.

-- | @since 2.01
instance Monad (ST s) where
    (>>) = (*>)

    -- Bind is lazy in the state: @m@ is run only when @r@ or @new_s@ is
    -- demanded (the tuple pattern below is an irrefutable let binding).
    m >>= k = ST $ \ s ->
       let
         -- See Note [Lazy ST and multithreading]
         {-# NOINLINE res #-}
         res = noDup (unST m s)
         (r,new_s) = res
       in
         unST (k r) new_s

-- | @since 4.10
instance Fail.MonadFail (ST s) where
    -- A failing lazy 'ST' action is simply bottom: the error (without a
    -- stack trace) surfaces only when something it produces is evaluated.
    fail s = errorWithoutStackTrace s

-- | Return the value computed by an 'ST' computation.
-- The @forall@ ensures that the internal state used by the 'ST'
-- computation is inaccessible to the rest of the program.
--
-- Only the result component is inspected; the final state token is
-- discarded without being forced, so trailing actions whose effects do not
-- contribute to the result are never performed.
runST :: (forall s. ST s a) -> a
runST (ST st) = runRW# (\s -> case st (S# s) of (r, _) -> r)

-- | Allow the result of an 'ST' computation to be used (lazily)
-- inside the computation.
-- Note that if @f@ is strict, @'fixST' f = _|_@.
fixST :: (a -> ST s a) -> ST s a
fixST f = ST (\ s ->
                let
                   -- r is the lazily-known result of running (f r) itself.
                   q@(r, _) = unST (f r) s
                in q)
-- Why don't we need unsafePerformIO in fixST? We create a thunk, q,
-- to perform a lazy state computation, and we pass a reference to that
-- thunk, r, to f. Uh oh? No, I think it should be fine, because that thunk
-- itself is demanded directly in the `let` body. See also
-- Note [Lazy ST: not producing lazy pairs].

-- | @since 2.01
--
-- 'mfix' is 'fixST'.
instance MonadFix (ST s) where
    mfix = fixST

-- ---------------------------------------------------------------------------
-- Strict <--> Lazy

{-|
Convert a strict 'ST' computation into a lazy one.  The strict state
thread passed to 'strictToLazyST' is not performed until the value or
the state produced by the lazy state thread it returns is demanded.
At that point the incoming lifted state is forced first (by the 'S#'
pattern), so all earlier delayed actions run before the strict one.
-}
strictToLazyST :: ST.ST s a -> ST s a
strictToLazyST (GHC.ST.ST m) = ST $ \(S# s) ->
  case m s of
    (# s', a #) -> (a, S# s')
-- See Note [Lazy ST: not producing lazy pairs]

{-|
Convert a lazy 'ST' computation into a strict one.

The lazy computation's final state token is forced (by the 'S#' pattern),
so all of its delayed actions are performed before the strict thread
continues. The result value itself is returned unevaluated.
-}
lazyToStrictST :: ST s a -> ST.ST s a
lazyToStrictST (ST m) = GHC.ST.ST $ \s ->
        case m (S# s) of (a, S# s') -> (# s', a #)

-- | Embed a lazy 'ST' computation in the 'IO' monad.  The 'RealWorld'
-- parameter indicates that the internal state used by the 'ST'
-- computation is a special one supplied by the 'IO' monad, and thus
-- distinct from those used by invocations of 'runST'.
--
-- The computation is converted with 'lazyToStrictST' first, so its final
-- state is forced when the resulting 'IO' action runs.
stToIO :: ST RealWorld a -> IO a
stToIO = ST.stToIO . lazyToStrictST

-- ---------------------------------------------------------------------------
-- Unsafe operations

-- | Lazy-'ST' counterpart of the strict 'ST.unsafeInterleaveST' from
-- "Control.Monad.ST.Unsafe". The argument is converted to strict 'ST',
-- deferred with the strict 'ST.unsafeInterleaveST', and converted back.
--
-- This is unsafe for the same reasons as the strict version: the deferred
-- computation's effects may happen at an unpredictable point relative to
-- the rest of the state thread.
unsafeInterleaveST :: ST s a -> ST s a
unsafeInterleaveST = strictToLazyST . ST.unsafeInterleaveST . lazyToStrictST

-- | Run an 'IO' action as a lazy 'ST' computation. It goes through the
-- strict 'ST.unsafeIOToST' from "Control.Monad.ST.Unsafe", then
-- 'strictToLazyST'. As with 'strictToLazyST', the action is not performed
-- until the value or state it produces is demanded.
--
-- This is unsafe: arbitrary 'IO' effects inside an 'ST' thread undermine
-- the isolation that the @s@ parameter of 'runST' is meant to provide.
unsafeIOToST :: IO a -> ST s a
unsafeIOToST = strictToLazyST . ST.unsafeIOToST