module Lens.Family.THCore (
defaultNameTransform
, LensTypeInfo
, ConstructorFieldInfo
, deriveLenses
, makeTraversals
) where
import Language.Haskell.TH
import Control.Applicative (pure)
import Data.Char (toLower)
-- | Default mapping from a record field name to a lens name.
--
-- A field written @_fooBar@ yields @Just "fooBar"@: the leading underscore
-- is dropped and the following character is lower-cased. Any name that is
-- not an underscore followed by at least one more character yields
-- 'Nothing', which tells the lens generator to skip that field.
defaultNameTransform :: String -> Maybe String
defaultNameTransform fieldStr =
  case fieldStr of
    '_' : firstChar : remaining -> Just (toLower firstChar : remaining)
    _ -> Nothing
-- | A data type's name together with its type-variable binders, as taken
-- from a reified @data@ or @newtype@ declaration.
type LensTypeInfo = (Name, [TyVarBndr])
-- | One record field of a constructor: the selector name, its strictness
-- annotation and its declared type (the triple stored in a 'RecC').
type ConstructorFieldInfo = (Name, Strict, Type)
-- | Generate lens declarations for every record field of @datatype@.
--
-- Each field name is passed through @nameTransform@. Fields mapped to
-- 'Nothing' are skipped. For a field mapped to @Just lensName@, the output
-- contains the declarations returned by @sigDeriver@ (presumably a type
-- signature; the contents are up to the caller) followed by the lens
-- definition itself.
deriveLenses ::
  (Name -> LensTypeInfo -> ConstructorFieldInfo -> Q [Dec])
  -> (String -> Maybe String)
  -> Name -> Q [Dec]
deriveLenses sigDeriver nameTransform datatype = do
  tyInfo <- extractLensTypeInfo datatype
  fields <- extractConstructorFields datatype
  perField <- mapM (deriveLens sigDeriver nameTransform tyInfo) fields
  return (concat perField)
-- | Look up @datatype@ and return its name and type-variable binders.
--
-- Only @data@ and @newtype@ declarations are accepted. For any other kind
-- of name, the 'error' is embedded in the returned value. It therefore
-- fires only when that value is forced, not when this action runs.
extractLensTypeInfo :: Name -> Q LensTypeInfo
extractLensTypeInfo datatype = fmap toTypeInfo (reify datatype)
  where
    toTypeInfo (TyConI (DataD _ tyName binders _ _)) = (tyName, binders)
    toTypeInfo (TyConI (NewtypeD _ tyName binders _ _)) = (tyName, binders)
    toTypeInfo _ =
      error $ "Can't derive Lens for: " ++ nameBase datatype
        ++ ", type name required."
-- | Look up @datatype@ and return the record fields of its single
-- constructor.
--
-- Failure is reported through the Q monad ('fail'), not as an 'error'
-- thunk that only throws when the result is forced. The splice fails when:
--
--   * the single constructor is not a record (plain, infix or forall'd);
--   * the name is a type synonym;
--   * the @data@ type has no constructors or more than one;
--   * the name does not refer to a type at all.
extractConstructorFields :: Name -> Q [ConstructorFieldInfo]
extractConstructorFields datatype = do
  info <- reify datatype
  case info of
    TyConI (DataD _ _ _ [RecC _ fs] _) -> return fs
    TyConI (NewtypeD _ _ _ (RecC _ fs) _) -> return fs
    -- Exactly one constructor, but it has no record selectors.
    TyConI (DataD _ _ _ [_] _) ->
      failWith "Can't derive Lens without record selectors: "
    TyConI NewtypeD{} ->
      failWith "Can't derive Lens without record selectors: "
    TyConI TySynD{} ->
      failWith "Can't derive Lens for type synonym: "
    -- Zero constructors: not a tagged union, so report it separately.
    TyConI (DataD _ _ _ [] _) ->
      failWith "Can't derive Lens for empty data type: "
    -- Two or more constructors.
    TyConI DataD{} ->
      failWith "Can't derive Lens for tagged union: "
    _ ->
      fail $ "Can't derive Lens for: " ++ datatypeStr
        ++ ", type name required."
  where
    datatypeStr = nameBase datatype
    failWith msg = fail (msg ++ datatypeStr)
-- | Generate the declarations for one record field.
--
-- The lens name is @nameTransform@ applied to the field's selector name.
-- When that yields 'Nothing', no declarations are produced. Otherwise the
-- result is whatever @sigDeriver@ returns for the field, followed by the
-- lens definition built by 'deriveLensBody'.
deriveLens :: (Name -> LensTypeInfo -> ConstructorFieldInfo -> Q [Dec])
  -> (String -> Maybe String)
  -> LensTypeInfo -> ConstructorFieldInfo -> Q [Dec]
deriveLens sigDeriver nameTransform ty field =
  maybe (return []) emitFor (nameTransform (nameBase selector))
  where
    (selector, _, _) = field
    emitFor lensNameStr = do
      let lensName = mkName lensNameStr
      sigDecls <- sigDeriver lensName ty field
      lensDecl <- deriveLensBody lensName selector
      return (sigDecls ++ [lensDecl])
-- | Build the lens definition for one record field:
--
-- > lensName f a = (\x -> a { fieldName = x }) `fmap` f (fieldName a)
--
-- The field is read from @a@ and passed to @f@. The functorial result is
-- then mapped back into @a@ with a record update.
deriveLensBody :: Name -> Name -> Q Dec
deriveLensBody lensName fieldName =
  funD lensName [clause [varP fVar, varP recVar] (normalB lensExp) []]
  where
    -- Created with 'mkName', so these appear literally as @f@ and @a@ in
    -- the generated code.
    fVar = mkName "f"
    recVar = mkName "a"
    -- f (fieldName a)
    focused = appE (varE fVar) (appE (varE fieldName) (varE recVar))
    lensExp =
      [| (\x -> $(setFieldTo [| x |])) `fmap` $focused |]
    -- a { fieldName = <value produced by valQ> }
    setFieldTo valQ = valQ >>= \newVal ->
      recUpdE (varE recVar) [return (fieldName, newVal)]
-- | Generate a traversal for every constructor of the named type. The
-- traversal for constructor @Foo@ is named @_Foo@.
makeTraversals :: Name -> Q [Dec]
makeTraversals = deriveTraversals (Just . ('_' :))
-- | Generate traversals for every constructor of @name@. Each one is named
-- by applying @nameTransform@ to the constructor's name; constructors
-- mapped to 'Nothing' are skipped.
deriveTraversals :: (String -> Maybe String) -> Name -> Q [Dec]
deriveTraversals nameTransform name = do
  tyInfo <- extractLensTypeInfo name
  cons <- extractConstructorInfo name
  perCon <- mapM (deriveTraversal nameTransform tyInfo) cons
  return (concat perCon)
-- | Look up @datatype@ and return all of its constructors. For a @data@
-- type this is the full constructor list, which may be empty. For a
-- @newtype@ it is the single constructor.
--
-- Any other kind of name fails the splice through the Q monad ('fail'),
-- rather than via an 'error' thunk that only throws when the list is
-- forced.
extractConstructorInfo :: Name -> Q [Con]
extractConstructorInfo datatype = do
  info <- reify datatype
  case info of
    -- Covers the empty case too: cons is simply [].
    TyConI (DataD _ _ _ cons _) -> return cons
    TyConI (NewtypeD _ _ _ con _) -> return [con]
    _ -> fail $ "Can't derive traversal for: " ++ nameBase datatype
-- | Generate the traversal for a single constructor @con@ of the type
-- described by @ty@.
--
-- The traversal is named by applying @nameTransform@ to the constructor's
-- name, and 'Nothing' skips the constructor. Only the constructor's field
-- count affects the generated code, and no type signature is emitted.
-- 'ForallC' (forall-quantified or constrained) constructors are rejected
-- with a Q-monad failure.
deriveTraversal :: (String -> Maybe String) -> LensTypeInfo -> Con -> Q [Dec]
deriveTraversal nameTransform ty con = do
  (cName, arity) <- constructorShape con
  case nameTransform (nameBase cName) of
    Nothing -> return []
    Just lensNameStr -> do
      body <- deriveTraversalBody (mkName lensNameStr) cName arity
      return [body]
  where
    -- Constructor name and number of fields for each supported form.
    constructorShape (NormalC n tys) = return (n, length tys)
    constructorShape (RecC n tys) = return (n, length tys)
    constructorShape (InfixC _ n _) = return (n, 2)
    constructorShape ForallC{} =
      fail $ "Traversal derivation not supported: "
        ++ "forall'd constructor in: " ++ nameBase (fst ty)
-- | Build a two-clause traversal for a constructor with @nArgs@ fields:
--
-- > lensName k (Con x x2 ... xn) = (\ ~(x', x'2, ..., x'n) -> Con x' x'2 ... x'n) `fmap` k (x, x2, ..., xn)
-- > lensName _ t                 = pure t
--
-- The fields are passed as @()@ for zero fields, bare for one field, and
-- as a tuple otherwise. The lambda's pattern is lazy (@~@).
--
-- NOTE(review): when the type has only this constructor, the second clause
-- is unreachable for non-bottom values.
deriveTraversalBody :: Name -> Name -> Int -> Q Dec
deriveTraversalBody lensName constructorName nArgs =
  funD lensName [matchClause, otherClause]
  where
    oldVars = mkArgNames nArgs "x"
    newVars = mkArgNames nArgs "x'"
    kName = mkName "k"
    tName = mkName "t"
    -- \ ~(x', x'2, ...) -> Con x' x'2 ...
    rebuild =
      constructorUncurriedFrom constructorName
        (TildeP (argPatFrom newVars)) (argVarsFrom newVars)
    -- k (x, x2, ...)
    focus = AppE (VarE kName) (argTupFrom oldVars)
    matchClause =
      clause
        [varP kName, conP constructorName (map varP oldVars)]
        (normalB [| $(return rebuild) `fmap` $(return focus) |])
        []
    otherClause =
      clause [wildP, varP tName] (normalB [| pure $(varE tName) |]) []
-- | Build @\\pat -> Con e1 e2 ... en@: a lambda over @pat@ whose body
-- applies constructor @conN@ to the given argument expressions, in order.
constructorUncurriedFrom :: Name -> Pat -> [Exp] -> Exp
constructorUncurriedFrom conN pat args = LamE [pat] applied
  where
    applied = foldl AppE (ConE conN) args
-- | The unit pattern @()@, written as an empty tuple pattern.
unitPat :: Pat
unitPat = TupP noComponents
  where
    noComponents = []
-- | The unit expression @()@, written as an empty tuple expression.
unitExp :: Exp
unitExp = TupE noComponents
  where
    noComponents = []
-- | A pattern binding the given names. It is @()@ for no names, a bare
-- variable for one name, and a tuple pattern otherwise.
argPatFrom :: [Name] -> Pat
argPatFrom names =
  case names of
    [] -> unitPat
    [single] -> VarP single
    _ -> TupP (map VarP names)
-- | An expression packaging the given variables. It is @()@ for none, the
-- bare variable for one, and a tuple expression otherwise.
argTupFrom :: [Name] -> Exp
argTupFrom names =
  case names of
    [] -> unitExp
    [single] -> VarE single
    _ -> TupE (map VarE names)
-- | Turn each name into a variable expression, preserving order.
argVarsFrom :: [Name] -> [Exp]
argVarsFrom names = [VarE n | n <- names]
-- | The first @nArgs@ names of the sequence @base@, @base2@, @base3@, ...
-- The first name carries no numeric suffix, and a non-positive @nArgs@
-- gives an empty list.
mkArgNames :: Int -> String -> [Name]
mkArgNames nArgs base = map mkName (take nArgs candidates)
  where
    candidates = base : [base ++ show i | i <- [2 :: Int ..]]