{-# LANGUAGE CPP #-}
module CoreSubst (
Subst(..),
TvSubstEnv, IdSubstEnv, InScopeSet,
deShadowBinds, substSpec, substRulesForImportedIds,
substTy, substCo, substExpr, substExprSC, substBind, substBindSC,
substUnfolding, substUnfoldingSC,
lookupIdSubst, lookupTCvSubst, substIdOcc,
substTickish, substDVarSet, substIdInfo,
emptySubst, mkEmptySubst, mkSubst, mkOpenSubst, substInScope, isEmptySubst,
extendIdSubst, extendIdSubstList, extendTCvSubst, extendTvSubstList,
extendSubst, extendSubstList, extendSubstWithVar, zapSubstEnv,
addInScopeSet, extendInScope, extendInScopeList, extendInScopeIds,
isInScope, setInScope, getTCvSubst, extendTvSubst, extendCvSubst,
delBndr, delBndrs,
substBndr, substBndrs, substRecBndrs, substTyVarBndr, substCoVarBndr,
cloneBndr, cloneBndrs, cloneIdBndr, cloneIdBndrs, cloneRecIdBndrs,
) where
#include "HsVersions.h"
import CoreSyn
import CoreFVs
import CoreSeq
import CoreUtils
import qualified Type
import qualified Coercion
import Type hiding ( substTy, extendTvSubst, extendCvSubst, extendTvSubstList
, isInScope, substTyVarBndr, cloneTyVarBndr )
import Coercion hiding ( substCo, substCoVarBndr )
import PrelNames
import VarSet
import VarEnv
import Id
import Name ( Name )
import Var
import IdInfo
import UniqSupply
import Maybes
import Util
import Outputable
import PprCore ()
import Data.List
-- | A substitution: an in-scope set bundled with three separate
-- environments, one each for term variables, type variables and
-- coercion variables.
data Subst = Subst
    InScopeSet   -- the in-scope set consulted by 'isInScope'
    IdSubstEnv   -- mappings for 'Id's (range is 'CoreExpr')
    TvSubstEnv   -- mappings for type variables
    CvSubstEnv   -- mappings for coercion variables
-- | The term-variable component of a 'Subst': an 'IdEnv' whose range
-- is 'CoreExpr'.
type IdSubstEnv = IdEnv CoreExpr
-- | True exactly when all three environments are empty.  The in-scope
-- set is not consulted.
isEmptySubst :: Subst -> Bool
isEmptySubst (Subst _ ids tvs cvs)
  | not (isEmptyVarEnv ids) = False
  | not (isEmptyVarEnv tvs) = False
  | otherwise               = isEmptyVarEnv cvs
-- | The substitution with an empty in-scope set and no mappings.
emptySubst :: Subst
emptySubst = mkEmptySubst emptyInScopeSet
-- | A substitution carrying the given in-scope set and three empty
-- environments.
mkEmptySubst :: InScopeSet -> Subst
mkEmptySubst in_scope = Subst in_scope no_ids no_tvs no_cvs
  where
    no_ids = emptyVarEnv
    no_tvs = emptyVarEnv
    no_cvs = emptyVarEnv
-- | Assemble a 'Subst' from its parts.
--
-- Note the argument order (type env, coercion env, id env) differs from
-- the field order of the 'Subst' constructor (id env, type env,
-- coercion env).
mkSubst :: InScopeSet -> TvSubstEnv -> CvSubstEnv -> IdSubstEnv -> Subst
mkSubst in_scope tv_env cv_env id_env =
  Subst in_scope id_env tv_env cv_env
-- | Project the in-scope set out of a substitution.
substInScope :: Subst -> InScopeSet
substInScope subst =
  case subst of
    Subst in_scope _ _ _ -> in_scope
-- | Discard every mapping in all three environments while keeping the
-- in-scope set unchanged.
zapSubstEnv :: Subst -> Subst
zapSubstEnv (Subst in_scope _ _ _) = mkEmptySubst in_scope
-- | Add one mapping to the 'Id' environment.  The other components are
-- left alone.  A debug assertion checks that the variable satisfies
-- 'isNonCoVarId'.
extendIdSubst :: Subst -> Id -> CoreExpr -> Subst
extendIdSubst (Subst in_scope id_env tv_env cv_env) var rhs
  = ASSERT2( isNonCoVarId var, ppr var $$ ppr rhs )
    Subst in_scope id_env' tv_env cv_env
  where
    id_env' = extendVarEnv id_env var rhs
-- | Add a list of mappings to the 'Id' environment in one go.  A debug
-- assertion checks that every variable in the list satisfies
-- 'isNonCoVarId'.
extendIdSubstList :: Subst -> [(Id, CoreExpr)] -> Subst
extendIdSubstList (Subst in_scope id_env tv_env cv_env) prs
  = ASSERT( all (isNonCoVarId . fst) prs )
    Subst in_scope id_env' tv_env cv_env
  where
    id_env' = extendVarEnvList id_env prs
-- | Add one mapping to the type-variable environment.  A debug
-- assertion checks that the variable satisfies 'isTyVar'.
extendTvSubst :: Subst -> TyVar -> Type -> Subst
extendTvSubst (Subst in_scope id_env tv_env cv_env) tv ty
  = ASSERT( isTyVar tv )
    Subst in_scope id_env tv_env' cv_env
  where
    tv_env' = extendVarEnv tv_env tv ty
-- | Add a list of type-variable mappings, left to right, by repeated
-- use of 'extendTvSubst'.
extendTvSubstList :: Subst -> [(TyVar,Type)] -> Subst
extendTvSubstList subst vrs =
  foldl' (\acc (tv, ty) -> extendTvSubst acc tv ty) subst vrs
-- | Add one mapping to the coercion-variable environment.  A debug
-- assertion checks that the variable satisfies 'isCoVar'.
extendCvSubst :: Subst -> CoVar -> Coercion -> Subst
extendCvSubst (Subst in_scope id_env tv_env cv_env) cv co
  = ASSERT( isCoVar cv )
    Subst in_scope id_env tv_env cv_env'
  where
    cv_env' = extendVarEnv cv_env cv co
-- | Add a mapping to whichever environment matches the shape of the
-- argument: a @Type@ argument goes to the type-variable environment, a
-- @Coercion@ argument to the coercion-variable environment, and
-- anything else to the 'Id' environment.  Each case carries a debug
-- assertion on the kind of variable supplied.
extendSubst :: Subst -> Var -> CoreArg -> Subst
extendSubst subst var (Type ty)     = ASSERT( isTyVar var ) extendTvSubst subst var ty
extendSubst subst var (Coercion co) = ASSERT( isCoVar var ) extendCvSubst subst var co
extendSubst subst var arg           = ASSERT( isId var ) extendIdSubst subst var arg
-- | Map one variable to another.  The environment is chosen by the
-- kind of the first variable; the second is wrapped as a type,
-- a coercion or a @Var@ expression accordingly.  Debug assertions check
-- that the second variable is of a matching kind.
extendSubstWithVar :: Subst -> Var -> Var -> Subst
extendSubstWithVar subst old new =
  if isTyVar old
    then ASSERT( isTyVar new ) extendTvSubst subst old (mkTyVarTy new)
    else if isCoVar old
      then ASSERT( isCoVar new ) extendCvSubst subst old (mkCoVarCo new)
      else ASSERT( isId new ) extendIdSubst subst old (Var new)
-- | Add a list of mappings, left to right, dispatching each pair
-- through 'extendSubst'.
extendSubstList :: Subst -> [(Var,CoreArg)] -> Subst
extendSubstList subst prs =
  foldl' (\acc (var, rhs) -> extendSubst acc var rhs) subst prs
-- | Find what an 'Id' occurrence should be replaced with.
--
-- Non-local Ids are returned unchanged.  For a local Id the 'Id'
-- environment is tried first, then the in-scope set (whose entry, if
-- present, is returned as a @Var@).  If neither has it, a warning that
-- includes the supplied document is emitted and the Id is returned as is.
lookupIdSubst :: SDoc -> Subst -> Id -> CoreExpr
lookupIdSubst doc (Subst in_scope ids _ _) v
  | isLocalId v = lookup_local
  | otherwise   = Var v
  where
    lookup_local =
      case lookupVarEnv ids v of
        Just e  -> e
        Nothing ->
          case lookupInScope in_scope v of
            Just v' -> Var v'
            Nothing ->
              WARN( True, text "CoreSubst.lookupIdSubst" <+> doc <+> ppr v $$ ppr in_scope )
              Var v
-- | Look a type-level variable up in the substitution, yielding a type.
--
-- A variable satisfying 'isTyVar' is looked up in the type-variable
-- environment, defaulting to the variable itself as a type.  Any other
-- variable is looked up in the coercion-variable environment (again
-- defaulting to itself) and the coercion is wrapped with 'mkCoercionTy'.
lookupTCvSubst :: Subst -> TyVar -> Type
lookupTCvSubst (Subst _ _ tvs cvs) v
  | isTyVar v =
      case lookupVarEnv tvs v of
        Just ty -> ty
        Nothing -> Type.mkTyVarTy v
  | otherwise =
      mkCoercionTy
        (case lookupVarEnv cvs v of
           Just co -> co
           Nothing -> mkCoVarCo v)
-- | Remove any mapping for the given variable from the one environment
-- that matches its kind (coercion variable, type variable, or
-- otherwise 'Id').  The in-scope set is untouched.
delBndr :: Subst -> Var -> Subst
delBndr (Subst in_scope ids tvs cvs) v
  | isCoVar v = Subst in_scope ids tvs (without_v cvs)
  | isTyVar v = Subst in_scope ids (without_v tvs) cvs
  | otherwise = Subst in_scope (without_v ids) tvs cvs
  where
    without_v :: VarEnv a -> VarEnv a
    without_v env = delVarEnv env v
-- | Remove mappings for a list of variables.  Unlike 'delBndr' no
-- attempt is made to classify the variables: every one is deleted from
-- all three environments.
delBndrs :: Subst -> [Var] -> Subst
delBndrs (Subst in_scope ids tvs cvs) vs
  = Subst in_scope (prune ids) (prune tvs) (prune cvs)
  where
    prune :: VarEnv a -> VarEnv a
    prune env = delVarEnvList env vs
-- | Build a substitution from an in-scope set and a list of
-- variable/argument pairs.
--
-- Pairs whose variable satisfies 'isId' populate the 'Id' environment;
-- pairs whose argument is a @Type@ populate the type-variable
-- environment; pairs whose argument is a @Coercion@ populate the
-- coercion-variable environment.
mkOpenSubst :: InScopeSet -> [(Var,CoreArg)] -> Subst
mkOpenSubst in_scope pairs = Subst in_scope id_env tv_env cv_env
  where
    id_env = mkVarEnv [ pr       | pr@(var, _)      <- pairs, isId var ]
    tv_env = mkVarEnv [ (tv, ty) | (tv, Type ty)     <- pairs ]
    cv_env = mkVarEnv [ (cv, co) | (cv, Coercion co) <- pairs ]
-- | Is the variable a member of the substitution's in-scope set?
isInScope :: Var -> Subst -> Bool
isInScope v (Subst in_scope _ _ _) = elemInScopeSet v in_scope
-- | Enlarge the in-scope set with every variable of the given set.
-- The three environments are left exactly as they were.
addInScopeSet :: Subst -> VarSet -> Subst
addInScopeSet (Subst in_scope ids tvs cvs) vs
  = Subst in_scope' ids tvs cvs
  where
    in_scope' = extendInScopeSetSet in_scope vs
-- | Bring one variable into scope and, at the same time, delete any
-- mapping for it from all three environments.
extendInScope :: Subst -> Var -> Subst
extendInScope (Subst in_scope ids tvs cvs) v
  = Subst (extendInScopeSet in_scope v)
          (delVarEnv ids v)
          (delVarEnv tvs v)
          (delVarEnv cvs v)
-- | Bring a list of variables into scope and delete any mappings for
-- them from all three environments.
extendInScopeList :: Subst -> [Var] -> Subst
extendInScopeList (Subst in_scope ids tvs cvs) vs
  = Subst (extendInScopeSetList in_scope vs)
          (delVarEnvList ids vs)
          (delVarEnvList tvs vs)
          (delVarEnvList cvs vs)
-- | Bring a list of 'Id's into scope.  Only the 'Id' environment has
-- the corresponding mappings deleted; the type- and coercion-variable
-- environments are passed through untouched.
extendInScopeIds :: Subst -> [Id] -> Subst
extendInScopeIds (Subst in_scope ids tvs cvs) vs
  = Subst (extendInScopeSetList in_scope vs)
          (delVarEnvList ids vs)
          tvs
          cvs
-- | Replace the in-scope set wholesale, keeping all three environments.
setInScope :: Subst -> InScopeSet -> Subst
setInScope subst in_scope =
  case subst of
    Subst _ ids tvs cvs -> Subst in_scope ids tvs cvs
-- | Debug rendering: one line per component, enclosed in angle brackets.
instance Outputable Subst where
  ppr (Subst in_scope ids tvs cvs)
    = scope_line $$ id_line $$ tv_line $$ (cv_line <> char '>')
    where
      scope_line = text "<InScope =" <+> in_scope_doc
      id_line    = text " IdSubst =" <+> ppr ids
      tv_line    = text " TvSubst =" <+> ppr tvs
      cv_line    = text " CvSubst =" <+> ppr cvs
      -- The in-scope variables are shown as a brace-wrapped, filled list.
      in_scope_doc =
        pprVarSet (getInScopeVars in_scope) (\vs -> braces (fsep (map ppr vs)))
-- | Substitute in an expression, with a short cut: when the
-- substitution has no mappings the expression is returned unchanged
-- without being traversed.
substExprSC :: SDoc -> Subst -> CoreExpr -> CoreExpr
substExprSC doc subst orig_expr =
  if isEmptySubst subst
    then orig_expr
    else subst_expr doc subst orig_expr
-- | Substitute in an expression.  Unlike 'substExprSC' the expression
-- is always traversed, even when the substitution has no mappings.
-- The document is threaded through for use in diagnostics.
substExpr :: SDoc -> Subst -> CoreExpr -> CoreExpr
substExpr doc subst = subst_expr doc subst
-- | The worker behind 'substExpr' and 'substExprSC': a structural
-- traversal that rewrites variable occurrences, types and coercions,
-- and extends the substitution whenever it goes under a binder.
subst_expr :: SDoc -> Subst -> CoreExpr -> CoreExpr
subst_expr doc subst expr
  = go expr
  where
    go e = case e of
      Var v          -> lookupIdSubst (doc $$ text "subst_expr") subst v
      Type ty        -> Type (substTy subst ty)
      Coercion co    -> Coercion (substCo subst co)
      Lit lit        -> Lit lit
      App fun arg    -> App (go fun) (go arg)
      Tick tickish b -> mkTick (substTickish subst tickish) (go b)
      Cast b co      -> Cast (go b) (substCo subst co)

      -- Binding forms: the body is processed with the substitution
      -- obtained from handling the binder(s).
      Lam bndr body ->
        let (subst', bndr') = substBndr subst bndr
        in  Lam bndr' (subst_expr doc subst' body)

      Let bind body ->
        let (subst', bind') = substBind subst bind
        in  Let bind' (subst_expr doc subst' body)

      -- The scrutinee and result type use the outer substitution; the
      -- alternatives see the one extended with the case binder.
      Case scrut bndr ty alts ->
        let (subst', bndr') = substBndr subst bndr
        in  Case (go scrut) bndr' (substTy subst ty) (map (go_alt subst') alts)

    go_alt alt_subst (con, bndrs, rhs) =
      let (subst', bndrs') = substBndrs alt_subst bndrs
      in  (con, bndrs', subst_expr doc subst' rhs)
-- | Apply a substitution to a binding, returning the substitution to
-- use for whatever is in the binding's scope together with the new
-- binding.
--
-- 'substBindSC' is the short-cut variant: when the incoming
-- substitution has no mappings it still processes the binders, but
-- leaves right-hand sides alone wherever that is possible.
substBind, substBindSC :: Subst -> CoreBind -> (Subst, CoreBind)

substBindSC subst bind
  | isEmptySubst subst = short_cut bind
  | otherwise          = substBind subst bind
  where
    -- A non-recursive RHS is only subject to the incoming (empty)
    -- substitution, so it is reused as is.
    short_cut (NonRec bndr rhs) =
      let (subst', bndr') = substBndr subst bndr
      in  (subst', NonRec bndr' rhs)

    -- Recursive RHSs see the substitution produced by the binders, so
    -- they are reused only if that one is still empty.
    short_cut (Rec pairs) =
      let (bndrs, rhss)    = unzip pairs
          (subst', bndrs') = substRecBndrs subst bndrs
          rhss' = if isEmptySubst subst'
                    then rhss
                    else map (subst_expr (text "substBindSC") subst') rhss
      in  (subst', Rec (zip bndrs' rhss'))

substBind subst bind =
  case bind of
    -- The RHS of a non-recursive binding uses the incoming substitution.
    NonRec bndr rhs ->
      let (subst', bndr') = substBndr subst bndr
      in  (subst', NonRec bndr' (subst_expr (text "substBind") subst rhs))

    -- The RHSs of a recursive group use the substitution extended with
    -- the group's own binders.
    Rec pairs ->
      let (bndrs, rhss)    = unzip pairs
          (subst', bndrs') = substRecBndrs subst bndrs
          rhss'            = map (subst_expr (text "substBind") subst') rhss
      in  (subst', Rec (zip bndrs' rhss'))
-- | Run 'substBind' over a whole program, starting from 'emptySubst'
-- and threading the substitution from each binding to the next.  Only
-- the rewritten bindings are returned.
deShadowBinds :: CoreProgram -> CoreProgram
deShadowBinds binds =
  let (_, binds') = mapAccumL substBind emptySubst binds
  in  binds'
-- | Process a single binder, dispatching on its kind to
-- 'substTyVarBndr', 'substCoVarBndr' or the 'Id' binder routine.
-- Returns the substitution for the binder's scope and the new binder.
substBndr :: Subst -> Var -> (Subst, Var)
substBndr subst bndr =
  if isTyVar bndr
    then substTyVarBndr subst bndr
    else if isCoVar bndr
      then substCoVarBndr subst bndr
      else substIdBndr (text "var-bndr") subst subst bndr
-- | Process a list of binders left to right with 'substBndr', each one
-- seeing the substitution produced by its predecessors.
substBndrs :: Subst -> [Var] -> (Subst, [Var])
substBndrs subst = mapAccumL substBndr subst
-- | Process the binders of a recursive group.
--
-- The final substitution is fed back (lazily) as the @rec_subst@
-- argument of 'substIdBndr' for every binder, which is why the result
-- is defined in terms of itself.
substRecBndrs :: Subst -> [Id] -> (Subst, [Id])
substRecBndrs subst bndrs =
  let (final_subst, bndrs') =
        mapAccumL (substIdBndr (text "rec-bndr") final_subst) subst bndrs
  in  (final_subst, bndrs')
-- | Process one 'Id' binder.
--
-- Arguments, in order: a document (unused here), the substitution
-- applied to the binder's IdInfo, the substitution applied to the
-- binder's type and extended for the result, and the binder itself.
--
-- The result pairs the substitution for the binder's scope with the new
-- binder.  The new binder is always added to the in-scope set.
substIdBndr :: SDoc
            -> Subst
            -> Subst -> Id
            -> (Subst, Id)
substIdBndr _doc rec_subst subst@(Subst in_scope env tvs cvs) old_id
  = (Subst in_scope' env' tvs cvs, new_id)
  where
    -- Step 1: pick a name via 'uniqAway' against the in-scope set.
    fresh_id = uniqAway in_scope old_id

    -- Step 2: substitute in the type, unless there are no type or
    -- coercion mappings or the type has no free variables.
    old_ty         = idType old_id
    type_unchanged = (isEmptyVarEnv tvs && isEmptyVarEnv cvs)
                     || noFreeVarsOfType old_ty
    typed_id       = if type_unchanged
                       then fresh_id
                       else setIdType fresh_id (substTy subst old_ty)

    -- Step 3: substitute in the IdInfo using rec_subst.
    new_id = maybeModifyIdInfo (substIdInfo rec_subst typed_id (idInfo typed_id))
                               typed_id

    in_scope' = in_scope `extendInScopeSet` new_id

    -- If 'uniqAway' left the Id equal to the original, drop any mapping
    -- for it; otherwise map the old Id to the new one.
    env' = if fresh_id == old_id
             then delVarEnv env old_id
             else extendVarEnv env old_id (Var new_id)
-- | Clone one 'Id' binder, giving it a unique drawn from the supply.
-- Returns the extended substitution and the clone.
cloneIdBndr :: Subst -> UniqSupply -> Id -> (Subst, Id)
cloneIdBndr subst us old_id = clone_id subst subst (old_id, uniq)
  where
    uniq = uniqFromSupply us
-- | Clone a list of 'Id' binders.  The substitution is threaded through
-- the list for types and mappings, while the IdInfo of every clone is
-- processed with the original incoming substitution.
cloneIdBndrs :: Subst -> UniqSupply -> [Id] -> (Subst, [Id])
cloneIdBndrs subst us ids = mapAccumL (clone_id subst) subst ids_with_uniqs
  where
    ids_with_uniqs = zip ids (uniqsFromSupply us)
-- | Clone a list of binders of any kind with 'cloneBndr', threading the
-- substitution left to right and pairing each binder with a unique from
-- the supply.
cloneBndrs :: Subst -> UniqSupply -> [Var] -> (Subst, [Var])
cloneBndrs subst us vs = mapAccumL step subst (zip vs (uniqsFromSupply us))
  where
    step acc (v, uniq) = cloneBndr acc uniq v
-- | Clone a single binder using the given unique.  Type variables go
-- through 'cloneTyVarBndr'; every other variable goes through the 'Id'
-- cloning routine.
cloneBndr :: Subst -> Unique -> Var -> (Subst, Var)
cloneBndr subst uniq v =
  if isTyVar v
    then cloneTyVarBndr subst v uniq
    else clone_id subst subst (v, uniq)
-- | Clone the binders of a recursive group.  The final substitution is
-- fed back (lazily) as the substitution used for each clone's IdInfo,
-- hence the self-referential definition.
cloneRecIdBndrs :: Subst -> UniqSupply -> [Id] -> (Subst, [Id])
cloneRecIdBndrs subst us ids =
  let (final_subst, clones) =
        mapAccumL (clone_id final_subst) subst (zip ids (uniqsFromSupply us))
  in  (final_subst, clones)
-- | Clone one binder with a supplied unique.
--
-- The first substitution is used for the IdInfo, the second for the
-- type and as the base of the result.  The clone is added to the
-- in-scope set, and the old variable is mapped to the clone in the
-- coercion-variable environment if it is a coercion variable, or in
-- the 'Id' environment otherwise.
clone_id :: Subst
         -> Subst -> (Id, Unique)
         -> (Subst, Id)
clone_id rec_subst subst@(Subst in_scope idvs tvs cvs) (old_id, uniq)
  = (Subst (in_scope `extendInScopeSet` new_id) new_idvs tvs new_cvs, new_id)
  where
    renamed = setVarUnique old_id uniq
    retyped = substIdType subst renamed
    new_id  = maybeModifyIdInfo (substIdInfo rec_subst retyped (idInfo old_id))
                                retyped

    is_covar = isCoVar old_id

    new_idvs | is_covar  = idvs
             | otherwise = extendVarEnv idvs old_id (Var new_id)

    new_cvs  | is_covar  = extendVarEnv cvs old_id (mkCoVarCo new_id)
             | otherwise = cvs
-- | Process a type-variable binder by delegating to
-- 'Type.substTyVarBndr' on the type/coercion part of the substitution.
-- The 'Id' environment is carried across unchanged.
substTyVarBndr :: Subst -> TyVar -> (Subst, TyVar)
substTyVarBndr (Subst in_scope id_env tv_env cv_env) tv
  = case Type.substTyVarBndr tcv_subst tv of
      (TCvSubst in_scope' tv_env' cv_env', tv') ->
        (Subst in_scope' id_env tv_env' cv_env', tv')
  where
    tcv_subst = TCvSubst in_scope tv_env cv_env
-- | Clone a type-variable binder with the given unique by delegating to
-- 'Type.cloneTyVarBndr' on the type/coercion part of the substitution.
-- The 'Id' environment is carried across unchanged.
cloneTyVarBndr :: Subst -> TyVar -> Unique -> (Subst, TyVar)
cloneTyVarBndr (Subst in_scope id_env tv_env cv_env) tv uniq
  = case Type.cloneTyVarBndr tcv_subst tv uniq of
      (TCvSubst in_scope' tv_env' cv_env', tv') ->
        (Subst in_scope' id_env tv_env' cv_env', tv')
  where
    tcv_subst = TCvSubst in_scope tv_env cv_env
-- | Process a coercion-variable binder by delegating to
-- 'Coercion.substCoVarBndr' on the type/coercion part of the
-- substitution.  The 'Id' environment is carried across unchanged.
--
-- The signature names 'CoVar' because this is what 'substBndr' passes
-- when the binder satisfies 'isCoVar'.  It is the same underlying
-- variable type as before, so callers are unaffected.
substCoVarBndr :: Subst -> CoVar -> (Subst, CoVar)
substCoVarBndr (Subst in_scope id_env tv_env cv_env) cv
  = case Coercion.substCoVarBndr (TCvSubst in_scope tv_env cv_env) cv of
      (TCvSubst in_scope' tv_env' cv_env', cv') ->
        (Subst in_scope' id_env tv_env' cv_env', cv')
-- | Apply the type/coercion part of the substitution to a type, using
-- 'Type.substTyUnchecked'.
substTy :: Subst -> Type -> Type
substTy subst = Type.substTyUnchecked (getTCvSubst subst)
-- | Extract the type/coercion substitution: the in-scope set plus the
-- type- and coercion-variable environments.  The 'Id' environment is
-- dropped.
getTCvSubst :: Subst -> TCvSubst
getTCvSubst subst =
  case subst of
    Subst in_scope _ tenv cenv -> TCvSubst in_scope tenv cenv
-- | Apply the type/coercion part of the substitution to a coercion.
substCo :: Subst -> Coercion -> Coercion
substCo subst = Coercion.substCo (getTCvSubst subst)
-- | Substitute in the type of an 'Id'.  The Id is returned as is when
-- there are no type or coercion mappings, or when its type has no free
-- variables.
substIdType :: Subst -> Id -> Id
substIdType subst@(Subst _ _ tv_env cv_env) var =
  if nothing_to_do
    then var
    else setIdType var (substTy subst old_ty)
  where
    old_ty        = idType var
    nothing_to_do = (isEmptyVarEnv tv_env && isEmptyVarEnv cv_env)
                    || noFreeVarsOfType old_ty
-- | Substitute in the rules and unfolding held in an IdInfo.
--
-- Returns 'Nothing' when there is no work to do, i.e. the rule info is
-- empty and the unfolding is not fragile.  Otherwise returns the IdInfo
-- with both fields replaced by their substituted versions.
substIdInfo :: Subst -> Id -> IdInfo -> Maybe IdInfo
substIdInfo subst new_id info
  | needs_work = Just new_info
  | otherwise  = Nothing
  where
    old_rules  = ruleInfo info
    old_unf    = unfoldingInfo info
    needs_work = not (isEmptyRuleInfo old_rules) || isFragileUnfolding old_unf

    with_rules = setRuleInfo info (substSpec subst new_id old_rules)
    new_info   = setUnfoldingInfo with_rules (substUnfolding subst old_unf)
-- | Apply a substitution to an unfolding.
--
-- 'substUnfoldingSC' is the short-cut variant that returns the
-- unfolding untouched when the substitution has no mappings.
substUnfolding, substUnfoldingSC :: Subst -> Unfolding -> Unfolding

substUnfoldingSC subst unf =
  if isEmptySubst subst
    then unf
    else substUnfolding subst unf

substUnfolding subst unf =
  case unf of
    -- Dictionary-function unfolding: substitute the binders, then the
    -- arguments under the extended substitution.
    DFunUnfolding { df_bndrs = bndrs, df_args = args } ->
      let (subst', bndrs') = substBndrs subst bndrs
          args'            = map (substExpr (text "subst-unf:dfun") subst') args
      in  unf { df_bndrs = bndrs', df_args = args' }

    -- Ordinary unfolding: kept (with its template substituted and
    -- forced via seqExpr) only if its source is stable; otherwise it is
    -- replaced by NoUnfolding.
    CoreUnfolding { uf_tmpl = tmpl, uf_src = src }
      | isStableSource src ->
          let new_tmpl = substExpr (text "subst-unf") subst tmpl
          in  seqExpr new_tmpl `seq` unf { uf_tmpl = new_tmpl }
      | otherwise -> NoUnfolding

    -- Every other form of unfolding is returned unchanged.
    _ -> unf
-- | Look up an 'Id' occurrence, insisting that the result is itself a
-- variable.  Panics (showing the Id, the expression found and the whole
-- substitution) if the lookup yields anything other than a @Var@.
substIdOcc :: Subst -> Id -> Id
substIdOcc subst v
  | Var v' <- found = v'
  | otherwise       = pprPanic "substIdOcc" (vcat [ppr v <+> ppr found, ppr subst])
  where
    found = lookupIdSubst (text "substIdOcc") subst v
-- | Substitute in rule info attached to an 'Id'.
--
-- Each rule is rewritten with 'substRule', with local rules being
-- renamed to the name of @new_id@, and the recorded free-variable set is
-- rewritten with 'substDVarSet'.  The result is forced with
-- 'seqRuleInfo' before being returned.
substSpec :: Subst -> Id -> RuleInfo -> RuleInfo
substSpec subst new_id (RuleInfo rules rhs_fvs) =
  let rename_to_new = \_ -> idName new_id
      rules'        = [ substRule subst rename_to_new rule | rule <- rules ]
      new_spec      = RuleInfo rules' (substDVarSet subst rhs_fvs)
  in  seqRuleInfo new_spec `seq` new_spec
-- | Substitute in a list of rules.  The renaming function handed to
-- 'substRule' panics if it is ever called, so this is only suitable for
-- rules that 'substRule' does not need to rename.
substRulesForImportedIds :: Subst -> [CoreRule] -> [CoreRule]
substRulesForImportedIds subst rules =
  [ substRule subst must_not_rename rule | rule <- rules ]
  where
    must_not_rename name = pprPanic "substRulesForImportedIds" (ppr name)
-- | Apply a substitution to a single rewrite rule.
--
-- Built-in rules are returned untouched.  For an ordinary rule the
-- binders are processed first (via 'substBndrs') and the resulting
-- substitution is applied to both the argument patterns and the
-- right-hand side.  The rule's function name is passed through
-- @subst_ru_fn@ only when the rule is marked local; otherwise it is
-- kept as it is.
substRule :: Subst -> (Name -> Name) -> CoreRule -> CoreRule
substRule _ _ rule@(BuiltinRule {}) = rule
substRule subst subst_ru_fn rule@(Rule { ru_bndrs = bndrs, ru_args = args
                                       , ru_fn = fn_name, ru_rhs = rhs
                                       , ru_local = is_local })
  = rule { ru_bndrs = bndrs'
         , ru_fn    = if is_local
                        then subst_ru_fn fn_name
                        else fn_name
         , ru_args  = map (substExpr doc subst') args
         , ru_rhs   = substExpr doc subst' rhs }
  where
    -- One diagnostic label for both the arguments and the right-hand
    -- side, so that any warning raised while substituting names the
    -- function the rule belongs to.
    doc = text "subst-rule" <+> ppr fn_name
    (subst', bndrs') = substBndrs subst bndrs
-- | Apply a substitution to a deterministic variable set.
--
-- Each member is looked up in the substitution and replaced by the free
-- variables of whatever it maps to: for an 'Id', the local free
-- variables of the resulting expression; for any other variable, all
-- free variables of the resulting type.
substDVarSet :: Subst -> DVarSet -> DVarSet
substDVarSet subst fvs = mkDVarSet new_fvs
  where
    (new_fvs, _) = foldr add_fv ([], emptyVarSet) (dVarSetElems fvs)

    -- The accumulator is forced before each free-variable computation
    -- is applied to it.
    add_fv fv acc
      | isId fv   = expr_fvs (lookupIdSubst (text "substDVarSet") subst fv)
                             isLocalVar emptyVarSet $! acc
      | otherwise = tyCoFVsOfType (lookupTCvSubst subst fv)
                                  (\_ -> True) emptyVarSet $! acc
-- | Apply a substitution to a tick.  Only breakpoints are affected:
-- each of their Ids is looked up and the Id is extracted from the
-- resulting expression with 'getIdFromTrivialExpr'.  Every other kind
-- of tick is returned unchanged.
substTickish :: Subst -> Tickish Id -> Tickish Id
substTickish subst tickish =
  case tickish of
    Breakpoint n ids ->
      Breakpoint n
        [ getIdFromTrivialExpr (lookupIdSubst (text "subst_tickish") subst v)
        | v <- ids ]
    other -> other