{-# LANGUAGE CPP, NoMonomorphismRestriction, ScopedTypeVariables #-}
module Language.Haskell.TH.Desugar.Expand (
expand, expandType,
expandUnsoundly
) where
import qualified Data.Map as M
import Control.Monad
#if __GLASGOW_HASKELL__ < 709
import Control.Applicative
#endif
import Language.Haskell.TH hiding (cxt)
import Language.Haskell.TH.Syntax ( Quasi(..) )
import Data.Data
import Data.Generics
import qualified Data.Traversable as T
import Language.Haskell.TH.Desugar.AST
import Language.Haskell.TH.Desugar.Core
import Language.Haskell.TH.Desugar.Util
import Language.Haskell.TH.Desugar.Sweeten
import Language.Haskell.TH.Desugar.Reify
import Language.Haskell.TH.Desugar.Subst
-- | Expand the type synonyms and type families in a 'DType' as far as
-- 'expand_type' is able to, passing 'NoIgnore' as the 'IgnoreKinds' setting.
expandType :: DsMonad q => DType -> q DType
expandType ty = expand_type NoIgnore ty
-- | Expand type synonyms and (where possible) type families throughout a
-- type, threading the 'IgnoreKinds' setting through to 'expand_con' and to
-- every recursive call.
--
-- The type is walked down its left spine. The arguments of each application
-- are expanded and collected until the head is reached. A constructor head is
-- handed to 'expand_con' together with those arguments; any other head is
-- simply reapplied to them. A 'DForallT' is only legal with no pending
-- arguments; otherwise 'impossible' is called.
expand_type :: forall q. DsMonad q => IgnoreKinds -> DType -> q DType
expand_type ign = walk []
  where
    -- Fully expand a subterm that is not itself applied to anything.
    recur :: DType -> q DType
    recur = expand_type ign

    -- Re-attach the collected arguments to an already-expanded head.
    reapply :: [DTypeArg] -> DType -> q DType
    reapply spine hd = return (applyDType hd spine)

    -- @walk spine ty@ expands @ty@ applied to the already-expanded
    -- arguments in @spine@ (leftmost argument first).
    walk :: [DTypeArg] -> DType -> q DType
    walk spine ty = case ty of
      DForallT tvbs ctxt body
        | null spine ->
            DForallT <$> mapM (expand_tvb ign) tvbs
                     <*> mapM recur ctxt
                     <*> recur body
        | otherwise ->
            impossible "A forall type is applied to another type."
      DAppT fun arg -> do
        arg' <- recur arg
        walk (DTANormal arg' : spine) fun
      DAppKindT fun ki -> do
        ki' <- recur ki
        walk (DTyArg ki' : spine) fun
      DSigT inner ki -> do
        inner' <- walk [] inner
        ki'    <- walk [] ki
        reapply spine (DSigT inner' ki')
      DConT name -> expand_con ign name spine
      DVarT _    -> reapply spine ty
      DArrowT    -> reapply spine ty
      DLitT _    -> reapply spine ty
      DWildCardT -> reapply spine ty
-- | Expand the kind annotation of a type variable binder, if it has one.
-- Plain binders are returned unchanged.
expand_tvb :: DsMonad q => IgnoreKinds -> DTyVarBndr -> q DTyVarBndr
expand_tvb ign tvb = case tvb of
  DKindedTV name ki -> DKindedTV name <$> expand_type ign ki
  DPlainTV{}        -> pure tvb
-- | Expand an application of the type constructor @n@ to @args@.
--
-- What happens depends on what @n@ reifies to:
--
-- * A type synonym whose reified right-hand side is literally 'StarT' is
--   not unfolded. Instead, the constructor named by 'typeKindName' is applied
--   to @args@.
-- * Any other type synonym applied to at least one normal argument per
--   binder is unfolded. The binders are substituted by the leading normal
--   arguments, the result is expanded again, and the leftover normal
--   arguments are reapplied.
-- * An open type family with enough normal arguments is reduced when
--   instance lookup on the leading arguments returns exactly one instance
--   whose left-hand side matches them. Before GHC 7.10 this also requires
--   the argument check described for closed families.
-- * A closed type family with enough normal arguments is reduced by the
--   first equation whose left-hand side matches. This only happens if no
--   normal argument mentions a type variable, a type or data family, or a
--   name that cannot be reified.
--
-- Anything else is returned as the original, unexpanded application.
--
-- NOTE(review): the reducing branches reapply only the leftover /normal/
-- arguments, so visible kind arguments ('DTyArg') in @args@ do not survive
-- in the result. This matches the previous behaviour; confirm it is intended.
expand_con :: forall q.
              DsMonad q
           => IgnoreKinds
           -> Name       -- ^ the type constructor being applied
           -> [DTypeArg] -- ^ its arguments, in application order
           -> q DType
expand_con ign n args = do
  info <- reifyWithLocals n
  case info of
    TyConI (TySynD _ _ StarT)
      -- Keep a synonym for * as the named kind rather than unfolding it.
      -> return $ applyDType (DConT typeKindName) args
    _ -> go info
  where
    -- Only the normal (non-kind) arguments are lined up against the
    -- synonym/family binders; visible kind applications would skew that.
    normal_args :: [DType]
    normal_args = filterDTANormals args

    -- Does the application supply at least one normal argument per binder?
    saturates :: [DTyVarBndr] -> Bool
    saturates tvbs = length normal_args >= length tvbs

    go :: Info -> q DType
    go info = do
      dinfo <- dsInfo info
      case dinfo of
        DTyConI (DTySynD _n tvbs rhs) _
          | saturates tvbs
          -> do
            let (syn_args, rest_args) = splitAtList tvbs normal_args
            ty <- substTy (M.fromList $ zip (map extractDTvbName tvbs) syn_args) rhs
            expand_and_reapply rest_args ty

        DTyConI (DOpenTypeFamilyD (DTypeFamilyHead _n tvbs _frs _ann)) _
          | saturates tvbs
          -> when_args_ok_pre709 $ do
               let (syn_args, rest_args) = splitAtList tvbs normal_args
               -- If instance reification fails, carry on as though no
               -- instances were found (which leads to give_up below).
               insts <- qRecover (return []) $
                        qReifyInstances n (map typeToTH syn_args)
               dinsts <- dsDecs insts
               case dinsts of
                 [DTySynInstD (DTySynEqn _ lhs rhs)]
                   | Just subst <- match_lhs lhs syn_args
                   -> do ty <- substTy subst rhs
                         expand_and_reapply rest_args ty
                 _ -> give_up

        DTyConI (DClosedTypeFamilyD (DTypeFamilyHead _n tvbs _frs _ann) eqns) _
          | saturates tvbs
          -> when_args_ok $ do
               let (syn_args, rest_args) = splitAtList tvbs normal_args
               rhss <- mapMaybeM (check_eqn syn_args) eqns
               case rhss of
                 (rhs : _) -> expand_and_reapply rest_args rhs
                 []        -> give_up

        _ -> give_up

    -- Expand a (substituted) right-hand side, then reapply the normal
    -- arguments that were not consumed by the synonym/family binders.
    expand_and_reapply :: [DType] -> DType -> q DType
    expand_and_reapply rest_args ty = do
      ty' <- expand_type ign ty
      return $ applyDType ty' $ map DTANormal rest_args

    -- Match an equation's left-hand side (its normal arguments only)
    -- against the given argument types, yielding a substitution on success.
    match_lhs lhs arg_tys =
      let (_, lhs_args) = unfoldDType lhs
      in unionMaybeSubsts $
           zipWith (matchTy ign) (filterDTANormals lhs_args) arg_tys

    -- Return the substituted right-hand side if the equation matches.
    check_eqn :: [DType] -> DTySynEqn -> q (Maybe DType)
    check_eqn arg_tys (DTySynEqn _ lhs rhs) =
      T.mapM (flip substTy rhs) (match_lhs lhs arg_tys)

    -- Run the continuation only if no normal argument mentions a type
    -- variable, a type/data family, or an unreifiable name; otherwise
    -- give_up. The check is run here, on demand, instead of up front. This
    -- way, plain synonym expansion (and open family reduction on GHC >= 7.10)
    -- does not reify every constructor in the arguments only to ignore the
    -- result.
    when_args_ok :: q DType -> q DType
    when_args_ok k = do
      args_ok <- allM no_tyvars_tyfams normal_args
      if args_ok then k else give_up

    -- Before GHC 7.10 open type family reduction is subject to the same
    -- argument check as closed families; from 7.10 on it is not.
    when_args_ok_pre709 :: q DType -> q DType
#if __GLASGOW_HASKELL__ < 709
    when_args_ok_pre709 = when_args_ok
#else
    when_args_ok_pre709 = id
#endif

    -- Used when expansion cannot proceed: the type constructor applied to
    -- all of its original arguments, unchanged.
    give_up :: q DType
    give_up = return $ applyDType (DConT n) args

    -- True iff the type mentions no type variables, no open/closed type
    -- families or data families, and only constructors that can be reified.
    no_tyvars_tyfams :: DType -> q Bool
    no_tyvars_tyfams = go_ty
      where
        go_ty :: DType -> q Bool
        go_ty (DVarT _) = return False
        go_ty (DConT con_name) = do
          m_info <- dsReify con_name
          return $ case m_info of
            Nothing -> False
            Just (DTyConI (DOpenTypeFamilyD {}) _) -> False
            Just (DTyConI (DDataFamilyD {}) _) -> False
            Just (DTyConI (DClosedTypeFamilyD {}) _) -> False
            _ -> True
        go_ty (DForallT tvbs ctxt ty) =
          liftM3 (\x y z -> x && y && z)
                 (allM go_tvb tvbs) (allM go_ty ctxt) (go_ty ty)
        go_ty (DAppT t1 t2) = liftM2 (&&) (go_ty t1) (go_ty t2)
        go_ty (DAppKindT t k) = liftM2 (&&) (go_ty t) (go_ty k)
        go_ty (DSigT t k) = liftM2 (&&) (go_ty t) (go_ty k)
        go_ty DLitT{} = return True
        go_ty DArrowT = return True
        go_ty DWildCardT = return True

        -- A binder is fine if it is plain or its kind passes the check.
        go_tvb :: DTyVarBndr -> q Bool
        go_tvb DPlainTV{} = return True
        go_tvb (DKindedTV _ k) = go_ty k
-- | Monadic 'all'. @allM p xs@ yields 'True' iff @p@ yields 'True' for every
-- element of @xs@, and yields 'True' for an empty list.
--
-- Elements are tested left to right, and the traversal stops at the first
-- 'False'. The predicates used in this module reify names, so skipping the
-- remaining elements avoids needless reification. (The previous 'foldM'
-- version ran the predicate on every element regardless.)
allM :: Monad m => (a -> m Bool) -> [a] -> m Bool
allM _ [] = return True
allM p (x : xs) = do
  ok <- p x
  if ok then allM p xs else return False
-- | The name bound by a type variable binder, ignoring any kind annotation.
extractDTvbName :: DTyVarBndr -> Name
extractDTvbName tvb = case tvb of
  DKindedTV name _ -> name
  DPlainTV name    -> name
-- | Expand synonyms and families in every 'DType' found anywhere inside a
-- value, via a generic traversal, passing 'NoIgnore' as the kind setting.
expand :: (DsMonad q, Data a) => a -> q a
expand x = expand_ NoIgnore x
-- | Like 'expand', but passes 'YesIgnore'. That setting is forwarded to
-- 'matchTy' when type family left-hand sides are matched, which, as the name
-- warns, may produce unsound expansions.
expandUnsoundly :: (DsMonad q, Data a) => a -> q a
expandUnsoundly x = expand_ YesIgnore x
-- | Apply 'expand_type' (with the given 'IgnoreKinds' setting) to every
-- 'DType' reachable inside a value, using SYB's monadic 'everywhereM'
-- traversal; values of other types are left as they are.
expand_ :: (DsMonad q, Data a) => IgnoreKinds -> a -> q a
expand_ ign x = everywhereM (mkM expand_one) x
  where
    expand_one = expand_type ign