--
-- Copyright (c) 2019 Andreas Klebinger
--

{-# LANGUAGE BangPatterns               #-}
{-# LANGUAGE CPP                        #-}
{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralisedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE TupleSections              #-}
{-# LANGUAGE TypeFamilies               #-}

module GHC.Stg.InferTags.Rewrite (rewriteTopBinds)
where

import GHC.Prelude

import GHC.Builtin.PrimOps ( PrimOp(..) )
import GHC.Types.Basic     ( CbvMark (..), isMarkedCbv
                           , TopLevelFlag(..), isTopLevel
                           , Levity(..) )
import GHC.Types.Id
import GHC.Types.Name
import GHC.Types.Unique.Supply
import GHC.Types.Unique.FM
import GHC.Types.RepType
import GHC.Types.Var.Set
import GHC.Unit.Types

import GHC.Core.DataCon
import GHC.Core            ( AltCon(..) )
import GHC.Core.Type

import GHC.StgToCmm.Types

import GHC.Stg.Utils
import GHC.Stg.Syntax as StgSyn

import GHC.Data.Maybe
import GHC.Utils.Panic

import GHC.Utils.Outputable
import GHC.Utils.Monad.State.Strict
import GHC.Utils.Misc

import GHC.Stg.InferTags.Types

import Control.Monad

-- | The monad used by the rewrite pass: a state monad over a 4-tuple of
--
--   * the mapping from binders to their tag signatures (see 'getMap'/'setMap'),
--   * a unique supply used to create fresh ids (see the 'MonadUnique' instance),
--   * the module being compiled (see 'getMod'),
--   * the set of locally bound ids which closures have to capture as free
--     variables (see 'getFVs'/'setFVs').
newtype RM a = RM { unRM :: (State (UniqFM Id TagSig, UniqSupply, Module, IdSet) a) }
    deriving (Functor, Monad, Applicative)

------------------------------------------------------------
-- Add cases around strict fields where required.
------------------------------------------------------------
{-
The work of this pass is simple:
* We traverse the STG AST looking for constructor allocations.
* For all allocations we check if there are strict fields in the constructor.
* For any strict field we check if the argument is known to be properly tagged.
* If it's not known to be properly tagged, we wrap the whole thing in a case,
  which will force the argument before allocation.
This is described in detail in Note [Strict Field Invariant].

The only slight complication is that we have to make sure not to invalidate free
variable analysis in the process.

Note [Partially applied workers]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Sometimes we will get a function f of the form
    -- Arity 1
    f :: Dict a -> a -> b -> (c -> d)
    f dict a b = case dict of
        C m1 m2 -> m1 a b

Which will result in a W/W split along the lines of
    -- Arity 1
    f :: Dict a -> a -> b -> (c -> d)
    f dict a = case dict of
        C m1 m2 -> $wf m1 a b

    -- Arity 4
    $wf :: (a -> b -> d -> c) -> a -> b -> c -> d
    $wf m1 a b c = m1 a b c

It's notable that the worker is called *undersaturated* in the wrapper.
At runtime what happens is that the wrapper will allocate a PAP which
once fully applied will call the worker. And all is fine.

But what about a call-by-value function? Well, the function returned by `f` would
be an unknown call, so we lose the ability to enforce the invariant that
cbv-marked arguments from StrictWorkerIds are actually properly tagged,
as the annotations would be unavailable at the (unknown) call site.

The fix is easy. We eta-expand all calls to functions taking call-by-value
arguments during CorePrep just like we do with constructor allocations.

Note [Upholding free variable annotations]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The code generator requires us to maintain exact information
about the free variables of closures. Since we convert some
RHSs from constructor allocations to closures we have to provide
the fvs of these closures. Not all constructor arguments will become
free variables. Only those which are not bound at the top level
have to be captured.
To facilitate this we keep track of a set of locally bound variables in
the current context which we then use to filter constructor arguments
when building the free variable list.
-}

--------------------------------
-- Utilities
--------------------------------

-- | Unique supplies are handed out by splitting the supply held in the state:
-- one half is returned to the caller, the other half is stored back.
instance MonadUnique RM where
    getUniqueSupplyM = RM $ do
        (m, us, mod, lcls) <- get
        let (us1, us2) = splitUniqSupply us
        -- Keep us2 for later requests; us1 is given away.
        put (m, us2, mod, lcls)
        return us1

-- | Read the binder -> tag signature mapping (first component of the state).
getMap :: RM (UniqFM Id TagSig)
getMap = RM $ (\(tag_map, _, _, _) -> tag_map) <$> get

-- | Replace the binder -> tag signature mapping, leaving the other state
-- components untouched. The bang pattern forces the new map to WHNF.
setMap :: (UniqFM Id TagSig) -> RM ()
setMap !m = RM $ do
    (_, us, mod, lcls) <- get
    put (m, us, mod, lcls)

-- | Read the module stored in the state (third component).
getMod :: RM Module
getMod = RM $ (\(_, _, this_mod, _) -> this_mod) <$> get

-- | Read the set of locally bound ids (fourth component of the state).
-- See Note [Upholding free variable annotations].
getFVs :: RM IdSet
getFVs = RM $ (\(_, _, _, lcls) -> lcls) <$> get

-- | Replace the set of locally bound ids, leaving the other state components
-- untouched. The bang pattern forces the new set to WHNF.
setFVs :: IdSet -> RM ()
setFVs !fvs = RM $ do
    (tag_map, us, mod, _lcls) <- get
    put (tag_map, us, mod, fvs)

-- Rewrite the RHS(s) while making the id and its sig available
-- to determine if things are tagged/need to be captured as FV.
--
-- The binder(s) of the given binding are brought into scope for the duration
-- of the continuation only; the right hand sides themselves are ignored here.
withBind :: TopLevelFlag -> GenStgBinding 'InferTaggedBinders -> RM a -> RM a
withBind top_flag (StgNonRec bnd _) cont = withBinder top_flag bnd cont
withBind top_flag (StgRec binds) cont =
    -- Only the binders (with their signatures) are needed, not the rhss.
    withBinders top_flag (map fst binds) cont

-- | Add the binder(s) of a binding to the tag signature map.
-- Unlike 'withBind' the previous map is not restored afterwards, so the
-- entries stay visible to everything processed later.
addTopBind :: GenStgBinding 'InferTaggedBinders -> RM ()
addTopBind (StgNonRec (id, tag) _) = do
    s <- getMap
    -- pprTraceM "AddBind" (ppr id)
    setMap $ addToUFM s id tag
addTopBind (StgRec binds) = do
    let (bnds,_rhss) = unzip binds
    !s <- getMap
    -- pprTraceM "AddBinds" (ppr $ map fst bnds)
    setMap $! addListToUFM s bnds

-- | Run the continuation with a single binder and its signature added to the
-- tag signature map, restoring the previous map afterwards.
-- For non-top-level binders the id is additionally registered as a local
-- (see 'withLcl') while the continuation runs.
withBinder :: TopLevelFlag ->  (Id, TagSig) -> RM a -> RM a
withBinder top_flag (id,sig) cont = do
    oldMap <- getMap
    setMap $ addToUFM oldMap id sig
    a <- if isTopLevel top_flag
            then cont
            else withLcl id cont
    -- Undo the map extension once the continuation is done.
    setMap oldMap
    return a

-- | Like 'withBinder' but for a list of binders at once.
-- The tag signature map (and, for non-top-level binders, the set of locals)
-- is extended for the duration of the continuation and restored afterwards.
withBinders :: TopLevelFlag -> [(Id, TagSig)] -> RM a -> RM a
withBinders TopLevel sigs cont = do
    oldMap <- getMap
    setMap $ addListToUFM oldMap sigs
    a <- cont
    setMap oldMap
    return a
withBinders NotTopLevel sigs cont = do
    oldMap <- getMap
    oldFvs <- getFVs
    setMap $ addListToUFM oldMap sigs
    -- Local binders may have to be captured as free variables by closures.
    setFVs $ extendVarSetList oldFvs (map fst sigs)
    a <- cont
    setMap oldMap
    setFVs oldFvs
    return a

-- | Compute the argument with the given set of ids treated as requiring capture
-- as free variables.
--
-- The ids are added to the current set of locals for the duration of the
-- action; the previous set is restored afterwards.
withClosureLcls :: DIdSet -> RM a -> RM a
withClosureLcls fvs act = do
    old_fvs <- getFVs
    let !fvs' = nonDetStrictFoldDVarSet (flip extendVarSet) old_fvs fvs
    setFVs fvs'
    !r <- act
    setFVs old_fvs
    return r

-- | Compute the argument with the given id treated as requiring capture
-- as free variables in closures.
--
-- The id is added to the current set of locals for the duration of the
-- action; the previous set is restored afterwards.
withLcl :: Id -> RM a -> RM a
withLcl fv act = do
    old_fvs <- getFVs
    let !fvs' = extendVarSet old_fvs fv
    setFVs fvs'
    !r <- act
    setFVs old_fvs
    return r

{- Note [Tag inference for interactive contexts]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
When compiling bytecode we call myCoreToStg to get STG code first.
myCoreToStg in turn calls out to stg2stg which runs the STG to STG
passes followed by free variables analysis and the tag inference pass including
its rewriting phase at the end.
Running tag inference is important as it upholds Note [Strict Field Invariant].
While code executed by GHCi doesn't take advantage of the SFI it can call into
compiled code which does. So it must still make sure that the SFI is upheld.
See also #21083 and #22042.

However there is one important difference in code generation for GHCi and regular
compilation. When compiling an entire module (not a GHCi expression), we call
`stg2stg` on the entire module which allows us to build up a map which is guaranteed
to have an entry for every binder in the current module.
For non-interactive compilation the tag inference rewrite pass takes advantage
of this by building up a map from binders to their tag signatures.

When compiling a GHCi expression on the other hand we invoke stg2stg separately
for each expression on the prompt. This means in GHCi for a sequence of:
    > let x = True
    > let y = StrictJust x
We first run stg2stg for `[x = True]`. And then again for `[y = StrictJust x]`.

While computing the tag signature for `y` during tag inference inferConTag will check
if `x` is already tagged by looking up the tagsig of `x` in the binder->signature mapping.
However since this mapping isn't persistent between stg2stg
invocations the lookup will fail. This isn't a correctness issue since it's always
safe to assume a binding isn't tagged and that's what we do in such cases.

However for non-interactive mode we *don't* want to do this. Since in non-interactive mode
we have all binders of the module available for each invocation we can expect the binder->signature
mapping to be complete and all lookups to succeed. This means in non-interactive contexts a failed lookup
indicates a bug in the tag inference implementation.
For this reason we assert that we are running in interactive mode if a lookup fails.
-}
-- | Is the given id known to be properly tagged?
--
-- * Ids from the current module: unlifted ids count as tagged, otherwise the
--   answer comes from the tag signature map.
-- * Imported ids: workers of nullary constructors count as tagged, otherwise
--   the answer is derived from the id's LFInfo if it has one.
--
-- Whenever nothing is known the answer is 'False'.
isTagged :: Id -> RM Bool
isTagged v = do
    this_mod <- getMod
    -- See Note [Tag inference for interactive contexts]
    -- The default is only evaluated (and the assertion only checked) when the
    -- lookup below fails.
    let lookupDefault unknown = assertPpr (isInteractiveModule this_mod)
                                    (text "unknown Id:" <> ppr this_mod <+> ppr unknown)
                                    (TagSig TagDunno)
    case nameIsLocalOrFrom this_mod (idName v) of
        True
            | Just Unlifted <- typeLevity_maybe (idType v)
              -- NB: v might be the Id of a representation-polymorphic join point,
              -- so we shouldn't use isUnliftedType here. See T22212.
            -> return True
            | otherwise -> do -- Local binding
                !s <- getMap
                let !sig = lookupWithDefaultUFM s (lookupDefault v) v
                return $ case sig of
                    TagSig info ->
                        case info of
                            TagDunno -> False
                            TagProper -> True
                            TagTagged -> True
                            TagTuple _ -> True -- Consider unboxed tuples tagged.
        False -- Imported
            | Just con <- (isDataConWorkId_maybe v)
            , isNullaryRepDataCon con
            -> return True
            | Just lf_info <- idLFInfo_maybe v
            -> return $!
                -- Can we treat the thing as tagged based on its LFInfo?
                case lf_info of
                    -- Function, applied not entered.
                    LFReEntrant {}
                        -> True
                    -- Thunks need to be entered.
                    LFThunk {}
                        -> False
                    -- LFCon means we already know the tag, and it's tagged.
                    LFCon {}
                        -> True
                    LFUnknown {}
                        -> False
                    LFUnlifted {}
                        -> True
                    LFLetNoEscape {}
                    -- Shouldn't be possible. I don't think we can export letNoEscapes
                        -> True

            | otherwise
            -> return False


-- | Is the argument known to be properly tagged?
-- Literals always are; for variables see 'isTagged'.
isArgTagged :: StgArg -> RM Bool
isArgTagged (StgLitArg _) = return True
isArgTagged (StgVarArg v) = isTagged v

-- | Make a localised copy of the given id carrying a fresh unique.
mkLocalArgId :: Id -> RM Id
mkLocalArgId id = do
    !u <- getUniqueM
    return $! setIdUnique (localiseId id) u

---------------------------
-- Actual rewrite pass
---------------------------


-- | Entry point of the pass: rewrite all top level bindings of a module.
-- The pass starts with an empty tag signature map and an empty set of locals;
-- the given unique supply is used for any fresh ids that are needed.
rewriteTopBinds :: Module -> UniqSupply -> [GenStgTopBinding 'InferTaggedBinders] -> [TgStgTopBinding]
rewriteTopBinds mod us binds =
    let doBinds = mapM rewriteTop binds

    in evalState (unRM doBinds) (mempty, us, mod, mempty)

-- | Rewrite a single top level binding.
-- String literals are passed through unchanged. For lifted bindings the
-- binders are first recorded in the tag signature map and then the binding
-- itself is rewritten.
rewriteTop :: InferStgTopBinding -> RM TgStgTopBinding
rewriteTop (StgTopStringLit v s) = return $! (StgTopStringLit v s)
rewriteTop (StgTopLifted bind)   = do
    -- Top level bindings can, and must remain in scope
    addTopBind bind
    (StgTopLifted) <$!> (rewriteBinds TopLevel bind)

-- For top level binds, the wrapper is guaranteed to be `id`
-- | Rewrite the right hand side(s) of a binding and drop the tag signatures
-- from its binder(s).
rewriteBinds :: TopLevelFlag -> InferStgBinding -> RM (TgStgBinding)
rewriteBinds _top_flag (StgNonRec v rhs) = do
        (!rhs') <-  rewriteRhs v rhs
        return $! (StgNonRec (fst v) rhs')
rewriteBinds top_flag b@(StgRec binds) =
    -- Bring sigs of binds into scope for all rhss
    withBind top_flag b $ do
        (rhss) <- mapM (uncurry rewriteRhs) binds
        return $! (mkRec rhss)
        where
            -- Pair the original binders (minus their signatures) with the
            -- rewritten right hand sides.
            mkRec :: [TgStgRhs] -> TgStgBinding
            mkRec rhss = StgRec (zip (map (fst . fst) binds) rhss)

-- | Rewrite the right-hand side of a binding.
--
-- * @StgRhsCon@: if some variable argument that ends up in a strict field is
--   not known to be tagged, the RHS is converted into a @StgRhsClosure@ whose
--   body evaluates those arguments (see 'mkSeqs') before applying the
--   constructor. Otherwise the constructor RHS is returned as is.
--
-- * @StgRhsClosure@: the body is rewritten with 'rewriteExpr' and the
--   @(Id, TagSig)@ argument binders are reduced to their 'Id's.
rewriteRhs :: (Id,TagSig) -> InferStgRhs
           -> RM (TgStgRhs)
rewriteRhs (_id, _tagSig) (StgRhsCon ccs con cn ticks args) = {-# SCC rewriteRhs_ #-} do
    -- pprTraceM "rewriteRhs" (ppr _id)

    -- For every constructor argument, ask whether it is known to be tagged.
    fieldInfos <- mapM isArgTagged args

    let -- Only arguments ending up in strict fields matter: (argument, tagInfo).
        strictFields :: [(StgArg, Bool)]
        strictFields = getStrictConArgs con (zip args fieldInfos)

        -- Variable arguments to strict fields that are not known to be tagged
        -- (tagInfo == False); these have to be evaluated first.
        evalArgs :: [Id]
        evalArgs = [v | (StgVarArg v, False) <- strictFields]

    if null evalArgs
        then return $! StgRhsCon ccs con cn ticks args
        else do
            --assert not (isTaggedSig tagSig)
            -- pprTraceM "CreatingSeqs for " $ ppr _id <+> ppr node_id

            -- At this point we have possibly untagged arguments to strict
            -- fields, so we convert the RHS into a RhsClosure which will
            -- evaluate the arguments before allocating the constructor.
            let ty_stub = panic "mkSeqs shouldn't use the type arg"
            conExpr <- mkSeqs args evalArgs (\taggedArgs -> StgConApp con cn taggedArgs ty_stub)

            fvs <- fvArgs args
            -- lcls <- getFVs
            -- pprTraceM "RhsClosureConversion" (ppr (StgRhsClosure fvs ccs ReEntrant [] $! conExpr) $$ text "lcls:" <> ppr lcls)
            return $! (StgRhsClosure fvs ccs ReEntrant [] $! conExpr)
rewriteRhs _binding (StgRhsClosure fvs ccs flag args body) =
    -- The body is rewritten inside 'withBinders' / 'withClosureLcls'
    -- (presumably making the argument binders and the closure's free
    -- variables visible to the rewrite; see their definitions).
    withBinders NotTopLevel args $
        withClosureLcls fvs $
            StgRhsClosure fvs ccs flag (map fst args) <$> rewriteExpr body

-- | Collect the variable arguments that are members of the set returned by
-- 'getFVs' into a deterministic variable set. Literal arguments are ignored.
--
-- NOTE(review): 'getFVs' presumably yields the locally bound variables
-- currently in scope -- confirm against its definition.
fvArgs :: [StgArg] -> RM DVarSet
fvArgs args = do
    fv_lcls <- getFVs
    -- pprTraceM "fvArgs" (text "args:" <> ppr args $$ text "lcls:" <> pprVarSet (fv_lcls) (braces . fsep . map ppr) )
    return $ mkDVarSet [ v | StgVarArg v <- args, elemVarSet v fv_lcls ]

-- | Apply 'rewriteArg' to every argument, left to right.
rewriteArgs :: [StgArg] -> RM [StgArg]
rewriteArgs = mapM rewriteArg

-- | Rewrite a single argument: variables go through 'rewriteId' (the result
-- is forced by '<$!>'), literals are returned unchanged.
rewriteArg :: StgArg -> RM StgArg
rewriteArg (StgVarArg v)     = StgVarArg <$!> rewriteId v
rewriteArg lit@(StgLitArg{}) = return lit

-- | If 'isTagged' reports the 'Id' as tagged, record that on the 'Id' itself
-- by giving it the signature @TagSig TagProper@; otherwise return the 'Id'
-- unchanged.
rewriteId :: Id -> RM Id
rewriteId v = do
    !is_tagged <- isTagged v
    if is_tagged
        then return $! setIdTagSig v (TagSig TagProper)
        else return v

-- | Rewrite an expression, dispatching on its constructor.
--
-- Cases, lets, constructor applications and function applications are
-- delegated to their dedicated rewrite functions. Ticks are rewritten
-- underneath. Literals and primop/foreign applications are rebuilt unchanged,
-- except for @DataToTagOp@ whose arguments go through 'rewriteArgs'.
rewriteExpr :: InferStgExpr -> RM TgStgExpr
rewriteExpr e@(StgCase {})        = rewriteCase e
rewriteExpr e@(StgLet {})         = rewriteLet e
rewriteExpr e@(StgLetNoEscape {}) = rewriteLetNoEscape e
rewriteExpr (StgTick t e)         = StgTick t <$!> rewriteExpr e
rewriteExpr e@(StgConApp {})      = rewriteConApp e
rewriteExpr e@(StgApp {})         = rewriteApp e
rewriteExpr (StgLit lit)          = return $! StgLit lit
rewriteExpr (StgOpApp op@(StgPrimOp DataToTagOp) args res_ty) = do
    -- Only for DataToTagOp are the arguments rewritten (see 'rewriteId').
    args' <- rewriteArgs args
    return $! StgOpApp op args' res_ty
rewriteExpr (StgOpApp op args res_ty) = return $! StgOpApp op args res_ty


-- | Rewrite a case expression: the scrutinee and then every alternative are
-- rewritten, and the @(Id, TagSig)@ case binder is reduced to its 'Id'.
--
-- Both the scrutinee and the alternatives are processed inside
-- @withBinder NotTopLevel bndr@. Panics when given anything but a @StgCase@.
rewriteCase :: InferStgExpr -> RM TgStgExpr
rewriteCase (StgCase scrut bndr alt_type alts) =
    withBinder NotTopLevel bndr $ do
        scrut' <- rewriteExpr scrut
        alts'  <- mapM rewriteAlt alts
        pure (StgCase scrut' (fst bndr) alt_type alts')

rewriteCase _ = panic "Impossible: nodeCase"

-- | Rewrite a case alternative: the right-hand side is rewritten inside
-- @withBinders NotTopLevel@ for the alternative's binders, and the binders
-- are reduced to their 'Id's. The alternative's constructor is kept.
rewriteAlt :: InferStgAlt -> RM TgStgAlt
rewriteAlt alt@GenStgAlt{alt_con=_, alt_bndrs=bndrs, alt_rhs=rhs} =
    withBinders NotTopLevel bndrs $ do
        !rhs' <- rewriteExpr rhs
        -- Type-changing record update: every pass-indexed field is replaced.
        return $! alt {alt_bndrs = map fst bndrs, alt_rhs = rhs'}

-- | Rewrite a @let@: first the binding (via 'rewriteBinds'), then the body
-- inside @withBind NotTopLevel@ for the original binding.
-- Panics when given anything but a @StgLet@.
rewriteLet :: InferStgExpr -> RM TgStgExpr
rewriteLet (StgLet xt bind expr) = do
    !bind' <- rewriteBinds NotTopLevel bind
    withBind NotTopLevel bind $ do
        -- pprTraceM "withBindLet" (ppr $ bindersOfX bind)
        !expr' <- rewriteExpr expr
        return $! StgLet xt bind' expr'
rewriteLet _ = panic "Impossible"

-- | Rewrite a let-no-escape: first the binding (via 'rewriteBinds'), then the
-- body inside @withBind NotTopLevel@ for the original binding.
-- Panics when given anything but a @StgLetNoEscape@.
rewriteLetNoEscape :: InferStgExpr -> RM TgStgExpr
rewriteLetNoEscape (StgLetNoEscape xt bind expr) = do
    !bind' <- rewriteBinds NotTopLevel bind
    withBind NotTopLevel bind $ do
        !expr' <- rewriteExpr expr
        return $! StgLetNoEscape xt bind' expr'
rewriteLetNoEscape _ = panic "Impossible"

-- | Rewrite a saturated constructor application.
--
-- We check if the strict field arguments are already known to be tagged.
-- If not, they are evaluated first (see 'mkSeqs') and the constructor is
-- applied to the evaluated binders instead.
-- Panics when given anything but a @StgConApp@.
rewriteConApp :: InferStgExpr -> RM TgStgExpr
rewriteConApp (StgConApp con cn args tys) = do
    fieldInfos <- mapM isArgTagged args
    let -- Arguments that end up in strict fields: (tagInfo, argument).
        strictFields :: [(Bool, StgArg)]
        strictFields = getStrictConArgs con (zip fieldInfos args)

        -- Variable arguments to strict fields not known to be tagged.
        evalArgs :: [Id]
        evalArgs = [v | (False, StgVarArg v) <- strictFields]
    if null evalArgs
        then return $! StgConApp con cn args tys
        else
            -- pprTraceM "Creating conAppSeqs for " $ ppr nodeId <+> parens ( ppr evalArgs ) -- <+> parens ( ppr fieldInfos )
            mkSeqs args evalArgs (\taggedArgs -> StgConApp con cn taggedArgs tys)

rewriteConApp _ = panic "Impossible"

-- | Rewrite a function application.
--
-- * Special case: atomic binders (no arguments), usually in a case context
--   like @case f of ...@. The variable goes through 'rewriteId'.
--
-- * If the function carries call-by-value marks ('idCbvMarks_maybe') and at
--   least one of them is 'MarkedCbv', every variable argument in a marked
--   position that is not known to be tagged is evaluated before the call.
--
-- * Any other application is returned unchanged.
--
-- Panics when given anything but a @StgApp@.
rewriteApp :: InferStgExpr -> RM TgStgExpr
rewriteApp (StgApp f []) = do
    f' <- rewriteId f
    return $! StgApp f' []
rewriteApp (StgApp f args)
    -- pprTrace "rewriteAppOther" (ppr f <+> ppr args) False
    -- = undefined
    | Just marks <- idCbvMarks_maybe f
      -- Trailing unmarked positions carry no information; drop them.
    , relevant_marks <- dropWhileEndLE (not . isMarkedCbv) marks
    , any isMarkedCbv relevant_marks
    = assertPpr (length relevant_marks <= length args)
                (ppr f $$ ppr args $$ ppr relevant_marks)
                (unliftArg relevant_marks)

    where
      -- If the function expects any argument to be call-by-value ensure the
      -- argument is already evaluated.
      unliftArg :: [CbvMark] -> RM TgStgExpr
      unliftArg relevant_marks = do
        argTags <- mapM isArgTagged args
        let -- Pad the marks so that every argument gets one.
            argInfo :: [(StgArg, CbvMark, Bool)]
            argInfo = zip3 args (relevant_marks ++ repeat NotMarkedCbv) argTags

            -- Variables in cbv positions that are not known to be tagged.
            cbvArgIds :: [Id]
            cbvArgIds = [x | (StgVarArg x, MarkedCbv, False) <- argInfo]
        mkSeqs args cbvArgIds (\cbv_args -> StgApp f cbv_args)

rewriteApp (StgApp f args) = return $ StgApp f args
rewriteApp _ = panic "Impossible"

-- | @mkSeq x x' e@ generates @case x of x' { DEFAULT -> e }@.
--
-- We could also substitute @x'@ for @x@ in @e@ but that's so rarely
-- beneficial that we don't bother.
mkSeq :: Id -> Id -> TgStgExpr -> TgStgExpr
mkSeq var bndr !expr =
    -- pprTrace "mkSeq" (ppr (var,bndr)) $
    StgCase (StgApp var []) bndr altTy alts
  where
    -- A single DEFAULT alternative binding nothing but the case binder.
    alts  = [GenStgAlt {alt_con = DEFAULT, alt_bndrs = [], alt_rhs = expr}]
    altTy = mkStgAltTypeFromStgAlts bndr alts

-- `mkSeqs args vs mkExpr` will force all vs, and construct
-- an argument list args' where each v is replaced by it's evaluated
-- counterpart v'.
-- That is if we call `mkSeqs [StgVar x, StgLit l] [x] mkExpr` then
-- the result will be (case x of x' { _DEFAULT -> <mkExpr [StgVar x', StgLit l]>}
{-# INLINE mkSeqs #-} -- We inline to avoid allocating mkExpr
mkSeqs  :: [StgArg] -- ^ Original arguments
        -> [Id]     -- ^ var args to be evaluated ahead of time
        -> ([StgArg] -> TgStgExpr)
                    -- ^ Function that reconstructs the expressions when passed
                    -- the list of evaluated arguments.
        -> RM TgStgExpr
mkSeqs :: [StgArg] -> [Id] -> ([StgArg] -> TgStgExpr) -> RM TgStgExpr
mkSeqs [StgArg]
args [Id]
untaggedIds [StgArg] -> TgStgExpr
mkExpr = do
    [(Id, Id)]
argMap <- (Id -> RM (Id, Id)) -> [Id] -> RM [(Id, Id)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\Id
arg -> (Id
arg,) (Id -> (Id, Id)) -> RM Id -> RM (Id, Id)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Id -> RM Id
mkLocalArgId Id
arg ) [Id]
untaggedIds :: RM [(InId, OutId)]
    -- mapM_ (pprTraceM "Forcing strict args before allocation:" . ppr) argMap
    let [StgArg]
taggedArgs :: [StgArg]
            = (StgArg -> StgArg) -> [StgArg] -> [StgArg]
forall a b. (a -> b) -> [a] -> [b]
map   (\StgArg
v -> case StgArg
v of
                        StgVarArg Id
v' -> Id -> StgArg
StgVarArg (Id -> StgArg) -> Id -> StgArg
forall a b. (a -> b) -> a -> b
$ Id -> Maybe Id -> Id
forall a. a -> Maybe a -> a
fromMaybe Id
v' (Maybe Id -> Id) -> Maybe Id -> Id
forall a b. (a -> b) -> a -> b
$ Id -> [(Id, Id)] -> Maybe Id
forall a b. Eq a => a -> [(a, b)] -> Maybe b
lookup Id
v' [(Id, Id)]
argMap
                        StgArg
lit -> StgArg
lit)
                    [StgArg]
args

    let conBody :: TgStgExpr
conBody = [StgArg] -> TgStgExpr
mkExpr [StgArg]
taggedArgs
    let body :: TgStgExpr
body = ((Id, Id) -> TgStgExpr -> TgStgExpr)
-> TgStgExpr -> [(Id, Id)] -> TgStgExpr
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr (\(Id
v,Id
bndr) TgStgExpr
expr -> Id -> Id -> TgStgExpr -> TgStgExpr
mkSeq Id
v Id
bndr TgStgExpr
expr) TgStgExpr
conBody [(Id, Id)]
argMap
    TgStgExpr -> RM TgStgExpr
forall a. a -> RM a
forall (m :: * -> *) a. Monad m => a -> m a
return (TgStgExpr -> RM TgStgExpr) -> TgStgExpr -> RM TgStgExpr
forall a b. (a -> b) -> a -> b
$! TgStgExpr
body

-- | Out of all arguments passed at runtime only return those ending up in a
-- strict field.
--
-- Unboxed tuple and unboxed sum constructors always yield the empty list.
-- For any other constructor the arguments are matched up with
-- 'dataConRuntimeRepStrictness'; both lists must have the same length
-- (checked by an assertion and by 'zipEqual').
getStrictConArgs :: Outputable a => DataCon -> [a] -> [a]
getStrictConArgs con args
    -- These are always lazy in their arguments.
    | isUnboxedTupleDataCon con = []
    | isUnboxedSumDataCon con = []
    -- For proper data cons we have to check.
    | otherwise =
        assertPpr   (length args == length repStrictness)
                    (text "Mismatched con arg and con rep strictness lengths:" $$
                     text "Con" <+> ppr con <+> text "is applied to" <+> ppr args $$
                     text "But seems to have arity" <+> ppr (length repStrictness)) $
        [ arg | (arg,MarkedStrict)
                    <- zipEqual "getStrictConArgs"
                                args
                                repStrictness]
        where
            -- One strictness mark per runtime field of the constructor.
            repStrictness = dataConRuntimeRepStrictness con