{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE TupleSections #-}

-- | Adds cost-centers after the core pipeline has run.
module GHC.Core.LateCC
    ( addLateCostCentres
    ) where

import Control.Applicative
import GHC.Utils.Monad.State.Strict
import Control.Monad

import GHC.Prelude
import GHC.Driver.Session
import GHC.Types.CostCentre
import GHC.Types.CostCentre.State
import GHC.Types.Name hiding (varName)
import GHC.Types.Tickish
import GHC.Unit.Module.ModGuts
import GHC.Types.Var
import GHC.Unit.Types
import GHC.Data.FastString
import GHC.Core
import GHC.Core.Opt.Monad
import GHC.Types.Id
import GHC.Core.Utils (mkTick)

-- | Add cost centres to the bindings of a module.
--
-- Builds an 'Env' from the module being compiled ('mg_module') and the
-- session 'DynFlags', then replaces 'mg_binds' with the result of running
-- 'doCoreProgram' over it. Every other field of the 'ModGuts' is returned
-- unchanged.
addLateCostCentres :: ModGuts -> CoreM ModGuts
addLateCostCentres guts = do
  -- Named @flags@ rather than @dflags@ so that it does not shadow the
  -- 'dflags' record selector of 'Env'.
  flags <- getDynFlags
  let env = Env
        { thisModule = mg_module guts
        , ccState    = newCostCentreState
        , dflags     = flags
        }
  pure $ guts { mg_binds = doCoreProgram env (mg_binds guts) }

-- | Instrument every binding group of a Core program.
--
-- Runs 'doBind' over the bindings in order inside the 'M' state monad and
-- discards the final state. The state is seeded from the 'ccState' field of
-- the environment, so the caller that builds the 'Env' decides the starting
-- 'CostCentreState'.
doCoreProgram :: Env -> CoreProgram -> CoreProgram
doCoreProgram env binds = evalState (mapM (doBind env) binds) (ccState env)

-- | Instrument a single binding group.
--
-- A 'NonRec' binding has its right-hand side rewritten by 'doBndr'; a 'Rec'
-- group has each of its right-hand sides rewritten in list order. Binders
-- themselves are returned unchanged.
doBind :: Env -> CoreBind -> M CoreBind
doBind env bind = case bind of
  NonRec bndr rhs -> NonRec bndr <$> doBndr env bndr rhs
  Rec pairs       -> Rec <$> mapM tickPair pairs
  where
    -- Rewrite the right-hand side of one member of a recursive group,
    -- keeping it paired with its binder.
    tickPair :: (Id, CoreExpr) -> M (Id, CoreExpr)
    tickPair (bndr, rhs) = (bndr,) <$> doBndr env bndr rhs

-- | Wrap the right-hand side of a binder in a profiling tick.
--
-- The cost centre is a 'NormalCC' whose name is the binder's occurrence
-- name ('getOccFS'), whose module is 'thisModule' of the environment, and
-- whose location is the binder's 'nameSrcSpan'. Its flavour comes from
-- 'getCCExprFlavour', which threads the 'CostCentreState' through 'M'.
-- The tick counts entries exactly when 'Opt_ProfCountEntries' is set in the
-- environment's 'DynFlags'.
doBndr :: Env -> Id -> CoreExpr -> M CoreExpr
doBndr env bndr rhs = do
    flavour <- getCCExprFlavour ccLabel
    let centre = NormalCC flavour ccLabel (thisModule env) (nameSrcSpan name)
        -- The trailing 'True' is the third 'ProfNote' field, the scoping
        -- flag ('profNoteScope' in "GHC.Types.Tickish").
        note   = ProfNote centre countEntries True
    pure (mkTick note rhs)
  where
    name         = idName bndr
    ccLabel      = getOccFS name
    countEntries = gopt Opt_ProfCountEntries (dflags env)

-- | The monad used by this pass: a 'State' monad (from
-- "GHC.Utils.Monad.State.Strict") threading a 'CostCentreState'.
type M = State CostCentreState

-- | Produce an 'ExprCC' flavour for the given cost-centre name, using the
-- index obtained from 'getCCIndex''. The 'CostCentreState' is updated as a
-- side effect in 'M'.
getCCExprFlavour :: FastString -> M CCFlavour
getCCExprFlavour = fmap ExprCC . getCCIndex'

-- | 'getCCIndex' lifted into 'M': applies it to the current
-- 'CostCentreState', returns the resulting index and stores the updated
-- state.
getCCIndex' :: FastString -> M CostCentreIndex
getCCIndex' = state . getCCIndex

-- | Context shared by the functions of this pass.
data Env = Env
  { thisModule  :: Module
    -- ^ Module recorded in every cost centre created by 'doBndr'.
  , dflags      :: DynFlags
    -- ^ Session flags; 'doBndr' reads 'Opt_ProfCountEntries' from them.
  , ccState     :: CostCentreState
    -- ^ A cost-centre state stored with the environment;
    -- 'addLateCostCentres' sets it to 'newCostCentreState'.
  }