{-# LANGUAGE CPP #-}
#if __GLASGOW_HASKELL__ >= 800
{-# OPTIONS_GHC -Wno-overlapping-patterns #-}
#endif
#include "free-common.h"
module Control.Monad.Free.TH
(
makeFree,
makeFree_,
makeFreeCon,
makeFreeCon_,
) where
import Control.Arrow
import Control.Monad
import Data.Char (toLower)
import Data.List ((\\), nub)
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
#if !(MIN_VERSION_base(4,8,0))
import Control.Applicative
#endif
-- | Classification of a single constructor field, as produced by 'mkArg'.
data Arg
  = Captured Type Exp
    -- ^ The field mentions the "next" type variable of the functor. The
    -- 'Type' contributes to the generated operation's result type and the
    -- 'Exp' is used to build the value passed in that field.
  | Param Type
    -- ^ The field does not mention the "next" type variable; it becomes an
    -- ordinary argument of the generated operation.
  deriving Show
-- | The types of all 'Param' fields, in field order.
params :: [Arg] -> [Type]
params args = [ ty | Param ty <- args ]
-- | The type/expression pairs of all 'Captured' fields, in field order.
captured :: [Arg] -> [(Type, Exp)]
captured args = [ (ty, ex) | Captured ty ex <- args ]
-- | Interleave parameter expressions and captured expressions according to
-- the field layout: each 'Param' takes the next expression from the first
-- list, each 'Captured' takes the next one from the second list. Stops as
-- soon as the list needed by the current field is exhausted.
zipExprs :: [Exp] -> [Exp] -> [Arg] -> [Exp]
zipExprs pvals cvals fields = case (pvals, cvals, fields) of
  (p : ps, _, Param _ : rest)      -> p : zipExprs ps cvals rest
  (_, c : cs, Captured _ _ : rest) -> c : zipExprs pvals cs rest
  _                                -> []
-- | The variable name bound by a type-variable binder, with or without a kind.
tyVarBndrName :: TyVarBndr -> Name
tyVarBndrName bndr = case bndr of
  PlainTV nm    -> nm
  KindedTV nm _ -> nm
-- | Resolve a type-level name with 'lookupTypeName', failing in 'Q' with
-- "<name> is not in scope" when the lookup finds nothing.
findTypeOrFail :: String -> Q Name
findTypeOrFail s = do
  found <- lookupTypeName s
  case found of
    Just nm -> return nm
    Nothing -> fail (s ++ " is not in scope")
-- | Resolve a value-level name (e.g. a data constructor such as "Just")
-- with 'lookupValueName', failing in 'Q' with "<name> is not in scope"
-- when the lookup finds nothing. The message has the same format as the
-- one produced by 'findTypeOrFail', including the space after the name.
findValueOrFail :: String -> Q Name
findValueOrFail s =
  lookupValueName s >>= maybe (fail $ s ++ " is not in scope") return
-- | Derive the operation name from a constructor name: an infix
-- constructor drops its leading ':', any other name gets its first
-- character lowercased. An empty name is reported as a failure in 'Q'.
mkOpName :: String -> Q String
mkOpName conName = case conName of
  ':' : rest -> return rest
  c : rest   -> return (toLower c : rest)
  []         -> fail "impossible happened: empty (null) constructor name"
-- | Does the type variable @n@ occur in the given type?
--
-- Looks through type applications and kind signatures (the kind itself is
-- not inspected). Under a 'ForallT' that binds @n@ the answer is 'False'.
-- Any other form of type is treated as not mentioning @n@.
usesTV :: Name -> Type -> Bool
usesTV n ty = case ty of
  VarT v                 -> v == n
  AppT fun arg           -> usesTV n fun || usesTV n arg
  SigT inner _           -> usesTV n inner
  ForallT binders _ body -> usesTV n body && n `notElem` map tyVarBndrName binders
  _                      -> False
-- | Classify one constructor field.
--
-- @mkArg n t@ inspects the field type @t@ against @n@, which must be the
-- type variable used as the last parameter of the type constructor:
--
-- * if @t@ does not mention @n@ (per 'usesTV'), the field is a 'Param',
--   i.e. an argument of the generated operation;
-- * if @t@ is a bare type variable, it is 'Captured' with the unit type and
--   the unit expression;
-- * if @t@ is @a1 -> ... -> aN -> n@ with @n@ absent from every @ai@, it is
--   'Captured' with the tuple type of the @ai@ (just @a1@ when N = 1) and a
--   lambda that packs its N arguments into that tuple;
-- * any other type mentioning @n@ is rejected with 'fail'.
--
-- Also fails when @n@ is not a type variable.
mkArg :: Type -> Type -> Q Arg
mkArg (VarT n) t
  | usesTV n t =
      case t of
        -- The field is the next variable itself: nothing to capture, use ().
        VarT _ -> return $ Captured (TupleT 0) (TupE [])
        -- A function field: split it into argument types and final result.
        AppT (AppT ArrowT _) _ -> do
          (ts, name) <- arrowsToTuple t
          -- The next variable may only appear as the final result.
          when (any (usesTV n) ts) $ fail $ unlines
            [ "type variable " ++ pprint n ++ " is forbidden"
            , "in a type like (a1 -> ... -> aN -> " ++ pprint n ++ ")"
            , "in a constructor's argument type: " ++ pprint t ]
          -- The final result must be exactly the next variable.
          when (name /= n) $ fail $ unlines
            [ "expected final return type `" ++ pprint n ++ "'"
            , "but got `" ++ pprint name ++ "'"
            , "in a constructor's argument type: `" ++ pprint t ++ "'" ]
          let tup = nonUnaryTupleT ts
          xs <- mapM (const $ newName "x") ts
          -- \x1 ... xN -> (x1, ..., xN), or \x1 -> x1 when N = 1.
          return $ Captured tup (LamE (map VarP xs) (nonUnaryTupE $ map VarE xs))
        _ -> fail $ unlines
          [ "expected a type variable `" ++ pprint n ++ "'"
          , "or a type like (a1 -> ... -> aN -> " ++ pprint n ++ ")"
          , "but got `" ++ pprint t ++ "'"
          , "in a constructor's argument" ]
  | otherwise = return $ Param t
  where
    -- Split a1 -> ... -> aN -> v into ([a1, ..., aN], v); the final result
    -- must be a type variable, otherwise this fails.
    arrowsToTuple (AppT (AppT ArrowT t1) t2) = do
      (ts, name) <- arrowsToTuple t2
      return (t1:ts, name)
    arrowsToTuple (VarT name) = return ([], name)
    arrowsToTuple rt = fail $ unlines
      [ "expected final return type `" ++ pprint n ++ "'"
      , "but got `" ++ pprint rt ++ "'"
      , "in a constructor's argument type: `" ++ pprint t ++ "'" ]
    -- A single type is kept as-is; otherwise build a tuple type of that arity.
    nonUnaryTupleT :: [Type] -> Type
    nonUnaryTupleT [t'] = t'
    nonUnaryTupleT ts = foldl AppT (TupleT $ length ts) ts
    -- Expression counterpart of nonUnaryTupleT. From TH 2.16 on, tuple
    -- expression components are wrapped in Maybe.
    nonUnaryTupE :: [Exp] -> Exp
    nonUnaryTupE [e] = e
    nonUnaryTupE es = TupE $
#if MIN_VERSION_template_haskell(2,16,0)
      map Just
#endif
      es
mkArg n _ = fail $ unlines
  [ "expected a type variable"
  , "but got `" ++ pprint n ++ "'"
  , "as the last parameter of the type constructor" ]
-- | Apply a function to the innermost body of a (possibly nested) lambda;
-- an expression that is not a lambda is transformed directly.
mapRet :: (Exp -> Exp) -> Exp -> Exp
mapRet f expr = case expr of
  LamE pats body -> LamE pats (mapRet f body)
  _              -> f expr
-- | Combine two captured fields into a single result type plus one
-- expression per field (in the same order as the inputs).
--
-- * Two unit-typed fields cannot be distinguished and are rejected.
-- * A unit field paired with a field of type @t@ yields @Maybe t@, using
--   @Nothing@ for the unit field and @Just@ (inside any lambda) for the other.
-- * Any other pair @t1@, @t2@ yields @Either t1 t2@ with @Left@ / @Right@.
unifyT :: (Type, Exp) -> (Type, Exp) -> Q (Type, [Exp])
unifyT lhs rhs = case (lhs, rhs) of
  ((TupleT 0, _), (TupleT 0, _)) ->
    fail "can't accept 2 mere parameters"
  ((TupleT 0, _), (ty, ex)) -> do
    maybeT   <- ConT <$> findTypeOrFail "Maybe"
    nothingE <- ConE <$> findValueOrFail "Nothing"
    justE    <- ConE <$> findValueOrFail "Just"
    pure (maybeT `AppT` ty, [nothingE, mapRet (AppE justE) ex])
  (_, (TupleT 0, _)) ->
    -- Solve the mirrored case, then restore the caller's field order.
    second reverse <$> unifyT rhs lhs
  ((ty1, ex1), (ty2, ex2)) -> do
    eitherT <- ConT <$> findTypeOrFail "Either"
    leftE   <- ConE <$> findValueOrFail "Left"
    rightE  <- ConE <$> findValueOrFail "Right"
    pure ( eitherT `AppT` ty1 `AppT` ty2
         , [mapRet (AppE leftE) ex1, mapRet (AppE rightE) ex2] )
-- | Determine the generated operation's result type from its captured
-- fields: none gives the fresh variable @a@, one gives its own type, two
-- are merged by 'unifyT', and more than two are rejected.
unifyCaptured :: Name -> [(Type, Exp)] -> Q (Type, [Exp])
unifyCaptured a caps = case caps of
  []           -> return (VarT a, [])
  [(ty, ex)]   -> return (ty, [ex])
  [one, other] -> unifyT one other
  _            -> fail $ unlines
    [ "can't unify more than 2 return types"
    , "that use type parameter"
    , "when unifying return types: "
    , unlines (map (pprint . fst) caps) ]
-- | Collect the type variables occurring free in a type, in left-to-right
-- order of appearance. Duplicates are kept; callers 'nub' the result.
--
-- Variables bound by a 'ForallT' are removed from the variables of its
-- body. Every occurrence is removed: list difference would delete only the
-- first occurrence of each binder and so report e.g. @a@ as free in
-- @forall a. a -> a@. The context of a 'ForallT' is not inspected.
extractVars :: Type -> [Name]
extractVars (ForallT bs _ t) = filter (`notElem` bound) (extractVars t)
  where
    bound = map tyVarBndrName bs
extractVars (VarT n) = [n]
extractVars (AppT x y) = extractVars x ++ extractVars y
#if MIN_VERSION_template_haskell(2,8,0)
-- Here the kind is itself a 'Type', so variables in it are collected too.
extractVars (SigT x k) = extractVars x ++ extractVars k
#else
extractVars (SigT x _) = extractVars x
#endif
#if MIN_VERSION_template_haskell(2,11,0)
extractVars (InfixT x _ y) = extractVars x ++ extractVars y
extractVars (UInfixT x _ y) = extractVars x ++ extractVars y
extractVars (ParensT x) = extractVars x
#endif
extractVars _ = []
-- | Generate the declarations for one constructor, given its name and field
-- types.
--
-- Arguments, in order: whether to emit a type signature; type-variable
-- binders and context collected so far (e.g. from 'ForallC'); the type
-- constructor @f@; the "next" type @n@; the remaining type arguments @ns@
-- applied to @f@; the constructor name @cn@; and its field types @ts@.
--
-- The operation is named via 'mkOpName', takes one argument per 'Param'
-- field, returns @m r@ where @r@ comes from 'unifyCaptured', and is defined
-- as @liftF (cn e1 ... ek)@. When a signature is requested, it is
-- @forall q. (cx, MonadFree (f ns) m) => p1 -> ... -> m r@.
liftCon' :: Bool -> [TyVarBndr] -> Cxt -> Type -> Type -> [Type] -> Name -> [Type] -> Q [Dec]
liftCon' typeSig tvbs cx f n ns cn ts = do
  opName <- mkName <$> mkOpName (nameBase cn)
  m <- newName "m"
  a <- newName "a"
  -- MonadFree and liftF must be in scope where the splice is used.
  monadFree <- findTypeOrFail "MonadFree"
  liftF <- findValueOrFail "liftF"
  args <- mapM (mkArg n) ts
  let ps = params args
      cs = captured args
  (retType, es) <- unifyCaptured a cs
  -- p1 -> ... -> pk -> m retType
  let opType = foldr (AppT . AppT ArrowT) (AppT (VarT m) retType) ps
  xs <- mapM (const $ newName "p") ps
  let pat = map VarP xs
      -- Constructor arguments: parameter variables and captured expressions
      -- interleaved in the original field order.
      exprs = zipExprs (map VarE xs) es args
      fval = foldl AppE (ConE cn) exprs
      -- Every variable occurring in the other type arguments is quantified.
      ns' = nub (concatMap extractVars ns)
      q = filter nonNext tvbs ++ map PlainTV (qa ++ m : ns')
      -- Quantify the fresh @a@ only if it ended up as the result type.
      qa = case retType of VarT b | a == b -> [a]; _ -> []
      f' = foldl AppT f ns
  -- TH 2.10 represents class constraints as types; older versions use ClassP.
  return $ concat
    [ if typeSig
#if MIN_VERSION_template_haskell(2,10,0)
        then [ SigD opName (ForallT q (cx ++ [ConT monadFree `AppT` f' `AppT` VarT m]) opType) ]
#else
        then [ SigD opName (ForallT q (cx ++ [ClassP monadFree [f', VarT m]]) opType) ]
#endif
        else []
    , [ FunD opName [ Clause pat (NormalB $ AppE (VarE liftF) fval) [] ] ] ]
  where
    -- Drop the binder of the "next" type @n@ so it is not quantified.
    nonNext (PlainTV pn) = VarT pn /= n
    nonNext (KindedTV kn _) = VarT kn /= n
-- | Generate the declarations for one constructor of the data type.
--
-- A constructor none of whose names is selected by @onlyCons@ (see
-- 'melem') produces nothing. Plain, record and infix constructors are
-- passed to 'liftCon'' with their field types. 'ForallC' adds its binders
-- and context to those collected so far and recurses. GADT-style
-- constructors (TH >= 2.11) are handled by 'liftGadtC' for each selected name.
liftCon :: Bool -> [TyVarBndr] -> Cxt -> Type -> Type -> [Type] -> Maybe [Name] -> Con -> Q [Dec]
liftCon typeSig ts cx f n ns onlyCons con
  | not (any (`melem` onlyCons) (constructorNames con)) = return []
  | otherwise = case con of
      NormalC cName fields -> liftCon' typeSig ts cx f n ns cName $ map snd fields
      RecC cName fields -> liftCon' typeSig ts cx f n ns cName $ map (\(_, _, ty) -> ty) fields
      InfixC (_,t1) cName (_,t2) -> liftCon' typeSig ts cx f n ns cName [t1, t2]
      ForallC ts' cx' con' -> liftCon typeSig (ts ++ ts') (cx ++ cx') f n ns onlyCons con'
#if MIN_VERSION_template_haskell(2,11,0)
      -- A GADT constructor may declare several names; only the selected
      -- ones are generated. Its own result type supplies the type arguments.
      GadtC cNames fields resType -> do
        decs <- forM (filter (`melem` onlyCons) cNames) $ \cName ->
          liftGadtC cName fields resType typeSig ts cx f
        return (concat decs)
      RecGadtC cNames fields resType -> do
        -- Drop the record field names, keeping (bang, type) pairs.
        let fields' = map (\(_, x, y) -> (x, y)) fields
        decs <- forM (filter (`melem` onlyCons) cNames) $ \cName ->
          liftGadtC cName fields' resType typeSig ts cx f
        return (concat decs)
#endif
      -- Fallback for any constructor form not matched above.
      _ -> fail $ "Unsupported constructor type: `" ++ pprint con ++ "'"
#if MIN_VERSION_template_haskell(2,11,0)
-- | Flatten a left-nested type application @t a1 ... aN@ into
-- @[t, a1, ..., aN]@. The result is never empty.
splitAppT :: Type -> [Type]
splitAppT = go []
  where
    go acc (AppT fun arg) = go (arg : acc) fun
    go acc headTy         = headTy : acc
-- | Generate the declarations for one name of a GADT-style constructor.
--
-- The type arguments are read from the constructor's own result type: the
-- last one is the "next" type, the preceding ones are applied to @f@. The
-- constructor is then treated as a plain 'NormalC'. 'Nothing' is passed as
-- the selection because the caller has already filtered the names.
liftGadtC :: Name -> [BangType] -> Type -> Bool -> [TyVarBndr] -> Cxt -> Type -> Q [Dec]
liftGadtC cName fields resType typeSig ts cx f =
  let resArgs = drop 1 (splitAppT resType)
      nextArg = last resArgs
  in liftCon typeSig ts cx f nextArg (init resArgs) Nothing (NormalC cName fields)
#endif
-- | Membership in an optional selection: 'Nothing' selects everything,
-- @Just xs@ selects exactly the elements of @xs@.
melem :: Eq a => a -> Maybe [a] -> Bool
melem x = maybe True (elem x)
-- | All names declared by a constructor, looking through 'ForallC'.
--
-- The final alternative uses 'fail' in the list monad, which evaluates to
-- @[]@; the message text is therefore never reported.
constructorNames :: Con -> [Name]
constructorNames c = case c of
  NormalC nm _      -> [nm]
  RecC nm _         -> [nm]
  InfixC _ nm _     -> [nm]
  ForallC _ _ inner -> constructorNames inner
#if MIN_VERSION_template_haskell(2,11,0)
  GadtC nms _ _     -> nms
  RecGadtC nms _ _  -> nms
#endif
  _                 -> fail $ "Unsupported constructor type: `" ++ pprint c ++ "'"
-- | Generate operations for the constructors of a data declaration,
-- restricted to @onlyCons@ when it is @Just@.
--
-- The last type parameter is used as the "next" type and the preceding
-- parameters are applied to the type constructor. Fails for a data type
-- without type parameters and for any declaration other than 'DataD'.
liftDec :: Bool
        -> Maybe [Name]
        -> Dec
        -> Q [Dec]
#if MIN_VERSION_template_haskell(2,11,0)
liftDec typeSig onlyCons (DataD _ tyName tyVarBndrs _ cons _)
#else
liftDec typeSig onlyCons (DataD _ tyName tyVarBndrs cons _)
#endif
  | null tyVarBndrs = fail $ "Type constructor " ++ pprint tyName ++ " needs at least one type parameter"
  | otherwise = concat <$> mapM (liftCon typeSig [] [] con nextTy (init tys) onlyCons) cons
  where
    -- The declared type parameters as type variables.
    tys = map (VarT . tyVarBndrName) tyVarBndrs
    -- Only forced in the 'otherwise' branch, where tys is non-empty.
    nextTy = last tys
    con = ConT tyName
liftDec _ _ dec = fail $ unlines
  [ "failed to derive makeFree operations:"
  , "expected a data type constructor"
  , "but got " ++ pprint dec ]
-- | Reify the named type and generate operations for its constructors
-- (all of them, or only those listed in @cnames@). Fails if the name does
-- not reify to a type constructor declaration.
genFree :: Bool         -- ^ also emit type signatures?
        -> Maybe [Name] -- ^ constructors to generate; 'Nothing' means all
        -> Name         -- ^ the type constructor
        -> Q [Dec]
genFree typeSig cnames tyCon = reify tyCon >>= fromInfo
  where
    fromInfo (TyConI dec) = liftDec typeSig cnames dec
    fromInfo _            = fail "makeFree expects a type constructor"
-- | Reify a data constructor and generate the operation for it alone, by
-- delegating to 'genFree' on its parent type. Fails if the name is not a
-- data constructor.
genFreeCon :: Bool -- ^ also emit a type signature?
           -> Name -- ^ the data constructor
           -> Q [Dec]
genFreeCon typeSig cname = reify cname >>= fromInfo
  where
    -- Before TH 2.11, DataConI carried an additional trailing field.
    fromInfo (DataConI _ _ tname
#if !(MIN_VERSION_template_haskell(2,11,0))
                       _
#endif
             ) = genFree typeSig (Just [cname]) tname
    fromInfo info = fail $ unlines
      [ "expected a data constructor"
      , "but got " ++ pprint info ]
-- | Generate one monadic operation per constructor of the given data type,
-- each with a type signature. The operation's name is the constructor's
-- name with the first character lowercased (an infix constructor drops its
-- leading ':'), and its body is @liftF@ applied to the constructor.
makeFree :: Name -> Q [Dec]
makeFree tyName = genFree True Nothing tyName

-- | Like 'makeFree', but without type signatures.
makeFree_ :: Name -> Q [Dec]
makeFree_ tyName = genFree False Nothing tyName

-- | Like 'makeFree', but only for the given data constructor.
makeFreeCon :: Name -> Q [Dec]
makeFreeCon conName = genFreeCon True conName

-- | Like 'makeFreeCon', but without a type signature.
makeFreeCon_ :: Name -> Q [Dec]
makeFreeCon_ conName = genFreeCon False conName