{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE Trustworthy #-}
module Grisette.Core.TH
(
makeUnionWrapper,
makeUnionWrapper',
)
where
import Control.Monad
import Grisette.Core.THCompat
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
-- | Generate one wrapper per constructor of the type @typName@, using the
-- explicitly supplied wrapper names.
--
-- The @i@-th name is paired with the @i@-th constructor, in the order the
-- constructors are reported by 'reify'. Use this variant when a wrapper name
-- cannot be derived automatically (for example for infix constructors).
--
-- Fails if the number of names differs from the number of constructors.
makeUnionWrapper' ::
  -- | Names of the generated wrappers, one per constructor.
  [String] ->
  -- | The data type or newtype whose constructors are wrapped.
  Name ->
  Q [Dec]
makeUnionWrapper' names typName = do
  constructors <- getConstructors typName
  unless (length names == length constructors) $
    fail "Number of names does not match the number of constructors"
  -- Each wrapper contributes a signature and a definition; flatten them.
  concat <$> zipWithM mkSingleWrapper names constructors
-- | The unqualified base of a 'Name': its occurrence string, with any module
-- qualifier or uniqueness information discarded.
occName :: Name -> String
occName = nameBase
-- | Derive the default wrapper base name for a constructor: its unqualified
-- name. Constructors wrapped in 'ForallC' are looked through.
--
-- Infix constructors are rejected with a message asking the caller to supply
-- the names through 'makeUnionWrapper''. GADT constructors are accepted only
-- when they declare exactly one name. Any other form is rejected with 'fail'.
getConstructorName :: Con -> Q String
getConstructorName con = case con of
  NormalC n _ -> baseOf n
  RecC n _ -> baseOf n
  InfixC {} ->
    fail "You should use makeUnionWrapper' to manually provide the name for infix constructors"
  ForallC _ _ inner -> getConstructorName inner
  GadtC [n] _ _ -> baseOf n
  RecGadtC [n] _ _ -> baseOf n
  other -> fail $ "Unsupported constructor at this time: " ++ pprint other
  where
    baseOf :: Name -> Q String
    baseOf = pure . occName
-- | Reify @typName@ and return its constructors.
--
-- Only plain @data@ declarations (all constructors) and @newtype@
-- declarations (the single constructor) are supported. Anything else reported
-- by 'reify' is rejected with 'fail', including the pretty-printed 'Info'.
getConstructors :: Name -> Q [Con]
getConstructors typName = reify typName >>= extract
  where
    extract :: Info -> Q [Con]
    extract (TyConI (DataD _ _ _ _ cons _)) = pure cons
    extract (TyConI (NewtypeD _ _ _ _ con _)) = pure [con]
    extract info = fail $ "Unsupported declaration: " ++ pprint info
-- | Generate one wrapper per constructor of the type @typName@, naming each
-- wrapper by prepending @prefix@ to the constructor's unqualified name.
--
-- Fails when a constructor's name cannot be derived automatically (notably
-- for infix constructors); use 'makeUnionWrapper'' to supply the names
-- explicitly in that case.
makeUnionWrapper ::
  -- | Prefix prepended to each constructor name.
  String ->
  -- | The data type or newtype whose constructors are wrapped.
  Name ->
  Q [Dec]
makeUnionWrapper prefix typName = do
  baseNames <- traverse getConstructorName =<< getConstructors typName
  makeUnionWrapper' (map (prefix ++) baseNames) typName
-- | Build the wrapper expression for @f@ applied to @n@ arguments:
--
-- > \x1 ... xn -> mrgSingle (f x1 ... xn)
--
-- Fresh names are generated for the @n@ lambda-bound variables. When @n@ is 0
-- the lambda has an empty pattern list.
augmentNormalCExpr :: Int -> Exp -> Q Exp
augmentNormalCExpr n f = do
  vars <- replicateM n (newName "x")
  mrgSingleFun <- [|mrgSingle|]
  -- Left-nested application: ((f x1) x2) ... xn
  let saturated = foldl AppE f (VarE <$> vars)
  pure (LamE (VarP <$> vars) (mrgSingleFun `AppE` saturated))
-- | Rewrite a constructor's type into the type of its wrapper.
--
-- The type beneath any outer @forall@ is passed to 'augmentFinalType'
-- (defined elsewhere; presumably it rewrites the final result type and
-- returns the extra binders and constraints it needs — TODO confirm). The
-- binders and constraints it returns are placed in front of the outer
-- @forall@ binders and context the original type already had. A type with no
-- outer @forall@ is treated as having empty binders and context.
augmentNormalCType :: Type -> Q Type
augmentNormalCType ty = do
  ((newBinders, newPreds), augmented) <- augmentFinalType body
  pure $ ForallT (newBinders ++ outerBinders) (newPreds ++ outerCtx) augmented
  where
    (outerBinders, outerCtx, body) = case ty of
      ForallT tvs ctx inner -> (tvs, ctx, inner)
      _ -> ([], [], ty)
-- | Generate the declarations (a type signature and a definition) for one
-- wrapper named @name@ around the data constructor described by the 'Con'.
--
-- The wrapper's body is built by 'augmentNormalCExpr' (a lambda over the
-- constructor's fields whose saturated application is passed to
-- @mrgSingle@). Its type is the constructor's reified type rewritten by
-- 'augmentNormalCType'.
--
-- Supported constructor forms:
--
-- * prefix ('NormalC') and record ('RecC') constructors;
-- * infix ('InfixC') constructors, which always take two fields;
-- * GADT constructors ('GadtC' / 'RecGadtC') that declare a single name;
-- * any of the above wrapped in 'ForallC'.
--
-- Any other form is rejected with 'fail'.
mkSingleWrapper :: String -> Con -> Q [Dec]
mkSingleWrapper name con = case con of
  NormalC oriName fields -> wrapConstructor oriName (length fields)
  RecC oriName fields -> wrapConstructor oriName (length fields)
  -- 'getConstructorName' tells users to go through 'makeUnionWrapper'' for
  -- infix constructors, so they must be supported here.
  InfixC _ oriName _ -> wrapConstructor oriName 2
  -- 'reify' on the constructor name yields its full type (quantifiers and
  -- context included), which 'augmentNormalCType' already handles, so only
  -- the underlying constructor shape matters here.
  ForallC _ _ inner -> mkSingleWrapper name inner
  GadtC [oriName] fields _ -> wrapConstructor oriName (length fields)
  RecGadtC [oriName] fields _ -> wrapConstructor oriName (length fields)
  _ -> fail $ "Unsupported constructor: " ++ pprint con
  where
    retName :: Name
    retName = mkName name

    -- Build the signature and definition for constructor @oriName@, which
    -- takes @arity@ arguments.
    wrapConstructor :: Name -> Int -> Q [Dec]
    wrapConstructor oriName arity = do
      DataConI _ constructorTyp _ <- reify oriName
      augmentedTyp <- augmentNormalCType constructorTyp
      expr <- augmentNormalCExpr arity (ConE oriName)
      return
        [ SigD retName augmentedTyp,
          FunD retName [Clause [] (NormalB expr) []]
        ]