{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE UnboxedTuples       #-}
{-# LANGUAGE ScopedTypeVariables #-}

#include "inline.hs"

-- |
-- Module      : Streamly.Internal.Mutable.Prim.Var
-- Copyright   : (c) 2019 Composewell Technologies
--
-- License     : BSD3
-- Maintainer  : streamly@composewell.com
-- Stability   : experimental
-- Portability : GHC
--
-- A mutable variable in a mutation capable monad (IO/ST) holding a 'Prim'
-- value. This allows fast modification because of unboxed storage.
--
-- = Multithread Consistency Notes
--
-- In general, a value that straddles a machine word boundary cannot be
-- guaranteed to be read consistently from another thread without a lock. GHC
-- heap objects are always machine-word aligned; therefore, a 'Var' is also
-- word aligned. On a 64-bit platform, writing a 64-bit aligned value from one
-- thread and reading it from another thread should yield either the old or
-- the new value. The same holds true for 32-bit values on a 32-bit platform.

module Streamly.Internal.Mutable.Prim.Var
    (
      Var
    , MonadMut
    , Prim

    -- * Construction
    , newVar

    -- * Write
    , writeVar
    , modifyVar'

    -- * Read
    , readVar
    )
where

import Control.Monad.Primitive (PrimMonad(..), primitive_)
import Data.Primitive.Types (Prim, sizeOf#, readByteArray#, writeByteArray#)
import GHC.Exts (MutableByteArray#, newByteArray#)

-- | A 'Var' holds a single 'Prim' value.
--
-- The value lives in a 'MutableByteArray#' whose state token type is
-- @'PrimState' m@. The type parameter @a@ does not occur in the constructor's
-- field; it only records the type of the value that the accessor functions
-- in this module read from and write to the array. The data constructor is
-- not part of this module's export list.
data Var m a = Var (MutableByteArray# (PrimState m))

-- The name PrimMonad does not give a clue as to what it means; an explicit
-- "Mut" suffix provides a better hint. MonadMut is just a generalization of
-- MonadIO.
--
-- | A monad that allows mutable operations using a state token. This is a
-- plain synonym for 'PrimMonad'.
type MonadMut = PrimMonad

-- | Create a new mutable variable initialised with the given value.
--
-- A fresh byte array of @sizeOf# (undefined :: a)@ bytes is allocated and the
-- value is written at element index 0 before the 'Var' is returned.
{-# INLINE newVar #-}
newVar :: forall m a. (MonadMut m, Prim a) => a -> m (Var m a)
newVar x = primitive (\s# ->
      -- 'undefined' is passed only to select the 'Prim' instance for @a@
      -- (the scoped type variable comes from the explicit forall above);
      -- NOTE(review): this relies on 'sizeOf#' not evaluating its argument.
      case newByteArray# (sizeOf# (undefined :: a)) s# of
        (# s1#, arr# #) ->
            -- Thread the state token through the initial write so that it is
            -- sequenced after the allocation.
            case writeByteArray# arr# 0# x s1# of
                s2# -> (# s2#, Var arr# #)
    )

-- | Write a value to a mutable variable, replacing whatever is stored at
-- element index 0 of the underlying byte array.
{-# INLINE writeVar #-}
writeVar :: (MonadMut m, Prim a) => Var m a -> a -> m ()
writeVar (Var arr#) x = primitive_ (writeByteArray# arr# 0# x)

-- | Read the value currently stored in a variable, i.e. the element at
-- index 0 of the underlying byte array.
{-# INLINE readVar #-}
readVar :: (MonadMut m, Prim a) => Var m a -> m a
readVar (Var arr#) = primitive (readByteArray# arr# 0#)

-- | Modify the value of a mutable variable using a function with strict
-- application.
--
-- The current value is read, the function is applied to it, and the result is
-- forced to weak head normal form with 'seq' before being written back.
--
-- NOTE(review): the read and the write are two separate primitive operations;
-- nothing in this definition appears to make the update atomic with respect
-- to other threads - confirm before relying on it concurrently.
{-# INLINE modifyVar' #-}
modifyVar' :: (MonadMut m, Prim a) => Var m a -> (a -> a) -> m ()
modifyVar' (Var arr#) g = primitive_ $ \s# ->
  case readByteArray# arr# 0# s# of
    (# s'#, a #) -> let a' = g a in a' `seq` writeByteArray# arr# 0# a' s'#