{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}
module Clash.Normalize.Util
( isConstantArg
, shouldReduce
, alreadyInlined
, addNewInline
, specializeNorm
, isRecursiveBndr
, isClosed
, callGraph
, classifyFunction
, isCheapFunction
, isNonRecursiveGlobalVar
, canConstantSpec
, normalizeTopLvlBndr
, rewriteExpr
, removedTm
)
where
import Control.Lens ((&),(+~),(%=),(^.),_4,(.=))
import qualified Control.Lens as Lens
import qualified Data.List as List
import qualified Data.Map as Map
import qualified Data.HashMap.Strict as HashMapS
import Data.Text (Text)
import BasicTypes (InlineSpec)
import Clash.Annotations.Primitive (extractPrim)
import Clash.Core.FreeVars
(globalIds, hasLocalFreeVars, globalIdOccursIn)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst (deShadowTerm)
import Clash.Core.Term
(Context, CoreContext(AppArg), PrimInfo (..), Term (..), WorkInfo (..),
collectArgs)
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Type (Type, undefinedTy)
import Clash.Core.Util (isClockOrReset, isPolyFun, termType)
import Clash.Core.Var (Id, Var (..), isGlobalId)
import Clash.Core.VarEnv
(VarEnv, emptyInScopeSet, emptyVarEnv, extendVarEnv, extendVarEnvWith,
lookupVarEnv, unionVarEnvWith, unitVarEnv)
import Clash.Driver.Types (BindingMap, DebugLevel (..))
import {-# SOURCE #-} Clash.Normalize.Strategy (normalization)
import Clash.Normalize.Types
import Clash.Primitives.Util (constantArgs)
import Clash.Rewrite.Types
(RewriteMonad, bindings, curFun, dbgLevel, extra, tcCache)
import Clash.Rewrite.Util (runRewrite, specialise)
import Clash.Unique
import Clash.Util (SrcSpan, anyM, makeCachedU, traceIf)
-- | Determine whether argument position @i@ of primitive @nm@ is one of that
-- primitive's constant arguments.
--
-- Results are memoised in 'primitiveArgs'. On a cache miss the primitive is
-- looked up in 'primitives'. If it is found, its constant-argument set is
-- computed with 'constantArgs' and stored in the cache. If the primitive is
-- unknown, 'False' is returned and nothing is cached.
isConstantArg
  :: Text
  -- ^ Name of the primitive
  -> Int
  -- ^ Argument position
  -> RewriteMonad NormalizeState Bool
isConstantArg nm i = do
  cached <- Lens.use (extra.primitiveArgs)
  case Map.lookup nm cached of
    Just m -> pure (i `elem` m)
    Nothing -> do
      -- Cache miss: consult the primitive definition itself.
      prims <- Lens.use (extra.primitives)
      case extractPrim =<< HashMapS.lookup nm prims of
        -- Unknown primitive: report "not constant" without caching anything.
        Nothing -> pure False
        Just p -> do
          let m = constantArgs nm p
          (extra.primitiveArgs) Lens.%= Map.insert nm m
          pure (i `elem` m)
-- | Check whether any entry of the given context is an 'AppArg' for a
-- primitive argument that 'isConstantArg' reports as constant. All other
-- context entries count as "no".
shouldReduce
  :: Context
  -> RewriteMonad NormalizeState Bool
shouldReduce = anyM isConstantArg'
  where
    isConstantArg' :: CoreContext -> RewriteMonad NormalizeState Bool
    isConstantArg' = \case
      AppArg (Just (nm, _, i)) -> isConstantArg nm i
      _                        -> pure False
-- | Look up how many times @f@ has been recorded (via 'addNewInline') as
-- inlined into @cf@. Returns 'Nothing' when there is no record for that pair.
alreadyInlined
  :: Id
  -- ^ Function being inlined
  -> Id
  -- ^ Function it is inlined into
  -> NormalizeMonad (Maybe Int)
alreadyInlined f cf = do
  inlinedHM <- Lens.use inlineHistory
  -- 'inlineHistory' is keyed first on the host function, then on the callee.
  return (lookupVarEnv cf inlinedHM >>= lookupVarEnv f)
-- | Record one more inlining of @f@ into @cf@ in 'inlineHistory'.
--
-- If @cf@ has no entry yet, it gets the singleton map @{f: 1}@. Otherwise the
-- counter for @f@ in @cf@'s existing map is incremented, starting at 1 if
-- absent.
addNewInline
  :: Id
  -- ^ Function being inlined
  -> Id
  -- ^ Function it is inlined into
  -> NormalizeMonad ()
addNewInline f cf =
  inlineHistory %= extendVarEnvWith cf (unitVarEnv f 1) bumpCount
  where
    -- The first argument (the fresh singleton) is discarded and the second
    -- map has f's counter bumped.
    -- NOTE(review): this relies on 'extendVarEnvWith' passing (new, old) to
    -- the combining function, like 'Data.Map.insertWith' — confirm.
    bumpCount _ hm = extendVarEnvWith f 1 (+) hm
-- | The generic 'specialise' rewrite, instantiated with the normaliser's
-- specialisation cache, history and limit.
specializeNorm :: NormRewrite
specializeNorm =
  specialise specialisationCache specialisationHistory specialisationLimit
-- | A term is considered closed when 'isPolyFun' does not classify it as a
-- polymorphic function.
isClosed :: TyConMap
         -> Term
         -> Bool
isClosed tcm = not . isPolyFun tcm
-- | Check whether the head of a (possibly applied) term is a global 'Var'
-- that 'isRecursiveBndr' does not consider recursive. Terms headed by
-- anything other than a 'Var' yield 'False'.
isNonRecursiveGlobalVar
  :: Term
  -> NormalizeSession Bool
isNonRecursiveGlobalVar (collectArgs -> (Var i, _args)) = do
  -- 'isRecursiveBndr' is consulted even for local ids, so any caching it
  -- performs happens regardless of the global-ness test.
  eIsRec <- isRecursiveBndr i
  return (isGlobalId i && not eIsRec)
isNonRecursiveGlobalVar _ = return False
-- | Determine whether the global binder @f@ refers to itself in its own body.
--
-- Answers are memoised in 'recursiveComponents'. On a cache miss, @f@'s body
-- is fetched from 'bindings'. If @f@ has no binding, the result is 'False'
-- and nothing is cached. Otherwise the result of 'globalIdOccursIn' is cached
-- and returned.
--
-- NOTE(review): only direct self-reference is detected; mutual recursion
-- through other globals is not. Presumably callers guarantee that globals are
-- never mutually recursive — confirm.
isRecursiveBndr
  :: Id
  -> NormalizeSession Bool
isRecursiveBndr f = do
  cg <- Lens.use (extra.recursiveComponents)
  case lookupVarEnv f cg of
    Just isR -> return isR
    Nothing -> do
      fBodyM <- lookupVarEnv f <$> Lens.use bindings
      case fBodyM of
        Nothing -> return False
        Just (_,_,_,fBody) -> do
          let isR = f `globalIdOccursIn` fBody
          (extra.recursiveComponents) %= extendVarEnv f isR
          return isR
-- | Decide whether a term may be used for constant specialisation.
--
-- * Clock/reset-typed terms qualify only if they are the
--   @Clash.Transformations.removedArg@ primitive (see 'removedTm').
-- * 'Data' and 'Prim' applications qualify when all their term arguments do.
--   Type arguments always qualify.
-- * Lambdas qualify when they have no local free variables.
-- * 'Var' applications qualify when all arguments qualify, the head is a
--   non-recursive global ('isNonRecursiveGlobalVar'), and it is not the
--   function currently being rewritten ('curFun').
-- * Literals qualify; anything else does not.
canConstantSpec
  :: Term
  -> RewriteMonad NormalizeState Bool
canConstantSpec e = do
  tcm <- Lens.view tcCache
  if isClockOrReset tcm (termType tcm e)
    then pure (isRemovedArg (collectArgs e))
    else case collectArgs e of
      (Data _, args)   -> allArgsConst args
      (Prim _ _, args) -> allArgsConst args
      (Lam _ _, _)     -> pure (not (hasLocalFreeVars e))
      (Var f, args)    -> do
        (curF, _) <- Lens.use curFun
        argsConst <- allArgsConst args
        isNonRecGlobVar <- isNonRecursiveGlobalVar e
        pure (argsConst && isNonRecGlobVar && f /= curF)
      (Literal _, _)   -> pure True
      _                -> pure False
  where
    -- Is the (collected) term the placeholder produced for removed arguments?
    isRemovedArg (Prim nm _, _) = nm == "Clash.Transformations.removedArg"
    isRemovedArg _              = False

    -- Every argument is checked (no short-circuiting), so the monadic effects
    -- of the recursive calls all run, in argument order.
    allArgsConst args =
      and <$> mapM (either canConstantSpec (const (pure True))) args
-- | For each visited binder: the global 'Id's occurring in its body, with the
-- number of times each occurs.
type CallGraph = VarEnv (VarEnv Word)

-- | Build the 'CallGraph' reachable from the root binder @rt@.
--
-- Every binder reached that has an entry in @bndrs@ is mapped to the
-- occurrence counts of the 'Id's enumerated by 'globalIds' over its body. The
-- walk then continues into each of those 'Id's. A binder is added to the
-- graph before its callees are visited, so binders already present (including
-- self-references) are not revisited. 'Id's without a binding in @bndrs@ get
-- no entry of their own.
callGraph
  :: BindingMap
  -> Id
  -> CallGraph
callGraph bndrs rt = go emptyVarEnv (varUniq rt)
  where
    go cg root
      | Nothing     <- lookupUniqMap root cg
      , Just rootTm <- lookupUniqMap root bndrs =
        let used = Lens.foldMapByOf globalIds (unionVarEnvWith (+))
                     emptyVarEnv (`unitUniqMap` 1) (rootTm ^. _4)
            cg'  = extendUniqMap root used cg
        in  List.foldl' go cg' (keysUniqMap used)
    go cg _ = cg
-- | Count the components of a term:
--
-- * '_primitive': applications headed by a 'Prim'
-- * '_function': applications headed by a 'Var'
-- * '_selection': 'Case' expressions with two or more alternatives
--
-- The walk looks through 'Lam', 'TyLam' and 'Tick', and into the right-hand
-- sides of a 'Letrec'. It does not descend into application arguments, case
-- scrutinees or alternatives, or the body of a 'Letrec'.
--
-- NOTE(review): presumably this relies on the term being in normal form,
-- where those positions hold only variables — confirm.
classifyFunction
  :: Term
  -> TermClassification
classifyFunction = go (TermClassification 0 0 0)
  where
    go :: TermClassification -> Term -> TermClassification
    go !c (Lam _ e)       = go c e
    go !c (TyLam _ e)     = go c e
    go !c (Letrec bs _)   = List.foldl' go c (map snd bs)
    go !c e@(App {})      = case fst (collectArgs e) of
      Prim {} -> c & primitive +~ 1
      Var {}  -> c & function  +~ 1
      _       -> c
    go !c (Case _ _ alts) = case alts of
      (_:_:_) -> c & selection +~ 1
      _       -> c
    go !c (Tick _ e)      = go c e
    go c _                = c
-- | Decide whether a term is cheap, based on the counts from
-- 'classifyFunction'.
--
-- A term is cheap when it contains at most one classified component:
--
-- * a single function application,
-- * a single primitive application, or
-- * a single multi-way selection,
--
-- with none of the other kinds. A term with no components is also cheap.
isCheapFunction
  :: Term
  -> Bool
isCheapFunction tm = case classifyFunction tm of
  TermClassification {..} ->
    -- Written as independent disjuncts rather than a guard chain: guards
    -- commit to the first one that holds, and '_function <= 1' holds whenever
    -- there are no function applications. That would make the "single
    -- primitive" and "single selection" cases unreachable.
    (_function  <= 1 && _primitive <= 0 && _selection <= 0) ||
    (_primitive <= 1 && _function  <= 0 && _selection <= 0) ||
    (_selection <= 1 && _function  <= 0 && _primitive <= 0)
-- | Normalise a top-level binding.
--
-- The work is wrapped in 'makeCachedU', keyed on @nm@ in the 'normalized'
-- map; presumably a previously stored result is returned instead of
-- re-running. Otherwise the body is deshadowed, rewritten with the
-- 'normalization' strategy via 'rewriteExpr', and the binder's type is
-- replaced by the type of the normalised body. 'curFun' is saved beforehand
-- and restored afterwards, because 'rewriteExpr' overwrites it.
normalizeTopLvlBndr
  :: Id
  -> (Id, SrcSpan, InlineSpec, Term)
  -> NormalizeSession (Id, SrcSpan, InlineSpec, Term)
normalizeTopLvlBndr nm (nm',sp,inl,tm) = makeCachedU nm (extra.normalized) $ do
  tcm <- Lens.view tcCache
  let nmS = showPpr (varName nm)
      -- Freshen shadowed binders before rewriting.
      tm1 = deShadowTerm emptyInScopeSet tm
  old <- Lens.use curFun
  tm2 <- rewriteExpr ("normalization",normalization) (nmS,tm1) (nm',sp)
  curFun .= old
  let ty' = termType tcm tm2
  return (nm' {varType = ty'},sp,inl,tm2)
-- | Run the rewrite @nrw@ (labelled @nrwS@) over @expr@ (labelled @bndrS@).
--
-- Sets 'curFun' to @(nm, sp)@ and does not restore it. When the debug level
-- is at least 'DebugFinal', the term is traced before and after the rewrite.
rewriteExpr :: (String,NormRewrite)
            -> (String,Term)
            -> (Id, SrcSpan)
            -> NormalizeSession Term
rewriteExpr (nrwS,nrw) (bndrS,expr) (nm, sp) = do
  curFun .= (nm, sp)
  lvl <- Lens.view dbgLevel
  let dbgFinal = lvl >= DebugFinal
      before   = showPpr expr
      expr'    = traceIf dbgFinal
                   (bndrS ++ " before " ++ nrwS ++ ":\n\n" ++ before ++ "\n")
                   expr
  rewritten <- runRewrite nrwS emptyInScopeSet nrw expr'
  let after = showPpr rewritten
  traceIf dbgFinal
    (bndrS ++ " after " ++ nrwS ++ ":\n\n" ++ after ++ "\n") $
    return rewritten
-- | Placeholder term for a removed argument of the given type: the primitive
-- @Clash.Transformations.removedArg@ (typed 'undefinedTy', 'WorkNever')
-- applied to that type. 'canConstantSpec' recognises this primitive name.
removedTm
  :: Type
  -> Term
removedTm =
  TyApp (Prim "Clash.Transformations.removedArg" (PrimInfo undefinedTy WorkNever))