{-# LANGUAGE ExistentialQuantification, FlexibleContexts, TypeOperators #-}

module Test.IOSpec.STM
   (
   -- * The specification of STM
     STMS
   -- * Atomically
   , atomically
   -- * The STM monad
   , STM
   , TVar
   , newTVar
   , readTVar
   , writeTVar
   , retry
   , orElse
   , check
   )
   where

import Test.IOSpec.VirtualMachine
import Test.IOSpec.Types
import Data.Dynamic
import Data.Maybe (fromJust)
import Control.Monad.State
import Control.Monad (ap)

-- The 'STMS' data type and its instances.
--
-- | An expression of type @IOSpec 'STMS' a@ corresponds to an 'IO'
-- computation that may use 'atomically' and returns a value of type
-- @a@.
--
-- By itself, 'STMS' is not terribly useful. You will probably want
-- to use @IOSpec (ForkS :+: STMS)@.
data STMS a =
    -- The transaction's result type @r@ is hidden (existentially
    -- quantified); the function turns that result into the rest of the
    -- program.
    forall r . Atomically (STM r) (r -> a)

-- | Mapping over an 'Atomically' node post-composes the function with the
-- continuation; the transaction itself is left untouched.
instance Functor STMS where
  fmap f (Atomically tx k) = Atomically tx (f . k)

-- | The 'atomically' function atomically executes an 'STM' action.
--
-- The transaction is wrapped in an 'Atomically' node whose continuation
-- simply returns the transaction's result, and injected into any
-- signature @f@ that contains 'STMS'.
atomically     :: (STMS :<: f) => STM a -> IOSpec f a
atomically stm = inject (Atomically stm return)

-- | One scheduler step of an 'Atomically' node.
--
-- The transaction is run against a snapshot of the current store. If it
-- retries ('Nothing'), the step reports 'Block' and the store is left
-- unchanged. If it completes, its final store is committed with 'put' and
-- the continuation is applied to the result. Any outcome other than 'Done'
-- is reported as an internal error.
instance Executable STMS where
  step (Atomically stm k) = do
    snapshot <- get
    case runStateT (executeSTM stm) snapshot of
      Done (Nothing, _)         -> return Block
      Done (Just x, finalState) -> put finalState >> return (Step (k x))
      _                         -> internalError "Unsafe usage of STM"

-- The 'STM' data type and its instances.
-- | Syntax of an STM transaction. Each constructor is one primitive
-- operation together with the continuation that consumes its result.
--
-- 'OrElse' carries its own continuation instead of having '>>=' pushed
-- into both branches. A 'retry' raised by code that runs /after/ an
-- 'orElse' has completed must not fall back to that 'orElse''s second
-- branch. Distributing the bind would make
-- @(return 1 `orElse` return 2) >>= \\x -> check (x == 2)@ succeed with 2,
-- whereas real STM retries the whole transaction.
data STM a =
    STMReturn a
  | NewTVar Data (Loc -> STM a)
  | ReadTVar Loc (Data -> STM a)
  | WriteTVar Loc Data (STM a)
  | Retry
  | forall b . OrElse (STM b) (STM b) (b -> STM a)

instance Functor STM where
  fmap f (STMReturn x)      = STMReturn (f x)
  fmap f (NewTVar d io)     = NewTVar d (fmap f . io)
  fmap f (ReadTVar l io)    = ReadTVar l (fmap f . io)
  fmap f (WriteTVar l d io) = WriteTVar l d (fmap f io)
  fmap _ Retry              = Retry
  fmap f (OrElse p q k)     = OrElse p q (fmap f . k)

instance Applicative STM where
  pure  = STMReturn
  (<*>) = ap

instance Monad STM where
  return = pure
  STMReturn a     >>= f = f a
  NewTVar d g     >>= f = NewTVar d (\l -> g l >>= f)
  ReadTVar l g    >>= f = ReadTVar l (\d -> g d >>= f)
  WriteTVar l d p >>= f = WriteTVar l d (p >>= f)
  Retry           >>= _ = Retry
  -- Only the continuation is extended; the two branches are untouched, so
  -- a retry raised by f propagates outwards instead of selecting q.
  OrElse p q k    >>= f = OrElse p q (\x -> k x >>= f)

-- | A 'TVar' is a shared, mutable variable used by STM.
newtype TVar a = TVar Loc

-- | The 'newTVar' function creates a new transactional variable.
newTVar   :: Typeable a => a -> STM (TVar a)
newTVar d = NewTVar (toDyn d) (STMReturn . TVar)

-- | The 'readTVar' function reads the value stored in a
-- transactional variable.
--
-- The stored value is kept as a 'Data' and converted back with
-- 'fromDynamic'. The phantom type of 'TVar' should make a mismatch
-- impossible; if one does occur, it is reported with a descriptive
-- 'internalError' rather than an anonymous 'fromJust' failure.
readTVar          :: Typeable a => TVar a -> STM a
readTVar (TVar l) =
  ReadTVar l
    (maybe (internalError "readTVar: stored value has an unexpected type")
           STMReturn
       . fromDynamic)

-- | The 'writeTVar' function overwrites the value stored in a
-- transactional variable.
writeTVar            :: Typeable a => TVar a -> a -> STM ()
writeTVar (TVar l) d = WriteTVar l (toDyn d) (STMReturn ())

-- | The 'retry' function abandons a transaction and retries at some
-- later time.
retry :: STM a
retry = Retry

-- | The 'check' function checks if its boolean argument holds. If
-- the boolean is true, it returns (); otherwise it calls 'retry'.
check       :: Bool -> STM ()
check True  = return ()
check False = retry

-- | The 'orElse' function takes two 'STM' actions @stm1@ and @stm2@ and
-- performs @stm1@. If @stm1@ calls 'retry', it performs @stm2@ instead.
-- If @stm1@ succeeds, @stm2@ is not executed.
orElse     :: STM a -> STM a -> STM a
orElse p q = OrElse p q STMReturn

-- | Run a transaction against the current store.
--
-- Returns 'Nothing' if the transaction retried, and @Just x@ if it
-- completed with result @x@. Writes go straight to the store via
-- 'updateHeap'. Rollback on retry is provided by the callers that run this
-- on a snapshot with 'runStateT' and only 'put' the final store on
-- success: the 'Atomically' step, and the first branch of 'OrElse'.
-- Because 'Nothing' short-circuits every continuation, writes left behind
-- by a retrying second branch are never observed before such a snapshot
-- boundary discards them.
executeSTM :: STM a -> VM (Maybe a)
executeSTM (STMReturn x) = return (Just x)
executeSTM (NewTVar d io) = do
  loc <- alloc
  updateHeap loc d
  executeSTM (io loc)
executeSTM (ReadTVar l io) = do
  contents <- lookupHeap l
  case contents of
    Just d  -> executeSTM (io d)
    Nothing -> internalError "readTVar: no value stored at TVar location"
executeSTM (WriteTVar l d io) = do
  updateHeap l d
  executeSTM io
executeSTM Retry = return Nothing
executeSTM (OrElse p q k) = do
  snapshot <- get
  -- Run the first branch on a copy of the store so that its writes can be
  -- discarded if it retries.
  case runStateT (executeSTM p) snapshot of
    Done (Nothing, _) ->
      -- First branch retried: run the second branch on the unchanged
      -- store, continuing with k only if it succeeds.
      executeSTM q >>= maybe (return Nothing) (executeSTM . k)
    Done (Just x, s)  -> put s >> executeSTM (k x)
    _                 -> internalError "Unsafe usage of STM"

-- | Abort with an 'error' whose message is prefixed with @"IOSpec.STM: "@.
-- Used in this module for outcomes the interpreter does not expect to
-- reach.
internalError :: String -> a
internalError msg = error ("IOSpec.STM: " ++ msg)