{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralisedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
module GHC.Stg.InferTags.Rewrite (rewriteTopBinds)
where
import GHC.Prelude
import GHC.Types.Id
import GHC.Types.Name
import GHC.Types.Unique.Supply
import GHC.Types.Unique.FM
import GHC.Types.RepType
import GHC.Unit.Types (Module)
import GHC.Core.DataCon
import GHC.Core (AltCon(..) )
import GHC.Core.Type
import GHC.StgToCmm.Types
import GHC.Stg.Utils
import GHC.Stg.Syntax as StgSyn
import GHC.Data.Maybe
import GHC.Utils.Panic
import GHC.Utils.Panic.Plain
import GHC.Utils.Outputable
import GHC.Utils.Monad.State.Strict
import GHC.Utils.Misc
import GHC.Stg.InferTags.Types
import Control.Monad
import GHC.Types.Basic (CbvMark (NotMarkedCbv, MarkedCbv), isMarkedCbv, TopLevelFlag(..), isTopLevel)
import GHC.Types.Var.Set
-- | The rewrite monad: a strict 'State' computation over a four-tuple of
--
--   * the map from ids to their tag signatures ('getMap' \/ 'setMap'),
--   * a unique supply (consumed through the 'MonadUnique' instance),
--   * a 'Module' ('getMod'),
--   * a set of ids ('getFVs' \/ 'setFVs').
newtype RM a = RM { unRM :: State (UniqFM Id TagSig, UniqSupply, Module, IdSet) a }
    deriving (Functor, Monad, Applicative)
-- | Uniques come from the supply stored in the state. Each request splits
-- that supply: the first half is handed to the caller and the second half is
-- written back for later requests. The other state components are untouched.
instance MonadUnique RM where
    getUniqueSupplyM = RM $ do
        (tagMap, supply, thisMod, lcls) <- get
        let (handedOut, kept) = splitUniqSupply supply
        put (tagMap, kept, thisMod, lcls)
        return handedOut
-- | Read the id-to-tag-signature map (first state component).
getMap :: RM (UniqFM Id TagSig)
getMap = RM $ (\(tagMap, _, _, _) -> tagMap) <$> get
-- | Replace the id-to-tag-signature map (first state component), leaving the
-- remaining three components as they were.
setMap :: UniqFM Id TagSig -> RM ()
setMap tagMap = RM $ do
    (_, us, thisMod, lcls) <- get
    put (tagMap, us, thisMod, lcls)
-- | Read the 'Module' stored in the state (third state component).
getMod :: RM Module
getMod = RM $ (\(_, _, thisMod, _) -> thisMod) <$> get
-- | Read the id set stored in the state (fourth state component). 'fvArgs'
-- consults this set to decide which variable arguments go into a closure's
-- free-variable set.
getFVs :: RM IdSet
getFVs = RM $ (\(_, _, _, lcls) -> lcls) <$> get
-- | Replace the id set stored in the state (fourth state component), leaving
-- the remaining three components as they were.
setFVs :: IdSet -> RM ()
setFVs fvs = RM $ do
    (tagMap, us, thisMod, _oldLcls) <- get
    put (tagMap, us, thisMod, fvs)
-- | Run the continuation with the binder(s) of the given binding, together
-- with their tag signatures, brought into scope. A non-recursive binding
-- defers to 'withBinder'; a recursive group defers to 'withBinders' with all
-- of the group's binders. The right-hand sides are not inspected.
withBind :: TopLevelFlag -> GenStgBinding 'InferTaggedBinders -> RM a -> RM a
withBind top_flag (StgNonRec bnd _) cont = withBinder top_flag bnd cont
withBind top_flag (StgRec binds) cont = do
    let (bnds, _rhss) = unzip binds :: ([(Id, TagSig)], [GenStgRhs 'InferTaggedBinders])
    withBinders top_flag bnds cont
-- | Permanently record the tag signature(s) of a binding's binder(s) in the
-- tag map. Unlike 'withBinder' \/ 'withBinders', the previous map is not
-- restored afterwards.
addTopBind :: GenStgBinding 'InferTaggedBinders -> RM ()
addTopBind (StgNonRec (id, tag) _) = do
    s <- getMap
    setMap $ addToUFM s id tag
addTopBind (StgRec binds) = do
    let (bnds, _rhss) = unzip binds
    !s <- getMap
    -- The extended map is forced before it is stored.
    setMap $! addListToUFM s bnds
-- | Run the continuation with one binder's tag signature added to the tag
-- map, restoring the previous map afterwards. For a binder that is not
-- top-level the id is additionally added to the id set for the duration of
-- the continuation (via 'withLcl').
withBinder :: TopLevelFlag -> (Id, TagSig) -> RM a -> RM a
withBinder top_flag (id, sig) cont = do
    oldMap <- getMap
    setMap $ addToUFM oldMap id sig
    a <- if isTopLevel top_flag
            then cont
            else withLcl id cont
    setMap oldMap
    return a
-- | Run the continuation with several binders' tag signatures added to the
-- tag map, restoring the previous map afterwards.
--
-- For 'NotTopLevel' binders the ids are also added to the id set while the
-- continuation runs, and the previous set is restored afterwards. For
-- 'TopLevel' binders the id set is left alone.
withBinders :: TopLevelFlag -> [(Id, TagSig)] -> RM a -> RM a
withBinders TopLevel sigs cont = do
    oldMap <- getMap
    setMap $ addListToUFM oldMap sigs
    a <- cont
    setMap oldMap
    return a
withBinders NotTopLevel sigs cont = do
    oldMap <- getMap
    oldFvs <- getFVs
    setMap $ addListToUFM oldMap sigs
    setFVs $ extendVarSetList oldFvs (map fst sigs)
    a <- cont
    setMap oldMap
    setFVs oldFvs
    return a
-- | Run the action with every id of the given set added to the id set held
-- in the state; the previous id set is restored once the action has finished.
withClosureLcls :: DIdSet -> RM a -> RM a
withClosureLcls fvs act = do
    old_fvs <- getFVs
    -- Only set membership is built here; the fold's name marks it as
    -- non-deterministic in traversal order.
    let fvs' = nonDetStrictFoldDVarSet (flip extendVarSet) old_fvs fvs
    setFVs fvs'
    r <- act
    setFVs old_fvs
    return r
-- | Run the action with the given id added to the id set held in the state;
-- the previous id set is restored once the action has finished.
withLcl :: Id -> RM a -> RM a
withLcl fv act = do
    old_fvs <- getFVs
    let fvs' = extendVarSet old_fvs fv
    setFVs fvs'
    r <- act
    setFVs old_fvs
    return r
-- | Decide whether an id can be treated as tagged.
--
-- Ids whose name is local to, or from, the module held in the state:
--
--   * ids of unlifted type count as tagged;
--   * otherwise the id's signature is looked up in the tag map, and only
--     'TagDunno' counts as untagged. An id missing from the map is a panic
--     (@"unknown Id:"@).
--
-- All other ids:
--
--   * workers of data constructors that are nullary in their representation
--     count as tagged;
--   * otherwise, if the id carries 'LambdaFormInfo', the answer is taken
--     from it (see the table in the body);
--   * otherwise the id is treated as untagged.
isTagged :: Id -> RM Bool
isTagged v = do
    this_mod <- getMod
    case nameIsLocalOrFrom this_mod (idName v) of
        True
            | isUnliftedType (idType v)
            -> return True
            | otherwise -> do
                !s <- getMap
                -- The default is a panic thunk; the bang on 'sig' forces the
                -- lookup result, so a missing id panics here.
                let !sig = lookupWithDefaultUFM s (pprPanic "unknown Id:" (ppr v)) v
                return $ case sig of
                    TagSig info ->
                        case info of
                            TagDunno   -> False
                            TagProper  -> True
                            TagTagged  -> True
                            TagTuple _ -> True
        False
            | Just con <- isDataConWorkId_maybe v
            , isNullaryRepDataCon con
            -> return True
            | Just lf_info <- idLFInfo_maybe v
            -> return $
                case lf_info of
                    LFReEntrant {}   -> True
                    LFThunk {}       -> False
                    LFCon {}         -> True
                    LFUnknown {}     -> False
                    LFUnlifted {}    -> True
                    LFLetNoEscape {} -> True
            | otherwise
            -> return False
-- | Decide whether an argument can be treated as tagged: literal arguments
-- always are, variable arguments defer to 'isTagged'.
isArgTagged :: StgArg -> RM Bool
isArgTagged (StgLitArg _) = return True
isArgTagged (StgVarArg v) = isTagged v
-- | Make a copy of the given id that has been passed through 'localiseId'
-- and given a fresh unique drawn from the monad's supply. Both the unique
-- and the resulting id are forced before being returned.
mkLocalArgId :: Id -> RM Id
mkLocalArgId id = do
    !u <- getUniqueM
    return $! setIdUnique (localiseId id) u
-- | Entry point of the pass: rewrite every top-level binding in order with
-- 'rewriteTop'.
--
-- The computation starts from an empty tag map and an empty id set, with the
-- supplied module and unique supply; only the rewritten bindings are
-- returned, the final state is discarded.
rewriteTopBinds :: Module -> UniqSupply -> [GenStgTopBinding 'InferTaggedBinders] -> [TgStgTopBinding]
rewriteTopBinds mod us binds =
    let doBinds = mapM rewriteTop binds
    in evalState (unRM doBinds) (mempty, us, mod, mempty)
-- | Rewrite a single top-level binding.
--
-- String literals are rebuilt unchanged. For lifted bindings the binders'
-- signatures are first recorded permanently with 'addTopBind' (so they stay
-- visible to the bindings that follow), then the binding itself is rewritten.
rewriteTop :: InferStgTopBinding -> RM TgStgTopBinding
rewriteTop (StgTopStringLit v s) = return $! StgTopStringLit v s
rewriteTop (StgTopLifted bind) = do
    addTopBind bind
    StgTopLifted <$!> rewriteBinds TopLevel bind
-- | Rewrite the right-hand side(s) of a binding and strip the tag signatures
-- from its binder(s).
--
-- For a non-recursive binding the binder is not brought into scope here. For
-- a recursive group all binders of the group are in scope (via 'withBind')
-- while the right-hand sides are rewritten.
rewriteBinds :: TopLevelFlag -> InferStgBinding -> RM (TgStgBinding)
rewriteBinds _top_flag (StgNonRec v rhs) = do
    !rhs' <- rewriteRhs v rhs
    return $! StgNonRec (fst v) rhs'
rewriteBinds top_flag b@(StgRec binds) =
    withBind top_flag b $ do
        rhss <- mapM (uncurry rewriteRhs) binds
        return $! mkRec rhss
  where
    -- Pair each original binder (signature dropped) with its rewritten rhs.
    mkRec :: [TgStgRhs] -> TgStgBinding
    mkRec rhss = StgRec (zip (map (fst . fst) binds) rhss)
-- | Rewrite one right-hand side.
--
-- Constructor applications: every argument is classified with 'isArgTagged'.
-- If no variable argument in a strict field is possibly untagged, the
-- constructor application is rebuilt unchanged. Otherwise the right-hand side
-- becomes a zero-argument closure whose body ('mkSeqs') handles those
-- arguments before building the constructor.
--
-- Closures: the body is rewritten with the lambda binders in scope as
-- non-top-level binders and with the closure's free variables added to the
-- id set; the binders' tag signatures are dropped from the result.
rewriteRhs :: (Id,TagSig) -> InferStgRhs
           -> RM (TgStgRhs)
rewriteRhs (_id, _tagSig) (StgRhsCon ccs con cn ticks args) = {-# SCC rewriteRhs_ #-} do
    -- For each argument: can it be treated as tagged?
    fieldInfos <- mapM isArgTagged args

    -- Keep only the arguments that land in strict fields.
    let strictFields =
            getStrictConArgs con (zip args fieldInfos) :: [(StgArg,Bool)]
    -- Of those, keep the ones not known to be tagged.
    let needsEval = map fst (filter (not . snd) strictFields) :: [StgArg]
    -- Only variables can need handling; literals are always tagged.
    let evalArgs = [v | StgVarArg v <- needsEval] :: [Id]

    if null evalArgs
        then return $! StgRhsCon ccs con cn ticks args
        else do
            -- The type argument of the constructor application is
            -- deliberately a panic: mkSeqs must not look at it.
            let ty_stub = panic "mkSeqs shouldn't use the type arg"
            conExpr <- mkSeqs args evalArgs (\taggedArgs -> StgConApp con cn taggedArgs ty_stub)

            fvs <- fvArgs args
            -- The closure takes no arguments, so it is marked 'Updatable'
            -- rather than 'ReEntrant': a re-entrant closure would re-run its
            -- body (and re-allocate the constructor) every time it is
            -- entered, losing the sharing the original constructor binding
            -- had. See GHC issue #23783.
            return $! (StgRhsClosure fvs ccs Updatable [] $! conExpr)
rewriteRhs _binding (StgRhsClosure fvs ccs flag args body) = do
    withBinders NotTopLevel args $
        withClosureLcls fvs $
            StgRhsClosure fvs ccs flag (map fst args) <$> rewriteExpr False body
-- | Collect, as a deterministic set, those variable arguments that are
-- members of the id set currently held in the state. Literal arguments and
-- variables outside that set are left out.
fvArgs :: [StgArg] -> RM DVarSet
fvArgs args = do
    fv_lcls <- getFVs
    return $ mkDVarSet [ v | StgVarArg v <- args, elemVarSet v fv_lcls ]
-- | 'True' when the expression being rewritten is the scrutinee of a case
-- expression; every other call site in this module passes 'False'.
type IsScrut = Bool
-- | Rewrite a single STG expression from the tag-inference representation to
-- the code-generation representation.
--
-- Dispatches on the constructor: cases, lets, let-no-escapes, constructor
-- applications and function applications are delegated to their dedicated
-- rewriters. Ticks are rewritten through, preserving the scrutinee flag.
-- Literals and primitive/foreign operations are rebuilt unchanged.
rewriteExpr :: IsScrut -> InferStgExpr -> RM TgStgExpr
rewriteExpr _       e@StgCase{}               = rewriteCase e
rewriteExpr _       e@StgLet{}                = rewriteLet e
rewriteExpr _       e@StgLetNoEscape{}        = rewriteLetNoEscape e
-- A tick does not change whether the wrapped expression is a scrutinee.
rewriteExpr isScrut (StgTick t e)             = StgTick t <$!> rewriteExpr isScrut e
rewriteExpr _       e@StgConApp{}             = rewriteConApp e
rewriteExpr isScrut e@StgApp{}                = rewriteApp isScrut e
rewriteExpr _       (StgLit lit)              = return $! StgLit lit
rewriteExpr _       (StgOpApp op args res_ty) = return $! StgOpApp op args res_ty
-- | Rewrite a case expression.
--
-- Both the scrutinee and the alternatives are rewritten inside 'withBinder'
-- for the case binder. The scrutinee is rewritten first, with the scrutinee
-- flag set; the alternatives follow in order. The rebuilt case keeps only the
-- 'Id' of the binder.
--
-- Panics when given anything other than 'StgCase'.
rewriteCase :: InferStgExpr -> RM TgStgExpr
rewriteCase (StgCase scrut bndr alt_type alts) =
    withBinder NotTopLevel bndr $ do
        scrut' <- rewriteExpr True scrut
        alts'  <- mapM rewriteAlt alts
        pure (StgCase scrut' (fst bndr) alt_type alts')
rewriteCase _ = panic "Impossible: nodeCase"
-- | Rewrite one case alternative: the right-hand side is rewritten (not as a
-- scrutinee) inside 'withBinders' for the alternative's binders, and the
-- binders lose their 'TagSig'. The alternative's constructor is kept as is.
rewriteAlt :: InferStgAlt -> RM TgStgAlt
rewriteAlt alt@GenStgAlt{alt_bndrs = bndrs, alt_rhs = rhs} =
    withBinders NotTopLevel bndrs $ do
        -- Force the rewritten rhs so no thunk is stored in the alternative.
        !rhs' <- rewriteExpr False rhs
        return $! alt { alt_bndrs = map fst bndrs, alt_rhs = rhs' }
-- | Rewrite a @let@ expression: first the binding group via 'rewriteBinds',
-- then the body inside 'withBind' for the original binding. Both results are
-- forced before the 'StgLet' is rebuilt.
--
-- Panics when given anything other than 'StgLet'.
rewriteLet :: InferStgExpr -> RM TgStgExpr
rewriteLet (StgLet xt bind expr) = do
    !bind' <- rewriteBinds NotTopLevel bind
    withBind NotTopLevel bind $ do
        !expr' <- rewriteExpr False expr
        return $! StgLet xt bind' expr'
rewriteLet _ = panic "Impossible"
-- | Rewrite a let-no-escape expression; mirrors 'rewriteLet'. The binding
-- group is rewritten via 'rewriteBinds', then the body inside 'withBind' for
-- the original binding. Both results are forced before the 'StgLetNoEscape'
-- is rebuilt.
--
-- Panics when given anything other than 'StgLetNoEscape'.
rewriteLetNoEscape :: InferStgExpr -> RM TgStgExpr
rewriteLetNoEscape (StgLetNoEscape xt bind expr) = do
    !bind' <- rewriteBinds NotTopLevel bind
    withBind NotTopLevel bind $ do
        !expr' <- rewriteExpr False expr
        return $! StgLetNoEscape xt bind' expr'
rewriteLetNoEscape _ = panic "Impossible"
-- | Rewrite a saturated constructor application.
--
-- Each argument is paired with the result of 'isArgTagged' and the pairs are
-- narrowed by 'getStrictConArgs'. Variable arguments in the narrowed list
-- whose tag flag is 'False' are wrapped in evaluating cases by 'mkSeqs' before
-- the constructor is applied. When there are none, the application is rebuilt
-- unchanged.
--
-- Panics when given anything other than 'StgConApp'.
rewriteConApp :: InferStgExpr -> RM TgStgExpr
rewriteConApp (StgConApp con cn args tys) = do
    fieldInfos <- mapM isArgTagged args
    let -- (tag flag, argument) for the fields selected by 'getStrictConArgs'.
        strictFields = getStrictConArgs con (zip fieldInfos args) :: [(Bool, StgArg)]
        -- Variable arguments in those fields that are not known to be tagged.
        evalArgs = [ v | (False, StgVarArg v) <- strictFields ] :: [Id]
    if null evalArgs
        then return $! StgConApp con cn args tys
        else mkSeqs args evalArgs (\taggedArgs -> StgConApp con cn taggedArgs tys)
rewriteConApp _ = panic "Impossible"
-- | Rewrite a function application.
--
-- * A nullary application in scrutinee position: if 'isTagged' holds for the
--   variable, its 'Id' is given a @TagSig TagProper@ signature via
--   'setIdTagSig'; otherwise it is left unchanged.
--
-- * An application whose function has call-by-value marks
--   ('idCbvMarks_maybe'), at least one of which is 'MarkedCbv': variable
--   arguments in marked positions that 'isArgTagged' reports as untagged are
--   evaluated by 'mkSeqs' before the call.
--
-- * Any other application is rebuilt unchanged.
--
-- Panics when given anything other than 'StgApp'.
rewriteApp :: IsScrut -> InferStgExpr -> RM TgStgExpr
rewriteApp True (StgApp f []) = do
    f_tagged <- isTagged f
    let f' | f_tagged  = setIdTagSig f (TagSig TagProper)
           | otherwise = f
    return $! StgApp f' []
rewriteApp _ (StgApp f args)
    | Just marks <- idCbvMarks_maybe f
      -- Trailing unmarked positions are irrelevant; drop them.
    , relevant_marks <- dropWhileEndLE (not . isMarkedCbv) marks
    , any isMarkedCbv relevant_marks
    = assert (length relevant_marks <= length args) $
      unliftArg relevant_marks
    where
      unliftArg :: [CbvMark] -> RM TgStgExpr
      unliftArg relevant_marks = do
        argTags <- mapM isArgTagged args
        let -- Pad the marks with 'NotMarkedCbv' so every argument has one.
            argInfo = zipWith3 (,,) args (relevant_marks ++ repeat NotMarkedCbv) argTags
                        :: [(StgArg, CbvMark, Bool)]
            -- Variable arguments in call-by-value positions that are untagged.
            cbvArgIds = [ x | (StgVarArg x, MarkedCbv, False) <- argInfo ] :: [Id]
        mkSeqs args cbvArgIds (\cbv_args -> StgApp f cbv_args)
rewriteApp _ (StgApp f args) = return $ StgApp f args
rewriteApp _ _ = panic "Impossible"
-- | @mkSeq v bndr body@ builds
--
-- > case v of bndr { DEFAULT -> body }
--
-- The alternative type is derived from the binder and the single alternative
-- by 'mkStgAltTypeFromStgAlts'. The body is forced on entry.
mkSeq :: Id -> Id -> TgStgExpr -> TgStgExpr
mkSeq scrutVar bndr !expr =
    StgCase (StgApp scrutVar []) bndr altTy alts
  where
    alts  = [GenStgAlt { alt_con = DEFAULT, alt_bndrs = [], alt_rhs = expr }]
    altTy = mkStgAltTypeFromStgAlts bndr alts
-- | Wrap an expression in one evaluating case (see 'mkSeq') per untagged
-- variable.
--
-- Each id in the second argument is paired with a binder obtained from
-- 'mkLocalArgId'. Occurrences of those ids among the original arguments are
-- replaced by their binder, and the substituted argument list is handed to
-- the builder function. The result is nested inside the cases, with the first
-- id in the list as the outermost scrutinee.
{-# INLINE mkSeqs #-}
mkSeqs  :: [StgArg]                 -- ^ Original arguments
        -> [Id]                     -- ^ Variable arguments to evaluate first
        -> ([StgArg] -> TgStgExpr)  -- ^ Builds the expression from the substituted arguments
        -> RM TgStgExpr
mkSeqs args untaggedIds mkExpr = do
    argMap <- mapM (\arg -> (arg,) <$> mkLocalArgId arg) untaggedIds :: RM [(InId, OutId)]
    let -- Replace a variable by its case binder when it is being evaluated.
        substArg :: StgArg -> StgArg
        substArg (StgVarArg v) = StgVarArg (fromMaybe v (lookup v argMap))
        substArg other         = other

        taggedArgs = map substArg args :: [StgArg]
        inner      = mkExpr taggedArgs
        body       = foldr (\(v, bndr) expr -> mkSeq v bndr expr) inner argMap
    return $! body
-- | Keep the elements of @args@ whose position is 'MarkedStrict' in
-- 'dataConRuntimeRepStrictness' of the constructor.
--
-- Unboxed tuple and unboxed sum constructors always yield the empty list.
-- For other constructors the two lists are paired with 'zipEqual'.
getStrictConArgs :: DataCon -> [a] -> [a]
getStrictConArgs con args
    | isUnboxedTupleDataCon con = []
    | isUnboxedSumDataCon con   = []
    | otherwise =
        [ arg
        | (arg, MarkedStrict) <- zipEqual "getStrictConArgs"
                                          args
                                          (dataConRuntimeRepStrictness con)
        ]