module Control.Monad.Free.TH
(
makeFree
) where
import Control.Arrow
import Control.Applicative
import Control.Monad
import Data.Char (toLower)
import Language.Haskell.TH
-- | How one constructor field is treated by the generated operation.
data Arg
  = Captured Type Exp
    -- ^ The field mentions the functor's type parameter. The type is what
    -- the field contributes to the operation's result type, and the
    -- expression is what the generated code stores in the field.
  | Param Type
    -- ^ The field does not mention the type parameter, so it becomes an
    -- ordinary argument of the generated operation.
  deriving Show
-- | Types of the fields that become ordinary arguments of the generated
-- operation, kept in constructor order.
params :: [Arg] -> [Type]
params args = [ty | Param ty <- args]
-- | Type and filler expression of every field that mentions the functor's
-- type parameter, kept in constructor order.
captured :: [Arg] -> [(Type, Exp)]
captured args = [(ty, ex) | Captured ty ex <- args]
-- | Rebuild a constructor's field expressions in field order.
--
-- Each 'Param' field takes the next expression from the first list, and
-- each 'Captured' field takes the next one from the second list. The result
-- ends when the fields run out, or when the list a field needs is exhausted.
zipExprs :: [Exp] -> [Exp] -> [Arg] -> [Exp]
zipExprs ps cs args = case (ps, cs, args) of
  (p : ps', _, Param _ : rest) -> p : zipExprs ps' cs rest
  (_, c : cs', Captured _ _ : rest) -> c : zipExprs ps cs' rest
  _ -> []
-- | The variable bound by a binder, with any kind annotation ignored.
tyVarBndrName :: TyVarBndr -> Name
tyVarBndrName bndr = case bndr of
  PlainTV name -> name
  KindedTV name _ -> name
-- | Resolve a type name in the scope where the splice runs. If the name is
-- not in scope, the splice fails with a readable message.
findTypeOrFail :: String -> Q Name
findTypeOrFail s = do
  found <- lookupTypeName s
  case found of
    Nothing -> fail $ s ++ " is not in scope"
    Just name -> return name
-- | Resolve a value (or data constructor) name in the scope where the splice
-- runs. If the name is not in scope, the splice fails with a readable message.
--
-- The message is spaced the same way as 'findTypeOrFail'. Previously the
-- name ran straight into "is not in scope".
findValueOrFail :: String -> Q Name
findValueOrFail s = lookupValueName s >>= maybe (fail $ s ++ " is not in scope") return
-- | Derive an operation name from a constructor name.
--
-- An infix constructor (leading @:@) drops its colon; any other name has its
-- first character lower-cased. An empty name fails.
mkOpName :: String -> Q String
mkOpName conName = case conName of
  [] -> fail "null constructor name"
  ':' : rest -> return rest
  c : rest -> return (toLower c : rest)
-- | Whether the type variable @n@ occurs free in a type.
--
-- * Applications are searched on both sides.
-- * Kind signatures are looked through.
-- * A @forall@ that rebinds @n@ hides it. The @forall@'s context is not
--   inspected.
-- * Every other type form counts as not using @n@.
usesTV :: Name -> Type -> Bool
usesTV n ty = case ty of
  VarT v -> n == v
  AppT l r -> usesTV n l || usesTV n r
  SigT inner _ -> usesTV n inner
  ForallT bndrs _ body -> usesTV n body && n `notElem` map tyVarBndrName bndrs
  _ -> False
-- | Classify one constructor field relative to the functor's type
-- parameter @n@.
--
-- * A field that does not mention @n@ becomes a 'Param'. The generated
--   operation takes it as an argument.
-- * A field that is exactly @n@ becomes a 'Captured' unit. It contributes
--   @()@ to the result and is filled with @()@.
-- * A field @t1 -> ... -> tk -> n@ becomes a 'Captured' tuple
--   @(t1, ..., tk)@. It is filled with @\\x1 ... xk -> (x1, ..., xk)@.
--   The argument types must not mention @n@ themselves.
--
-- Any other use of @n@ fails the splice.
--
-- NOTE(review): for k = 1 this builds a one-element tuple type and
-- expression. Presumably the TH-to-GHC conversion treats these as plain
-- parentheses for the targeted template-haskell versions; confirm.
mkArg :: Name -> Type -> Q Arg
mkArg n t
  | usesTV n t =
      case t of
        VarT _ -> return $ Captured (TupleT 0) (TupE [])
        AppT (AppT ArrowT _) _ -> do
          (ts, name) <- arrowsToTuple t
          when (name /= n) $ fail "return type is not the parameter"
          -- n is not bound in the generated signature, so letting it leak
          -- into the captured tuple would only surface later as an
          -- out-of-scope type variable in generated code.
          when (any (usesTV n) ts) $
            fail "type parameter appears in an argument of a function field"
          let tup = foldl AppT (TupleT $ length ts) ts
          xs <- mapM (const $ newName "x") ts
          return $ Captured tup (LamE (map VarP xs) (TupE (map VarE xs)))
        _ -> fail "don't know how to make Arg"
  | otherwise = return $ Param t
  where
    -- Split @t1 -> ... -> tk -> v@ into @([t1, ..., tk], v)@. Fails when
    -- the final result is not a type variable.
    arrowsToTuple (AppT (AppT ArrowT t1) (VarT name)) = return ([t1], name)
    arrowsToTuple (AppT (AppT ArrowT t1) t2) = do
      (ts, name) <- arrowsToTuple t2
      return (t1 : ts, name)
    arrowsToTuple _ = fail "return type is not a variable"
-- | Apply a transformation to the body under any number of leading lambdas,
-- leaving the lambda patterns untouched. A non-lambda expression is
-- transformed directly.
mapRet :: (Exp -> Exp) -> Exp -> Exp
mapRet f expr = case expr of
  LamE pats body -> LamE pats (mapRet f body)
  _ -> f expr
-- | Merge two captured fields into one result type plus one filler
-- expression per field (in argument order).
--
-- * Two unit fields cannot be told apart and are rejected.
-- * A unit field paired with a typed one yields @Maybe t@. The unit field
--   is filled with @Nothing@, and the typed field's filler wraps its result
--   in @Just@.
-- * Two typed fields yield @Either t1 t2@, filled via @Left@ / @Right@.
--
-- The names @Maybe@, @Nothing@, @Just@, @Either@, @Left@ and @Right@ are
-- looked up in the scope where the splice runs.
unifyT :: (Type, Exp) -> (Type, Exp) -> Q (Type, [Exp])
unifyT l r = case (l, r) of
  ((TupleT 0, _), (TupleT 0, _)) ->
    fail "can't accept 2 mere parameters"
  ((TupleT 0, _), (t, e)) -> do
    maybeT <- ConT <$> findTypeOrFail "Maybe"
    nothingE <- ConE <$> findValueOrFail "Nothing"
    justE <- ConE <$> findValueOrFail "Just"
    return (maybeT `AppT` t, [nothingE, mapRet (AppE justE) e])
  -- Unit on the right: solve the mirrored problem, then restore the order.
  (_, (TupleT 0, _)) ->
    second reverse <$> unifyT r l
  ((t1, e1), (t2, e2)) -> do
    eitherT <- ConT <$> findTypeOrFail "Either"
    leftE <- ConE <$> findValueOrFail "Left"
    rightE <- ConE <$> findValueOrFail "Right"
    return
      ( eitherT `AppT` t1 `AppT` t2
      , [mapRet (AppE leftE) e1, mapRet (AppE rightE) e2]
      )
-- | Combine all captured fields of a constructor into the operation's result
-- type and one filler expression per captured field.
--
-- * No captured field: the result type is the supplied fresh variable.
-- * One captured field: its own type is the result type.
-- * Two captured fields: they are merged by 'unifyT'.
-- * More than two: the splice fails.
unifyCaptured :: Name -> [(Type, Exp)] -> Q (Type, [Exp])
unifyCaptured a cs = case cs of
  [] -> return (VarT a, [])
  [(t, e)] -> return (t, [e])
  [x, y] -> unifyT x y
  _ -> fail "can't unify more than 2 arguments that use type parameter"
-- | Generate the type signature and the definition of the operation for a
-- single constructor.
--
-- Arguments, in order:
--
-- * extra type-variable binders collected from enclosing ForallC wrappers,
--   which are put first in the signature quantifier;
-- * extra context collected the same way, which is put before the
--   MonadFree constraint;
-- * the functor type constructor, as a Type;
-- * the last type parameter of the functor (the one fields may capture);
-- * the remaining type parameters of the functor;
-- * the constructor name;
-- * the field types of the constructor.
--
-- Schematically, for a constructor @Foo@ the generated code is
--
-- > foo :: forall binders [a] m ns. (cx, MonadFree (f ns) m) => p1 -> ... -> m r
-- > foo x1 ... = liftF (Foo e1 ... ek)
--
-- Here the pi are the field types that do not mention the last parameter,
-- r is the type produced by unifyCaptured, and a is only quantified when r
-- is that fresh variable.
liftCon' :: [TyVarBndr] -> Cxt -> Type -> Name -> [Name] -> Name -> [Type] -> Q [Dec]
liftCon' tvbs cx f n ns cn ts = do
  opName <- mkName <$> mkOpName (nameBase cn)
  m <- newName "m"
  -- Result type used when no field mentions the last parameter.
  a <- newName "a"
  -- Both names are resolved in the scope where the splice runs.
  monadFree <- findTypeOrFail "MonadFree"
  liftF <- findValueOrFail "liftF"
  args <- mapM (mkArg n) ts
  let ps = params args
      cs = captured args
  (retType, es) <- unifyCaptured a cs
  -- Parameter fields become arguments, in order; the result is m retType.
  let opType = foldr (AppT . AppT ArrowT) (AppT (VarT m) retType) ps
  xs <- mapM (const $ newName "p") ps
  let pat = map VarP xs
      -- Field expressions in constructor order: argument variables for
      -- parameter fields, filler expressions for captured fields.
      exprs = zipExprs (map VarE xs) es args
      fval = foldl AppE (ConE cn) exprs
      q = tvbs ++ map PlainTV (qa ++ m : ns)
      -- Quantify a only when it actually is the result type.
      qa = case retType of VarT b | a == b -> [a]; _ -> []
      -- The functor applied to every parameter except the last.
      f' = foldl AppT f (map VarT ns)
  -- From template-haskell 2.10 on a class constraint is an ordinary type
  -- application; before that it is a ClassP predicate.
  return
#if MIN_VERSION_template_haskell(2,10,0)
    [ SigD opName (ForallT q (cx ++ [ConT monadFree `AppT` f' `AppT` VarT m]) opType)
#else
    [ SigD opName (ForallT q (cx ++ [ClassP monadFree [f', VarT m]]) opType)
#endif
    , FunD opName [ Clause pat (NormalB $ AppE (VarE liftF) fval) [] ] ]
-- | Generate the operation for one constructor.
--
-- ForallC wrappers are peeled off, accumulating their binders and context
-- in front of those already collected. The constructor name and its field
-- types (from a normal, record or infix constructor) are then passed to
-- liftCon'.
liftCon :: [TyVarBndr] -> Cxt -> Type -> Name -> [Name] -> Con -> Q [Dec]
liftCon tvbs cx f n ns = worker
  where
    emit = liftCon' tvbs cx f n ns
    worker (NormalC cName fields) = emit cName [ty | (_, ty) <- fields]
    worker (RecC cName fields) = emit cName [ty | (_, _, ty) <- fields]
    worker (InfixC (_, t1) cName (_, t2)) = emit cName [t1, t2]
    worker (ForallC tvbs' cx' con') = liftCon (tvbs ++ tvbs') (cx ++ cx') f n ns con'
liftDec :: Dec -> Q [Dec]
liftDec (DataD _ tyName tyVarBndrs cons _)
| null tyVarBndrs = fail $ "Type " ++ show tyName ++ " needs at least one free variable"
| otherwise = concat <$> mapM (liftCon [] [] con nextTyName (init tyNames)) cons
where
tyNames = map tyVarBndrName tyVarBndrs
nextTyName = last tyNames
con = ConT tyName
liftDec dec = fail $ "liftDec: Don't know how to lift " ++ show dec
-- | Generate one operation per constructor of the named functor type.
--
-- Each operation is named after its constructor: the first letter is
-- lower-cased, and an infix constructor loses its leading colon. Its
-- definition wraps the constructor with @liftF@. Its type is polymorphic in
-- any monad @m@ with a @MonadFree@ instance for the functor.
--
-- For example,
--
-- > data Teletype next = Output String next
-- > makeFree ''Teletype
--
-- generates @output :: MonadFree Teletype m => String -> m ()@.
--
-- @MonadFree@ and @liftF@ must be in scope where the splice runs, as must
-- @Maybe@/@Just@/@Nothing@ or @Either@/@Left@/@Right@ when a constructor
-- has two fields that use the parameter. The name must refer to a type
-- constructor whose declaration liftDec understands; otherwise the splice
-- fails.
makeFree :: Name -> Q [Dec]
makeFree tyCon = reify tyCon >>= fromInfo
  where
    fromInfo (TyConI dec) = liftDec dec
    fromInfo _ = fail "makeFree expects a type constructor"