{-# LANGUAGE Trustworthy, TemplateHaskell, LambdaCase, ViewPatterns #-}
module Data.Extensible.TH (mkField
, mkFieldAs
, decEffects
, decEffectSet
, decEffectSuite
, customDecEffects) where
import Data.Extensible.Class (itemAssoc)
import Data.Extensible.Effect
import Data.Extensible.Field
import Data.List (nub)
import Language.Haskell.TH
import Data.Char
import Control.Monad
import Type.Membership
mkField :: String -> DecsQ
mkField str = fmap concat $ forM (words str) $ \s@(x:xs) ->
let name = mkName $ if isLower x then x : xs else '_' : x : xs
in mkFieldAs name s
mkFieldAs :: Name -> String -> DecsQ
mkFieldAs name s = do
let st = litT (strTyLit s)
let lbl = conE 'Proxy `sigE` (conT ''Proxy `appT` st)
sequence [sigD name $ conT ''FieldOptic `appT` st
, valD (varP name) (normalB $ varE 'itemAssoc `appE` lbl) []
, return $ PragmaD $ InlineP name Inline FunLike AllPhases
]
decEffects :: DecsQ -> DecsQ
decEffects = customDecEffects False True
decEffectSet :: DecsQ -> DecsQ
decEffectSet = customDecEffects True False
decEffectSuite :: DecsQ -> DecsQ
decEffectSuite = customDecEffects True True
customDecEffects :: Bool
-> Bool
-> DecsQ -> DecsQ
customDecEffects synSet synActions decs = decs >>= \ds -> fmap concat $ forM ds $ \case
#if MIN_VERSION_template_haskell(2,11,0)
DataD _ dataName tparams _ cs _
#else
DataD _ dataName tparams cs _
#endif
-> do
(cxts, dcs) <- fmap unzip $ traverse (con2Eff tparams) cs
let vars = map PlainTV $ nub $ concatMap (varsT . snd) cxts
return $ [TySynD dataName vars (typeListT $ map snd cxts) | synSet]
++ [ TySynD k (map PlainTV $ nub $ varsT t) t | synActions, (k, t) <- cxts]
++ concat dcs
_ -> fail "mkEffects accepts GADT declaration"
con2Eff :: [TyVarBndr] -> Con -> Q ((Name, Type), [Dec])
#if MIN_VERSION_template_haskell(2,11,0)
con2Eff _ (GadtC [name] st (AppT _ resultT))
= return $ effectFunD name (map snd st) resultT
#endif
con2Eff tparams (ForallC _ eqs (NormalC name st))
= return $ fromMangledGADT tparams eqs name st
con2Eff tparams (ForallC _ _ c) = con2Eff tparams c
con2Eff _ p = do
runIO (print p)
fail "Unsupported constructor"
fromMangledGADT :: [TyVarBndr] -> [Type] -> Name -> [(Strict, Type)] -> ((Name, Type), [Dec])
fromMangledGADT tyvars_ eqs con fieldTypes
= effectFunD con argumentsT result
where
getTV (PlainTV n) = n
getTV (KindedTV n _) = n
tyvars = map getTV tyvars_
dic_ = [(v, t) | AppT (AppT EqualityT (VarT v)) t <- eqs]
dic = dic_ ++ [(t, VarT v) | (v, VarT t) <- dic_]
params' = do
(t, v) <- zip tyvars uniqueNames
case lookup t dic of
Just (VarT p) -> return (t, p)
_ -> return (t, v)
argumentsT = map (\case
(_, VarT n) -> maybe (VarT n) VarT $ lookup n params'
(_, x) -> x) fieldTypes
result = case lookup (last tyvars) dic of
Just (VarT v) -> case lookup v params' of
Just p -> VarT p
Nothing -> VarT v
Just t -> t
Nothing -> VarT $ mkName "x"
varsT :: Type -> [Name]
varsT (VarT v) = [v]
varsT (AppT s t) = varsT s ++ varsT t
varsT _ = []
effectFunD :: Name
-> [Type]
-> Type
-> ((Name, Type), [Dec])
effectFunD key argumentsT resultT = ((key, PromotedT '(:>) `AppT` nameT `AppT` actionT)
, [SigD fName typ, FunD fName [effClause nameT (length argumentsT)]]) where
varList = mkName "xs"
fName = let (ch : rest) = nameBase key in mkName $ toLower ch : rest
typ = ForallT (map PlainTV $ varList : varsT resultT ++ concatMap varsT argumentsT)
[associateT nameT actionT varList]
$ effectFunT varList argumentsT resultT
actionT = ConT ''Action `AppT` typeListT argumentsT `AppT` resultT
nameT = LitT $ StrTyLit $ nameBase key
effectFunT :: Name
-> [Type]
-> Type
-> Type
effectFunT varList argumentsT resultT
= foldr (\x y -> ArrowT `AppT` x `AppT` y) rt argumentsT where
rt = ConT ''Eff `AppT` VarT varList `AppT` resultT
uniqueNames :: [Name]
uniqueNames = map mkName $ concatMap (flip replicateM ['a'..'z']) [1..]
typeListT :: [Type] -> Type
typeListT = foldr (\x y -> PromotedConsT `AppT` x `AppT` y) PromotedNilT
associateT :: Type
-> Type
-> Name
-> Type
associateT nameT t xs = ConT ''Lookup `AppT` VarT xs `AppT` nameT `AppT` t
effClause :: Type
-> Int
-> Clause
effClause nameT n = Clause (map VarP argNames) (NormalB rhs) [] where
lifter = VarE 'liftEff `AppE` (ConE 'Proxy `SigE` AppT (ConT ''Proxy) nameT)
argNames = map (mkName . ("a" ++) . show) [0..n-1]
rhs = lifter `AppE` foldr (\x y -> ConE 'AArgument `AppE` x `AppE` y)
(ConE 'AResult)
(map VarE argNames)