{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Normalize where
import Control.Concurrent.Supply (Supply)
import Control.Exception (throw)
import qualified Control.Lens as Lens
import Control.Monad (when)
import Control.Monad.State.Strict (State)
import Data.Default (def)
import Data.Either (lefts,partitionEithers)
import qualified Data.IntMap as IntMap
import Data.IntMap.Strict (IntMap)
import Data.List
(intersect, mapAccumL)
import qualified Data.Map as Map
import qualified Data.Maybe as Maybe
import qualified Data.Set as Set
import qualified Data.Set.Lens as Lens
#if MIN_VERSION_prettyprinter(1,7,0)
import Prettyprinter (vcat)
#else
import Data.Text.Prettyprint.Doc (vcat)
#endif
#if MIN_VERSION_ghc(9,0,0)
import GHC.Types.Basic (InlineSpec (..))
#else
import BasicTypes (InlineSpec (..))
#endif
import Clash.Annotations.BitRepresentation.Internal
(CustomReprs)
#if EXPERIMENTAL_EVALUATOR
import Clash.Core.PartialEval (Evaluator)
#else
import Clash.Core.Evaluator.Types (Evaluator)
#endif
import Clash.Core.FreeVars
(freeLocalIds, globalIds, globalIdOccursIn, localIdDoesNotOccurIn)
import Clash.Core.Pretty (PrettyOptions(..), showPpr, showPpr', ppr)
import Clash.Core.Subst
(extendGblSubstList, mkSubst, substTm)
import Clash.Core.Term (Term (..), collectArgsTicks
,mkApps, mkTicks)
import Clash.Core.Type (Type, splitCoreFunForallTy)
import Clash.Core.TyCon
(TyConMap, TyConName)
import Clash.Core.Type (isPolyTy)
import Clash.Core.Var (Id, varName, varType)
import Clash.Core.VarEnv
(VarEnv, elemVarSet, eltsVarEnv, emptyInScopeSet, emptyVarEnv,
extendVarEnv, lookupVarEnv, mapVarEnv, mapMaybeVarEnv,
mkVarEnv, mkVarSet, notElemVarEnv, notElemVarSet, nullVarEnv, unionVarEnv)
import Clash.Debug (traceIf)
import Clash.Driver.Types
(BindingMap, Binding(..), ClashOpts (..), DebugLevel (..))
import Clash.Netlist.Types
(HWMap, FilteredHWType(..))
import Clash.Netlist.Util
(splitNormalized)
import Clash.Normalize.Strategy
import Clash.Normalize.Transformations
import Clash.Normalize.Types
import Clash.Normalize.Util
import Clash.Primitives.Types (CompiledPrimMap)
import Clash.Rewrite.Combinators ((>->),(!->),repeatR,topdownR)
import Clash.Rewrite.Types
(RewriteEnv (..), RewriteState (..), bindings, dbgLevel, dbgRewriteHistoryFile, extra,
tcCache, topEntities)
import Clash.Rewrite.Util
(apply, isUntranslatableType, runRewriteSession)
import Clash.Util
import Clash.Util.Interpolate (i)
import Data.Binary (encode)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import System.IO.Unsafe (unsafePerformIO)
import Clash.Rewrite.Types (RewriteStep(..))
-- | Run a 'NormalizeSession' to completion in an environment and initial
-- state assembled from the given options and caches.
--
-- Debug settings, the evaluator fuel limit and the specialisation/inlining
-- limits are read from the 'ClashOpts'; every other argument is stored in
-- the 'RewriteEnv' or the 'NormalizeState' as-is.
runNormalization
  :: ClashOpts
  -- ^ Options: debug settings and rewrite/inline/specialisation limits
  -> Supply
  -- ^ Unique supply
  -> BindingMap
  -- ^ Global binders
  -> (CustomReprs -> TyConMap -> Type ->
      State HWMap (Maybe (Either String FilteredHWType)))
  -- ^ Translation from core types to hardware types
  -> CustomReprs
  -- ^ Custom bit representations
  -> TyConMap
  -- ^ Type constructor cache
  -> IntMap TyConName
  -- ^ Tuple type constructor cache
  -> Evaluator
  -- ^ Evaluator, stored in the 'RewriteEnv'
  -> CompiledPrimMap
  -- ^ Primitive definitions
  -> VarEnv Bool
  -- ^ Stored in the 'NormalizeState'; presumably marks which binders are
  -- part of a recursive group (TODO confirm)
  -> [Id]
  -- ^ Top entities
  -> NormalizeSession a
  -- ^ The session to run
  -> a
runNormalization opts supply globals typeTrans reprs tcm tupTcm eval primMap rcsMap topEnts
  = runRewriteSession rwEnv rwState
  where
    -- Read-only environment; the arguments are positional, in the order the
    -- 'RewriteEnv' constructor expects them.
    rwEnv :: RewriteEnv
    rwEnv = RewriteEnv
      (opt_dbgLevel opts)
      (opt_dbgTransformations opts)
      (opt_dbgTransformationsFrom opts)
      (opt_dbgTransformationsLimit opts)
      (opt_dbgRewriteHistoryFile opts)
      (opt_aggressiveXOpt opts)
      typeTrans
      tcm
      tupTcm
      eval
      (mkVarSet topEnts)
      reprs
      (opt_evaluatorFuelLimit opts)

    -- Initial rewrite state: counters start at zero and the global binders
    -- and unique supply are threaded in.
    rwState :: RewriteState NormalizeState
    rwState = RewriteState
      0
      globals
      supply
      -- There is no "current function" before any rewrite starts; forcing
      -- this placeholder indicates a bug.
      (error $ $(curLoc) ++ "Report as bug: no curFun", noSrcSpan)
      0
#if EXPERIMENTAL_EVALUATOR
      IntMap.empty
      0
#else
      (IntMap.empty, 0)
#endif
      emptyVarEnv
      normState

    -- Normalization-specific state: empty caches plus the limits and flags
    -- taken from the options.
    normState :: NormalizeState
    normState = NormalizeState
      emptyVarEnv
      Map.empty
      emptyVarEnv
      (opt_specLimit opts)
      emptyVarEnv
      (opt_inlineLimit opts)
      (opt_inlineFunctionLimit opts)
      (opt_inlineConstantLimit opts)
      primMap
      Map.empty
      rcsMap
      (opt_newInlineStrat opts)
      (opt_ultra opts)
      (opt_inlineWFCacheLimit opts)
-- | Normalize the given global binders and, transitively, every global
-- binder that their normalized terms refer to and that still needs
-- normalizing.
--
-- Each round normalizes all binders in the list with 'normalize'' and then
-- recurses on the binders those calls report. The recursion stops when a
-- round reports no new binders.
normalize
  :: [Id]
  -> NormalizeSession BindingMap
normalize [] = return emptyVarEnv
normalize top = do
  (new, topNormalized) <- unzip <$> mapM normalize' top
  newNormalized <- normalize (concat new)
  return (unionVarEnv (mkVarEnv topNormalized) newNormalized)
-- | Normalize a single global binder.
--
-- Looks the binder up in 'bindings' and then:
--
-- * throws a 'ClashException' if its type is polymorphic. For top entities
--   the type is checked after 'tvSubstWithTyEq' has been applied.
-- * if its result type is representable, normalizes it with
--   'normalizeTopLvlBndr'. It also returns the global binders used by the
--   normalized term, excluding top entities, already-normalized binders and
--   the binder itself.
-- * otherwise throws a 'ClashException' when the binder is a top entity (or
--   @isFunction@ holds). If it does not throw, it returns the binding
--   unchanged with nothing further to normalize.
--
-- Calls 'error' when the binder is not present in 'bindings'.
normalize'
  :: Id
  -> NormalizeSession ([Id], (Id, Binding Term))
normalize' nm = do
  exprM <- lookupVarEnv nm <$> Lens.use bindings
  let nmS = showPpr (varName nm)
  case exprM of
    Just (Binding nm' sp inl pr tm) -> do
      tcm <- Lens.view tcCache
      topEnts <- Lens.view topEntities
      let isTopEnt = nm `elemVarSet` topEnts
          ty0 = varType nm'
          -- Type-equality constraints are only substituted for top entities
          ty1 = if isTopEnt then tvSubstWithTyEq ty0 else ty0

      when (isPolyTy ty1) $
        let
          msg = $(curLoc) ++ [i|
            Clash can only normalize monomorphic functions, but this is polymorphic:
            #{showPpr' def{displayUniques=False\} nm'}
          |]
          -- Only mention the substitution when it actually changed the type
          msgExtra
            | ty0 == ty1 = Nothing
            | otherwise = Just [i|
                Even after applying type equality constraints it remained polymorphic:
                #{showPpr' def{displayUniques=False\} nm'{varType=ty1\}}
              |]
        in throw (ClashException sp msg msgExtra)

      let (args, resTy) = splitCoreFunForallTy tcm ty1
          -- NOTE(review): 'lefts' selects the type-variable (forall) binders
          -- from 'splitCoreFunForallTy', not the value arguments. Polymorphic
          -- types were rejected above, so this is likely always False here;
          -- confirm whether 'rights' was intended.
          isFunction = not $ null $ lefts args
      resTyRep <- not <$> isUntranslatableType False resTy
      if resTyRep
        then do
          tmNorm <- normalizeTopLvlBndr isTopEnt nm (Binding nm' sp inl pr tm)
          let usedBndrs = Lens.toListOf globalIds (bindingTerm tmNorm)
          -- Only a trace: a binder that still refers to itself is reported,
          -- not rejected, here.
          traceIf (nm `elem` usedBndrs)
                  (concat [ $(curLoc), "Expr belonging to bndr: ", nmS, " (:: "
                          , showPpr (varType (bindingId tmNorm))
                          , ") remains recursive after normalization:\n"
                          , showPpr (bindingTerm tmNorm) ])
                  (return ())
          prevNorm <- mapVarEnv bindingId <$> Lens.use (extra . normalized)
          -- Skip top entities, binders normalized in earlier rounds and the
          -- binder we just normalized.
          let toNormalize = filter (`notElemVarSet` topEnts)
                          $ filter (`notElemVarEnv` (extendVarEnv nm nm prevNorm)) usedBndrs
          return (toNormalize, (nm, tmNorm))
        else do
          when (isTopEnt || isFunction) $
            let msg = $(curLoc) ++ [i|
                  This bndr has a non-representable return type and can't be normalized:
                  #{showPpr' def{displayUniques=False\} nm'}
                |]
            in throw (ClashException sp msg Nothing)
          lvl <- Lens.view dbgLevel
          traceIf (lvl > DebugNone)
                  (concat [ $(curLoc), "Expr belonging to bndr: ", nmS, " (:: "
                          , showPpr (varType nm')
                          , ") has a non-representable return type."
                          , " Not normalising:\n", showPpr tm ])
                  (return ([], (nm, Binding nm' sp inl pr tm)))
    Nothing ->
      error $ $(curLoc) ++ "Expr belonging to bndr: " ++ nmS ++ " not found"
-- | Check that no binding in a normalized binding map refers to its own
-- binder.
--
-- Only direct self-reference is detected; the check does not look for
-- cycles that go through other binders. Returns the map unchanged when
-- there is no self-reference. Otherwise it calls 'error' with the offending
-- binders and their terms pretty-printed.
checkNonRecursive
  :: BindingMap
  -> BindingMap
checkNonRecursive norm = case mapMaybeVarEnv go norm of
  rcs | nullVarEnv rcs -> norm
  rcs -> error $ $(curLoc) ++ "Callgraph after normalization contains following recursive components: "
                 ++ show (vcat [ ppr a <> ppr b
                               | (a, b) <- eltsVarEnv rcs
                               ])
 where
  -- Keep only the bindings whose binder occurs (as a global) in its own body
  go :: Binding Term -> Maybe (Id, Term)
  go (Binding nm _ _ _ tm)
    | nm `globalIdOccursIn` tm = Just (nm, tm)
    | otherwise = Nothing
-- | Rebuild a normalized binding map from the call tree rooted at a top
-- entity.
--
-- The tree is built with 'mkCallTree' and flattened with 'flattenCallTree'.
-- The resulting map contains exactly the bindings that 'callTreeToList'
-- produces for the flattened tree. If the top entity is not in the map, the
-- map is returned unchanged.
cleanupGraph
  :: Id
  -> BindingMap
  -> NormalizeSession BindingMap
cleanupGraph topEntity norm
  | Just ct <- mkCallTree [] norm topEntity
  = do ctFlat <- flattenCallTree ct
       return (mkVarEnv $ snd $ callTreeToList [] ctFlat)
cleanupGraph _ norm = return norm
-- | The call graph of a 'BindingMap', unfolded into a tree (see
-- 'mkCallTree').
data CallTree
  = CLeaf (Id, Binding Term)
  -- ^ A binding whose term refers to no global binders
  | CBranch (Id, Binding Term) [CallTree]
  -- ^ A binding together with the subtrees of the global binders it uses
-- | Unfold the call graph of a 'BindingMap' into a 'CallTree' rooted at the
-- given binder.
--
-- Binders on the path from the overall root are not descended into again,
-- which cuts cycles. A binder reachable along several paths appears once
-- per path. Used global binders that are not in the map are dropped. The
-- leaf/branch decision is made on the unfiltered set of used binders, so a
-- 'CBranch' can end up with no children.
--
-- Returns 'Nothing' if the root itself is not in the map.
mkCallTree
  :: [Id]
  -- ^ Binders already on the current path
  -> BindingMap
  -- ^ Bindings to unfold
  -> Id
  -- ^ Root of the tree
  -> Maybe CallTree
mkCallTree visited bindingMap root
  | Just rootTm <- lookupVarEnv root bindingMap
  = let used  = Set.toList $ Lens.setOf globalIds $ bindingTerm rootTm
        other = Maybe.mapMaybe (mkCallTree (root:visited) bindingMap)
                               (filter (`notElem` visited) used)
    in  case used of
          [] -> Just (CLeaf   (root, rootTm))
          _  -> Just (CBranch (root, rootTm) other)
mkCallTree _ _ _ = Nothing
-- | Strip a run of variable arguments from the front of an argument list.
--
-- @stripArgs allIds ids args@ walks @ids@ and @args@ in lockstep. Each
-- @id_@ must be matched by a @'Left' ('Var' nm)@ argument with
-- @id_ == nm@. Once @ids@ is exhausted, the remaining arguments are
-- returned, but only if none of them mentions any of @allIds@ as a free
-- local variable.
--
-- Returns 'Nothing' in any of these cases:
--
-- * the arguments run out first;
-- * an argument does not match;
-- * a remaining argument mentions one of @allIds@.
--
-- 'flattenNode' passes both lists reversed, to check whether an application
-- ends in the lambda-bound variables and can therefore be eta-reduced.
stripArgs
  :: [Id]
  -- ^ Variables that must not occur free in the remaining arguments
  -> [Id]
  -- ^ Variables to strip, in argument order
  -> [Either Term Type]
  -- ^ Arguments
  -> Maybe [Either Term Type]
stripArgs _ (_:_) [] = Nothing
stripArgs allIds [] args
  | any mentionsId args = Nothing
  | otherwise = Just args
  where
    -- Type arguments never mention term variables
    mentionsId :: Either Term b -> Bool
    mentionsId t = not $ null (either (Lens.toListOf freeLocalIds) (const []) t
                              `intersect`
                              allIds)
stripArgs allIds (id_:ids) (Left (Var nm):args)
  | id_ == nm = stripArgs allIds ids args
  | otherwise = Nothing
stripArgs _ _ _ = Nothing
-- | Decide whether a call-tree node should be inlined into its parent.
--
-- Returns @'Left' node@ to keep the node as a separate binding. Returns
-- @'Right' ((nm, term), children)@ to inline @term@ for @nm@; @children@
-- are the node's own subtrees, and are empty for a leaf.
--
-- * 'NoInline' nodes and top entities are always kept.
-- * If 'splitNormalized' yields exactly one let-binding, and that binding
--   applies a function to exactly the term's lambda-bound variables, the
--   eta-reduced application (ticks kept) is inlined.
--   Exact conditions: 'stripArgs' must succeed on the reversed lists, and
--   the binding must not refer to itself.
-- * If there is exactly one let-binding but it cannot be eta-reduced, the
--   whole term is inlined.
-- * Otherwise leaves are inlined whole. Branches are inlined only when the
--   new inline strategy is enabled or 'isCheapFunction' holds; if neither
--   holds, they are kept.
flattenNode
  :: CallTree
  -> NormalizeSession (Either CallTree ((Id, Term), [CallTree]))
flattenNode node = case node of
  CLeaf (_, Binding _ _ NoInline _ _) ->
    return (Left node)
  CLeaf (nm, Binding _ _ _ _ e) ->
    -- Leaves are inlined whole whatever their cost
    inlineUnlessTop nm e [] (return (Right ((nm, e), [])))
  CBranch (_, Binding _ _ NoInline _ _) _ ->
    return (Left node)
  CBranch (nm, Binding _ _ _ _ e) us ->
    inlineUnlessTop nm e us $ do
      newInlineStrat <- Lens.use (extra . newInlineStrategy)
      return $
        if newInlineStrat || isCheapFunction e
          then Right ((nm, e), us)
          else Left node
 where
  -- Logic shared by leaves and branches. Keeps top entities; eta-reduces a
  -- single-binding normal form; otherwise runs the node-specific fallback.
  inlineUnlessTop
    :: Id
    -> Term
    -> [CallTree]
    -> NormalizeSession (Either CallTree ((Id, Term), [CallTree]))
    -> NormalizeSession (Either CallTree ((Id, Term), [CallTree]))
  inlineUnlessTop nm e us fallback = do
    isTopEntity <- elemVarSet nm <$> Lens.view topEntities
    if isTopEntity then return (Left node) else do
      tcm <- Lens.view tcCache
      -- NOTE(review): the result variable reported by 'splitNormalized' is
      -- not compared with bId; presumably normal form guarantees they agree.
      case splitNormalized tcm e of
        Right (ids, [(bId, bExpr)], _) -> do
          let (fun, args, ticks) = collectArgsTicks bExpr
          return $ case stripArgs ids (reverse ids) (reverse args) of
            Just remainder | bId `localIdDoesNotOccurIn` bExpr ->
              Right ((nm, mkApps (mkTicks fun ticks) (reverse remainder)), us)
            _ ->
              Right ((nm, e), us)
        _ -> fallback
-- | Flatten a 'CallTree' bottom-up.
--
-- Leaves are returned unchanged. For a branch:
--
-- 1. All children are flattened first (recursively).
-- 2. Each flattened child is classified by 'flattenNode': the 'Right' results
--    are substituted into this node's body (and their own callees are adopted
--    as callees of this node); the 'Left' results remain as callees.
-- 3. If anything was substituted, the body is simplified with the local
--    @flatten@ strategy via 'rewriteExpr'. When 'dbgRewriteHistoryFile' is
--    set, an @INLINE@ 'RewriteStep' describing the substitution is appended
--    to that file first.
-- 4. If this binding is not 'NoInline' and 'isCheapFunction' accepts the
--    resulting body, every callee that is not 'NoInline' is substituted as
--    well and the body is simplified once more.
flattenCallTree
  :: CallTree
  -> NormalizeSession CallTree
flattenCallTree c@(CLeaf _) = return c
flattenCallTree (CBranch (nm,(Binding nm' sp inl pr tm)) used) = do
  flattenedUsed   <- mapM flattenCallTree used
  (newUsed,il_ct) <- partitionEithers <$> mapM flattenNode flattenedUsed
  let (toInline,il_used) = unzip il_ct
      subst = extendGblSubstList (mkSubst emptyInScopeSet) toInline
  newExpr <- case toInline of
    [] -> return tm
    _  -> do
      let tm1 = substTm "flattenCallTree.flattenExpr" subst tm
      rewriteHistFile <- Lens.view dbgRewriteHistoryFile
      -- Debug output only: append the inlining step to the rewrite history
      -- file, if one is configured. Matching on the 'Maybe' directly avoids
      -- the partial 'Maybe.fromJust' after an 'isJust' test.
      case rewriteHistFile of
        Nothing -> pure ()
        Just histFile ->
          let !_ = unsafePerformIO
                 $ BS.appendFile histFile
                 $ BL.toStrict
                 $ encode RewriteStep
                     { t_ctx    = []
                     , t_name   = "INLINE"
                     , t_bndrS  = showPpr (varName nm')
                     , t_before = tm
                     , t_after  = tm1
                     }
          in  pure ()
      rewriteExpr ("flattenExpr",flatten) (showPpr nm, tm1) (nm', sp)
  let allUsed = newUsed ++ concat il_used
  -- If the flattened body is still accepted by 'isCheapFunction' (and this
  -- binding may be inlined at all), also substitute every callee that is not
  -- marked 'NoInline', and simplify again.
  if inl /= NoInline && isCheapFunction newExpr
    then do
      let (toInline',allUsed') = unzip (map goCheap allUsed)
          subst' = extendGblSubstList (mkSubst emptyInScopeSet)
                                      (Maybe.catMaybes toInline')
          tm1    = substTm "flattenCallTree.flattenCheap" subst' newExpr
      newExpr' <- rewriteExpr ("flattenCheap",flatten) (showPpr nm, tm1) (nm', sp)
      return (CBranch (nm,(Binding nm' sp inl pr newExpr')) (concat allUsed'))
    else
      return (CBranch (nm,(Binding nm' sp inl pr newExpr)) allUsed)
  where
    -- Simplification strategy run after substitution: a repeated top-down pass
    -- over the listed rewrites, combined via '!->' with a top-down 'topLet'
    -- pass.
    flatten =
      repeatR (topdownR (apply "appPropFast" appPropFast >->
                         apply "bindConstantVar" bindConstantVar >->
                         apply "caseCon" caseCon >->
#if EXPERIMENTAL_EVALUATOR
                         apply "deadcode" deadCode >->
#else
                         (apply "reduceConst" reduceConst !-> apply "deadcode" deadCode) >->
#endif
                         apply "reduceNonRepPrim" reduceNonRepPrim >->
                         apply "removeUnusedExpr" removeUnusedExpr >->
                         apply "flattenLet" flattenLet)) !->
      topdownSucR (apply "topLet" topLet)

    -- Decide what to do with a callee of a cheap function.
    -- A 'NoInline' callee is kept in the call tree as-is ('Nothing' to
    -- substitute). Any other callee is substituted by its body, and its own
    -- callees (none for a leaf) take its place in the tree.
    goCheap :: CallTree -> (Maybe (Id, Term), [CallTree])
    goCheap c = case c of
        CLeaf   (nm2,(Binding _ _ inl2 _ e))    -> pick nm2 inl2 e []
        CBranch (nm2,(Binding _ _ inl2 _ e)) us -> pick nm2 inl2 e us
      where
        pick nm2 inl2 e us
          | inl2 == NoInline = (Nothing, [c])
          | otherwise        = (Just (nm2,e), us)
callTreeToList :: [Id] -> CallTree -> ([Id], [(Id, Binding Term)])
-- | Linearise a 'CallTree' into its bindings, in pre-order (a node before its
-- callees, callees left to right).
--
-- The first argument lists the binders already emitted. A node whose 'Id' is
-- already in that list is skipped together with its whole subtree. The
-- returned list is the input list extended with every binder emitted here,
-- and it is threaded left to right through the children of a branch.
callTreeToList seen tree = case tree of
  CLeaf node@(nm, _)
    | nm `elem` seen -> (seen, [])
    | otherwise      -> (nm : seen, [node])
  CBranch node@(nm, _) children
    | nm `elem` seen -> (seen, [])
    | otherwise      ->
        let (seen', nested) = mapAccumL callTreeToList (nm : seen) children
        in  (seen', node : concat nested)