{-# LANGUAGE CPP #-}
#ifdef TRUSTWORTHY
{-# LANGUAGE Trustworthy #-}
#endif
#ifndef MIN_VERSION_template_haskell
#define MIN_VERSION_template_haskell(x,y,z) 1
#endif
module Control.Lens.Internal.PrismTH
( makePrisms
, makeClassyPrisms
, makeDecPrisms
) where
import Control.Applicative
import Control.Lens.Fold
import Control.Lens.Getter
import Control.Lens.Internal.TH
import Control.Lens.Lens
import Control.Lens.Setter
import Control.Monad
import Data.Char (isUpper)
import Data.List
import Data.Set.Lens
import Data.Traversable
import Language.Haskell.TH
import qualified Language.Haskell.TH.Datatype as D
import Language.Haskell.TH.Lens
import qualified Data.Map as Map
import qualified Data.Set as Set
import Prelude
-- | Generate top-level optics for every constructor of the named datatype.
--
-- Each optic is named by 'prismName' (@_Con@ for @Con@, a leading @.@ for
-- operator constructors). A datatype with exactly one constructor that has
-- no constructor-bound type variables and no context gets an 'Iso'.
-- Otherwise, each constructor gets a 'Prism', or a 'Review' when it binds
-- its own type variables.
makePrisms :: Name -> DecsQ
makePrisms typeName = makePrisms' True typeName
-- | Generate a class of optics for the named datatype.
--
-- For a datatype @T@ this produces a class @AsT@, whose main method @_T@
-- focuses on the datatype itself. It also produces one method per
-- constructor and an instance of the class for @T@.
makeClassyPrisms :: Name -> DecsQ
makeClassyPrisms typeName = makePrisms' False typeName
-- | Shared worker for 'makePrisms' and 'makeClassyPrisms'.
--
-- Reifies @typeName@ and hands its normalized constructors to
-- 'makeConsPrisms'. When @normal@ is 'True', top-level optics are requested
-- ('Nothing'). Otherwise the datatype's own name is passed so that a class
-- of optics is generated.
makePrisms' :: Bool -> Name -> DecsQ
makePrisms' normal typeName = do
  info <- D.reifyDatatype typeName
  let classy = if normal then Nothing else Just (D.datatypeName info)
  makeConsPrisms (D.datatypeType info) (normalizeCon <$> D.datatypeCons info) classy
-- | Like 'makePrisms' (when @normal@ is 'True') or 'makeClassyPrisms'
-- (when it is 'False'), but starting from a declaration instead of a name.
--
-- The declaration is normalized with 'D.normalizeDec' rather than reified.
makeDecPrisms :: Bool -> Dec -> DecsQ
makeDecPrisms normal dec = do
  info <- D.normalizeDec dec
  let classy = if normal then Nothing else Just (D.datatypeName info)
  makeConsPrisms (D.datatypeType info) (normalizeCon <$> D.datatypeCons info) classy
-- | Generate the declarations for a datatype of type @t@ with constructors
-- @cons@.
--
-- With 'Nothing', top-level optics are produced: a single 'Iso' for a lone
-- constructor with no constructor-bound type variables and no context, or
-- else a signature and definition per constructor. With @'Just' typeName@,
-- an @As\<typeName\>@ class and its instance for @t@ are produced instead.
makeConsPrisms :: Type -> [NCon] -> Maybe Name -> DecsQ
makeConsPrisms t [con@(NCon _ [] [] _)] Nothing = makeConIso t con
makeConsPrisms t cons Nothing = concat <$> traverse mkOptic cons
  where
    -- Signature plus definition for one constructor's optic.
    mkOptic con = do
      stab <- computeOpticType t cons con
      let name = prismName (view nconName con)
      sequenceA
        [ sigD name (close (stabToType stab))
        , valD (varP name) (normalB (makeConOpticExp stab cons con)) []
        ]
makeConsPrisms t cons (Just typeName) =
  sequenceA
    [ makeClassyPrismClass t className methodName cons
    , makeClassyPrismInstance t className methodName cons
    ]
  where
    className  = mkName ("As" ++ nameBase typeName)
    methodName = prismName typeName
-- | Which kind of optic is generated for a constructor.
data OpticType
  = PrismType  -- ^ a 'Prism': can both match and build the constructor
  | ReviewType -- ^ a 'Review': build-only, used when the constructor binds its own type variables
-- | Everything needed to write an optic's type. It holds the constraints,
-- the optic kind, and the four @s t a b@ type parameters, in that order.
data Stab = Stab
  Cxt       -- ^ constraints quantified in the optic's type
  OpticType -- ^ which optic to generate
  Type      -- ^ @s@
  Type      -- ^ @t@
  Type      -- ^ @a@
  Type      -- ^ @b@
-- | Make an optic type-preserving: @s@ is replaced by @t@ and @a@ by @b@.
-- The context and optic kind are kept.
simplifyStab :: Stab -> Stab
simplifyStab stab = case stab of
  Stab cx opticType _ t _ b -> Stab cx opticType t t b b
-- | 'True' when the optic does not change types, i.e. @s == t@ and @a == b@.
stabSimple :: Stab -> Bool
stabSimple (Stab _ _ s t a b)
  | s /= t    = False
  | otherwise = a == b
-- | Render a 'Stab' as a type.
--
-- The result is @Prism' t b@ for a simple prism, @Prism s t a b@ for a
-- type-changing one, and @Review t b@ for a review. It is wrapped in a
-- @forall@ that binds the type variables of the context and carries the
-- context itself.
stabToType :: Stab -> Type
stabToType stab@(Stab cx ty s t a b) = ForallT binders cx opticTy
  where
    opticTy = case ty of
      ReviewType -> reviewTypeName `conAppsT` [t, b]
      PrismType
        | stabSimple stab -> prism'TypeName `conAppsT` [t, b]
        | otherwise       -> prismTypeName `conAppsT` [s, t, a, b]
    -- nub keeps first-occurrence order, so the binder order is stable.
    binders = PlainTV <$> nub (toListOf typeVars cx)
-- | The kind of optic described by a 'Stab'.
stabType :: Stab -> OpticType
stabType stab = case stab of
  Stab _ opticType _ _ _ _ -> opticType
-- | Choose and compute the optic type for constructor @con@ of type @t@.
--
-- Constructors that bind no type variables of their own get a prism. The
-- other constructors (all of @cons@ except @con@) are consulted to decide
-- which type variables may change. Constructors that do bind type variables
-- get a review of their field types, since their fields cannot be projected
-- out without letting those variables escape.
computeOpticType :: Type -> [NCon] -> NCon -> Q Stab
computeOpticType t cons con
  | null (_nconVars con) = computePrismType t cx others con
  | otherwise            = computeReviewType t cx (view nconTypes con)
  where
    cx     = view nconCxt con
    others = delete con cons
-- | Compute a review's 'Stab' for a constructor of type @t@ with context
-- @cx@ and field types @tys@.
--
-- @b@ is the tuple of the fields. @s@ and @a@ are fresh type variables;
-- 'stabToType' only uses @t@ and @b@ for reviews.
computeReviewType :: Type -> Cxt -> [Type] -> Q Stab
computeReviewType t cx tys = do
  s <- VarT <$> newName "s"
  a <- VarT <$> newName "a"
  b <- toTupleT (return <$> tys)
  pure (Stab cx ReviewType s t a b)
-- | Compute a prism's 'Stab' for constructor @con@ of type @t@.
--
-- Type variables of @t@ that are not mentioned in @cons@ are renamed to
-- fresh variables to form @s@ and @a@. This allows a type-changing prism;
-- @t@ and @b@ (the tuple of the constructor's fields) use the original
-- variables.
computePrismType :: Type -> Cxt -> [NCon] -> NCon -> Q Stab
computePrismType t cx cons con = do
  let fields     = view nconTypes con
      changeable = setOf typeVars t `Set.difference` setOf typeVars cons
  sub <- sequenceA (fromSet (newName . nameBase) changeable)
  b   <- toTupleT (return <$> fields)
  a   <- toTupleT (return <$> substTypeVars sub fields)
  pure (Stab cx PrismType (substTypeVars sub t) t a b)
-- | Type of the 'Iso' generated for a lone constructor of type @orig@ with
-- field types @fields@.
--
-- Every type variable of @orig@ is given a fresh counterpart. When there
-- are none, the result is @Iso' orig fields@. Otherwise it is the
-- type-changing @Iso s orig a fields@. Remaining free variables are
-- quantified by 'close'.
computeIsoType :: Type -> [Type] -> TypeQ
computeIsoType orig fields = do
  sub <- sequenceA (fromSet (newName . nameBase) (setOf typeVars orig))
  let t = return orig
      s = return (substTypeVars sub orig)
      b = toTupleT (map return fields)
      a = toTupleT (map return (substTypeVars sub fields))
#ifndef HLINT
      isoTy = if Map.null sub
                then appsT (conT iso'TypeName) [t, b]
                else appsT (conT isoTypeName) [s, t, a, b]
#endif
  isoTy >>= close
-- | Build the expression implementing constructor @con@'s optic. The kind
-- of optic (prism or review) is taken from @stab@.
makeConOpticExp :: Stab -> [NCon] -> NCon -> ExpQ
makeConOpticExp stab@(Stab _ PrismType _ _ _ _) cons con = makeConPrismExp stab cons con
makeConOpticExp (Stab _ ReviewType _ _ _ _) _ con = makeConReviewExp con
-- | Signature and definition of the 'Iso' for the lone constructor @con@ of
-- type @s@. The iso is named by 'prismName'.
makeConIso :: Type -> NCon -> DecsQ
makeConIso s con = sequenceA [signature, definition]
  where
    name       = prismName (view nconName con)
    signature  = sigD name (computeIsoType s (view nconTypes con))
    definition = valD (varP name) (normalB (makeConIsoExp con)) []
-- | Expression @prism reviewer remitter@ for constructor @con@.
--
-- A simple (type-preserving) prism returns the scrutinee unchanged on a
-- mismatch. A type-changing one rebuilds the other constructors at the new
-- type, which requires knowing all of @cons@.
makeConPrismExp ::
  Stab ->
  [NCon] ->
  NCon ->
  ExpQ
makeConPrismExp stab cons con =
  appsE [varE prismValName, reviewer, remitter]
  where
    name     = view nconName con
    arity    = length (view nconTypes con)
    reviewer = makeReviewer name arity
    remitter
      | stabSimple stab = makeSimpleRemitter name arity
      | otherwise       = makeFullRemitter cons name
-- | Expression @iso remitter reviewer@ for a lone constructor. It converts
-- between the constructor and the tuple of its fields.
makeConIsoExp :: NCon -> ExpQ
makeConIsoExp con = appsE [varE isoValName, remitter, reviewer]
  where
    name     = view nconName con
    arity    = length (view nconTypes con)
    remitter = makeIsoRemitter name arity
    reviewer = makeReviewer name arity
-- | Expression @unto reviewer@: a build-only optic for constructor @con@.
makeConReviewExp :: NCon -> ExpQ
makeConReviewExp con =
  appE (varE untoValName)
       (makeReviewer (view nconName con) (length (view nconTypes con)))
-- | Lambda taking a tuple of @fields@ values and applying constructor
-- @conName@ to them.
makeReviewer :: Name -> Int -> ExpQ
makeReviewer conName fields = do
  names <- newNames "x" fields
  let pat  = toTupleP (varP <$> names)
      body = conE conName `appsE1` (varE <$> names)
  lam1E pat body
-- | Matcher for a type-preserving prism.
--
-- @conName@ values yield @Right@ of the tuple of their fields. Anything
-- else is returned unchanged in @Left@.
makeSimpleRemitter :: Name -> Int -> ExpQ
makeSimpleRemitter conName fields = do
  scrutinee <- newName "x"
  ys        <- newNames "y" fields
  let hit  = match (conP conName (varP <$> ys))
                   (normalB (appE (conE rightDataName) (toTupleE (varE <$> ys))))
                   []
      miss = match wildP (normalB (appE (conE leftDataName) (varE scrutinee))) []
  lam1E (varP scrutinee) (caseE (varE scrutinee) [hit, miss])
-- | Matcher for a type-changing prism focusing on constructor @target@.
--
-- The generated @case@ has one alternative per constructor in @cons@. The
-- target yields @Right@ of its field tuple. Every other constructor is
-- rebuilt from its fields inside @Left@, so it can take on the new type.
makeFullRemitter :: [NCon] -> Name -> ExpQ
makeFullRemitter cons target = do
  scrutinee <- newName "x"
  lam1E (varP scrutinee) (caseE (varE scrutinee) (branch <$> cons))
  where
    branch (NCon conName _ _ fieldTys) = do
      ys <- newNames "y" (length fieldTys)
      let vals = varE <$> ys
          rhs
            | conName == target = appE (conE rightDataName) (toTupleE vals)
            | otherwise         = appE (conE leftDataName) (conE conName `appsE1` vals)
      match (conP conName (varP <$> ys)) (normalB rhs) []
-- | Lambda that pattern matches constructor @conName@ (with @fields@
-- fields) and returns the tuple of its fields.
makeIsoRemitter :: Name -> Int -> ExpQ
makeIsoRemitter conName fields = do
  names <- newNames "x" fields
  let pat = conP conName (varP <$> names)
  lam1E pat (toTupleE (varE <$> names))
-- | Declare the type class generated by 'makeClassyPrisms'.
--
-- For a datatype @t@ whose type variables are @vs@ this emits, roughly,
--
-- > class className r vs | r -> vs where
-- >   methodName :: Prism' r t
-- >   _Con :: ...                 -- one per constructor, named by 'prismName'
-- >   _Con = methodName . _Con
--
-- The functional dependency is omitted when @t@ has no type variables.
-- Each per-constructor method keeps its constructor's context, optic kind
-- and field tuple, but uses @r@ as both its @s@ and @t@.
makeClassyPrismClass ::
  Type ->
  Name ->
  Name ->
  [NCon] ->
  DecQ
makeClassyPrismClass t className methodName cons = do
  r <- newName "r"
#ifndef HLINT
  let methodType = appsT (conT prism'TypeName) [varT r, return t]
#endif
  methods <- traverse (mkMethod (VarT r)) cons
  classD (cxt []) className (map PlainTV (r : vs)) (fds r)
    (sigD methodName methodType : map return (concat methods))
  where
    -- Signature and default definition of one constructor's method.
    -- 'computeOpticType' is given the constructor exactly as it appears in
    -- 'cons'. Previously a copy renamed by 'prismName' was passed, so its
    -- 'delete' never matched and the constructor was wrongly counted among
    -- the "other" constructors. Only the context, optic kind and field tuple
    -- are kept; @s@ and @t@ become @r@.
    mkMethod r con = do
      Stab cx o _ _ _ b <- computeOpticType t cons con
      let stab'   = Stab cx o r r b b
          defName = prismName (view nconName con)
          body    = appsE [varE composeValName, varE methodName, varE defName]
      sequenceA
        [ sigD defName (return (stabToType stab'))
        , valD (varP defName) (normalB body) []
        ]
    vs = Set.toList (setOf typeVars t)
    fds r
      | null vs   = []
      | otherwise = [FunDep [r] vs]
-- | Instance of the classy-prism class for the datatype @s@ itself.
--
-- @methodName@ is defined as @id@. Each constructor's method is given its
-- concrete optic, made type-preserving with 'simplifyStab'.
makeClassyPrismInstance ::
  Type ->
  Name ->
  Name ->
  [NCon] ->
  DecQ
makeClassyPrismInstance s className methodName cons =
  instanceD (cxt []) (return instanceHead) (idMethod : map conMethod cons)
  where
    vs           = Set.toList (setOf typeVars s)
    instanceHead = className `conAppsT` (s : map VarT vs)
    idMethod     = valD (varP methodName) (normalB (varE idValName)) []
    conMethod con = do
      stab <- simplifyStab <$> computeOpticType s cons con
      valD (varP (prismName (view nconName con)))
           (normalB (makeConOpticExp stab cons con))
           []
-- | A constructor reduced to the parts that optic generation needs
-- (built by 'normalizeCon').
data NCon = NCon
  { _nconName  :: Name   -- ^ constructor name
  , _nconVars  :: [Name] -- ^ type variables bound by the constructor itself
  , _nconCxt   :: Cxt    -- ^ constructor context
  , _nconTypes :: [Type] -- ^ field types, in declaration order
  }
  deriving Eq
-- | Traverse the free type variables of a constructor's context and fields.
-- The constructor's own type variables are treated as bound.
instance HasTypeVars NCon where
  typeVarsEx bound f (NCon name vars ctx fields) =
    NCon name vars <$> typeVarsEx bound' f ctx <*> typeVarsEx bound' f fields
    where
      bound' = bound `Set.union` Set.fromList vars
-- | Lens onto a constructor's name.
nconName :: Lens' NCon Name
nconName f ncon = setName <$> f (_nconName ncon)
  where
    setName n = ncon { _nconName = n }
-- | Lens onto a constructor's context.
nconCxt :: Lens' NCon Cxt
nconCxt f ncon = setCxt <$> f (_nconCxt ncon)
  where
    setCxt c = ncon { _nconCxt = c }
-- | Lens onto a constructor's field types.
nconTypes :: Lens' NCon [Type]
nconTypes f ncon = setTypes <$> f (_nconTypes ncon)
  where
    setTypes tys = ncon { _nconTypes = tys }
-- | Convert th-abstraction's constructor description into an 'NCon'.
normalizeCon :: D.ConstructorInfo -> NCon
normalizeCon info =
  NCon
    { _nconName  = D.constructorName info
    , _nconVars  = map D.tvName (D.constructorVars info)
    , _nconCxt   = D.constructorContext info
    , _nconTypes = D.constructorFields info
    }
-- | Name of the optic generated for a constructor (or type) name.
--
-- A name starting with an upper-case letter gets a leading underscore
-- (@Foo@ becomes @_Foo@). Anything else, i.e. an operator, gets a leading
-- period (@:+@ becomes @.:+@).
prismName :: Name -> Name
prismName n = mkName (prefix (nameBase n))
  where
    prefix [] = error "prismName: empty name base?"
    prefix base@(c:_)
      | isUpper c = '_' : base
      | otherwise = '.' : base
-- | Universally quantify, with an empty context, every type variable that
-- 'typeVars' finds in @t@.
close :: Type -> TypeQ
close t = forallT binders (cxt []) (return t)
  where
    binders = PlainTV <$> Set.toList (setOf typeVars t)