{-# LANGUAGE MagicHash, UnboxedTuples, Rank2Types, FlexibleInstances,
    MultiParamTypeClasses, UndecidableInstances, RecursiveDo #-}
{- |
   Module      :  Control.Monad.ST.Trans
   Copyright   :  Josef Svenningsson 2008-2010
                  (c) The University of Glasgow, 1994-2000
   License     :  BSD
 
   Maintainer  :  josef.svenningsson@gmail.com
   Stability   :  experimental
   Portability :  non-portable (GHC Extensions)

   This module provides the implementation of the 'STT' type for those
   occasions where it's needed in order to implement new liftings through
   operations in other monads.

   Warning! This monad transformer should not be used with monads that
   can contain multiple answers, like the list monad. The reason is that
   the state token will be duplicated across the different answers and
   this causes Bad Things to happen (such as loss of referential
   transparency). Safe monads include the monads State, Reader, Writer,
   Maybe and combinations of their corresponding monad transformers.
-}
module Control.Monad.ST.Trans.Internal where

import GHC.Base
import GHC.ST hiding (liftST)

import qualified Control.Monad.Fail as MF
import Control.Monad.Fix
import Control.Monad.Trans
import Control.Monad.Error.Class
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Writer.Class

#if __GLASGOW_HASKELL__ <= 708
import Control.Applicative
#endif

import Data.Array.ST
import Data.Array.Base
import GHC.Int    (      Int8,  Int16,  Int32,  Int64)
import GHC.Word   (Word, Word8, Word16, Word32, Word64)
import GHC.Ptr    (Ptr, FunPtr)
import GHC.Stable (StablePtr)

-- | 'STT' is the monad transformer providing polymorphic updatable
--   references.
--
--   A computation is represented as a function that takes the current
--   state token (@State# s@) and produces, in the underlying monad @m@,
--   an 'STTRet' holding the next state token together with the result.
newtype STT s m a = STT (State# s -> m (STTRet s a))

-- | Unwrap an 'STT' computation, exposing the function from a state token
--   to the action in the underlying monad.
unSTT :: STT s m a -> (State# s -> m (STTRet s a))
unSTT action = case action of
  STT run -> run

-- | 'STTRet' is needed to encapsulate the unboxed state token that GHC passes
--   around. This type is essentially a pair, but an ordinary pair is not
--   allowed to contain unboxed types.
--
--   The first field is the state token, the second is the result value.
data STTRet s a = STTRet (State# s) a

-- | Embed a plain 'ST' computation into 'STT'.
--
--   The 'ST' step function is applied to the incoming state token; the
--   resulting token and value are repackaged as an 'STTRet' and injected
--   into the underlying applicative with 'pure'. The rest of the library
--   relies on this function to reuse operations that already exist for 'ST'.
liftST :: Applicative m => ST s a -> STT s m a
liftST (ST step) = STT $ \tok ->
  case step tok of
    (# tok', result #) -> pure (STTRet tok' result)
{-# INLINE liftST #-}

-- All instances have to go in this module because otherwise they
-- would be orphan instances.

-- | Sequencing for 'STT': the state token produced by the first
--   computation is handed to the continuation's computation.
instance Monad m => Monad (STT s m) where
  return x = STT (\tok -> return (STTRet tok x))
  STT run >>= next = STT $ \tok ->
    run tok >>= \(STTRet tok' x) -> unSTT (next x) tok'

-- | Failure is delegated to the underlying monad's 'MF.MonadFail' instance.
--
--   The call is written as @MF.fail@ on purpose: an unqualified @fail@
--   resolves to whatever the Prelude exports, which on older versions of
--   base is the @fail@ method of 'Monad' rather than the one belonging to
--   the 'MF.MonadFail' constraint of this instance.
instance MF.MonadFail m => MF.MonadFail (STT s m) where
  fail msg = lift (MF.fail msg)

-- | Lifting an action of the underlying monad: the action is run and its
--   result is paired with the unchanged state token.
instance MonadTrans (STT s) where
  lift action = STT $ \tok ->
    action >>= \x -> return (STTRet tok x)

-- | Run an 'STT' computation on the given state token, yielding the
--   action in the underlying monad.
liftSTT :: STT s m a -> State# s -> m (STTRet s a)
liftSTT action tok = case action of
  STT run -> run tok

-- | Value recursion for 'STT', delegated to the underlying monad.
--
--   The @mdo@ block (RecursiveDo) feeds the result @r@ of running @k r@
--   back into @k@ itself, so the recursion goes through the 'MonadFix'
--   instance of @m@. The computation is run on the state token @s@ that
--   the whole 'mfix' action receives.
instance (MonadFix m) => MonadFix (STT s m) where
  mfix k = STT $ \ s -> mdo
    -- r is used on the right-hand side before it is bound on the left;
    -- this is the recursive binding that mdo permits.
    ans@(STTRet _ r) <- liftSTT (k r) s
    return ans

-- | Map over the result component; the state token is passed through
--   untouched.
instance Functor (STTRet s) where
  fmap f ret = case ret of
    STTRet tok x -> STTRet tok (f x)

-- | Map over the result of an 'STT' computation by mapping through the
--   underlying functor and then through the 'STTRet' wrapper.
instance Functor m => Functor (STT s m) where
  fmap f (STT run) = STT (\tok -> fmap (fmap f) (run tok))

-- | Applicative structure for 'STT'. The function computation runs first;
--   the state token it returns is used to run the argument computation.
instance (Monad m, Functor m) => Applicative (STT s m) where
  pure x = STT (\tok -> return (STTRet tok x))
  STT runF <*> STT runX = STT $ \tok0 ->
    runF tok0 >>= \(STTRet tok1 f) ->
    runX tok1 >>= \(STTRet tok2 x) ->
    return (STTRet tok2 (f x))

-- Instances of other monad classes

-- | Errors are thrown and caught in the underlying monad. When the
--   handler runs, it is given the state token that was current when
--   'catchError' was entered.
instance MonadError e m => MonadError e (STT s m) where
  throwError = lift . throwError
  catchError (STT body) handler = STT $ \tok ->
    body tok `catchError` \err -> unSTT (handler err) tok

-- | The environment lives in the underlying monad: 'ask' is lifted and
--   'local' is applied to the underlying action of the wrapped computation.
instance MonadReader r m => MonadReader r (STT s m) where
  ask = lift ask
  local adjust (STT run) = STT (\tok -> local adjust (run tok))

-- | State operations are forwarded to the underlying monad. Note that the
--   state type @s@ here is unrelated to the 'STT' thread parameter @s'@.
instance MonadState s m => MonadState s (STT s' m) where
  get = lift get
  put = lift . put

-- | Writer operations are forwarded to the underlying monad. 'listen' and
--   'pass' run the wrapped computation in the underlying monad and
--   re-associate the writer component with the result inside 'STTRet'.
instance MonadWriter w m => MonadWriter w (STT s m) where
  tell = lift . tell
  listen (STT run) = STT $ \tok ->
    listen (run tok) >>= \(STTRet tok' x, out) ->
      return (STTRet tok' (x, out))
  pass (STT run) = STT $ \tok ->
    pass (run tok >>= \(STTRet tok' (x, fn)) -> return (STTRet tok' x, fn))

-- MArray instances

-- | Boxed 'STArray's usable from 'STT'. Each method is the corresponding
--   'ST' operation lifted with 'liftST'.
instance (Applicative m, Monad m) => MArray (STArray s) e (STT s m) where
    {-# INLINE newArray #-}
    newArray ixRange initial = liftST (newArray ixRange initial)
    {-# INLINE getBounds #-}
    getBounds marr = liftST (getBounds marr)
    {-# INLINE getNumElements #-}
    getNumElements marr = liftST (getNumElements marr)
    {-# INLINE unsafeRead #-}
    unsafeRead marr ix = liftST (unsafeRead marr ix)
    {-# INLINE unsafeWrite #-}
    unsafeWrite marr ix val = liftST (unsafeWrite marr ix val)

-- | Unboxed 'STUArray's of 'Bool' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Bool (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Char' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Char (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Int' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Int (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Word' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Word (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Ptr' values usable from 'STT'. Each method is
--   the corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) (Ptr a) (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'FunPtr' values usable from 'STT'. Each method
--   is the corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) (FunPtr a) (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Float' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Float (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Double' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Double (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'StablePtr' values usable from 'STT'. Each
--   method is the corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) (StablePtr a) (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Int8' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Int8 (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Int16' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Int16 (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Int32' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Int32 (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Int64' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Int64 (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Word8' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Word8 (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Word16' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Word16 (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Word32' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Word32 (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)

-- | Unboxed 'STUArray's of 'Word64' usable from 'STT'. Each method is the
--   corresponding 'ST' operation lifted with 'liftST'.
--
--   'newArray_' and 'unsafeNewArray_' are lifted explicitly. The class
--   defaults implement both as @newArray bnds arrEleBottom@, and storing
--   that placeholder in an unboxed array forces it, so relying on the
--   defaults would make 'newArray_' (and helpers built on it, such as
--   'newListArray') fail with an \"undefined array element\" error.
instance (Applicative m, Monad m) => MArray (STUArray s) Word64 (STT s m) where
    {-# INLINE getBounds #-}
    getBounds arr = liftST (getBounds arr)
    {-# INLINE getNumElements #-}
    getNumElements arr = liftST (getNumElements arr)
    {-# INLINE newArray #-}
    newArray bnds e = liftST (newArray bnds e)
    {-# INLINE newArray_ #-}
    newArray_ bnds = liftST (newArray_ bnds)
    {-# INLINE unsafeNewArray_ #-}
    unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
    {-# INLINE unsafeRead #-}
    unsafeRead arr i = liftST (unsafeRead arr i)
    {-# INLINE unsafeWrite #-}
    unsafeWrite arr i e = liftST (unsafeWrite arr i e)