{-# LANGUAGE CPP, PatternGuards, Rank2Types #-}
module Data.Functor.Foldable.TH
( MakeBaseFunctor(..)
, BaseRules
, baseRules
, baseRulesType
, baseRulesCon
, baseRulesField
) where
import Control.Applicative as A
import Control.Monad
import Data.Traversable as T
import Data.Functor.Identity
import Language.Haskell.TH
import Language.Haskell.TH.Datatype as TH.Abs
import Language.Haskell.TH.Datatype.TyVarBndr
import Language.Haskell.TH.Syntax (mkNameG_tc, mkNameG_v)
import Data.Char (GeneralCategory (..), generalCategory)
import Data.Orphans ()
#ifndef CURRENT_PACKAGE_KEY
import Data.Version (showVersion)
import Paths_recursion_schemes (version)
#endif
#ifdef __HADDOCK__
import Data.Functor.Foldable
#endif
-- | Things from which a base functor can be generated: a data type name,
-- an empty @Recursive@/@Corecursive@ instance declaration, a quotation
-- producing such things, or a list of them.
--
-- For a data type this emits the base functor data type, its @Base@ type
-- instance, and @Recursive@ and @Corecursive@ instances.
class MakeBaseFunctor a where
  -- | Generate declarations with the default naming scheme 'baseRules'.
  makeBaseFunctor :: a -> DecsQ
  makeBaseFunctor x = makeBaseFunctorWith baseRules x

  -- | Generate declarations with caller-supplied naming rules.
  makeBaseFunctorWith :: BaseRules -> a -> DecsQ
-- | Handle each element in turn and concatenate all generated declarations.
instance MakeBaseFunctor a => MakeBaseFunctor [a] where
  makeBaseFunctorWith rules xs =
    concat <$> T.traverse (makeBaseFunctorWith rules) xs

-- | Run the quotation first, then generate from whatever it produced.
instance MakeBaseFunctor a => MakeBaseFunctor (Q a) where
  makeBaseFunctorWith rules qa = qa >>= makeBaseFunctorWith rules

-- | Reify the named type and generate its base functor; the generated
-- instances have an empty context.
instance MakeBaseFunctor Name where
  makeBaseFunctorWith rules name = do
    info <- reifyDatatype name
    makePrimForDI rules Nothing info
-- | Generate a base functor from an empty instance declaration whose head
-- is @Recursive t@ or @Corecursive t@, e.g. @instance Recursive (T a)@.
-- The declaration's context (and, on template-haskell >= 2.11, its overlap
-- annotation) is reused for both generated instances, applied to @t@.
instance MakeBaseFunctor Dec where
#if MIN_VERSION_template_haskell(2,11,0)
  makeBaseFunctorWith rules (InstanceD overlaps ctx classHead []) = do
    let instanceFor = InstanceD overlaps ctx
#else
  makeBaseFunctorWith rules (InstanceD ctx classHead []) = do
    let instanceFor = InstanceD ctx
#endif
    -- Only @Recursive t@ / @Corecursive t@ heads are accepted; the head
    -- constructor of @t@ names the data type to reify.
    case classHead of
      ConT u `AppT` t | u == recursiveTypeName || u == corecursiveTypeName -> do
        name <- headOfType t
        di <- reifyDatatype name
        -- The instance builder receives a class name and the method
        -- declarations, and keeps the user's context and type @t@.
        makePrimForDI rules (Just $ \n -> instanceFor (ConT n `AppT` t)) di
      _ -> fail $ "makeBaseFunctor: expected an instance head like `ctx => Recursive (T a b ...)`, got " ++ show classHead
  -- Any other declaration, including an instance with a non-empty body,
  -- is rejected.
  makeBaseFunctorWith _ _ = fail "makeBaseFunctor(With): expected an empty instance declaration"
-- | Naming rules for generated code. Each field maps a name of the
-- original data type to the corresponding name in its base functor.
data BaseRules = BaseRules
  { _baseRulesType  :: Name -> Name  -- ^ renames the data type itself
  , _baseRulesCon   :: Name -> Name  -- ^ renames each data constructor
  , _baseRulesField :: Name -> Name  -- ^ renames each record field
  }
-- | Default naming rules: the type, its constructors and its record fields
-- are all renamed by appending @F@, or @$@ for operator names (so @:+:@
-- becomes @:+:$@).
baseRules :: BaseRules
baseRules = BaseRules
  { _baseRulesType  = rename
  , _baseRulesCon   = rename
  , _baseRulesField = rename
  }
  where
    rename = toFName
-- | Van Laarhoven lens onto the rule that names the generated base
-- functor type.
baseRulesType :: Functor f => ((Name -> Name) -> f (Name -> Name)) -> BaseRules -> f BaseRules
baseRulesType = lens _baseRulesType (\rules rule -> rules { _baseRulesType = rule })
-- | Van Laarhoven lens onto the rule that names the generated data
-- constructors.
baseRulesCon :: Functor f => ((Name -> Name) -> f (Name -> Name)) -> BaseRules -> f BaseRules
baseRulesCon = lens _baseRulesCon (\rules rule -> rules { _baseRulesCon = rule })
-- | Van Laarhoven lens onto the rule that names the generated record
-- fields.
baseRulesField :: Functor f => ((Name -> Name) -> f (Name -> Name)) -> BaseRules -> f BaseRules
baseRulesField = lens _baseRulesField (\rules rule -> rules { _baseRulesField = rule })
-- | Default renaming rule. Operator names (made entirely of symbol
-- characters) get a @$@ suffix; every other name gets an @F@ suffix.
-- Only the base part of the name is kept, so module qualification is
-- dropped.
toFName :: Name -> Name
toFName name = mkName (base ++ suffix)
  where
    base = nameBase name
    suffix
      | all isSymbolChar base = "$"
      | otherwise             = "F"
-- | Generate base functor declarations for a reified data type or newtype.
--
-- The optional argument builds an instance declaration from a class name
-- and its method declarations; with 'Nothing', instances with an empty
-- context are generated. Data family instances are rejected, and so is a
-- type head containing anything other than (possibly kinded) type
-- variables. Both failures are reported with 'fail' in 'Q'.
makePrimForDI :: BaseRules
              -> Maybe (Name -> [Dec] -> Dec)
              -> DatatypeInfo
              -> DecsQ
makePrimForDI rules mkInstance'
              (DatatypeInfo { datatypeName      = tyName
                            , datatypeInstTypes = instTys
                            , datatypeCons      = cons
                            , datatypeVariant   = variant }) = do
    -- This check must run before converting instTys: a data family
    -- instance may have non-variable types there.
    when isDataFamInstance $
      fail "makeBaseFunctor: Data families are currently not supported."
    vars <- T.traverse toTyVarBndr instTys
    makePrimForDI' rules mkInstance'
                   (variant == Newtype) tyName
                   vars cons
  where
    isDataFamInstance = case variant of
                          DataInstance    -> True
                          NewtypeInstance -> True
                          Datatype        -> False
                          Newtype         -> False

    -- Each instance type is expected to be a type variable, optionally with
    -- a kind signature. Anything else becomes a proper splice error instead
    -- of an uninformative pure 'error'.
    toTyVarBndr :: Type -> Q TyVarBndrUnit
    toTyVarBndr (VarT n)          = pure (plainTV n)
    toTyVarBndr (SigT (VarT n) k) = pure (kindedTV n k)
    toTyVarBndr t                 = fail $
      "makeBaseFunctor: expected a type variable in the data type's head, got "
        ++ show t
-- | Build the base functor declarations for a data type or newtype.
--
-- Returns, in order:
--
-- * the base functor data declaration, deriving Functor, Foldable and
--   Traversable;
-- * the @Base@ type instance;
-- * the @Recursive@ instance;
-- * the @Corecursive@ instance.
--
-- Arguments, in order:
--
-- * the naming rules;
-- * an optional instance builder (when 'Nothing', empty-context instances
--   for the type applied to its variables are used);
-- * whether the original is a newtype;
-- * the original type name;
-- * its type variable binders;
-- * its constructors.
makePrimForDI' :: BaseRules
               -> Maybe (Name -> [Dec] -> Dec)
               -> Bool -> Name -> [TyVarBndrUnit]
               -> [ConstructorInfo] -> DecsQ
makePrimForDI' rules mkInstance' isNewtype tyName vars cons = do
  let vars' = map VarT (typeVars vars)
  let tyNameF = _baseRulesType rules tyName
  -- The original type applied to its variables, e.g. @T a b@.
  let s = conAppsT tyName vars'
  -- Fresh type variable standing for the recursive positions; it becomes
  -- the last parameter of the base functor.
  rName <- newName "r"
  let r = VarT rName
  let varsF = vars ++ [plainTV rName]
  -- Expand type synonyms in field types so that recursive occurrences
  -- hidden behind a synonym can still be matched by 'substType'.
  cons' <- traverse (conTypeTraversal resolveTypeSynonyms) cons
  -- Rename constructors and fields, and replace @s@ with @r@ in field types.
  let consF
        = toCon
        . conNameMap (_baseRulesCon rules)
        . conFieldNameMap (_baseRulesField rules)
        . conTypeMap (substType s r)
        <$> cons'
  -- A single-constructor newtype stays a newtype; otherwise emit @data@.
  let dataDec = case consF of
#if MIN_VERSION_template_haskell(2,11,0)
          [conF] | isNewtype ->
            NewtypeD [] tyNameF varsF Nothing conF deriveds
          _ ->
            DataD [] tyNameF varsF Nothing consF deriveds
#else
          [conF] | isNewtype ->
            NewtypeD [] tyNameF varsF conF deriveds
          _ ->
            DataD [] tyNameF varsF consF deriveds
#endif
        where
          -- The deriving clause representation differs between
          -- template-haskell versions.
          deriveds =
#if MIN_VERSION_template_haskell(2,12,0)
            [DerivClause Nothing
              [ ConT functorTypeName
              , ConT foldableTypeName
              , ConT traversableTypeName ]]
#elif MIN_VERSION_template_haskell(2,11,0)
            [ ConT functorTypeName
            , ConT foldableTypeName
            , ConT traversableTypeName ]
#else
            [functorTypeName, foldableTypeName, traversableTypeName]
#endif
  -- Base s = tyNameF applied to the original variables.
  baseDec <- tySynInstDCompat baseTypeName Nothing
               [pure s] (pure $ conAppsT tyNameF vars')
  -- Use the caller's instance builder if given, otherwise an instance with
  -- an empty context for @s@.
  let mkInstance :: Name -> [Dec] -> Dec
      mkInstance = case mkInstance' of
        Just f -> f
        Nothing -> \n ->
#if MIN_VERSION_template_haskell(2,11,0)
          InstanceD Nothing [] (ConT n `AppT` s)
#else
          InstanceD [] (ConT n `AppT` s)
#endif
  -- project: original constructor -> renamed constructor, same fields.
  projDec <- FunD projectValName <$> mkMorphism id (_baseRulesCon rules) cons'
  let recursiveDec = mkInstance recursiveTypeName [projDec]
  -- embed: renamed constructor -> original constructor, same fields.
  embedDec <- FunD embedValName <$> mkMorphism (_baseRulesCon rules) id cons'
  let corecursiveDec = mkInstance corecursiveTypeName [embedDec]
  A.pure [dataDec, baseDec, recursiveDec, corecursiveDec]
-- | Build one clause per constructor. Each clause matches the constructor
-- named @nFrom n@ and rebuilds the constructor named @nTo n@, passing all
-- fields through unchanged. Used to generate both @project@ and @embed@.
mkMorphism
  :: (Name -> Name)  -- ^ renaming applied to the constructor in the pattern
  -> (Name -> Name)  -- ^ renaming applied to the constructor in the body
  -> [ConstructorInfo]
  -> Q [Clause]
mkMorphism nFrom nTo args = for args $ \ci -> do
  let n = constructorName ci
  fs <- replicateM (length (constructorFields ci)) (newName "x")
  pure $ Clause [conP' (nFrom n) (map VarP fs)]
                (NormalB $ foldl AppE (ConE $ nTo n) (map VarE fs))
                []
  where
    -- template-haskell 2.18 (GHC 9.2) added a list of visible type
    -- applications to 'ConP'; we never use any.
    conP' :: Name -> [Pat] -> Pat
#if MIN_VERSION_template_haskell(2,18,0)
    conP' name pats = ConP name [] pats
#else
    conP' name pats = ConP name pats
#endif
-- | Focus on the name of a constructor.
conNameTraversal :: Traversal' ConstructorInfo Name
conNameTraversal f ci = setName <$> f (constructorName ci)
  where
    setName newName' = ci { constructorName = newName' }
-- | Focus on each record field name of a constructor. Normal and infix
-- constructors have no field names, so they have no targets.
conFieldNameTraversal :: Traversal' ConstructorInfo Name
conFieldNameTraversal f ci = setVariant <$> fieldNames (constructorVariant ci)
  where
    setVariant v = ci { constructorVariant = v }

    fieldNames (RecordConstructor names) = RecordConstructor <$> traverse f names
    fieldNames NormalConstructor         = pure NormalConstructor
    fieldNames InfixConstructor          = pure InfixConstructor
-- | Focus on each field type of a constructor, left to right.
conTypeTraversal :: Traversal' ConstructorInfo Type
conTypeTraversal f ci = setFields <$> traverse f (constructorFields ci)
  where
    setFields tys = ci { constructorFields = tys }
-- | Rename a constructor.
conNameMap :: (Name -> Name) -> ConstructorInfo -> ConstructorInfo
conNameMap rename ci = over conNameTraversal rename ci

-- | Rename every record field of a constructor (no-op for non-records).
conFieldNameMap :: (Name -> Name) -> ConstructorInfo -> ConstructorInfo
conFieldNameMap rename ci = over conFieldNameTraversal rename ci

-- | Transform every field type of a constructor.
conTypeMap :: (Type -> Type) -> ConstructorInfo -> ConstructorInfo
conTypeMap retype ci = over conTypeTraversal retype ci
-- | Van Laarhoven lens, defined locally rather than imported.
type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s
-- | Van Laarhoven traversal; any 'Lens'' can be used as one.
type Traversal' s a = forall f. Applicative f => (a -> f a) -> s -> f s
-- | Build a lens from a getter and a setter.
lens :: (s -> a) -> (s -> a -> s) -> Lens' s a
lens getter setter focus whole = fmap (setter whole) (focus (getter whole))
{-# INLINE lens #-}
-- | Apply a pure function to every target of a traversal.
over :: Traversal' s a -> (a -> a) -> s -> s
over l f s = runIdentity (l (Identity . f) s)
{-# INLINE over #-}
-- | Find the name at the head of a type application, e.g. @T@ in @T a b@.
-- Kind signatures are looked through, so @(T a :: k)@ also yields @T@.
-- Any other shape is reported with 'fail'.
headOfType :: Type -> Q Name
headOfType (AppT t _) = headOfType t
headOfType (SigT t _) = headOfType t
headOfType (VarT n)   = return n
headOfType (ConT n)   = return n
headOfType t          = fail $ "headOfType: " ++ show t
-- | The names bound by a list of type variable binders, in order.
typeVars :: [TyVarBndr_ flag] -> [Name]
typeVars binders = [tvName b | b <- binders]
-- | Apply a type constructor to a list of arguments, left-associated:
-- @conAppsT ''T [a, b]@ is @T a b@.
conAppsT :: Name -> [Type] -> Type
conAppsT conName = go (ConT conName)
  where
    go acc []       = acc
    go acc (t : ts) = go (AppT acc t) ts
-- | @substType a b t@ replaces every subterm of @t@ that is equal to @a@
-- with @b@. Only the type forms matched below are descended into; any
-- other form is returned unchanged without being searched.
substType
  :: Type  -- ^ type to replace (the original type applied to its variables)
  -> Type  -- ^ replacement
  -> Type  -- ^ type to search
  -> Type
substType a b = go
  where
    go x | x == a = b
    go (VarT n) = VarT n
    go (AppT l r) = AppT (go l) (go r)
    -- NOTE(review): the body is searched regardless of the variables bound
    -- by @xs@ (no shadowing check), and @ctx@ is not searched at all.
    go (ForallT xs ctx t) = ForallT xs ctx (go t)
    -- The kind @k@ is left untouched.
    go (SigT t k) = SigT (go t) k
#if MIN_VERSION_template_haskell(2,11,0)
    go (InfixT l n r) = InfixT (go l) n (go r)
    go (UInfixT l n r) = UInfixT (go l) n (go r)
    go (ParensT t) = ParensT (go t)
#endif
    go x = x
-- | Convert a th-abstraction 'ConstructorInfo' back into a TH 'Con',
-- keeping the constructor's name, field types, field strictness and
-- normal/record/infix form.
--
-- Calls 'error' for a constructor that binds its own type variables or
-- has a context (the message calls these GADTs), and for an infix
-- constructor that does not have exactly two fields.
toCon :: ConstructorInfo -> Con
toCon (ConstructorInfo { constructorName       = name
                       , constructorVars       = vars
                       , constructorContext    = ctxt
                       , constructorFields     = ftys
                       , constructorStrictness = fstricts
                       , constructorVariant    = variant })
  | not (null vars && null ctxt)
  = error "makeBaseFunctor: GADTs are not currently supported."
  | otherwise
  = let bangs = map toBang fstricts
     in case variant of
          NormalConstructor        -> NormalC name $ zip bangs ftys
          RecordConstructor fnames -> RecC name $ zip3 fnames bangs ftys
          InfixConstructor
            | [bang1, bang2] <- bangs
            , [fty1, fty2] <- ftys
            -> InfixC (bang1, fty1) name (bang2, fty2)
            | otherwise
            -> error $ "makeBaseFunctor: Encountered an InfixConstructor "
                    ++ "without exactly two fields"
  where
#if MIN_VERSION_template_haskell(2,11,0)
    -- template-haskell >= 2.11: map th-abstraction's strictness info onto
    -- the source-level unpackedness/strictness annotations of 'Bang'.
    toBang (FieldStrictness upkd strct) = Bang (toSourceUnpackedness upkd)
                                               (toSourceStrictness strct)
      where
        toSourceUnpackedness :: Unpackedness -> SourceUnpackedness
        toSourceUnpackedness UnspecifiedUnpackedness = NoSourceUnpackedness
        toSourceUnpackedness NoUnpack = SourceNoUnpack
        toSourceUnpackedness Unpack = SourceUnpack

        toSourceStrictness :: Strictness -> SourceStrictness
        toSourceStrictness UnspecifiedStrictness = NoSourceStrictness
        toSourceStrictness Lazy = SourceLazy
        toSourceStrictness TH.Abs.Strict = SourceStrict
#else
    -- Older template-haskell only has IsStrict / NotStrict / Unpacked; every
    -- combination not listed maps to NotStrict.
    toBang (FieldStrictness UnspecifiedUnpackedness Strict) = IsStrict
    toBang (FieldStrictness UnspecifiedUnpackedness UnspecifiedStrictness) = NotStrict
    toBang (FieldStrictness Unpack Strict) = Unpacked
    toBang FieldStrictness{} = NotStrict
#endif
-- | Whether a character counts as a symbol character; 'toFName' uses this
-- to recognise operator names.
--
-- The following count:
--
-- * math, currency, modifier and other symbols;
-- * dashes;
-- * other punctuation, except the quote characters;
-- * connector punctuation, except @_@.
--
-- Characters rejected by 'isPuncChar' never count.
isSymbolChar :: Char -> Bool
isSymbolChar c
  | isPuncChar c = False
  | otherwise    = symbolish (generalCategory c)
  where
    symbolish MathSymbol           = True
    symbolish CurrencySymbol       = True
    symbolish ModifierSymbol       = True
    symbolish OtherSymbol          = True
    symbolish DashPunctuation      = True
    symbolish OtherPunctuation     = c `notElem` "'\""
    symbolish ConnectorPunctuation = c /= '_'
    symbolish _                    = False
-- | Punctuation characters that 'isSymbolChar' always rejects: commas,
-- semicolons, parentheses, brackets, braces and backquotes.
isPuncChar :: Char -> Bool
isPuncChar = (`elem` ",;()[]{}`")
-- | Package key of recursion-schemes, used to build global names for its
-- definitions.
--
-- Taken from the CURRENT_PACKAGE_KEY macro when the build defines it;
-- otherwise reconstructed as @recursion-schemes-<version>@ from Paths_.
rsPackageKey :: String
#ifdef CURRENT_PACKAGE_KEY
rsPackageKey = CURRENT_PACKAGE_KEY
#else
rsPackageKey = "recursion-schemes-" ++ showVersion version
#endif
-- | Global type-level name in recursion-schemes, given a module name and an
-- occurrence name.
mkRsName_tc :: String -> String -> Name
mkRsName_tc modName occ = mkNameG_tc rsPackageKey modName occ

-- | Global value-level name in recursion-schemes, given a module name and
-- an occurrence name.
mkRsName_v :: String -> String -> Name
mkRsName_v modName occ = mkNameG_v rsPackageKey modName occ
-- | Global name of the @Base@ type family from "Data.Functor.Foldable".
baseTypeName :: Name
baseTypeName = mkRsName_tc "Data.Functor.Foldable" "Base"
-- | Global name of the @Recursive@ class.
recursiveTypeName :: Name
recursiveTypeName = mkRsName_tc "Data.Functor.Foldable" "Recursive"
-- | Global name of the @Corecursive@ class.
corecursiveTypeName :: Name
corecursiveTypeName = mkRsName_tc "Data.Functor.Foldable" "Corecursive"
-- | Global name of the @project@ method (defined in generated @Recursive@
-- instances).
projectValName :: Name
projectValName = mkRsName_v "Data.Functor.Foldable" "project"
-- | Global name of the @embed@ method (defined in generated @Corecursive@
-- instances).
embedValName :: Name
embedValName = mkRsName_v "Data.Functor.Foldable" "embed"
-- | Original names of the classes in the generated deriving clause, built
-- directly from package, defining module and occurrence name.
--
-- NOTE(review): these hard-code the defining modules inside @base@. If a
-- newer @base@ merely re-exports these classes from a different package
-- or module, the names will no longer resolve. Verify against the
-- supported GHC range.
functorTypeName :: Name
functorTypeName = mkNameG_tc "base" "GHC.Base" "Functor"
foldableTypeName :: Name
foldableTypeName = mkNameG_tc "base" "Data.Foldable" "Foldable"
traversableTypeName :: Name
traversableTypeName = mkNameG_tc "base" "Data.Traversable" "Traversable"