-- SPDX-FileCopyrightText: 2020 Tocqueville Group
--
-- SPDX-License-Identifier: LicenseRef-MIT-TQ

module Util.TH
  ( deriveGADTNFData
  , lookupTypeNameOrFail
  ) where

import Language.Haskell.TH

-- | Generates an @NFData@ instance for a GADT.
--
-- For each constructor @C@ with fields @a1 ... an@ the generated @rnf@ gets
-- the clause
--
-- > rnf (C a1 ... an) = rnf (a1, ..., an)
--
-- where a nullary constructor yields @rnf ()@. Plain GADT constructors and
-- GADT record constructors are supported, also when they carry an explicit
-- @forall@ or context. A declaration that lists several constructors at once
-- (@A, B :: ...@) produces one clause per constructor.
--
-- The instance head applies the data type to its declared type variables in
-- declaration order. @NFData@ and @rnf@ are referenced via 'mkName', so they
-- must be in scope at the splice site.
--
-- /Note:/ no constraints are added to the instance context, even if the
-- fields require them (e.g. @NFData a@ for a field of type @a@).
--
-- Fails in 'Q' when @name@ does not refer to a @data@ declaration, or when
-- any constructor is not written in GADT syntax.
deriveGADTNFData :: Name -> Q [Dec]
deriveGADTNFData name = do
  info <- reify name
  (dataName, vars, cons) <- case info of
    TyConI (DataD _ dName dVars _ dCons _) -> pure (dName, dVars, dCons)
    _ -> fail $ "deriveGADTNFData: '" <> pprint name <> "' is not a data type."
  -- Expand constructors in 'Q' rather than in the list monad: in the list
  -- monad 'fail' is just '[]', which would silently drop an unsupported
  -- constructor instead of aborting the splice.
  fields <- mconcat <$> traverse unfoldConstructor cons
  clauses <- traverse makeClause fields
  pure [InstanceD Nothing [] (instanceHead dataName vars) [FunD rnfName clauses]]
  where
    rnfName :: Name
    rnfName = mkName "rnf"

    tyVarName :: TyVarBndr -> Name
    tyVarName (PlainTV n) = n
    tyVarName (KindedTV n _) = n

    -- Builds @NFData (T v1 ... vn)@; 'foldl' applies the variables in
    -- declaration order (a 'foldr' here would apply them reversed).
    instanceHead :: Name -> [TyVarBndr] -> Type
    instanceHead dataName vars =
      AppT (ConT (mkName "NFData"))
        (foldl AppT (ConT dataName) (map (VarT . tyVarName) vars))

    -- Expands one constructor declaration into (constructor name, fields)
    -- pairs; "A, B, C :: X -> T" yields one pair per constructor name.
    -- Record field names are irrelevant for 'rnf' and are dropped.
    unfoldConstructor :: Con -> Q [(Name, [BangType])]
    unfoldConstructor (GadtC conNames bangs _) =
      pure [(conName, bangs) | conName <- conNames]
    unfoldConstructor (RecGadtC conNames fieldDecls _) =
      pure [(conName, [(b, t) | (_, b, t) <- fieldDecls]) | conName <- conNames]
    unfoldConstructor (ForallC _ _ con) = unfoldConstructor con
    unfoldConstructor _ = fail "Non GADT constructors are not supported."

    -- Builds the clause "rnf (ConName a1 a2 ...) = rnf (a1, a2, ...)" using
    -- one fresh variable per field.
    makeClause :: (Name, [BangType]) -> Q Clause
    makeClause (conName, fieldTypes) = do
      argNames <- traverse (const (newName "a")) fieldTypes
      clause
        [conP conName (map varP argNames)]
        (normalB (appE (varE rnfName) (tupE (map varE argNames))))
        []

-- | Resolves a type name in the type namespace of the current splice's
-- scope (via 'lookupTypeName'). If no such type is found, the enclosing 'Q'
-- computation fails with a message naming the string that was looked up.
lookupTypeNameOrFail :: String -> Q Name
lookupTypeNameOrFail typeStr = do
  found <- lookupTypeName typeStr
  maybe (fail failureMsg) pure found
  where
    failureMsg = "Failed type name lookup for: '" <> typeStr <> "'."