{-# LANGUAGE CPP #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_HADDOCK not-home #-}
module Polysemy.Internal.TH.Common
( ConLiftInfo (..)
, getEffectMetadata
, makeMemberConstraint
, makeMemberConstraint'
, makeSemType
, makeInterpreterType
, makeEffectType
, makeUnambiguousSend
, checkExtensions
, foldArrowTs
, splitArrowTs
, pattern (:->)
) where
import Control.Arrow ((>>>))
import Control.Monad
import Data.Bifunctor
import Data.Char (toLower)
import Data.Generics hiding (Fixity)
import Data.List
import qualified Data.Map.Strict as M
import Data.Tuple
import Language.Haskell.TH
import Language.Haskell.TH.Datatype
import Language.Haskell.TH.PprLib
import Polysemy.Internal (Sem, send)
import Polysemy.Internal.Union (MemberWithError)
#if __GLASGOW_HASKELL__ >= 804
import Prelude hiding ((<>))
#endif
-- | Everything the Template Haskell code needs to know about a single effect
-- constructor in order to generate code for it. Values are assembled by
-- 'makeCLInfo'.
data ConLiftInfo = CLInfo
  { cliEffName   :: Name
    -- ^ Name of the type the constructor belongs to (the effect).
  , cliEffArgs   :: [Type]
    -- ^ Type arguments of the effect, without its last two arguments (monad
    -- and result).
  , cliEffRes    :: Type
    -- ^ Last type argument of the constructor's return type, i.e. the result
    -- of the action.
  , cliConName   :: Name
    -- ^ Name of the data constructor itself.
  , cliFunName   :: Name
    -- ^ Function name derived from the constructor name.
  , cliFunFixity :: Maybe Fixity
    -- ^ Fixity reified for the constructor, if there is one.
  , cliFunArgs   :: [(Name, Type)]
    -- ^ Freshly generated argument names paired with the constructor's
    -- argument types.
  , cliFunCxt    :: Cxt
    -- ^ Constraints found at the top level of the constructor's type.
  , cliUnionName :: Name
    -- ^ Fresh name of the type variable that 'Sem' is applied to.
  } deriving (Show)
-- | Reify the datatype called @type_name@ and build a 'ConLiftInfo' for each
-- of its constructors. Returns the reified datatype's name together with the
-- infos, in the order the constructors are listed.
getEffectMetadata :: Name -> Q (Name, [ConLiftInfo])
getEffectMetadata type_name = do
  info       <- reifyDatatype type_name
  lift_infos <- mapM (makeCLInfo . constructorName) (datatypeCons info)
  return (datatypeName info, lift_infos)
-- | Turn a constructor name into the name of its lifted function: an operator
-- constructor loses its leading colon, any other constructor gets its first
-- character lower-cased. Any module qualification is dropped because only the
-- base of the name is used.
liftFunNameFromCon :: Name -> Name
liftFunNameFromCon = mkName . toFunBase . nameBase
  where
    toFunBase :: String -> String
    toFunBase (':' : rest)   = rest
    toFunBase (first : rest) = toLower first : rest
    toFunBase []             = error "liftFunNameFromCon: empty constructor name"
-- | Gather the 'ConLiftInfo' of a single data constructor.
--
-- Fails in 'Q' when the name does not reify to a data constructor, when the
-- constructor's return type has fewer than two type arguments, or when the
-- second-to-last of them (the monad) is not a type variable.
makeCLInfo :: Name -> Q ConLiftInfo
makeCLInfo cliConName = do
  (con_type, eff_name) <- do
    info <- reify cliConName
    case info of
      DataConI _ ty parent -> pure (ty, parent)
      _                    -> notDataCon cliConName

  -- 'splitArrowTs' always yields at least one element, so there is always a
  -- final element to take as the constructor's return type.
  let (con_args, [con_return_type]) = splitAtEnd 1 (splitArrowTs con_type)

  -- Drop the head of the application (the type constructor), then peel the
  -- monad and result arguments off the end.
  (ty_con_args, [monad_arg, res_arg]) <-
    case splitAtEnd 2 (tail (splitAppTs con_return_type)) of
      split@(_, [_, _]) -> pure split
      _                 -> missingEffArgs eff_name

  monad_name <- case tVarName monad_arg of
    Just name -> pure name
    Nothing   -> argNotVar eff_name monad_arg

  union_name <- newName "r"

  -- Strip trivial kind annotations and substitute @Sem r@ for the monad
  -- variable.
  let normalize :: (TypeSubstitution t, Data t) => t -> t
      normalize = replaceMArg monad_name union_name . simplifyKinds

  fixity    <- reifyFixity cliConName
  arg_names <- replicateM (length con_args) (newName "x")

  pure CLInfo
    { cliEffName   = eff_name
    , cliEffArgs   = normalize ty_con_args
    , cliEffRes    = normalize res_arg
    , cliConName   = cliConName
    , cliFunName   = liftFunNameFromCon cliConName
    , cliFunFixity = fixity
    , cliFunArgs   = zip arg_names (normalize con_args)
    , cliFunCxt    = topLevelConstraints con_type
    , cliUnionName = union_name
    }
-- | The effect's type constructor applied, left to right, to the effect
-- arguments stored in the 'ConLiftInfo'.
makeEffectType :: ConLiftInfo -> Type
makeEffectType cli = foldl' AppT eff_con (cliEffArgs cli)
  where
    eff_con = ConT (cliEffName cli)
-- | Type of a function from @Sem (eff ': r) result@ to @Sem r result@, where
-- @eff@ is the effect type built by 'makeEffectType'.
makeInterpreterType :: ConLiftInfo -> Name -> Type -> Type
makeInterpreterType cli r result =
  let effect_row = PromotedConsT `AppT` makeEffectType cli `AppT` VarT r
      input_sem  = ConT ''Sem `AppT` effect_row `AppT` result
  in  input_sem :-> makeSemType r result
-- | Membership constraint for the effect described by the 'ConLiftInfo' in
-- the row named @r@; see 'makeMemberConstraint''.
makeMemberConstraint :: Name -> ConLiftInfo -> Pred
makeMemberConstraint r = makeMemberConstraint' r . makeEffectType
-- | @MemberWithError eff r@ constraint for the given effect type and the row
-- type variable named @r@.
makeMemberConstraint' :: Name -> Type -> Pred
makeMemberConstraint' r eff = classPred ''MemberWithError class_args
  where
    class_args = [eff, VarT r]
-- | The type @Sem r result@, with @r@ being a type variable of the given
-- name.
makeSemType :: Name -> Type -> Type
makeSemType r result = AppT (AppT (ConT ''Sem) (VarT r)) result
-- | Expression applying 'send' to the constructor (itself applied to the
-- function's argument variables), with the constructor application annotated
-- by its full effect type: effect arguments, then @Sem r@, then the result
-- type.
--
-- When @should_make_sigs@ is 'False', type variables in that annotation are
-- rewritten with 'capturableTVars'.
makeUnambiguousSend :: Bool -> ConLiftInfo -> Exp
makeUnambiguousSend should_make_sigs cli =
  VarE 'send `AppE` SigE action eff
  where
    -- Constructor applied to every argument variable, left to right.
    action   = foldl' AppE (ConE (cliConName cli))
                 [VarE arg_name | (arg_name, _) <- cliFunArgs cli]

    sem      = ConT ''Sem `AppT` VarT (cliUnionName cli)
    raw_args = cliEffArgs cli ++ [sem, cliEffRes cli]

    args
      | should_make_sigs = raw_args
      | otherwise        = map capturableTVars raw_args

    eff      = foldl' AppT (ConT (cliEffName cli)) args
-- | Abort in 'Q' with a message saying that the given argument of the given
-- effect is not a type variable.
argNotVar :: Name -> Type -> Q a
argNotVar eff_name arg = fail (show message)
  where
    message = text "Argument ‘" <> ppr arg
           <> text "’ in effect ‘" <> ppr eff_name
           <> text "’ is not a type variable"
-- | Query whether each of the given extensions is enabled and fail in 'Q',
-- naming the first one in the list that is not. Does nothing when all of them
-- are enabled.
checkExtensions :: [Extension] -> Q ()
checkExtensions exts = do
  enabled <- traverse isExtEnabled exts
  case find (not . snd) (zip exts enabled) of
    Nothing       -> pure ()
    Just (ext, _) -> fail . show $
      char '‘' <> text (show ext) <> char '’'
        <+> text "extension needs to be enabled for Polysemy's Template Haskell to work"
-- | Abort in 'Q' with a message saying that the effect lacks type arguments.
-- The message includes an example data declaration for the effect (using the
-- base of its name) with type variables @m@ and @a@.
missingEffArgs :: Name -> Q a
missingEffArgs name = fail (show message)
  where
    message  = headline
           $+$ nest 4 (hint $+$ nest 4 (text "" $+$ example $+$ text ""))

    headline = text "Effect ‘" <> ppr name
            <> text "’ has not enough type arguments"
    hint     = text "At least monad and result argument are required, e.g.:"
    example  = ppr (DataD [] base args Nothing [] []) <+> text "..."

    base     = capturableBase name
    args     = map (PlainTV . mkName) ["m", "a"]
-- | Abort in 'Q' with a message saying that the given name is not a data
-- constructor.
notDataCon :: Name -> Q a
notDataCon name = fail (show message)
  where
    message = char '‘' <> ppr name <> text "’ is not a data constructor"
-- | Function type @a -> b@. As an expression it builds a plain arrow
-- application; as a pattern it also matches when the arrow at the head of the
-- application is wrapped in annotations that 'removeTyAnns' strips.
infixr 1 :->
pattern (:->) :: Type -> Type -> Type
pattern a :-> b <- AppT (AppT (removeTyAnns -> ArrowT) a) b where
  a :-> b = AppT (AppT ArrowT a) b
-- | Rebuild a name with 'mkName' from just its base, i.e. without any module
-- qualification or uniqueness information the original carried.
capturableBase :: Name -> Name
capturableBase name = mkName (nameBase name)
-- | Apply 'capturableBase' to every type variable in a type, including the
-- variables bound by @forall@s (and the kinds of kinded binders).
capturableTVars :: Type -> Type
capturableTVars = everywhere (mkT capture)
  where
    capture :: Type -> Type
    capture (VarT n)          = VarT (capturableBase n)
    capture (ForallT bs cs t) =
      ForallT (map captureBndr bs) (map capturableTVars cs) t
    capture t                 = t

    -- Binder names are not 'Type's, so the generic traversal does not rename
    -- them; do it by hand.
    captureBndr (PlainTV n)    = PlainTV (capturableBase n)
    captureBndr (KindedTV n k) = KindedTV (capturableBase n) (capturableTVars k)
-- | Build a function type from a final result type and a list of argument
-- types: @foldArrowTs r [a, b]@ is @a -> b -> r@.
foldArrowTs :: Type -> [Type] -> Type
foldArrowTs result arg_types = foldr (:->) result arg_types
-- | Substitute @Sem r@ for the type variable @m@, where @r@ is a type variable
-- of the given name.
replaceMArg :: TypeSubstitution t => Name -> Name -> t -> t
replaceMArg m r = applySubstitution substitution
  where
    substitution = M.singleton m (ConT ''Sem `AppT` VarT r)
-- | Drop kind annotations whose kind is @*@ or a kind variable, both on types
-- (@SigT@) and on the binders of @forall@s. Other kind annotations are left
-- alone.
simplifyKinds :: Data t => t -> t
simplifyKinds = everywhere (mkT simplify)
  where
    simplify :: Type -> Type
    simplify (SigT t k)
      | trivialKind k         = t
    simplify (ForallT bs cs t) =
      ForallT (map simplifyBndr bs) (map simplifyKinds cs) t
    simplify t                 = t

    simplifyBndr (KindedTV n k)
      | trivialKind k = PlainTV n
    simplifyBndr b    = b

    -- Kinds considered not worth keeping: @*@ and bare kind variables.
    trivialKind :: Type -> Bool
    trivialKind StarT  = True
    trivialKind VarT{} = True
    trivialKind _      = False
-- | Flatten a type application into its head followed by its arguments, in
-- order: @T a b@ becomes @[T, a, b]@. Annotations removed by 'removeTyAnns'
-- are looked through at every step of the application spine. The result is
-- never empty.
--
-- Arguments are collected in an accumulator while walking down the spine, so
-- the list is built in a single pass instead of by repeatedly appending to
-- its end.
splitAppTs :: Type -> [Type]
splitAppTs = go []
  where
    go :: [Type] -> Type -> [Type]
    go acc ty = case removeTyAnns ty of
      fun `AppT` arg -> go (arg : acc) fun
      hd             -> hd : acc
-- | Split a function type into its argument types followed by its result
-- type: @a -> b -> c@ becomes @[a, b, c]@. Annotations removed by
-- 'removeTyAnns' are looked through at every step. The result is never empty.
splitArrowTs :: Type -> [Type]
splitArrowTs ty = case removeTyAnns ty of
  arg :-> rest -> arg : splitArrowTs rest
  final        -> [final]
-- | Name of the type variable, if the type is one once the annotations
-- removed by 'removeTyAnns' are stripped; 'Nothing' otherwise.
tVarName :: Type -> Maybe Name
tVarName ty = case removeTyAnns ty of
  VarT name -> Just name
  _         -> Nothing
-- | Constraints of the outermost @forall@ of a type; empty when the type does
-- not start with a @forall@.
topLevelConstraints :: Type -> Cxt
topLevelConstraints (ForallT _ constraints _) = constraints
topLevelConstraints _                         = []
-- | Strip any nesting of outer @forall@s, kind signatures and parentheses from
-- a type, returning the first type underneath that is none of those.
removeTyAnns :: Type -> Type
removeTyAnns (ForallT _ _ inner) = removeTyAnns inner
removeTyAnns (SigT inner _)      = removeTyAnns inner
removeTyAnns (ParensT inner)     = removeTyAnns inner
removeTyAnns other               = other
-- | Like 'splitAt', but counting from the end of the list: the second
-- component holds the last @n@ elements and the first component everything
-- before them, both in their original order.
splitAtEnd :: Int -> [a] -> ([a], [a])
splitAtEnd n xs = (reverse front_rev, reverse back_rev)
  where
    -- Working on the reversed list turns "the last n" into "the first n".
    (back_rev, front_rev) = splitAt n (reverse xs)