{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Language.Haskell.TH.TestUtils
(
QState(..)
, MockedMode(..)
, QMode(..)
, ReifyInfo(..)
, loadNames
, unmockedState
, runTestQ
, runTestQErr
, tryTestQ
) where
#if !MIN_VERSION_base(4,13,0)
import Control.Monad.Fail (MonadFail(..))
#endif
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as Except
import qualified Control.Monad.Trans.Reader as Reader
import qualified Control.Monad.Trans.State as State
import Data.Maybe (fromMaybe)
import Language.Haskell.TH (Name, Q, runIO, runQ)
import Language.Haskell.TH.Syntax (Quasi(..), mkNameU)
import Language.Haskell.TH.TestUtils.QMode
import Language.Haskell.TH.TestUtils.QState
-- | Run a 'Q' action against the given 'QState' and return its result.
--
-- The action is executed with 'tryTestQ'; if it produced a failure message,
-- that message is raised with 'error' once the result is inspected.
runTestQ :: forall mode a. IsMockedMode mode => QState mode -> Q a -> TestQResult mode a
runTestQ state action = unwrap (tryTestQ state action)
  where
    -- Map over whatever wrapper 'TestQResult' puts around the 'Either'.
    unwrap = fmapResult @mode @(Either String a) @a (either error id)
-- | Run a 'Q' action that is expected to fail and return its failure message.
--
-- The action is executed with 'tryTestQ'. If it succeeds instead, 'error' is
-- raised with a message that includes the 'show'n result.
runTestQErr :: forall mode a. (IsMockedMode mode, Show a) => QState mode -> Q a -> TestQResult mode String
runTestQErr state action = expectFailure (tryTestQ state action)
  where
    expectFailure = fmapResult @mode @(Either String a) @String (either id unexpectedSuccess)
    unexpectedSuccess result = error ("Unexpected success: " ++ show result)
-- | Run a 'Q' action against the given 'QState', capturing failure as 'Left'.
--
-- The action is interpreted in 'TestQ' via 'runQ'. The layers of 'TestQ' are
-- then unwound from the outside in: the reader is supplied with @state@, the
-- internal state starts with no recorded error and a name counter of zero,
-- and the exception layer is turned into an 'Either'. The remaining 'Q'
-- computation is handed to 'runResult' for the current mode.
tryTestQ :: forall mode a. IsMockedMode mode => QState mode -> Q a -> TestQResult mode (Either String a)
tryTestQ state action = runResult @mode interpreted
  where
    interpreted =
      Except.runExceptT
        (State.evalStateT (Reader.runReaderT (unTestQ (runQ action)) state) freshState)

    freshState =
      InternalState
        { lastErrorReport = Nothing
        , newNameCounter = 0
        }
-- | Bookkeeping threaded through a 'TestQ' computation.
data InternalState = InternalState
  { lastErrorReport :: Maybe String
  -- ^ The most recent message stored by 'storeLastError', if any; consulted
  -- by 'fail' in the 'MonadFail' instance.
  , newNameCounter :: Int
  -- ^ The next value handed out by 'getAndIncrementNewNameCounter'.
  }
-- | The monad in which 'Q' actions are interpreted for testing.
--
-- From the outside in it is a reader over the user-supplied 'QState', a state
-- layer carrying 'InternalState', and an exception layer with 'String'
-- messages, all on top of the real 'Q' monad.
newtype TestQ (mode :: MockedMode) a = TestQ
  { unTestQ ::
      Reader.ReaderT
        (QState mode)
        (State.StateT InternalState (Except.ExceptT String Q))
        a
  }
  deriving (Functor, Applicative, Monad)
-- | Fetch the 'QState' the computation was started with (the reader
-- environment of 'TestQ').
getState :: TestQ mode (QState mode)
getState = TestQ (Reader.asks id)
-- | Fetch the @mode@ field of the current 'QState'.
getMode :: TestQ mode (QMode mode)
getMode = fmap mode getState
-- | Look up a name in the @reifyInfo@ table of the current 'QState' and
-- project a value out of the matching 'ReifyInfo'.
--
-- Calls 'error' when the name has no entry in the table.
lookupReifyInfo :: (ReifyInfo -> a) -> Name -> TestQ mode a
lookupReifyInfo f name = do
  QState{reifyInfo} <- getState
  maybe unknownName (return . f) (lookup name reifyInfo)
  where
    unknownName =
      error $ "Cannot reify " ++ show name ++ " (did you mean to add it to reifyInfo?)"
-- | Read the message most recently recorded by 'storeLastError', if any.
getLastError :: TestQ mode (Maybe String)
getLastError = TestQ (lift (State.gets lastErrorReport))
-- | Record a message as the last reported error, replacing any previous one.
storeLastError :: String -> TestQ mode ()
storeLastError msg = TestQ (lift (State.modify remember))
  where
    remember st = st{lastErrorReport = Just msg}
-- | Return the current value of the new-name counter and bump the stored
-- counter by one.
getAndIncrementNewNameCounter :: TestQ mode Int
getAndIncrementNewNameCounter = TestQ (lift (State.state step))
  where
    step st = (current, st{newNameCounter = current + 1})
      where
        current = newNameCounter st
-- | Abort the computation with the given message by throwing in the
-- exception layer of 'TestQ'.
throwError :: String -> TestQ mode a
throwError msg = TestQ (lift (lift (Except.throwE msg)))
-- | Run an action and, if it throws in the exception layer of 'TestQ', run
-- the handler on the thrown message instead.
--
-- 'Except.catchE' is lifted through the state and reader layers with their
-- respective @liftCatch@ functions.
catchError :: TestQ mode a -> (String -> TestQ mode a) -> TestQ mode a
catchError action handler =
  TestQ $
    Reader.liftCatch
      (State.liftCatch Except.catchE)
      (unTestQ action)
      (\msg -> unTestQ (handler msg))
-- | Embed a real 'Q' action into 'TestQ' by lifting it through the exception,
-- state and reader layers.
liftQ :: Q a -> TestQ mode a
liftQ q = TestQ (lift (lift (lift q)))
-- | IO actions are run in the underlying 'Q' monad via 'runIO'.
instance MonadIO (TestQ mode) where
  liftIO io = liftQ (runIO io)
-- | Failing throws in the exception layer of 'TestQ'. If a message was
-- previously recorded with 'storeLastError', that message is thrown in place
-- of the one passed to 'fail'.
instance MonadFail (TestQ mode) where
  fail msg = getLastError >>= throwError . fromMaybe msg
-- | Pick the implementation of a primitive according to the current mode.
--
-- In 'AllowQ' mode the real 'Q' action is run. In every other mode the
-- mocked behaviour applies: either the replacement action is run, or 'error'
-- is called naming the unsupported primitive.
use :: Override mode a -> TestQ mode a
use Override{whenAllowed, whenMocked} = do
  currentMode <- getMode
  case currentMode of
    AllowQ -> liftQ whenAllowed
    _ ->
      case whenMocked of
        DoInstead testQ -> testQ
        Unsupported label -> error $ "Cannot run '" ++ label ++ "' with TestQ"
-- | The two ways a primitive can be carried out; 'use' selects between them.
data Override mode a = Override
  { whenAllowed :: Q a
  -- ^ The real 'Q' action, run by 'use' in 'AllowQ' mode.
  , whenMocked :: WhenMocked mode a
  -- ^ What 'use' does in every other mode.
  }
-- | Behaviour of a primitive when the real 'Q' action may not be run.
data WhenMocked mode a
  = -- | Run this 'TestQ' action in place of the real one.
    DoInstead (TestQ mode a)
  | -- | There is no mocked implementation; the 'String' is the label that
    -- 'use' puts in its error message.
    Unsupported String
-- | Interprets the Template Haskell primitives inside 'TestQ'.
--
-- Apart from 'qRunIO', 'qRecover' and error reports, every method goes
-- through 'use': in 'AllowQ' mode it delegates to the real 'Q'
-- implementation; otherwise it either runs a mocked implementation
-- ('DoInstead') or calls 'error' naming the primitive ('Unsupported').
--
-- Methods that only exist in some template-haskell versions are guarded with
-- CPP so that the instance stays complete across versions.
instance Quasi (TestQ mode) where
  {- IO -}

  -- Only 'MockQ' mode rejects IO; every other mode runs the action.
  qRunIO io =
    getMode >>= \case
      MockQ -> error "IO actions not allowed"
      _ -> liftIO io

  {- Error handling + reporting -}

  -- The thrown message is discarded; the handler is run in its place.
  qRecover handler action = action `catchError` const handler

  -- Reports flagged False are forwarded to Q when allowed and dropped when mocked.
  qReport False msg =
    use
      Override
        { whenAllowed = qReport False msg
        , whenMocked = DoInstead $ return ()
        }
  -- Reports flagged True are recorded in every mode so that a later 'fail' can
  -- surface the recorded message (see the MonadFail instance).
  qReport True msg = storeLastError msg

  {- Names -}

  -- When mocked, the unique of the new name comes from the internal counter.
  qNewName name =
    use
      Override
        { whenAllowed = qNewName name
        , whenMocked = DoInstead $ mkNameU name . fromIntegral <$> getAndIncrementNewNameCounter
        }

  -- When mocked, the name is looked up in the knownNames table of the QState;
  -- the Bool argument is not consulted.
  qLookupName b name =
    use
      Override
        { whenAllowed = qLookupName b name
        , whenMocked = DoInstead $ do
            QState{knownNames} <- getState
            return $ lookup name knownNames
        }

  {- ReifyInfo -}

  qReify name =
    use
      Override
        { whenAllowed = qReify name
        , whenMocked = DoInstead $ lookupReifyInfo reifyInfoInfo name
        }

  qReifyFixity name =
    use
      Override
        { whenAllowed = qReifyFixity name
        , whenMocked = DoInstead $ lookupReifyInfo reifyInfoFixity name
        }

  -- When mocked, a ReifyInfo entry without roles is treated as an error.
  qReifyRoles name =
    use
      Override
        { whenAllowed = qReifyRoles name
        , whenMocked =
            DoInstead $
              lookupReifyInfo reifyInfoRoles name >>= \case
                Nothing -> error $ "No roles associated with " ++ show name
                Just roles -> return roles
        }

#if MIN_VERSION_template_haskell(2,16,0)
  qReifyType name =
    use
      Override
        { whenAllowed = qReifyType name
        , whenMocked = DoInstead $ lookupReifyInfo reifyInfoType name
        }
#endif

  {- Currently unsupported when mocked -}

  qReifyInstances name types =
    use
      Override
        { whenAllowed = qReifyInstances name types
        , whenMocked = Unsupported "qReifyInstances"
        }

  qReifyAnnotations annlookup =
    use
      Override
        { whenAllowed = qReifyAnnotations annlookup
        , whenMocked = Unsupported "qReifyAnnotations"
        }

  qReifyModule mod' =
    use
      Override
        { whenAllowed = qReifyModule mod'
        , whenMocked = Unsupported "qReifyModule"
        }

  qReifyConStrictness name =
    use
      Override
        { whenAllowed = qReifyConStrictness name
        , whenMocked = Unsupported "qReifyConStrictness"
        }

  qLocation =
    use
      Override
        { whenAllowed = qLocation
        , whenMocked = Unsupported "qLocation"
        }

  qAddDependentFile fp =
    use
      Override
        { whenAllowed = qAddDependentFile fp
        , whenMocked = Unsupported "qAddDependentFile"
        }

  qAddTopDecls decls =
    use
      Override
        { whenAllowed = qAddTopDecls decls
        , whenMocked = Unsupported "qAddTopDecls"
        }

  qAddModFinalizer q =
    use
      Override
        { whenAllowed = qAddModFinalizer q
        , whenMocked = Unsupported "qAddModFinalizer"
        }

  qGetQ =
    use
      Override
        { whenAllowed = qGetQ
        , whenMocked = Unsupported "qGetQ"
        }

  qPutQ a =
    use
      Override
        { whenAllowed = qPutQ a
        , whenMocked = Unsupported "qPutQ"
        }

  qIsExtEnabled ext =
    use
      Override
        { whenAllowed = qIsExtEnabled ext
        , whenMocked = Unsupported "qIsExtEnabled"
        }

  qExtsEnabled =
    use
      Override
        { whenAllowed = qExtsEnabled
        , whenMocked = Unsupported "qExtsEnabled"
        }

#if MIN_VERSION_template_haskell(2,13,0)
  qAddCorePlugin plugin =
    use
      Override
        { whenAllowed = qAddCorePlugin plugin
        , whenMocked = Unsupported "qAddCorePlugin"
        }
#endif

#if MIN_VERSION_template_haskell(2,14,0)
  qAddTempFile suffix =
    use
      Override
        { whenAllowed = qAddTempFile suffix
        , whenMocked = Unsupported "qAddTempFile"
        }

  qAddForeignFilePath lang fp =
    use
      Override
        { whenAllowed = qAddForeignFilePath lang fp
        , whenMocked = Unsupported "qAddForeignFilePath"
        }
#elif MIN_VERSION_template_haskell(2,12,0)
  qAddForeignFile lang fp =
    use
      Override
        { whenAllowed = qAddForeignFile lang fp
        , whenMocked = Unsupported "qAddForeignFile"
        }
#endif

  -- The methods below joined the Quasi class in later template-haskell
  -- releases. Without them the instance would be incomplete there and the
  -- primitives would fail even in AllowQ mode.
#if MIN_VERSION_template_haskell(2,18,0)
  qGetDoc loc =
    use
      Override
        { whenAllowed = qGetDoc loc
        , whenMocked = Unsupported "qGetDoc"
        }

  qPutDoc loc doc =
    use
      Override
        { whenAllowed = qPutDoc loc doc
        , whenMocked = Unsupported "qPutDoc"
        }
#endif

#if MIN_VERSION_template_haskell(2,19,0)
  qGetPackageRoot =
    use
      Override
        { whenAllowed = qGetPackageRoot
        , whenMocked = Unsupported "qGetPackageRoot"
        }
#endif