{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Normalize.Transformations
( caseLet
, caseCon
, caseCase
, caseElemNonReachable
, elemExistentials
, inlineNonRep
, inlineOrLiftNonRep
, typeSpec
, nonRepSpec
, etaExpansionTL
, nonRepANF
, bindConstantVar
, constantSpec
, makeANF
, deadCode
, topLet
, recToLetRec
, inlineWorkFree
, inlineHO
, inlineSmall
, simpleCSE
, reduceConst
, reduceNonRepPrim
, caseFlat
, disjointExpressionConsolidation
, removeUnusedExpr
, inlineCleanup
, flattenLet
, splitCastWork
, inlineCast
, caseCast
, letCast
, eliminateCastCast
, argCastSpec
, etaExpandSyn
, appPropFast
, separateArguments
, separateLambda
, xOptimize
)
where
import Control.Exception (throw)
import Control.Lens (_2)
import qualified Control.Lens as Lens
import qualified Control.Monad as Monad
import Control.Monad.State (StateT (..), modify)
import Control.Monad.State.Strict (evalState)
import Control.Monad.Writer (lift, listen)
import Control.Monad.Trans.Except (runExcept)
import Data.Coerce (coerce)
import qualified Data.Either as Either
import qualified Data.HashMap.Lazy as HashMap
import qualified Data.HashMap.Strict as HashMapS
import Data.List ((\\))
import qualified Data.List as List
import qualified Data.List.Extra as List
import qualified Data.Maybe as Maybe
import qualified Data.Monoid as Monoid
import qualified Data.Primitive.ByteArray as BA
import qualified Data.Text as Text
import qualified Data.Vector.Primitive as PV
import GHC.Integer.GMP.Internals (Integer (..), BigNat (..))
import BasicTypes (InlineSpec (..))
import Clash.Annotations.Primitive (extractPrim)
import Clash.Core.DataCon (DataCon (..))
import Clash.Core.EqSolver
import Clash.Core.Name
(Name (..), NameSort (..), mkUnsafeSystemName, nameOcc)
import Clash.Core.FreeVars
(localIdOccursIn, localIdsDoNotOccurIn, freeLocalIds, termFreeTyVars,
typeFreeVars, localVarsDoNotOccurIn, localIdDoesNotOccurIn,
countFreeOccurances)
import Clash.Core.Literal (Literal (..))
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst
import Clash.Core.Term
import Clash.Core.TermInfo
import Clash.Core.Type (Type (..), TypeView (..), applyFunTy,
isPolyFunCoreTy, isClassTy,
normalizeType, splitFunForallTy,
splitFunTy,
tyView, mkPolyFunTy, coreView,
LitTy (..), coreView1)
import Clash.Core.TyCon (TyConMap, tyConDataCons)
import Clash.Core.Util
( isSignalType, mkVec, tyNatSize, undefinedTm,
shouldSplit, inverseTopSortLetBindings)
import Clash.Core.Var
(Id, TyVar, Var (..), isGlobalId, isLocalId, mkLocalId)
import Clash.Core.VarEnv
(InScopeSet, VarEnv, VarSet, elemVarSet,
emptyVarEnv, extendInScopeSet, extendInScopeSetList, lookupVarEnv,
notElemVarSet, unionVarEnvWith, unionInScope, unitVarEnv,
unitVarSet, mkVarSet, mkInScopeSet, uniqAway, elemInScopeSet, elemVarEnv,
foldlWithUniqueVarEnv', lookupVarEnvDirectly, extendVarEnv, unionVarEnv,
eltsVarEnv, mkVarEnv, eltsVarSet)
import Clash.Debug
import Clash.Driver.Types (Binding(..), DebugLevel (..))
import Clash.Netlist.BlackBox.Types (Element(Err))
import Clash.Netlist.BlackBox.Util (getUsedArguments)
import Clash.Netlist.Types (BlackBox(..), HWType (..), FilteredHWType(..))
import Clash.Netlist.Util
(coreTypeToHWType, representableType, splitNormalized, bindsExistentials)
import Clash.Normalize.DEC
import Clash.Normalize.PrimitiveReductions
import Clash.Normalize.Types
import Clash.Normalize.Util
import Clash.Primitives.Types
(Primitive(..), TemplateKind(TExpr), CompiledPrimMap, UsedArguments(..))
import Clash.Rewrite.Combinators
import Clash.Rewrite.Types
import Clash.Rewrite.Util
import Clash.Unique (Unique, lookupUniqMap)
import Clash.Util
-- | Deal with let-bindings whose type is not representable in HDL, using
-- 'inlineOrLiftBinders'.
--
-- A binding is selected when the type of its binder is not representable.
-- A selected binding passes the inline test unless either
--
-- * its binder is a join point in the let-expression and its right-hand side
--   is not a void wrapper ('isVoidWrapper'), or
-- * its binder occurs more than once in the body of the let-expression.
--
-- Presumably bindings that pass the test are inlined and the others are
-- lifted (see 'inlineOrLiftBinders' to confirm).
inlineOrLiftNonRep :: HasCallStack => NormRewrite
inlineOrLiftNonRep ctx eLet@(Letrec _ body) =
  inlineOrLiftBinders isNonRepBinder shouldInline ctx eLet
 where
  -- Number of free occurrences of each local identifier in the let-body
  occsInBody = countFreeOccurances body

  isNonRepBinder :: (Id, Term) -> RewriteMonad extra Bool
  isNonRepBinder (Id {varType = bndrTy}, _) = do
    tt <- Lens.view typeTranslator
    reprs <- Lens.view customReprs
    tcm <- Lens.view tcCache
    return (not (representableType tt reprs False tcm bndrTy))
  isNonRepBinder _ = return False

  shouldInline :: Term -> (Id, Term) -> Bool
  shouldInline letExpr (bndr, bndrTm) =
    not (isJoinPointIn bndr letExpr && not (isVoidWrapper bndrTm))
      && not (maybe False (> 1) (lookupVarEnv bndr occsInBody))
inlineOrLiftNonRep _ e = return e
{-# SCC inlineOrLiftNonRep #-}
-- | Specialize a function on a closed type argument.
--
-- Fires on a type application of @e1@ to @ty@ when @e1@ is a variable applied
-- to term arguments only (no type arguments) and @ty@ has no free type
-- variables. The whole expression is then handed to 'specializeNorm'.
typeSpec :: HasCallStack => NormRewrite
typeSpec ctx e@(TyApp e1 ty)
  | (Var {}, args) <- collectArgs e1
  , null (Lens.toListOf typeFreeVars ty)
  , not (any Either.isRight args)
  = specializeNorm ctx e
typeSpec _ e = return e
{-# SCC typeSpec #-}
-- | Specialize a function on an argument whose type is not representable in
-- HDL.
--
-- Fires on an application of @e1@ to @e2@ when @e1@ is a variable applied to
-- term arguments only (no type arguments) and @e2@ has no free type
-- variables. When the type of @e2@ is not representable and @e2@ is not a
-- local variable, the application is specialized with 'specializeNorm'.
-- Before that, if @e2@ is an application of a global binder whose name has
-- the 'Internal' sort, that binder is inlined into @e2@.
nonRepSpec :: HasCallStack => NormRewrite
nonRepSpec ctx e@(App e1 e2)
  | (Var {}, args) <- collectArgs e1
  , (_, []) <- Either.partitionEithers args
  , null $ Lens.toListOf termFreeTyVars e2
  = do tcm <- Lens.view tcCache
       let e2Ty = termType tcm e2
       let localVar = isLocalVar e2
       nonRepE2 <- not <$> (representableType <$> Lens.view typeTranslator
                                              <*> Lens.view customReprs
                                              <*> pure False
                                              <*> Lens.view tcCache
                                              <*> pure e2Ty)
       if nonRepE2 && not localVar
         then do
           e2' <- inlineInternalSpecialisationArgument e2
           specializeNorm ctx (App e1 e2')
         else return e
  where
    -- If the argument is (a ticked application of) a global binder with an
    -- Internal name, replace that binder by its body and propagate the
    -- arguments into it with appPropFast. The writer output of this
    -- sub-rewrite is discarded via censor, so it does not mark the term as
    -- changed by itself.
    inlineInternalSpecialisationArgument
      :: Term
      -> NormalizeSession Term
    inlineInternalSpecialisationArgument app
      | (Var f,fArgs,ticks) <- collectArgsTicks app
      = do
        fTmM <- lookupVarEnv f <$> Lens.use bindings
        case fTmM of
          Just b
            | nameSort (varName (bindingId b)) == Internal
            -> censor (const mempty)
                      (topdownR appPropFast ctx
                        (mkApps (mkTicks (bindingTerm b) ticks) fArgs))
          _ -> return app
      | otherwise = return app
nonRepSpec _ e = return e
{-# SCC nonRepSpec #-}
-- | Float a let-expression out of the subject of a case-expression:
--
-- > case (letrec xes in e) of alts  ==>  letrec xes in case e of alts
--
-- The bindings are first deshadowed against the in-scope set with
-- 'deshadowLetExpr'. Ticks that surrounded the let-expression are put on
-- every bound expression and on the new subject.
caseLet :: HasCallStack => NormRewrite
caseLet (TransformContext is0 _) (Case (collectTicks -> (Letrec xes e,ticks)) ty alts) =
  let (xes1,e1) = deshadowLetExpr is0 xes e
      addTicks tm = mkTicks tm ticks
  in  changed (Letrec (map (second addTicks) xes1) (Case (addTicks e1) ty alts))
caseLet _ e = return e
{-# SCC caseLet #-}
-- | Remove the alternatives of a case-expression that 'isAbsurdAlt' reports as
-- absurd, then try 'caseOneAlt' on the case-expression with the remaining
-- alternatives. When no alternative is absurd the expression is returned
-- unchanged.
caseElemNonReachable :: HasCallStack => NormRewrite
caseElemNonReachable _ case0@(Case scrut altsTy alts0) = do
  tcm <- Lens.view tcCache
  let isAbsurd = isAbsurdAlt tcm
  if any isAbsurd alts0
    then caseOneAlt (Case scrut altsTy (filter (not . isAbsurd) alts0)) >>= changed
    else return case0
caseElemNonReachable _ e = return e
{-# SCC caseElemNonReachable #-}
-- | Use the type equations implied by the patterns of a case-expression to
-- substitute type variables bound by those patterns.
--
-- For every 'DataPat' alternative, the equations produced by 'altEqs' are
-- solved with 'solveNonAbsurds'. The solutions are substituted into the
-- pattern's existentials ('substInExistentialsList'), into the types of the
-- pattern's term binders, and into the alternative body. This repeats on the
-- rewritten alternative until no equations are solved. Afterwards
-- 'caseOneAlt' is applied to the resulting case-expression.
elemExistentials :: HasCallStack => NormRewrite
elemExistentials (TransformContext is0 _) (Case scrut altsTy alts0) = do
  tcm <- Lens.view tcCache
  alts1 <- mapM (go is0 tcm) alts0
  caseOneAlt (Case scrut altsTy alts1)
 where
  go :: InScopeSet -> TyConMap -> (Pat, Term) -> NormalizeSession (Pat, Term)
  go is2 tcm alt@(DataPat dc exts0 xs0, term0) =
    case solveNonAbsurds tcm (altEqs tcm alt) of
      -- Nothing solved: keep the alternative as is
      [] -> return alt
      -- Something solved: substitute, and retry on the rewritten alternative
      sols ->
        changed =<< go is2 tcm (DataPat dc exts1 xs1, term1)
       where
        is3 = extendInScopeSetList is2 exts0
        xs1 = map (substTyInVar (extendTvSubstList (mkSubst is3) sols)) xs0
        exts1 = substInExistentialsList is2 exts0 sols
        is4 = extendInScopeSetList is3 xs1
        subst = extendTvSubstList (mkSubst is4) sols
        term1 = substTm "Replacing tyVar due to solved eq" subst term0
  go _ _ alt = return alt
elemExistentials _ e = return e
{-# SCC elemExistentials #-}
-- | Case-of-case: when the subject of a case-expression is itself a
-- case-expression whose alternatives have a type that is not representable
-- in HDL, push the outer case-expression into every alternative of the
-- inner one:
--
-- > case (case s of p_i -> e_i) of alts2  ==>  case s of p_i -> case e_i of alts2
--
-- The inner alternatives are deshadowed with 'deShadowAlt' against the
-- current in-scope set first. When the type is representable the expression
-- is returned unchanged.
caseCase :: HasCallStack => NormRewrite
caseCase (TransformContext is0 _) e@(Case (stripTicks -> Case scrut alts1Ty alts1) alts2Ty alts2)
  = do
    tt <- Lens.view typeTranslator
    reprs <- Lens.view customReprs
    tcm <- Lens.view tcCache
    if representableType tt reprs False tcm alts1Ty
      then return e
      else do
        let wrapOuter altE = Case altE alts2Ty alts2
        changed (Case scrut alts2Ty (map (second wrapOuter . deShadowAlt is0) alts1))
caseCase _ e = return e
{-# SCC caseCase #-}
-- | Inline a global function that is applied in the subject of a
-- case-expression when the type of that subject is not representable in HDL.
--
-- Inlinings are tracked with 'alreadyInlined' and 'addNewInline' for the pair
-- (inlined function, function currently being normalized). Once that count
-- exceeds the configured inline limit, a warning is emitted via 'trace' and
-- the expression is left unchanged. Subjects whose type is a class type
-- ('isClassTy') skip this limit check and are not counted.
inlineNonRep :: HasCallStack => NormRewrite
inlineNonRep _ e@(Case scrut altsTy alts)
  | (Var f, args,ticks) <- collectArgsTicks scrut
  , isGlobalId f
  = do
    (cf,_) <- Lens.use curFun
    isInlined <- zoomExtra (alreadyInlined f cf)
    limit <- Lens.use (extra.inlineLimit)
    tcm <- Lens.view tcCache
    let scrutTy = termType tcm scrut
        noException = not (exception tcm scrutTy)
    if noException && (Maybe.fromMaybe 0 isInlined) > limit
      then
        -- Inline limit reached: warn and give up on this expression
        trace (concat [ $(curLoc) ++ "InlineNonRep: " ++ showPpr (varName f)
                      ," already inlined " ++ show limit ++ " times in:"
                      , showPpr (varName cf)
                      , "\nType of the subject is: " ++ showPpr scrutTy
                      , "\nFunction " ++ showPpr (varName cf)
                      , " will not reach a normal form, and compilation"
                      , " might fail."
                      , "\nRun with '-fclash-inline-limit=N' to increase"
                      , " the inlining limit to N."
                      ])
              (return e)
      else do
        bodyMaybe <- lookupVarEnv f <$> Lens.use bindings
        nonRepScrut <- not <$> (representableType <$> Lens.view typeTranslator
                                                  <*> Lens.view customReprs
                                                  <*> pure False
                                                  <*> Lens.view tcCache
                                                  <*> pure scrutTy)
        case (nonRepScrut, bodyMaybe) of
          (True,Just b) -> do
            -- Only count inlinings that are subject to the limit
            Monad.when noException (zoomExtra (addNewInline f cf))
            -- Replace the function by its body, marked with an inline tick
            let scrutBody0 = mkTicks (bindingTerm b) (mkInlineTick f : ticks)
            let scrutBody1 = mkApps scrutBody0 args
            changed $ Case scrutBody1 altsTy alts
          _ -> return e
  where
    exception = isClassTy
inlineNonRep _ e = return e
{-# SCC inlineNonRep #-}
-- | Case-of-known-constructor: simplify a case-expression whose subject is,
-- or evaluates to, a known constructor or literal.
--
-- * Applied data constructor: select the alternative with the same
--   constructor tag. Pattern variables that occur in that alternative are
--   bound to the corresponding term arguments (carrying the ticks of the
--   subject). Work-free arguments are substituted directly; the others are
--   let-bound under fresh names. The pattern's type variables are
--   substituted by the type arguments that follow the universal type
--   variables of the constructor. If no alternative matches, a leading
--   'DefaultPat' alternative is chosen; otherwise the result is
--   'undefinedTm'.
-- * Literal: select the matching 'LitPat' alternative, otherwise defer to
--   'matchLiteralContructor'.
-- * Primitive application: evaluate the subject with 'whnfRW' and retry on a
--   resulting literal or data constructor. The listed error primitives
--   (absentError, patError, undefined, EmptyCase) are re-applied at the type
--   of the case-expression. A subject whose hardware type is a void
--   BitVector 0, Unsigned 0, Signed 0 or Index 1 is treated as the literal 0.
--   Otherwise 'caseOneAlt' is tried, with a debug trace when the subject is
--   constant.
-- * Variable without arguments whose type is BitVector 0, Unsigned 0,
--   Signed 0 or Index 1: treated as the literal 0.
-- * Anything else: 'caseOneAlt'.
caseCon :: HasCallStack => NormRewrite
caseCon ctx@(TransformContext is0 _) e@(Case subj ty alts) = do
  tcm <- Lens.view tcCache
  case collectArgsTicks subj of
    -- The subject is an applied data constructor
    (Data dc, args, ticks) -> case List.find (equalCon . fst) alts of
      Just (DataPat _ tvs xs, altE) -> do
        let is1 = extendInScopeSetList (extendInScopeSetList is0 tvs) xs
        -- Only bind the pattern variables that the alternative uses
        let fvs = Lens.foldMapOf freeLocalIds unitVarSet altE
            (binds,_) = List.partition ((`elemVarSet` fvs) . fst)
                      $ zip xs (Either.lefts args)
            binds1 = map (second (`mkTicks` ticks)) binds
            altE1 =
              case binds1 of
                [] -> altE
                _ ->
                  let
                    ((is3,substIds),binds2) = List.mapAccumL newBinder (is1,[]) binds1
                    subst = extendIdSubstList (mkSubst is3) substIds
                    body = substTm "caseCon0" subst altE
                  in
                    case Maybe.catMaybes binds2 of
                      [] -> body
                      binds3 -> Letrec binds3 body
        -- Substitute the pattern type variables by the type arguments that
        -- follow the universal type variables of the constructor
        let subst = extendTvSubstList (mkSubst is0)
                  $ zip tvs (drop (length (dcUnivTyVars dc)) (Either.rights args))
        changed (substTm "caseCon1" subst altE1)
      _ -> case alts of
        ((DefaultPat,altE):_) -> changed altE
        _ -> changed (undefinedTm ty)
     where
      equalCon (DataPat dcPat _ _) = dcTag dc == dcTag dcPat
      equalCon _ = False

      -- Work-free arguments are substituted; other arguments are let-bound
      -- to a fresh variable that is substituted instead.
      newBinder (isN0,substN) (x,arg)
        | isWorkFree arg
        = ((isN0,(x,arg):substN),Nothing)
        | otherwise
        = let x' = uniqAway isN0 x
              isN1 = extendInScopeSet isN0 x'
          in ((isN1,(x,Var x'):substN),Just (x',arg))

    -- The subject is a literal
    (Literal l,_,_) -> case List.find (equalLit . fst) alts of
      Just (LitPat _,altE) -> changed altE
      _ -> matchLiteralContructor e l alts
     where
      equalLit (LitPat l') = l == l'
      equalLit _ = False

    -- The subject is a primitive application: try to evaluate it first
    (Prim _,_,_) ->
      whnfRW True ctx subj $ \ctx1 subj1 -> case collectArgsTicks subj1 of
        (Literal l,_,_) -> caseCon ctx1 (Case (Literal l) ty alts)
        (Data _,_,_) -> caseCon ctx1 (Case subj1 ty alts)
#if MIN_VERSION_ghc(8,2,2)
        (Prim pInfo,_:msgOrCallStack:_,ticks)
          | primName pInfo == "Control.Exception.Base.absentError" ->
            let e1 = mkApps (mkTicks (Prim pInfo) ticks)
                       [Right ty,msgOrCallStack]
            in changed e1
#endif
        (Prim pInfo,repTy:_:msgOrCallStack:_,ticks)
          | primName pInfo `elem` ["Control.Exception.Base.patError"
#if !MIN_VERSION_ghc(8,2,2)
                                  ,"Control.Exception.Base.absentError"
#endif
                                  ,"GHC.Err.undefined"] ->
            let e1 = mkApps (mkTicks (Prim pInfo) ticks)
                       [repTy,Right ty,msgOrCallStack]
            in changed e1
        (Prim pInfo,[_],ticks)
          | primName pInfo `elem` [ "Clash.Transformations.undefined"
                                  , "Clash.GHC.Evaluator.undefined"
                                  , "EmptyCase"] ->
            let e1 = mkApps (mkTicks (Prim pInfo) ticks) [Right ty]
            in changed e1
        _ -> do
          let subjTy = termType tcm subj
          tran <- Lens.view typeTranslator
          reprs <- Lens.view customReprs
          case (`evalState` HashMapS.empty) (coreTypeToHWType tran reprs tcm subjTy) of
            -- Zero-width numeric subjects can only hold the value 0
            Right (FilteredHWType (Void (Just hty)) _areVoids)
              | hty `elem` [BitVector 0, Unsigned 0, Signed 0, Index 1]
              -> caseCon ctx1 (Case (Literal (IntegerLiteral 0)) ty alts)
            _ -> do
              let ret = caseOneAlt e
              lvl <- Lens.view dbgLevel
              if lvl > DebugNone then do
                let subjIsConst = isConstant subj
                traceIf (lvl > DebugNone && subjIsConst)
                  ("Irreducible constant as case subject: " ++ showPpr subj ++
                   "\nCan be reduced to: " ++ showPpr subj1) ret
              else
                ret

    -- A variable of a zero-width numeric type can only hold the value 0
    (Var v, [], _) | isNum0 (varType v) ->
      caseCon ctx (Case (Literal (IntegerLiteral 0)) ty alts)
     where
      isNum0 (tyView -> TyConApp (nameOcc -> tcNm) [arg])
        | tcNm `elem`
          ["Clash.Sized.Internal.BitVector.BitVector"
          ,"Clash.Sized.Internal.Unsigned.Unsigned"
          ,"Clash.Sized.Internal.Signed.Signed"
          ]
        = isLitX 0 arg
        | tcNm ==
          "Clash.Sized.Internal.Index.Index"
        = isLitX 1 arg
      isNum0 (coreView1 tcm -> Just t) = isNum0 t
      isNum0 _ = False

      isLitX n (LitTy (NumTy m)) = n == m
      isLitX n (coreView1 tcm -> Just t) = isLitX n t
      isLitX _ _ = False

    _ -> caseOneAlt e

caseCon _ e = return e
{-# SCC caseCon #-}
-- | Select the alternative of a case-expression whose subject is the literal
-- @l@ while the alternatives match on the data constructors of @Integer@
-- (for an 'IntegerLiteral') or @Natural@ (for a 'NaturalLiteral') instead of
-- on literals.
--
-- The constructor tag the literal corresponds to is derived from its value:
--
-- * @Integer@: tag 1 for values in [-2^63, 2^63), bound as an 'IntLiteral';
--   tag 2 for values of at least 2^63 and tag 3 for values below -2^63, both
--   bound as a 'ByteArrayLiteral' of the underlying BigNat.
-- * @Natural@: tag 1 for values in [0, 2^64), bound as a 'WordLiteral'; tag 2
--   for values of at least 2^64, bound as a 'ByteArrayLiteral'.
--
-- The pattern variable is only bound when the alternative uses it. The
-- alternatives are searched in reverse order, so a 'DefaultPat' alternative
-- is only selected when it is the first alternative and nothing else
-- matched. For any other literal, a leading 'DefaultPat' alternative is
-- selected. All remaining situations are reported as bugs. The first
-- argument is the whole case-expression and is only used in error messages.
matchLiteralContructor
  :: Term
  -> Literal
  -> [(Pat,Term)]
  -> NormalizeSession Term
matchLiteralContructor c (IntegerLiteral l) alts = go (reverse alts)
 where
  go [(DefaultPat,e)] = changed e
  go ((DataPat dc [] xs,e):alts')
    | dcTag dc == 1
    , l >= ((-2)^(63::Int)) && l < 2^(63::Int)
    = changed (bindLitAltFields xs (IntLiteral l) e)
    | dcTag dc == 2
    , l >= 2^(63::Int)
    = let !(Jp# !(BN# ba)) = l
      in changed (bindLitAltFields xs (bigNatByteArrayLit (BA.ByteArray ba)) e)
    | dcTag dc == 3
    , l < ((-2)^(63::Int))
    = let !(Jn# !(BN# ba)) = l
      in changed (bindLitAltFields xs (bigNatByteArrayLit (BA.ByteArray ba)) e)
    | otherwise
    = go alts'
  go ((LitPat l', e):alts')
    | IntegerLiteral l == l'
    = changed e
    | otherwise
    = go alts'
  go _ = error $ $(curLoc) ++ "Report as bug: caseCon error: " ++ showPpr c

matchLiteralContructor c (NaturalLiteral l) alts = go (reverse alts)
 where
  go [(DefaultPat,e)] = changed e
  go ((DataPat dc [] xs,e):alts')
    | dcTag dc == 1
    , l >= 0 && l < 2^(64::Int)
    = changed (bindLitAltFields xs (WordLiteral l) e)
    | dcTag dc == 2
    , l >= 2^(64::Int)
    = let !(Jp# !(BN# ba)) = l
      in changed (bindLitAltFields xs (bigNatByteArrayLit (BA.ByteArray ba)) e)
    | otherwise
    = go alts'
  go ((LitPat l', e):alts')
    | NaturalLiteral l == l'
    = changed e
    | otherwise
    = go alts'
  go _ = error $ $(curLoc) ++ "Report as bug: caseCon error: " ++ showPpr c

matchLiteralContructor _ _ ((DefaultPat,e):_) = changed e

matchLiteralContructor c _ _ =
  error $ $(curLoc) ++ "Report as bug: caseCon error: " ++ showPpr c
{-# SCC matchLiteralContructor #-}

-- | Bind the pattern variables @xs@ of a selected constructor alternative to
-- the literal @lit@ (the single field of the constructor). Only variables
-- that occur free in the alternative body @e@ are bound; when none do, @e@
-- is returned as is.
bindLitAltFields :: [Id] -> Literal -> Term -> Term
bindLitAltFields xs lit e =
  case filter ((`elemVarSet` fvs) . fst) (zip xs [Literal lit]) of
    [] -> e
    binds -> Letrec binds e
 where
  fvs = Lens.foldMapOf freeLocalIds unitVarSet e

-- | A 'ByteArrayLiteral' that covers the whole given byte array.
bigNatByteArrayLit :: BA.ByteArray -> Literal
bigNatByteArrayLit ba =
  ByteArrayLiteral (PV.Vector 0 (BA.sizeofByteArray ba) ba)
-- | Replace a case-expression by the body of one of its alternatives when the
-- choice of alternative does not matter:
--
-- * there is a single alternative, and the variables bound by its pattern (if
--   any) do not occur in its body; or
-- * there are several alternatives, they all have the same body, and none of
--   the variables bound by any of the patterns occur in that body.
--
-- Otherwise the expression is returned unchanged.
caseOneAlt :: Term -> RewriteMonad extra Term
caseOneAlt e@(Case _ _ [(pat,altE)]) = case pat of
  DefaultPat -> changed altE
  LitPat _ -> changed altE
  DataPat _ tvs xs
    | (coerce tvs ++ coerce xs) `localVarsDoNotOccurIn` altE
    -> changed altE
    | otherwise
    -> return e
caseOneAlt (Case _ _ alts@((_,alt):_:_))
  | all ((== alt) . snd) (tail alts)
  -- The shared body may only leave the case-expression when it does not
  -- mention variables bound by the patterns; otherwise those variables would
  -- end up unbound.
  , all (patVarsUnusedIn alt . fst) alts
  = changed alt
 where
  patVarsUnusedIn body pat =
    let (tvs,xs) = patIds pat
    in  (coerce tvs ++ coerce xs) `localVarsDoNotOccurIn` body
caseOneAlt e = return e
{-# SCC caseOneAlt #-}
-- | Deal with untranslatable arguments of data constructors and primitives.
--
-- When the last argument of a constructor or primitive application is
-- untranslatable ('isUntranslatable'), it is inspected with its ticks
-- stripped:
--
-- * a let-expression is floated out of the application, after its bindings
--   are deshadowed with 'deshadowLetExpr';
-- * a case-expression, lambda or type-lambda causes the application to be
--   specialized with 'specializeNorm'.
--
-- In every other situation the expression is returned unchanged.
nonRepANF :: HasCallStack => NormRewrite
nonRepANF ctx@(TransformContext is0 _) e@(App appConPrim arg)
  | (conPrim, _) <- collectArgs e
  , isCon conPrim || isPrim conPrim
  = do
    untranslatable <- isUntranslatable False arg
    if not untranslatable
      then return e
      else case stripTicks arg of
        Letrec binds body ->
          let (binds1,body1) = deshadowLetExpr is0 binds body
          in  changed (Letrec binds1 (App appConPrim body1))
        Case {} -> specializeNorm ctx e
        Lam {} -> specializeNorm ctx e
        TyLam {} -> specializeNorm ctx e
        _ -> return e
nonRepANF _ e = return e
{-# SCC nonRepANF #-}
-- | Make sure a term in lambda-body position (every context above it is a
-- lambda body or a tick) is a let-expression whose body is a variable.
--
-- * A translatable term @e@ that is neither a let-expression nor a tick
--   becomes @letrec result = e in result@.
-- * For a let-expression whose body is neither a local variable nor
--   untranslatable, the body is bound to a fresh @result@ binder that is
--   added to the existing bindings, and the let-body becomes that binder.
topLet :: HasCallStack => NormRewrite
topLet (TransformContext is0 ctx) e
  | all (\c -> isLambdaBodyCtx c || isTickCtx c) ctx && not (isLet e) && not (isTick e)
  = do
    untranslatable <- isUntranslatable False e
    if untranslatable
      then return e
      else do
        tcm <- Lens.view tcCache
        argId <- mkTmBinderFor is0 tcm (mkUnsafeSystemName "result" 0) e
        changed (Letrec [(argId, e)] (Var argId))
  where
    isTick Tick{} = True
    isTick _ = False

topLet (TransformContext is0 ctx) e@(Letrec binds body)
  | all (\c -> isLambdaBodyCtx c || isTickCtx c) ctx
  = do
    let localVar = isLocalVar body
    untranslatable <- isUntranslatable False body
    if localVar || untranslatable
      then return e
      else do
        tcm <- Lens.view tcCache
        -- The new binder must not clash with the existing let-binders
        let is2 = extendInScopeSetList is0 (map fst binds)
        argId <- mkTmBinderFor is2 tcm (mkUnsafeSystemName "result" 0) body
        changed (Letrec (binds ++ [(argId,body)]) (Var argId))

topLet _ e = return e
{-# SCC topLet #-}
-- | Remove the bindings of a let-expression that are not used, directly or
-- transitively through other bindings, by the body of the let-expression.
-- When no binding is used, the let-expression is replaced by its body. The
-- surviving bindings come from 'eltsVarEnv', so their original order is not
-- necessarily preserved.
deadCode :: HasCallStack => NormRewrite
deadCode _ e@(Letrec binds body) = do
  let bodyFVs = Lens.foldMapOf freeLocalIds unitVarSet body
      used = List.foldl' collectUsed emptyVarEnv (eltsVarSet bodyFVs)
  case eltsVarEnv used of
    [] -> changed body
    -- Only report a change when some binding was actually dropped
    qqL | not (List.equalLength qqL binds)
        -> changed (Letrec qqL body)
        | otherwise
        -> return e
  where
    bindsEnv = mkVarEnv (map (\(x,e0) -> (x,(x,e0))) binds)

    -- Add the binding of v (if it is one of the let-binders) to the set of
    -- used bindings, followed by the bindings used by its right-hand side.
    -- Already collected binders are skipped, which stops recursion cycles.
    collectUsed env v =
      if v `elemVarEnv` env then
        env
      else
        case lookupVarEnv v bindsEnv of
          Just (x,e0) ->
            let eFVs = Lens.foldMapOf freeLocalIds unitVarSet e0
            in  List.foldl' collectUsed
                            (extendVarEnv x (x,e0) env)
                            (eltsVarSet eFVs)
          Nothing -> env

deadCode _ e = return e
{-# SCC deadCode #-}
-- | Remove expressions that do not contribute to the generated HDL.
--
-- * Term arguments of a primitive that its blackbox does not use are
--   replaced by 'removedTm' of the same type. The used arguments come from
--   the usedArguments / ignored arguments of a Haskell blackbox, or from the
--   argument references in the templates of a template blackbox (with fixed
--   sets for isFromInt primitives, dontApplyInHDL and Vector.splitAt). Only
--   term arguments at positions below the arity of the primitive type are
--   considered, and arguments that are already removedArg are kept.
-- * A single-alternative case-expression whose pattern binds no type
--   variables and whose term variables do not occur in the alternative is
--   replaced by that alternative.
-- * A Vector Cons whose tail has length 0 but is not a constructor
--   application is replaced by a freshly built one-element vector.
removeUnusedExpr :: HasCallStack => NormRewrite
removeUnusedExpr _ e@(collectArgsTicks -> (p@(Prim pInfo),args,ticks)) = do
  bbM <- HashMap.lookup (primName pInfo) <$> Lens.use (extra.primitives)
  let
    -- Positions (counting term arguments only) of the arguments the
    -- primitive uses, or Nothing when that cannot be determined
    usedArgs0 =
      case Monad.join (extractPrim <$> bbM) of
        Just (BlackBoxHaskell{usedArguments}) ->
          case usedArguments of
            UsedArguments used -> Just used
            IgnoredArguments ignored -> Just ([0..length args - 1] \\ ignored)
        Just (BlackBox pNm _ _ _ _ _ _ _ _ inc r ri templ) -> Just $
          if | isFromInt pNm -> [0,1,2]
             | primName pInfo `elem` [ "Clash.Annotations.BitRepresentation.Deriving.dontApplyInHDL"
                                     , "Clash.Sized.Vector.splitAt"
                                     ] -> [0,1]
             | otherwise -> concat [ maybe [] getUsedArguments r
                                   , maybe [] getUsedArguments ri
                                   , getUsedArguments templ
                                   , concatMap (getUsedArguments . snd) inc ]
        _ ->
          Nothing

  case usedArgs0 of
    Nothing ->
      return e
    Just usedArgs1 -> do
      tcm <- Lens.view tcCache
      (args1, Monoid.getAny -> hasChanged) <- listen (go tcm 0 usedArgs1 args)
      if hasChanged then
        return (mkApps (mkTicks p ticks) args1)
      else
        return e

  where
    -- Number of term arguments in the type of the primitive
    arity = length . Either.rights . fst $ splitFunForallTy (primType pInfo)

    -- Walk the arguments; n is the position among the term arguments only
    go _ _ _ [] = return []
    go tcm !n used (Right ty:args') = do
      args'' <- go tcm n used args'
      return (Right ty : args'')
    go tcm !n used (Left tm : args') = do
      args'' <- go tcm (n+1) used args'
      case tm of
        TyApp (Prim p0) _
          | primName p0 == "Clash.Transformations.removedArg"
          -> return (Left tm : args'')
        _ -> do
          let ty = termType tcm tm
              p' = removedTm ty
          if n < arity && n `notElem` used
            then changed (Left p' : args'')
            else return (Left tm : args'')

removeUnusedExpr _ e@(Case _ _ [(DataPat _ [] xs,altExpr)]) =
  if xs `localIdsDoNotOccurIn` altExpr
    then changed altExpr
    else return e

removeUnusedExpr _ e@(collectArgsTicks -> (Data dc, [_,Right aTy,Right nTy,_,Left a,Left nil],ticks))
  | nameOcc (dcName dc) == "Clash.Sized.Vector.Cons"
  = do
    tcm <- Lens.view tcCache
    case runExcept (tyNatSize tcm nTy) of
      Right 0
        | (con, _) <- collectArgs nil
        , not (isCon con)
        -> let eTy = termType tcm e
               (TyConApp vecTcNm _) = tyView eTy
               (Just vecTc) = lookupUniqMap vecTcNm tcm
               [nilCon,consCon] = tyConDataCons vecTc
               v = mkTicks (mkVec nilCon consCon aTy 1 [a]) ticks
           in changed v
      _ -> return e

removeUnusedExpr _ e = return e
{-# SCC removeUnusedExpr #-}
-- | Inline let-bindings whose right-hand side, with ticks stripped, is either
--
-- * a local variable other than the binder itself, or
-- * work-free according to 'isWorkFreeIsh' and, when the
--   @inlineConstantLimit@ is non-zero, of a 'termSize' of at most that limit.
bindConstantVar :: HasCallStack => NormRewrite
bindConstantVar = inlineBinders shouldInline
 where
  shouldInline _ (bndr, stripTicks -> rhs)
    | isLocalVar rhs = return (bndr `localIdDoesNotOccurIn` rhs)
    | otherwise = do
        workFree <- isWorkFreeIsh rhs
        if workFree
          then do
            limit <- Lens.use (extra.inlineConstantLimit)
            return (limit == 0 || termSize rhs <= limit)
          else return False
{-# SCC bindConstantVar #-}
-- | Push a cast into the alternatives of the case-expression it is applied
-- to:
--
-- > cast (case s of p_i -> e_i) :: ty1 -> ty2  ==>  case s of p_i -> cast e_i :: ty1 -> ty2
--
-- After the rewrite every alternative has the target type @ty2@ of the cast,
-- so @ty2@ is also the type of the alternatives of the new case-expression.
-- Keeping the type of the original alternatives would leave the annotated
-- type of the case-expression out of sync with its alternatives.
caseCast :: HasCallStack => NormRewrite
caseCast _ (Cast (stripTicks -> Case subj _ alts) ty1 ty2) = do
  let alts' = map (\(p,e) -> (p, Cast e ty1 ty2)) alts
  changed (Case subj ty2 alts')
caseCast _ e = return e
{-# SCC caseCast #-}
-- | Push a cast into the body of the let-expression it is applied to:
--
-- > cast (letrec bs in e)  ==>  letrec bs in cast e
letCast :: HasCallStack => NormRewrite
letCast _ e0 = case e0 of
  Cast (stripTicks -> Letrec binds body) ty1 ty2 ->
    changed (Letrec binds (Cast body ty1 ty2))
  _ -> return e0
{-# SCC letCast #-}
-- | Specialize a function on an argument that is a (possibly ticked) cast.
-- When the expression being cast is not work-free, specializing may
-- duplicate work. In that case a warning is emitted via 'trace', and the
-- function is specialized anyway.
argCastSpec :: HasCallStack => NormRewrite
argCastSpec ctx e@(App _ (stripTicks -> Cast e' _ _))
  | isWorkFree e' = specialized
  | otherwise = trace warning specialized
 where
  specialized = specializeNorm ctx e
  warning = unwords
    [ "WARNING:", $(curLoc), "specializing a function on a non work-free"
    , "cast. Generated HDL implementation might contain duplicate work."
    , "Please report this as a bug.", "\n\nExpression where this occured:"
    , "\n\n" ++ showPpr e
    ]
argCastSpec _ e = return e
{-# SCC argCastSpec #-}
-- | Inline let-bindings whose right-hand side is a cast of a (possibly
-- ticked) variable.
inlineCast :: HasCallStack => NormRewrite
inlineCast = inlineBinders isCastOfVar
 where
  isCastOfVar _ (_, rhs) = case rhs of
    Cast (stripTicks -> Var {}) _ _ -> return True
    _ -> return False
{-# SCC inlineCast #-}
-- | Remove two directly nested casts that cancel each other out:
--
-- > cast (cast e :: A -> B) :: B -> A  ==>  e
--
-- Types are compared after 'normalizeType'. If the target type of the inner
-- cast differs from the source type of the outer cast, or the target type of
-- the outer cast differs from the source type of the inner cast, a
-- 'ClashException' is thrown.
eliminateCastCast :: HasCallStack => NormRewrite
eliminateCastCast _ c@(Cast (stripTicks -> Cast e tyA tyB) tyB' tyC) = do
  tcm <- Lens.view tcCache
  let normTy = normalizeType tcm
  if normTy tyB == normTy tyB' && normTy tyA == normTy tyC
    then changed e
    else do
      (nm,sp) <- Lens.use curFun
      throw (ClashException sp ($(curLoc) ++ showPpr nm
                ++ ": Found 2 nested casts whose types don't line up:\n"
                ++ showPpr c)
              Nothing)
eliminateCastCast _ e = return e
{-# SCC eliminateCastCast #-}
-- | Split the work out of let-bound casts. A binding @x = cast e@, where @e@
-- is neither a variable nor another cast (ticks around the right-hand side
-- are stripped, and are not kept), becomes two bindings:
--
-- > x' = e
-- > x  = cast x'
--
-- where @x'@ is a fresh binder whose name is derived from @x@. The
-- let-expression is only rebuilt when at least one binding was split.
splitCastWork :: HasCallStack => NormRewrite
splitCastWork ctx@(TransformContext is0 _) unchanged@(Letrec vs e') = do
  (vss', Monoid.getAny -> hasChanged) <- listen (mapM (splitCastLetBinding is0) vs)
  let vs' = concat vss'
  if hasChanged then changed (Letrec vs' e')
                else return unchanged
  where
    splitCastLetBinding
      :: InScopeSet
      -> LetBinding
      -> RewriteMonad extra [LetBinding]
    splitCastLetBinding isN x@(nm, e) = case stripTicks e of
      Cast (Var {}) _ _ -> return [x]
      Cast (Cast {}) _ _ -> return [x]
      Cast e0 ty1 ty2 -> do
        tcm <- Lens.view tcCache
        nm' <- mkTmBinderFor isN tcm (mkDerivedName ctx (nameOcc $ varName nm)) e0
        changed [(nm',e0)
                ,(nm, Cast (Var nm') ty1 ty2)
                ]
      _ -> return [x]

splitCastWork _ e = return e
{-# SCC splitCastWork #-}
-- | Inline a global binder at the site where it is used.
--
-- * For an application of a global binder @f@ to at least one argument,
--   @f@ is replaced by its body (marked with an inline tick) unless the
--   result type is untranslatable or a Signal, some term argument has free
--   local variables or a Signal type, @f@ is local, or @f@ is recursive.
-- * For a bare global variable @f@, @f@ is replaced by its body when its
--   type is not polymorphic or a function type ('isPolyFunCoreTy'), @f@ is
--   not a top entity, its type is translatable and not a Signal, and @f@ is
--   not recursive. If the body has a 'termSize' of at least
--   @inlineWFCacheLimit@, the binder is normalized with
--   'normalizeTopLvlBndr' first and the normalized body is inlined.
inlineWorkFree :: HasCallStack => NormRewrite
inlineWorkFree _ e@(collectArgsTicks -> (Var f,args@(_:_),ticks))
  = do
    tcm <- Lens.view tcCache
    let eTy = termType tcm e
    argsHaveWork <- or <$> mapM (either expressionHasWork
                                        (const (pure False)))
                                args
    untranslatable <- isUntranslatableType True eTy
    let isSignal = isSignalType tcm eTy
    let lv = isLocalId f
    if untranslatable || isSignal || argsHaveWork || lv
      then return e
      else do
        bndrs <- Lens.use bindings
        case lookupVarEnv f bndrs of
          -- Do not inline recursive binders
          Just b -> do
            isRecBndr <- isRecursiveBndr f
            if isRecBndr
               then return e
               else do
                 let tm = mkTicks (bindingTerm b) (mkInlineTick f : ticks)
                 changed $ mkApps tm args
          _ -> return e
  where
    -- An argument counts as having work when it has free local variables or
    -- is of a Signal type
    expressionHasWork e' = do
      let fvIds = Lens.toListOf freeLocalIds e'
      tcm <- Lens.view tcCache
      let e'Ty = termType tcm e'
          isSignal = isSignalType tcm e'Ty
      return (not (null fvIds) || isSignal)

inlineWorkFree _ e@(Var f) = do
  tcm <- Lens.view tcCache
  let fTy = varType f
      closed = not (isPolyFunCoreTy tcm fTy)
      isSignal = isSignalType tcm fTy
  untranslatable <- isUntranslatableType True fTy
  topEnts <- Lens.view topEntities
  let gv = isGlobalId f
  if closed && f `notElemVarSet` topEnts && not untranslatable && not isSignal && gv
    then do
      bndrs <- Lens.use bindings
      case lookupVarEnv f bndrs of
        -- Do not inline recursive binders
        Just top -> do
          isRecBndr <- isRecursiveBndr f
          if isRecBndr
             then return e
             else do
               let topB = bindingTerm top
               sizeLimit <- Lens.use (extra.inlineWFCacheLimit)
               -- Large bodies are normalized before they are inlined
               if termSize topB < sizeLimit then
                 changed topB
               else do
                 b <- normalizeTopLvlBndr False f top
                 changed (bindingTerm b)
        _ -> return e
    else return e

inlineWorkFree _ e = return e
{-# SCC inlineWorkFree #-}
-- | Inline a global function (marked with an inline tick) when its body has a
-- 'termSize' below the @inlineFunctionLimit@. Nothing is inlined when the
-- expression is untranslatable, the function is a top entity or a local
-- binder, it is recursive, or it is marked 'NoInline'.
inlineSmall :: HasCallStack => NormRewrite
inlineSmall _ e@(collectArgsTicks -> (Var f,args,ticks)) = do
  untranslatable <- isUntranslatable True e
  topEnts <- Lens.view topEntities
  let excluded = untranslatable || f `elemVarSet` topEnts || isLocalId f
  if excluded then return e else do
    bndrs <- Lens.use bindings
    sizeLimit <- Lens.use (extra.inlineFunctionLimit)
    case lookupVarEnv f bndrs of
      Nothing -> return e
      Just b -> do
        isRecBndr <- isRecursiveBndr f
        let body = bindingTerm b
            inlinable = not isRecBndr
                     && bindingSpec b /= NoInline
                     && termSize body < sizeLimit
        if inlinable
          then changed (mkApps (mkTicks body (mkInlineTick f : ticks)) args)
          else return e
inlineSmall _ e = return e
{-# SCC inlineSmall #-}
-- | Specialize a function on an argument that contains constants.
--
-- Fires on an application of @e1@ to @e2@ when @e1@ is a variable applied to
-- term arguments only and @e2@ has no free type variables. The argument is
-- analysed with 'constantSpecInfo'. If no constant is found, the expression
-- is returned unchanged. If a constant is found without new bindings, the
-- application is specialized on @e2@ as is. Otherwise it is specialized on
-- the rewritten argument ('csrNewTerm'), and the new bindings are let-bound
-- around the result.
constantSpec :: HasCallStack => NormRewrite
constantSpec ctx@(TransformContext is0 tfCtx) e@(App e1 e2)
  | (Var {}, args) <- collectArgs e1
  , (_, []) <- Either.partitionEithers args
  , null $ Lens.toListOf termFreeTyVars e2
  = do specInfo<- constantSpecInfo ctx e2
       if csrFoundConstant specInfo then
         let newBindings = csrNewBindings specInfo in
         if null newBindings then
           specializeNorm ctx (App e1 e2)
         else do
           -- The new binders are in scope in the specialized application
           let is1 = extendInScopeSetList is0 (fst <$> csrNewBindings specInfo)
           Letrec newBindings
             <$> specializeNorm
                   (TransformContext is1 tfCtx)
                   (App e1 (csrNewTerm specInfo))
       else
         return e
constantSpec _ e = return e
{-# SCC constantSpec #-}
-- | Propagate the arguments of an application into the function position.
--
-- * Nested applications are flattened into one argument list (with ticks
--   accumulated).
-- * A lambda applied to a term argument: the argument is substituted when it
--   is work-free or a variable; otherwise it is let-bound to the lambda
--   binder.
-- * A let-expression applied to arguments: the arguments move into the body
--   of the let-expression.
-- * A type-lambda applied to a type: the type is substituted.
-- * A case-expression applied to arguments: the arguments are pushed into
--   every alternative, and the type of the case-expression is adjusted
--   accordingly. Term arguments that are neither work-free nor variables are
--   let-bound outside the case-expression first (presumably so their work is
--   not duplicated across alternatives).
-- * Ticks around the function are collected and put back on the result.
--
-- Otherwise the application is rebuilt unchanged.
appPropFast :: HasCallStack => NormRewrite
appPropFast ctx@(TransformContext is _) = \case
  e@App {}
    | let (fun,args,ticks) = collectArgsTicks e
    -> go is (deShadowTerm is fun) args ticks
  e@TyApp {}
    | let (fun,args,ticks) = collectArgsTicks e
    -> go is (deShadowTerm is fun) args ticks
  e -> return e
 where
  go :: InScopeSet -> Term -> [Either Term Type] -> [TickInfo]
     -> NormalizeSession Term
  go is0 (collectArgsTicks -> (fun,args0@(_:_),ticks0)) args1 ticks1 =
    go is0 fun (args0 ++ args1) (ticks0 ++ ticks1)

  go is0 (Lam v e) (Left arg:args) ticks = do
    setChanged
    if isWorkFree arg || isVar arg
      then do
        let subst = extendIdSubst (mkSubst is0) v arg
        (`mkTicks` ticks) <$> go is0 (substTm "appPropFast.AppLam" subst e) args []
      else do
        let is1 = extendInScopeSet is0 v
        Letrec [(v, arg)] <$> go is1 (deShadowTerm is1 e) args ticks

  go is0 (Letrec vs e) args@(_:_) ticks = do
    setChanged
    let vbs = map fst vs
        is1 = extendInScopeSetList is0 vbs
    Letrec vs <$> go is1 e args ticks

  go is0 (TyLam tv e) (Right t:args) ticks = do
    setChanged
    let subst = extendTvSubst (mkSubst is0) tv t
    (`mkTicks` ticks) <$> go is0 (substTm "appPropFast.TyAppTyLam" subst e) args []

  go is0 (Case scrut ty0 alts) args0@(_:_) ticks = do
    setChanged
    -- Names of new argument binders must not clash with pattern variables
    let isA1 = unionInScope
                 is0
                 ((mkInScopeSet . mkVarSet . concatMap (patVars . fst)) alts)
    (ty1,vs,args1) <- goCaseArg isA1 ty0 [] args0
    case vs of
      [] -> (`mkTicks` ticks) . Case scrut ty1 <$> mapM (goAlt is0 args1) alts
      _ -> do
        let vbs = map fst vs
            is1 = extendInScopeSetList is0 vbs
            alts1 = map (deShadowAlt is1) alts
        Letrec vs . (`mkTicks` ticks) . Case scrut ty1 <$> mapM (goAlt is1 args1) alts1

  go is0 (Tick sp e) args ticks = do
    setChanged
    go is0 e args (sp:ticks)

  go _ fun args ticks = return (mkApps (mkTicks fun ticks) args)

  -- Push the arguments into a single alternative
  goAlt is0 args0 (p,e) = do
    let (tvs,ids) = patIds p
        is1 = extendInScopeSetList (extendInScopeSetList is0 tvs) ids
    (p,) <$> go is1 e args0 []

  -- Compute the new type of the case-expression and let-bind the term
  -- arguments that are neither work-free nor variables
  goCaseArg isA ty0 ls0 (Right t:args0) = do
    tcm <- Lens.view tcCache
    let ty1 = piResultTy tcm ty0 t
    (ty2,ls1,args1) <- goCaseArg isA ty1 ls0 args0
    return (ty2,ls1,Right t:args1)

  goCaseArg isA0 ty0 ls0 (Left arg:args0) = do
    tcm <- Lens.view tcCache
    let argTy = termType tcm arg
        ty1 = applyFunTy tcm ty0 argTy
    case isWorkFree arg || isVar arg of
      True -> do
        (ty2,ls1,args1) <- goCaseArg isA0 ty1 ls0 args0
        return (ty2,ls1,Left arg:args1)
      False -> do
        boundArg <- mkTmBinderFor isA0 tcm (mkDerivedName ctx "app_arg") arg
        let isA1 = extendInScopeSet isA0 boundArg
        (ty2,ls1,args1) <- goCaseArg isA1 ty1 ls0 args0
        return (ty2,(boundArg,arg):ls1,Left (Var boundArg):args1)

  goCaseArg _ ty ls [] = return (ty,ls,[])
{-# SCC appPropFast #-}
-- | Flatten a chain of nested boolean case-expressions that compare the same
-- subject against literals (recognised by 'collectEqArgs' and 'collectFlat')
-- into one case-expression on that subject, with one literal alternative per
-- compared value. 'collectFlat' puts the default alternative last; here it is
-- moved to the front.
caseFlat :: HasCallStack => NormRewrite
caseFlat _ e@(Case (collectEqArgs -> Just (scrut',_)) ty _) =
  maybe (return e) flatten (collectFlat scrut' e)
 where
  flatten alts' = changed (Case scrut' ty (last alts' : init alts'))
caseFlat _ e = return e
{-# SCC caseFlat #-}
-- | Given the subject @scrut@, collect literal alternatives from a chain of
-- two-alternative case-expressions whose subject is an equality test
-- ('collectEqArgs') of @scrut@ against a literal. The literal is the last
-- argument of an isFromInt primitive or of the I# constructor.
--
-- The branch selected on True (a right True pattern, or the branch opposite
-- a left False pattern) becomes a 'LitPat' alternative. The other branch is
-- searched recursively; when it is not part of the chain it becomes the
-- final 'DefaultPat' alternative. Returns Nothing when the term does not
-- have this shape.
collectFlat :: Term -> Term -> Maybe [(Pat,Term)]
collectFlat scrut (Case (collectEqArgs -> Just (scrut', val)) _ty [lAlt,rAlt])
  | scrut' == scrut
  = case collectArgs val of
      (Prim p,args') | isFromInt (primName p) ->
        go (last args')
      (Data dc,args') | nameOcc (dcName dc) == "GHC.Types.I#" ->
        go (last args')
      _ -> Nothing
  where
    go (Left (Literal i)) = case (lAlt,rAlt) of
      ((pl,el),(pr,er))
        -- er is the branch taken when the subject equals the literal
        | isFalseDcPat pl || isTrueDcPat pr ->
           case collectFlat scrut el of
             Just alts' -> Just ((LitPat i, er) : alts')
             Nothing -> Just [(LitPat i, er)
                             ,(DefaultPat, el)
                             ]
        -- el is the branch taken when the subject equals the literal
        | otherwise ->
           case collectFlat scrut er of
             Just alts' -> Just ((LitPat i, el) : alts')
             Nothing -> Just [(LitPat i, el)
                             ,(DefaultPat, er)
                             ]
    go _ = Nothing

    isFalseDcPat (DataPat p _ _)
      = ((== "GHC.Types.False") . nameOcc . dcName) p
    isFalseDcPat _ = False

    isTrueDcPat (DataPat p _ _)
      = ((== "GHC.Types.True") . nameOcc . dcName) p
    isTrueDcPat _ = False

collectFlat _ _ = Nothing
{-# SCC collectFlat #-}
-- | Recognise an equality test between a subject and a value: an application
-- of the @eq#@ primitive of @BitVector@, @Index@, @Signed@ or @Unsigned@, or
-- of @Clash.Transformations.eqInt@.
--
-- Returns the subject, with the ticks of the application put back on it,
-- and the value it is compared against. Any other term gives 'Nothing', and
-- so does an application of one of these primitives to an unexpected number
-- or kind of arguments. Previously an irrefutable list pattern produced
-- @Just@ of two bottoms in that case.
collectEqArgs :: Term -> Maybe (Term,Term)
collectEqArgs (collectArgsTicks -> (Prim p, args, ticks))
  | nm == "Clash.Sized.Internal.BitVector.eq#"
  , [_,_,Left scrut,Left val] <- args
  = Just (mkTicks scrut ticks,val)
  | nm `elem` [ "Clash.Sized.Internal.Index.eq#"
              , "Clash.Sized.Internal.Signed.eq#"
              , "Clash.Sized.Internal.Unsigned.eq#"
              ]
  , [_,Left scrut,Left val] <- args
  = Just (mkTicks scrut ticks,val)
  | nm == "Clash.Transformations.eqInt"
  , [Left scrut,Left val] <- args
  = Just (mkTicks scrut ticks,val)
 where
  nm = primName p
collectEqArgs _ = Nothing
-- | A normalization transformation that also threads a state through the
-- rewrite. The state holds the let-bindings collected so far and an in-scope
-- set extended with their binders; it is updated by 'tellBinders' and
-- 'notifyBinders'.
type NormRewriteW = Transform (StateT ([LetBinding],InScopeSet) (RewriteMonad NormalizeState))
-- | Record new let-bindings: they are prepended to the collected bindings,
-- and their binders are added to the in-scope set.
tellBinders :: Monad m => [LetBinding] -> StateT ([LetBinding],InScopeSet) m ()
tellBinders newBinds = modify addBinds
 where
  addBinds ~(collected, inScope) =
    (newBinds ++ collected, extendInScopeSetList inScope (map fst newBinds))
-- | Add the binders of the given let-bindings to the in-scope set without
-- recording the bindings themselves.
notifyBinders :: Monad m => [LetBinding] -> StateT ([LetBinding],InScopeSet) m ()
notifyBinders newBinds =
  modify (\ ~(collected, inScope) ->
            (collected, extendInScopeSetList inScope (map fst newBinds)))
-- | Whether a type (after 'coreView') is a @SimIO@ type, an unboxed tuple
-- @(#,#)@ applied to four type arguments of which the third is a State#
-- token type ('isStateTokenTy'), or a function type whose result is such a
-- type.
isSimIOTy
  :: TyConMap
  -> Type
  -> Bool
isSimIOTy tcm ty = case tyView (coreView tcm ty) of
  FunTy _ res -> isSimIOTy tcm res
  TyConApp tcNm args -> simIOCon (nameOcc tcNm) args
  _ -> False
 where
  simIOCon nm tyArgs
    | nm == "Clash.Explicit.SimIO.SimIO" = True
    | nm == "GHC.Prim.(#,#)", [_,_,st,_] <- tyArgs = isStateTokenTy tcm st
    | otherwise = False
-- | Whether a type (after 'coreView') is an application of the
-- @GHC.Prim.State#@ type constructor.
isStateTokenTy
  :: TyConMap
  -> Type
  -> Bool
isStateTokenTy tcm ty
  | TyConApp tcNm _ <- tyView (coreView tcm ty) = nameOcc tcNm == "GHC.Prim.State#"
  | otherwise = False
-- | Bring an expression into A-normal form.
--
-- Leading lambdas are kept and the transformation is applied to their body;
-- type-lambdas are left untouched. For any other expression, the term is
-- first freshened ('freshenTm'), then 'collectANF' is run bottom-up. It
-- gathers the let-bindings it introduces in its state. Those bindings are
-- placed in one let-expression around the result. Source ticks of the result
-- stay inside that let-expression, while the remaining ticks
-- ('partitionTicks') are put outside it. When nothing was collected, the
-- rewritten term is returned if something changed, and the original term
-- otherwise.
makeANF :: HasCallStack => NormRewrite
makeANF (TransformContext is0 ctx) (Lam bndr e) = do
  e' <- makeANF (TransformContext (extendInScopeSet is0 bndr)
                                  (LamBody bndr:ctx))
                e
  return (Lam bndr e')

makeANF _ e@(TyLam {}) = return e

makeANF ctx@(TransformContext is0 _) e0
  = do
    let (is2,e1) = freshenTm is0 e0
    ((e2,(bndrs,_)),Monoid.getAny -> hasChanged) <-
      listen (runStateT (bottomupR collectANF ctx e1) ([],is2))
    case bndrs of
      [] -> if hasChanged then return e2 else return e0
      _ -> do
        let (e3,ticks) = collectTicks e2
            (srcTicks,nmTicks) = partitionTicks ticks
        changed (mkTicks (Letrec bndrs (mkTicks e3 srcTicks)) nmTicks)
{-# SCC makeANF #-}
-- | A single bottom-up step of 'makeANF'. It replaces non-trivial subterms by
-- variables bound to fresh let-binders. The bindings are recorded in the
-- state with 'tellBinders'.
--
-- * Applications headed by a constructor, primitive or variable: an argument
--   is let-bound when it is translatable, is neither a local variable nor a
--   constant that is not a clock/reset ('isConstantNotClockReset'), and the
--   head is not @bindSimIO#@. An untranslatable let-expression argument has
--   its bindings floated out.
-- * Let-expressions: the bindings are always floated out. The body is
--   let-bound as well, unless it is a local variable, untranslatable, or of
--   SimIO type.
-- * Case-expressions: the subject and the alternatives are let-bound unless
--   they are local variables or constants (see 'doAlt' for the details).
collectANF :: HasCallStack => NormRewriteW
collectANF ctx e@(App appf arg)
  | (conVarPrim, _) <- collectArgs e
  , isCon conVarPrim || isPrim conVarPrim || isVar conVarPrim
  = do
    untranslatable <- lift (isUntranslatable False arg)
    let localVar = isLocalVar arg
    constantNoCR <- lift (isConstantNotClockReset arg)
    case (untranslatable,localVar || constantNoCR, isSimBind conVarPrim,arg) of
      -- Translatable, non-trivial argument of a non-bindSimIO# head:
      -- bind it to a fresh "app_arg" binder.
      (False,False,False,_) -> do
        tcm <- Lens.view tcCache
        is1 <- Lens.use _2
        argId <- lift (mkTmBinderFor is1 tcm (mkDerivedName ctx "app_arg") arg)
        tellBinders [(argId,arg)]
        return (App appf (Var argId))
      -- Untranslatable let-expression argument: float its bindings out and
      -- keep only its body as the argument.
      (True,False,_,Letrec binds body) -> do
        tellBinders binds
        return (App appf body)
      _ -> return e
  where
    isSimBind (Prim p) = primName p == "Clash.Explicit.SimIO.bindSimIO#"
    isSimBind _ = False

collectANF _ (Letrec binds body) = do
  tcm <- Lens.view tcCache
  let isSimIO = isSimIOTy tcm (termType tcm body)
  untranslatable <- lift (isUntranslatable False body)
  let localVar = isLocalVar body
  if localVar || untranslatable || isSimIO
    then do
      tellBinders binds
      return body
    else do
      -- Also bind the body itself, to a fresh "result" binder.
      is1 <- Lens.use _2
      argId <- lift (mkTmBinderFor is1 tcm (mkUnsafeSystemName "result" 0) body)
      tellBinders [(argId,body)]
      tellBinders binds
      return (Var argId)

-- Single-alternative cases that match on the signal constructor ':-' are left
-- untouched.
collectANF _ e@(Case _ _ [(DataPat dc _ _,_)])
  | nameOcc (dcName dc) == "Clash.Signal.Internal.:-" = return e

collectANF ctx (Case subj ty alts) = do
  let localVar = isLocalVar subj
  let isConstantSubj = isConstant subj
  -- Bind the subject to a fresh "case_scrut" binder unless it is a local
  -- variable or a constant. At this point the binder is only added to the
  -- in-scope set ('notifyBinders'). The binding itself is recorded after the
  -- alternatives have been processed.
  (subj',subjBinders) <- if localVar || isConstantSubj
    then return (subj,[])
    else do
      tcm <- Lens.view tcCache
      is1 <- Lens.use _2
      argId <- lift (mkTmBinderFor is1 tcm (mkDerivedName ctx "case_scrut") subj)
      notifyBinders [(argId,subj)]
      return (Var argId,[(argId,subj)])
  tcm <- Lens.view tcCache
  let isSimIOAlt = isSimIOTy tcm ty
  alts' <- mapM (doAlt isSimIOAlt subj') alts
  tellBinders subjBinders
  -- A single data alternative that binds no existential type variables is
  -- replaced by its right-hand side when the pattern variables do not occur
  -- in that right-hand side, or when the case is SimIO-typed.
  case alts' of
    [(DataPat _ [] xs,altExpr)]
      | xs `localIdsDoNotOccurIn` altExpr || isSimIOAlt
      -> return altExpr
    _ -> return (Case subj' ty alts')
  where
    doAlt
      :: Bool -> Term -> (Pat,Term)
      -> StateT ([LetBinding],InScopeSet) (RewriteMonad NormalizeState)
           (Pat,Term)
    -- Data alternative for which 'bindsExistentials' does not hold. Every
    -- pattern variable is bound to a selection of the matching field of the
    -- (possibly let-bound) subject. The right-hand side is bound to a fresh
    -- "case_alt" binder unless one of these holds:
    --   * the case is SimIO-typed;
    --   * it is a constant;
    --   * it is a local variable that is not one of the pattern variables;
    --   * it is any local variable and the case has a single alternative.
    doAlt isSimIOAlt subj' alt@(DataPat dc exts xs,altExpr) | not (bindsExistentials exts xs) = do
      let lv = isLocalVar altExpr
      patSels <- Monad.zipWithM (doPatBndr subj' dc) xs [0..]
      let altExprIsConstant = isConstant altExpr
      let usesXs (Var n) = any (== n) xs
          usesXs _ = False
      if or [isSimIOAlt, lv && (not (usesXs altExpr) || length alts == 1), altExprIsConstant]
        then do
          tellBinders patSels
          return alt
        else do
          tcm <- Lens.view tcCache
          is1 <- Lens.use _2
          altId <- lift (mkTmBinderFor is1 tcm (mkDerivedName ctx "case_alt") altExpr)
          tellBinders (patSels ++ [(altId,altExpr)])
          return (DataPat dc exts xs,Var altId)
    -- Data alternatives for which 'bindsExistentials' holds are left untouched.
    doAlt _ _ alt@(DataPat {}, _) = return alt
    -- All other alternatives: bind the right-hand side to a fresh "case_alt"
    -- binder unless the case is SimIO-typed, or the right-hand side is a local
    -- variable or a constant.
    doAlt isSimIOAlt _ alt@(pat,altExpr) = do
      let lv = isLocalVar altExpr
      let altExprIsConstant = isConstant altExpr
      if isSimIOAlt || lv || altExprIsConstant
        then return alt
        else do
          tcm <- Lens.view tcCache
          is1 <- Lens.use _2
          altId <- lift (mkTmBinderFor is1 tcm (mkDerivedName ctx "case_alt") altExpr)
          tellBinders [(altId,altExpr)]
          return (pat,Var altId)

    -- Bind pattern variable @pId@ to a selector for field @i@ of constructor
    -- @dc@ applied to the subject ('mkSelectorCase').
    doPatBndr
      :: Term -> DataCon -> Id -> Int
      -> StateT ([LetBinding],InScopeSet) (RewriteMonad NormalizeState)
           LetBinding
    doPatBndr subj' dc pId i
      = do
        tcm <- Lens.view tcCache
        is1 <- Lens.use _2
        patExpr <- lift (mkSelectorCase ($(curLoc) ++ "doPatBndr") is1 tcm subj' (dcTag dc) i)
        return (pId,patExpr)

collectANF _ e = return e
{-# SCC collectANF #-}
-- | Eta-expand a top-level term until it is no longer of function type.
--
-- Lambdas and let-bodies are descended into. Lambdas that the recursion
-- produces inside a let-body are moved outside the 'Letrec'. A term of
-- function type is wrapped in a lambda over a fresh "arg" binder and
-- expanded further.
etaExpansionTL :: HasCallStack => NormRewrite
etaExpansionTL (TransformContext is0 ctx) (Lam bndr e) =
  Lam bndr <$> etaExpansionTL bodyCtx e
 where
  bodyCtx = TransformContext (extendInScopeSet is0 bndr) (LamBody bndr:ctx)
etaExpansionTL (TransformContext is0 ctx) (Letrec xes e) = do
  let bndrs = map fst xes
      bodyCtx = TransformContext (extendInScopeSetList is0 bndrs) (LetBody bndrs:ctx)
  e' <- etaExpansionTL bodyCtx e
  case peelLams e' of
    ([], _) -> return (Letrec xes e')
    (lamBndrs, inner) -> changed (mkLams (Letrec xes inner) lamBndrs)
 where
  -- Split off all leading lambda binders.
  peelLams :: Term -> ([Id],Term)
  peelLams (Lam b t) = let (bs, t') = peelLams t in (b:bs, t')
  peelLams t = ([], t)
etaExpansionTL (TransformContext is0 ctx) e = do
  tcm <- Lens.view tcCache
  if not (isFun tcm e)
    then return e
    else do
      let argTy = case splitFunTy tcm (termType tcm e) of
            Just (aTy, _) -> aTy
            Nothing -> error $ $(curLoc) ++ "etaExpansion splitFunTy"
      newId <- mkInternalVar is0 "arg" argTy
      e' <- etaExpansionTL
              (TransformContext (extendInScopeSet is0 newId) (LamBody newId:ctx))
              (App e (Var newId))
      changed (Lam newId e')
{-# SCC etaExpansionTL #-}
-- | Eta-expand (once) a reference to a top entity that has a function type,
-- unless it already occurs in function position of an application (looking
-- through tick contexts).
etaExpandSyn :: HasCallStack => NormRewrite
etaExpandSyn (TransformContext is0 ctx) e@(collectArgs -> (Var f, _)) = do
  topEnts <- Lens.view topEntities
  tcm <- Lens.view tcCache
  case fst <$> splitFunTy tcm (termType tcm e) of
    Just argTy
      | f `elemVarSet` topEnts
      , not (appliedAsFunction ctx)
      -> do
        newId <- mkInternalVar is0 "arg" argTy
        changed (Lam newId (App e (Var newId)))
    _ -> return e
 where
  appliedAsFunction (AppFun:_) = True
  appliedAsFunction (TickC _:rest) = appliedAsFunction rest
  appliedAsFunction _ = False
etaExpandSyn _ e = return e
{-# SCC etaExpandSyn #-}
-- | Heuristic check for class-constraint (dictionary) types. It looks at the
-- name of the type constructor: either it contains @GHC.Classes.(%@, or its
-- unqualified part (after the last '.') contains @C:@.
isClassConstraint :: Type -> Bool
isClassConstraint ty = case tyView ty of
  TyConApp tcNm _ ->
    let qualified = nameOcc tcNm
        unqualified = snd (Text.breakOnEnd "." qualified)
    in "GHC.Classes.(%" `Text.isInfixOf` qualified
         || "C:" `Text.isInfixOf` unqualified
  _ -> False
-- | Turn a self-recursive function into a let-recursive one.
--
-- This only applies at the root of the term (empty context) and when
-- 'splitNormalized' succeeds. It looks for let-bindings whose right-hand side
-- applies the current function ('curFun') to exactly its own arguments, as
-- judged by 'eqApp'. Those bindings are removed, and their binders are
-- replaced by the result variable in the remaining bindings. The rewrite
-- fires only when there is at least one such binding and at least one other
-- binding.
recToLetRec :: HasCallStack => NormRewrite
recToLetRec (TransformContext is0 []) e = do
  (fn,_) <- Lens.use curFun
  tcm <- Lens.view tcCache
  case splitNormalized tcm e of
    Right (args,bndrs,res) -> do
      let args' = map Var args
          (toInline,others) = List.partition (eqApp tcm fn args' . snd) bndrs
          resV = Var res
      case (toInline,others) of
        (_:_,_:_) -> do
          let is1 = extendInScopeSetList is0 (args ++ map fst bndrs)
          let substsInline = extendIdSubstList (mkSubst is1)
                           $ map (second (const resV)) toInline
              others' = map (second (substTm "recToLetRec" substsInline))
                            others
          changed $ mkLams (Letrec others' resV) args
        _ -> return e
    _ -> return e
 where
  -- Is the (tick-stripped) term an application of global variable @v@ to
  -- exactly as many term arguments as @args@, each equal according to
  -- 'eqArg'? Type arguments are ignored.
  eqApp tcm v args (collectArgs . stripTicks -> (Var v',args'))
    | isGlobalId v'
    , v == v'
    , let args2 = Either.lefts args'
    , length args == length args2
    = and (zipWith (eqArg tcm) args args2)
  eqApp _ _ _ _ = False

  -- When @v2@ is (after stripping ticks) a variable, the two terms are
  -- compared with '=='. When @v2@ is a constructor application with the same
  -- type as @v1@, it counts as equal if the type is a class constraint, or
  -- if every field is the matching projection of @v1@ ('eqDat').
  eqArg _ v1 v2@(stripTicks -> Var {})
    = v1 == v2
  eqArg tcm v1 v2@(collectArgs . stripTicks -> (Data _, args'))
    | let t1 = termType tcm v1
    , let t2 = termType tcm v2
    , t1 == t2
    = if isClassConstraint t1 then
        True
      else
        and (zipWith (eqDat v1) (map pure [0..]) (Either.lefts args'))
  eqArg _ _ _
    = False

  -- @eqDat v path t@: @t@ is either a constructor application whose fields
  -- recursively satisfy 'eqDat' with the path extended by the field index, or
  -- a chain of projections that selects exactly @path@ (innermost index first
  -- in @path@) out of @v@.
  eqDat :: Term -> [Int] -> Term -> Bool
  eqDat v fTrace (collectArgs . stripTicks -> (Data _, args)) =
    and (zipWith (eqDat v) (map (:fTrace) [0..]) (Either.lefts args))
  eqDat v1 fTrace v2 =
    case stripProjection (reverse fTrace) v1 v2 of
      Just [] -> True
      _ -> False

  -- Peel single-alternative case projections off the term. Each projection
  -- consumes one field index; when the innermost subject is the target
  -- variable, the indices that remain are returned.
  stripProjection :: [Int] -> Term -> Term -> Maybe [Int]
  stripProjection fTrace0 vTarget0 (Case v _ [(DataPat _ _ xs, r)]) = do
    fTrace1 <- stripProjection fTrace0 vTarget0 v
    (n, fTrace2) <- List.uncons fTrace1
    vTarget1 <- List.indexMaybe xs n
    stripProjection fTrace2 (Var vTarget1) r
  stripProjection fTrace (Var sTarget) (Var s) =
    if sTarget == s then Just fTrace else Nothing
  stripProjection _fTrace _vTarget _v =
    Nothing
recToLetRec _ e = return e
{-# SCC recToLetRec #-}
-- | Inline a global function that is applied to at least one argument of
-- polymorphic-function type ('isPolyFun').
--
-- The number of times @f@ has already been inlined into the current function
-- is compared against the @inlineLimit@. Beyond the limit nothing is inlined,
-- and a trace message is emitted when the debug level is above 'DebugNone'.
inlineHO :: HasCallStack => NormRewrite
inlineHO _ e@(App _ _)
  | (Var f, args, ticks) <- collectArgsTicks e
  = do
    tcm <- Lens.view tcCache
    let polyFunArg = either (isPolyFun tcm) (const False)
    if not (any polyFunArg args)
      then return e
      else do
        (cf,_) <- Lens.use curFun
        timesInlined <- Maybe.fromMaybe 0 <$> zoomExtra (alreadyInlined f cf)
        limit <- Lens.use (extra.inlineLimit)
        if timesInlined > limit
          then do
            lvl <- Lens.view dbgLevel
            traceIf (lvl > DebugNone)
              ($(curLoc) ++ "InlineHO: " ++ show f ++ " already inlined " ++ show limit ++ " times in:" ++ show cf)
              (return e)
          else
            lookupVarEnv f <$> Lens.use bindings >>= \case
              Nothing -> return e
              Just b -> do
                zoomExtra (addNewInline f cf)
                changed (mkApps (mkTicks (bindingTerm b) ticks) args)
inlineHO _ e = return e
{-# SCC inlineHO #-}
-- | Simple common-subexpression elimination on a let-expression.
--
-- The bindings are first ordered with 'inverseTopSortLetBindings'. Then
-- 'reduceBinders' removes bindings whose right-hand side equals one that was
-- already kept. The resulting substitution is applied to the remaining
-- bindings and to the body.
simpleCSE :: HasCallStack => NormRewrite
simpleCSE (TransformContext is0 _) (inverseTopSortLetBindings -> Letrec bndrs body) = do
  let inScope = extendInScopeSetList is0 (map fst bndrs)
  (subst, kept) <- reduceBinders (mkSubst inScope) [] bndrs
  let applySubst ~(i, rhs) = (i, substTm "simpleCSE.bndrs" subst rhs)
  return (Letrec (map applySubst kept) (substTm "simpleCSE.body" subst body))
simpleCSE _ e = return e
{-# SCC simpleCSE #-}
-- | Worker for 'simpleCSE'.
--
-- Walks the bindings, applying the substitution built so far to each
-- right-hand side. A binding is dropped when its right-hand side equals (by
-- '==') a binding that was already kept and it carries no 'NoDeDup' tick.
-- Its binder is then substituted by the kept binder, and the change is
-- flagged. Returns the final substitution and the kept bindings in reverse
-- order.
reduceBinders
  :: Subst
  -> [LetBinding]
  -> [LetBinding]
  -> RewriteMonad NormalizeState (Subst, [LetBinding])
reduceBinders !subst processed [] = return (subst,processed)
reduceBinders !subst processed ((i,rhs0):rest) =
  case duplicateOf of
    Just i1 -> do
      setChanged
      reduceBinders (extendIdSubst subst i (Var i1)) processed rest
    Nothing ->
      reduceBinders subst ((i,rhs):processed) rest
 where
  rhs = substTm "reduceBinders" subst rhs0
  (_,_,ticks) = collectArgsTicks rhs
  duplicateOf
    | NoDeDup `elem` ticks = Nothing
    | otherwise = fst <$> List.find ((== rhs) . snd) processed
{-# SCC reduceBinders #-}
-- | Evaluate a primitive application to weak head normal form ('whnfRW').
-- The result is kept only when it is no longer an application of the same
-- primitive.
reduceConst :: HasCallStack => NormRewrite
reduceConst ctx e@(App _ _)
  | (Prim p0, _) <- collectArgs e
  = whnfRW False ctx e $ \_ e1 ->
      case collectArgs e1 of
        (Prim p1, _) | primName p0 == primName p1 -> return e
        _ -> changed e1
reduceConst _ e = return e
{-# SCC reduceConst #-}
-- | Reduce applications of selected vector, tree and bit-vector primitives
-- whose size type arguments are statically known.
--
-- Any primitive application whose result type is @Vec 0 a@ becomes the empty
-- vector. Otherwise the primitive is dispatched on its name and its number of
-- arguments. Most reductions fire only when at least one of these holds:
--
-- * the @normalizeUltra@ flag is set (for some primitives, alternatively, the
--   size is small);
-- * 'shouldReduce' holds in the current context;
-- * an element type is untranslatable and has no free type variables.
--
-- These reductions fire unconditionally: @traverse#@, @dfold@, @dtfold@,
-- @tdfold@, the zero-length shortcuts of @++@, @unconcat@ and @transpose@
-- with a zero outer size, and zero-width @split#@ and @eq#@. @reverse@ fires
-- only in ultra mode. The ticks of the original application are re-applied
-- to the result, except in the zero-length @++@ shortcuts.
reduceNonRepPrim :: HasCallStack => NormRewrite
reduceNonRepPrim c@(TransformContext is0 ctx) e@(App _ _) | (Prim p, args, ticks) <- collectArgsTicks e = do
  tcm <- Lens.view tcCache
  ultra <- Lens.use (extra.normalizeUltra)
  let eTy = termType tcm e
  case tyView eTy of
    -- Result is an empty vector: replace the whole application by Nil.
    (TyConApp vecTcNm@(nameOcc -> "Clash.Sized.Vector.Vec")
              [runExcept . tyNatSize tcm -> Right 0, aTy]) -> do
      let (Just vecTc) = lookupUniqMap vecTcNm tcm
          [nilCon,consCon] = tyConDataCons vecTc
          nilE = mkVec nilCon consCon aTy 0 []
      changed (mkTicks nilE ticks)
    tv -> let argLen = length args in case primName p of
      "Clash.Sized.Vector.zipWith" | argLen == 7 -> do
        let [lhsElTy,rhsElty,resElTy,nTy] = Either.rights args
        case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            shouldReduce1 <- List.orM [ pure (ultra || n < 2)
                                      , shouldReduce ctx
                                      , List.anyM isUntranslatableType_not_poly
                                                  [lhsElTy,rhsElty,resElTy] ]
            if shouldReduce1
              then let [fun,lhsArg,rhsArg] = Either.lefts args
                   in (`mkTicks` ticks) <$>
                        reduceZipWith c n lhsElTy rhsElty resElTy fun lhsArg rhsArg
              else return e
          _ -> return e
      "Clash.Sized.Vector.map" | argLen == 5 -> do
        let [argElTy,resElTy,nTy] = Either.rights args
        case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            shouldReduce1 <- List.orM [ pure (ultra || n < 2 )
                                      , shouldReduce ctx
                                      , List.anyM isUntranslatableType_not_poly
                                                  [argElTy,resElTy] ]
            if shouldReduce1
              then let [fun,arg] = Either.lefts args
                   in (`mkTicks` ticks) <$> reduceMap c n argElTy resElTy fun arg
              else return e
          _ -> return e
      -- Always reduced once the size is known.
      "Clash.Sized.Vector.traverse#" | argLen == 7 ->
        let [aTy,fTy,bTy,nTy] = Either.rights args
        in case runExcept (tyNatSize tcm nTy) of
          Right n ->
            let [dict,fun,arg] = Either.lefts args
            in (`mkTicks` ticks) <$> reduceTraverse c n aTy fTy bTy dict fun arg
          _ -> return e
      "Clash.Sized.Vector.fold" | argLen == 4 -> do
        let [aTy,nTy] = Either.rights args
        case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            shouldReduce1 <- List.orM [ pure (ultra || n == 0)
                                      , shouldReduce ctx
                                      , isUntranslatableType_not_poly aTy ]
            -- NOTE(review): the length passed is @n + 1@; presumably the size
            -- type argument is one less than the vector length -- confirm
            -- against the primitive's type.
            if shouldReduce1 then
              let [fun,arg] = Either.lefts args
              in (`mkTicks` ticks) <$> reduceFold c (n + 1) aTy fun arg
            else return e
          _ -> return e
      "Clash.Sized.Vector.foldr" | argLen == 6 ->
        let [aTy,bTy,nTy] = Either.rights args
        in case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            shouldReduce1 <- List.orM [ pure ultra
                                      , shouldReduce ctx
                                      , List.anyM isUntranslatableType_not_poly [aTy,bTy] ]
            if shouldReduce1
              then let [fun,start,arg] = Either.lefts args
                   in (`mkTicks` ticks) <$> reduceFoldr c n aTy fun start arg
              else return e
          _ -> return e
      -- Always reduced once the size is known.
      "Clash.Sized.Vector.dfold" | argLen == 8 ->
        let ([_kn,_motive,fun,start,arg],[_mTy,nTy,aTy]) = Either.partitionEithers args
        in case runExcept (tyNatSize tcm nTy) of
          Right n -> (`mkTicks` ticks) <$> reduceDFold is0 n aTy fun start arg
          _ -> return e
      "Clash.Sized.Vector.++" | argLen == 5 ->
        let [nTy,aTy,mTy] = Either.rights args
            [lArg,rArg] = Either.lefts args
        in case (runExcept (tyNatSize tcm nTy), runExcept (tyNatSize tcm mTy)) of
          (Right n, Right m)
            -- Appending to/with an empty vector: return the other operand
            -- (the application's ticks are not re-applied here).
            | n == 0 -> changed rArg
            | m == 0 -> changed lArg
            | otherwise -> do
                shouldReduce1 <- List.orM [ shouldReduce ctx
                                          , isUntranslatableType_not_poly aTy ]
                if shouldReduce1
                  then (`mkTicks` ticks) <$> reduceAppend is0 n m aTy lArg rArg
                  else return e
          _ -> return e
      -- head/tail/last/init pass @n+1@ as the length, like 'fold' above.
      "Clash.Sized.Vector.head" | argLen == 3 -> do
        let [nTy,aTy] = Either.rights args
            [vArg] = Either.lefts args
        case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            shouldReduce1 <- List.orM [ shouldReduce ctx
                                      , isUntranslatableType_not_poly aTy ]
            if shouldReduce1
              then (`mkTicks` ticks) <$> reduceHead is0 (n+1) aTy vArg
              else return e
          _ -> return e
      "Clash.Sized.Vector.tail" | argLen == 3 -> do
        let [nTy,aTy] = Either.rights args
            [vArg] = Either.lefts args
        case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            shouldReduce1 <- List.orM [ shouldReduce ctx
                                      , isUntranslatableType_not_poly aTy ]
            if shouldReduce1
              then (`mkTicks` ticks) <$> reduceTail is0 (n+1) aTy vArg
              else return e
          _ -> return e
      "Clash.Sized.Vector.last" | argLen == 3 -> do
        let [nTy,aTy] = Either.rights args
            [vArg] = Either.lefts args
        case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            shouldReduce1 <- List.orM [ shouldReduce ctx
                                      , isUntranslatableType_not_poly aTy
                                      ]
            if shouldReduce1
              then (`mkTicks` ticks) <$> reduceLast is0 (n+1) aTy vArg
              else return e
          _ -> return e
      "Clash.Sized.Vector.init" | argLen == 3 -> do
        let [nTy,aTy] = Either.rights args
            [vArg] = Either.lefts args
        case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            shouldReduce1 <- List.orM [ shouldReduce ctx
                                      , isUntranslatableType_not_poly aTy ]
            if shouldReduce1
              then (`mkTicks` ticks) <$> reduceInit is0 (n+1) aTy vArg
              else return e
          _ -> return e
      -- Only reduced when the size @m@ is 0.
      "Clash.Sized.Vector.unconcat" | argLen == 6 -> do
        let ([_knN,_sm,arg],[mTy,nTy,aTy]) = Either.partitionEithers args
        case (runExcept (tyNatSize tcm nTy), runExcept (tyNatSize tcm mTy)) of
          (Right n, Right 0) -> (`mkTicks` ticks) <$> reduceUnconcat n 0 aTy arg
          _ -> return e
      -- Only reduced when the size @m@ is 0.
      "Clash.Sized.Vector.transpose" | argLen == 5 -> do
        let ([_knN,arg],[mTy,nTy,aTy]) = Either.partitionEithers args
        case (runExcept (tyNatSize tcm nTy), runExcept (tyNatSize tcm mTy)) of
          (Right n, Right 0) -> (`mkTicks` ticks) <$> reduceTranspose n 0 aTy arg
          _ -> return e
      "Clash.Sized.Vector.replicate" | argLen == 4 -> do
        let ([_sArg,vArg],[nTy,aTy]) = Either.partitionEithers args
        case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            shouldReduce1 <- List.orM [ shouldReduce ctx
                                      , isUntranslatableType_not_poly aTy
                                      ]
            if shouldReduce1
              then (`mkTicks` ticks) <$> reduceReplicate n aTy eTy vArg
              else return e
          _ -> return e
      "Clash.Sized.Vector.replace_int" | argLen == 6 -> do
        let ([_knArg,vArg,iArg,aArg],[nTy,aTy]) = Either.partitionEithers args
        case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            shouldReduce1 <- List.orM [ pure ultra
                                      , shouldReduce ctx
                                      , isUntranslatableType_not_poly aTy
                                      ]
            if shouldReduce1
              then (`mkTicks` ticks) <$> reduceReplace_int is0 n aTy eTy vArg iArg aArg
              else return e
          _ -> return e
      "Clash.Sized.Vector.index_int" | argLen == 5 -> do
        let ([_knArg,vArg,iArg],[nTy,aTy]) = Either.partitionEithers args
        case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            shouldReduce1 <- List.orM [ pure ultra
                                      , shouldReduce ctx
                                      , isUntranslatableType_not_poly aTy ]
            if shouldReduce1
              then (`mkTicks` ticks) <$> reduceIndex_int is0 n aTy vArg iArg
              else return e
          _ -> return e
      "Clash.Sized.Vector.imap" | argLen == 6 -> do
        let [nTy,argElTy,resElTy] = Either.rights args
        case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            shouldReduce1 <- List.orM [ pure (ultra || n < 2)
                                      , shouldReduce ctx
                                      , List.anyM isUntranslatableType_not_poly [argElTy,resElTy] ]
            if shouldReduce1
              then let [_,fun,arg] = Either.lefts args
                   in (`mkTicks` ticks) <$> reduceImap c n argElTy resElTy fun arg
              else return e
          _ -> return e
      -- Always reduced once the size is known.
      "Clash.Sized.Vector.dtfold" | argLen == 8 ->
        let ([_kn,_motive,lrFun,brFun,arg],[_mTy,nTy,aTy]) = Either.partitionEithers args
        in case runExcept (tyNatSize tcm nTy) of
          Right n -> (`mkTicks` ticks) <$> reduceDTFold is0 n aTy lrFun brFun arg
          _ -> return e
      -- Only reduced in ultra mode. When a guard fails, matching falls
      -- through to the final wildcard alternative.
      "Clash.Sized.Vector.reverse"
        | ultra
        , ([vArg],[nTy,aTy]) <- Either.partitionEithers args
        , Right n <- runExcept (tyNatSize tcm nTy)
        -> (`mkTicks` ticks) <$> reduceReverse is0 n aTy vArg
      -- Always reduced once the size is known.
      "Clash.Sized.RTree.tdfold" | argLen == 8 ->
        let ([_kn,_motive,lrFun,brFun,arg],[_mTy,nTy,aTy]) = Either.partitionEithers args
        in case runExcept (tyNatSize tcm nTy) of
          Right n -> (`mkTicks` ticks) <$> reduceTFold is0 n aTy lrFun brFun arg
          _ -> return e
      "Clash.Sized.RTree.treplicate" | argLen == 4 -> do
        let ([_sArg,vArg],[nTy,aTy]) = Either.partitionEithers args
        case runExcept (tyNatSize tcm nTy) of
          Right n -> do
            -- NOTE(review): unlike the other cases, this uses
            -- 'isUntranslatableType' directly, without the "no free type
            -- variables" check of 'isUntranslatableType_not_poly'. Confirm
            -- whether this is intended.
            shouldReduce1 <- List.orM [ shouldReduce ctx
                                      , isUntranslatableType False aTy ]
            if shouldReduce1
              then (`mkTicks` ticks) <$> reduceTReplicate n aTy eTy vArg
              else return e
          _ -> return e
      -- Splitting off a zero-width part: build the result tuple directly,
      -- with 'removedTm' for the zero-width component.
      "Clash.Sized.Internal.BitVector.split#" | argLen == 4 -> do
        let ([_knArg,bvArg],[nTy,mTy]) = Either.partitionEithers args
        case (runExcept (tyNatSize tcm nTy), runExcept (tyNatSize tcm mTy), tv) of
          (Right n, Right m, TyConApp tupTcNm [lTy,rTy])
            | n == 0 -> do
              let (Just tupTc) = lookupUniqMap tupTcNm tcm
                  [tupDc] = tyConDataCons tupTc
                  tup = mkApps (Data tupDc)
                          [Right lTy
                          ,Right rTy
                          ,Left bvArg
                          ,Left (removedTm rTy)
                          ]
              changed (mkTicks tup ticks)
            | m == 0 -> do
              let (Just tupTc) = lookupUniqMap tupTcNm tcm
                  [tupDc] = tyConDataCons tupTc
                  tup = mkApps (Data tupDc)
                          [Right lTy
                          ,Right rTy
                          ,Left (removedTm lTy)
                          ,Left bvArg
                          ]
              changed (mkTicks tup ticks)
          _ -> return e
      -- Equality of zero-width bit vectors is always 'True' (the second data
      -- constructor of the result type).
      "Clash.Sized.Internal.BitVector.eq#"
        | ([_,_],[nTy]) <- Either.partitionEithers args
        , Right 0 <- runExcept (tyNatSize tcm nTy)
        , TyConApp boolTcNm [] <- tv
        -> let (Just boolTc) = lookupUniqMap boolTcNm tcm
               [_falseDc,trueDc] = tyConDataCons boolTc
           in changed (mkTicks (Data trueDc) ticks)
      _ -> return e
  where
    -- Untranslatable type that has no free type variables.
    isUntranslatableType_not_poly t = do
      u <- isUntranslatableType False t
      if u
        then return (null $ Lens.toListOf typeFreeVars t)
        else return False
reduceNonRepPrim _ e = return e
{-# SCC reduceNonRepPrim #-}
-- | Consolidate expressions that 'collectGlobals' reports for a
-- case-expression with at least two alternatives. Presumably these are
-- applications that occur in mutually exclusive branches; confirm against
-- 'collectGlobals' / 'isDisjoint'.
--
-- Each entry marked by 'isDisjoint' is turned into one expression by
-- 'mkDisjointGroup' and let-bound to a fresh "<name>Out" binder.
-- 'collectGlobals' is then re-run with the corresponding substitutions, on
-- the original case-expression and on each new expression. Finally
-- 'deadCode' is run bottom-up over the resulting let-expression.
disjointExpressionConsolidation :: HasCallStack => NormRewrite
disjointExpressionConsolidation ctx@(TransformContext is0 _) e@(Case _scrut _ty _alts@(_:_:_)) = do
  (_,collected) <- collectGlobals is0 [] [] e
  let disJoint = filter (isDisjoint . snd . snd) collected
  if null disJoint
    then return e
    else do
      exprs <- mapM (mkDisjointGroup is0) disJoint
      tcm <- Lens.view tcCache
      lids <- Monad.zipWithM (mkFunOut is0 tcm) disJoint exprs
      let substitution = zip (map fst disJoint) (map Var lids)
          subsMatrix = l2m substitution
      -- The i-th new expression is processed with every substitution except
      -- its own (see 'l2m').
      (exprs',_) <- unzip <$> Monad.zipWithM
                      (\s (e',seen) -> collectGlobals is0 s seen e')
                      subsMatrix
                      exprs
      (e',_) <- collectGlobals is0 substitution [] e
      let lb = Letrec (zip lids exprs') e'
      lb' <- bottomupR deadCode ctx lb
      changed lb'
  where
    -- Fresh binder typed like the consolidated expression. Its name is the
    -- last '.'-separated component of the head variable's or primitive's
    -- name ("complex_expression_" for any other head), suffixed with "Out".
    mkFunOut isN tcm (fun,_) (e',_) = do
      let ty = termType tcm e'
          nm = case collectArgs fun of
                 (Var v,_) -> nameOcc (varName v)
                 (Prim p,_) -> primName p
                 _ -> "complex_expression_"
          nm'' = last (Text.splitOn "." nm) `Text.append` "Out"
      mkInternalVar isN nm'' ty

    -- @l2m [x1,..,xn]@ lists, for every position i, all elements except xi.
    l2m = go []
      where
        go _ [] = []
        go xs (y:ys) = (xs ++ ys) : go (xs ++ [y]) ys
disjointExpressionConsolidation _ e = return e
{-# SCC disjointExpressionConsolidation #-}
-- | Inline selected let-bindings into the other bindings of a let-expression.
--
-- A binding is selected only when 'isInteresting' holds for it, which
-- requires that its binder does not occur free in the let-body. The selected
-- bindings are substituted into the remaining bindings by
-- 'inlineBndrsCleanup'. Selected bindings found to be recursive stay as
-- bindings. If no binding remains, the result is just the body.
inlineCleanup :: HasCallStack => NormRewrite
inlineCleanup (TransformContext is0 _) (Letrec binds body) = do
  prims <- Lens.use (extra.primitives)
  let is1 = extendInScopeSetList is0 (map fst binds)
      -- Each binding paired with the free-variable occurrence counts of its
      -- right-hand side.
      bindsFvs = map (\(v,e) -> (v,((v,e),countFreeOccurances e))) binds
      -- Occurrence counts summed over all right-hand sides (not the body).
      allOccs = List.foldl' (unionVarEnvWith (+)) emptyVarEnv
              $ map (snd.snd) bindsFvs
      bodyFVs = Lens.foldMapOf freeLocalIds unitVarSet body
      (il,keep) = List.partition (isInteresting allOccs prims bodyFVs)
                                 bindsFvs
      keep' = inlineBndrsCleanup is1 (mkVarEnv il) emptyVarEnv
            $ map snd keep
  if | null il -> return (Letrec binds body)
     | null keep' -> changed body
     | otherwise -> changed (Letrec keep' body)
  where
    -- Should this binding be inlined? Only binders that do not occur free in
    -- the body qualify. The rules depend on whether the binder has a 'User'
    -- name. @tm@ is the head of the right-hand side ('collectArgs').
    isInteresting
      :: VarEnv Int
      -> CompiledPrimMap
      -> VarSet
      -> (Id,((Id, Term), VarEnv Int))
      -> Bool
    -- Non-user (system) binders:
    --   * expression blackbox primitives referenced at most once across
    --     all right-hand sides;
    --   * 'bindSimIO#' applications;
    --   * single-alternative cases;
    --   * constructor applications;
    --   * SimIO-typed cases with several alternatives.
    isInteresting allOccs prims bodyFVs (id_,((_,(fst.collectArgs) -> tm),_))
      | nameSort (varName id_) /= User
      , id_ `notElemVarSet` bodyFVs
      = case tm of
          Prim pInfo
            | let nm = primName pInfo
            , Just (extractPrim -> Just p@(BlackBox {})) <- HashMap.lookup nm prims
            , TExpr <- kind p
            , Just occ <- lookupVarEnv id_ allOccs
            , occ < 2
            -> True
            | otherwise
            -> primName pInfo `elem` ["Clash.Explicit.SimIO.bindSimIO#"]
          Case _ _ [_] -> True
          Data _ -> True
          Case _ aTy (_:_:_)
            | TyConApp (nameOcc -> "Clash.Explicit.SimIO.SimIO") _ <- tyView aTy
            -> True
          _ -> False
    -- User-named binders:
    --   * SimIO openFile/fgetc/feof applications referenced at most once;
    --   * 'bindSimIO#' applications;
    --   * single-alternative cases whose pattern constructor is BV, Bit, or
    --     comes from GHC.Classes;
    --   * SimIO-typed cases with several alternatives.
      | id_ `notElemVarSet` bodyFVs
      = case tm of
          Prim pInfo
            | primName pInfo `elem` [ "Clash.Explicit.SimIO.openFile"
                                    , "Clash.Explicit.SimIO.fgetc"
                                    , "Clash.Explicit.SimIO.feof"
                                    ]
            , Just occ <- lookupVarEnv id_ allOccs
            , occ < 2
            -> True
            | otherwise
            -> primName pInfo `elem` ["Clash.Explicit.SimIO.bindSimIO#"]
          Case _ _ [(DataPat dcE _ _,_)]
            -> let nm = (nameOcc (dcName dcE))
               in
                 nm == "Clash.Sized.Internal.BitVector.BV" ||
                 nm == "Clash.Sized.Internal.BitVector.Bit" ||
                 "GHC.Classes" `Text.isPrefixOf` nm
          Case _ aTy (_:_:_)
            | TyConApp (nameOcc -> "Clash.Explicit.SimIO.SimIO") _ <- tyView aTy
            -> True
          _ -> False
    isInteresting _ _ _ _ = False
inlineCleanup _ e = return e
{-# SCC inlineCleanup #-}
-- | Visit state of an inlinable binding in 'inlineBndrsCleanup' and
-- 'reduceBindersCleanup'.
data Mark
  = Temp -- ^ Its right-hand side is currently being processed
  | Done -- ^ Processed; its right-hand side is substituted for it
  | Rec  -- ^ Found to be recursive; kept as a let-binding
-- | Substitute inlinable bindings into the bindings that are kept.
--
-- The kept bindings are processed in order. For each one, its free variables
-- are resolved against the inlinable bindings with 'reduceBindersCleanup',
-- and the resulting substitution (if any) is applied to its right-hand side.
-- Finally, the inlinable bindings marked 'Rec' are appended to the result.
inlineBndrsCleanup
  :: InScopeSet
  -> VarEnv ((Id,Term),VarEnv Int)
  -- ^ Inlinable bindings, with the free-variable occurrence counts of their
  -- right-hand sides
  -> VarEnv ((Id,Term),VarEnv Int,Mark)
  -- ^ Visit state of the inlinable bindings processed so far
  -> [((Id,Term),VarEnv Int)]
  -- ^ Bindings to keep, with the free-variable occurrence counts of their
  -- right-hand sides
  -> [(Id,Term)]
inlineBndrsCleanup isN origInl = go
  where
    go doneInl [] =
      -- All kept bindings done: add the recursive inlinable bindings.
      [ (v,e) | ((v,e),_,Rec) <- eltsVarEnv doneInl ]
    go !doneInl (((v,e),eFVs):il) =
      let (sM,_,doneInl1) = foldlWithUniqueVarEnv'
                              (reduceBindersCleanup isN origInl)
                              (Nothing, emptyVarEnv, doneInl)
                              eFVs
          e1 = case sM of
                 Nothing -> e
                 Just s -> substTm "inlineBndrsCleanup" s e
      in (v,e1):go doneInl1 il
{-# SCC inlineBndrsCleanup #-}
-- | Fold step over the free variables (keyed by 'Unique') of a right-hand
-- side. It builds the substitution that inlines the inlinable bindings.
--
-- The accumulator holds the substitution so far, the union of the free
-- variables of the substituted terms, and the visit state of the inlinable
-- bindings. A variable is handled according to what it is:
--
-- * Not an inlinable binding: nothing changes.
-- * Not yet visited: it is marked 'Temp' and its own right-hand side is
--   processed recursively. If the variable then occurs among the free
--   variables collected for that right-hand side, it is marked 'Rec' and not
--   substituted. Otherwise it is marked 'Done' and added to the substitution.
-- * Marked 'Done': it is added to the substitution.
-- * Marked 'Temp' or 'Rec': nothing changes.
reduceBindersCleanup
  :: InScopeSet
  -> VarEnv ((Id,Term),VarEnv Int)
  -> (Maybe Subst,VarEnv Int,VarEnv ((Id,Term),VarEnv Int,Mark))
  -> Unique
  -> Int
  -> (Maybe Subst,VarEnv Int,VarEnv ((Id,Term),VarEnv Int,Mark))
reduceBindersCleanup isN origInl (!substM,!substFVs,!doneInl) u _ = case lookupVarEnvDirectly u doneInl of
  Nothing -> case lookupVarEnvDirectly u origInl of
    Nothing ->
      -- Not an inlinable binding.
      (substM,substFVs,doneInl)
    Just ((v,e),eFVs) ->
      -- First visit: mark as 'Temp', then process its right-hand side.
      let (sM,substFVsE,doneInl1) =
            foldlWithUniqueVarEnv'
              (reduceBindersCleanup isN origInl)
              ( Nothing
              , eFVs
              , extendVarEnv v ((v,e),eFVs,Temp) doneInl)
              eFVs
          e1 = case sM of
                 Nothing -> e
                 Just s -> substTm "reduceBindersCleanup" s e
      in if v `elemVarEnv` substFVsE then
           -- Recursive: keep as a binding, do not substitute.
           ( substM
           , substFVs
           , extendVarEnv v ((v,e1),substFVsE,Rec) doneInl1
           )
         else
           ( Just (extendIdSubst (Maybe.fromMaybe (mkSubst isN) substM) v e1)
           , unionVarEnv substFVsE substFVs
           , extendVarEnv v ((v,e1),substFVsE,Done) doneInl1
           )
  Just ((v,e),eFVs,Done) ->
    -- Already processed: substitute its (processed) right-hand side.
    ( Just (extendIdSubst (Maybe.fromMaybe (mkSubst isN) substM) v e)
    , unionVarEnv eFVs substFVs
    , doneInl
    )
  Just _ ->
    -- 'Temp' or 'Rec': leave unchanged.
    ( substM
    , substFVs
    , doneInl
    )
{-# SCC reduceBindersCleanup #-}
-- | Flatten nested let-expressions.
--
-- For each binding whose right-hand side is itself a (ticked)
-- let-expression, the inner bindings are lifted into the outer let. If any
-- inner binder is already in scope, the inner let is deshadowed first. The
-- inner name ticks are put on every lifted binding; the inner source ticks
-- go only on the inner let's body.
--
-- A let (the outer one, or an inner one being lifted) can end up with a
-- single binding. That binding is substituted into the body when it is
-- work-free or used at most once in the body, unless it refers to itself.
flattenLet :: HasCallStack => NormRewrite
flattenLet (TransformContext is0 _) (Letrec binds body) = do
  let is1 = extendInScopeSetList is0 (map fst binds)
      bodyOccs = Lens.foldMapByOf
                   freeLocalIds (unionVarEnvWith (+))
                   emptyVarEnv (`unitVarEnv` (1 :: Int))
                   body
  (is2,binds1) <- second concat <$> List.mapAccumLM go is1 binds
  case binds1 of
    [(id1,e1)] | Just occ <- lookupVarEnv id1 bodyOccs, isWorkFree e1 || occ < 2 ->
      if id1 `localIdOccursIn` e1
        -- Self-recursive: cannot be substituted.
        then return (Letrec binds1 body)
        else let subst = extendIdSubst (mkSubst is2) id1 e1
             in changed (substTm "flattenLet" subst body)
    _ -> return (Letrec binds1 body)
  where
    go :: InScopeSet -> LetBinding -> NormalizeSession (InScopeSet,[LetBinding])
    go isN (id1,collectTicks -> (Letrec binds1 body1,ticks)) = do
      let bs1 = map fst binds1
      -- Deshadow the inner let when any of its binders is already in scope.
      let (binds2,body2,isN1) =
            if any (`elemInScopeSet` isN) bs1 then
              let Letrec bindsN bodyN = deShadowTerm isN (Letrec binds1 body1)
              in (bindsN,bodyN,extendInScopeSetList isN (map fst bindsN))
            else
              (binds1,body1,extendInScopeSetList isN bs1)
      let bodyOccs = Lens.foldMapByOf
                       freeLocalIds (unionVarEnvWith (+))
                       emptyVarEnv (`unitVarEnv` (1 :: Int))
                       body2
          (srcTicks,nmTicks) = partitionTicks ticks
      -- Name ticks go on every lifted binding; source ticks only on the
      -- inner body (in every branch below).
      (isN1,) . map (second (`mkTicks` nmTicks)) <$> case binds2 of
        [(id2,e2)] | Just occ <- lookupVarEnv id2 bodyOccs, isWorkFree e2 || occ < 2 ->
          if id2 `localIdOccursIn` e2
            -- Self-recursive: keep the binding. The inner body still gets
            -- the source ticks; they used to be dropped here.
            then changed [(id2,e2),(id1,mkTicks body2 srcTicks)]
            else let subst = extendIdSubst (mkSubst isN1) id2 e2
                 in changed [(id1
                             ,mkTicks (substTm "flattenLetGo" subst body2)
                                      srcTicks)]
        bs -> changed (bs ++ [(id1
                              ,mkTicks body2 srcTicks)])
    go isN b = return (isN,[b])
flattenLet _ e = return e
{-# SCC flattenLet #-}
-- | Split a lambda whose binder type 'shouldSplit' decomposes into at least
-- two components.
--
-- The binder is replaced by fresh lambda binders, one per component, named
-- after the original binder. Inside the body, the original binder is
-- replaced by the constructor application that rebuilds it from the new
-- binders. Returns 'Nothing' when the binder is not split.
separateLambda
  :: TyConMap
  -> TransformContext
  -> Id
  -> Term
  -> Maybe Term
separateLambda tcm ctx@(TransformContext is0 _) b eb0
  | Just (dc, argTys@(_:_:_)) <- shouldSplit tcm (varType b)
  = let baseName = mkDerivedName ctx (nameOcc (varName b))
        (is1, freshBndrs) =
          List.mapAccumL freshen is0 (map (`mkLocalId` baseName) argTys)
        rebuilt = mkApps dc (map (Left . Var) freshBndrs)
        subst = extendIdSubst (mkSubst is1) b rebuilt
    in Just (mkLams (substTm "separateArguments" subst eb0) freshBndrs)
  | otherwise
  = Nothing
  where
    -- Make a binder unique with respect to the in-scope set, and add it.
    freshen inScope x =
      let x' = uniqAway inScope x
      in (extendInScopeSet inScope x', x')
{-# SCC separateLambda #-}
-- | Split apart arguments whose type 'shouldSplit' decomposes into at least
-- two components.
--
-- * For a lambda, the binder is split by 'separateLambda'.
-- * For an application of a global function, each such term argument is
--   replaced by one field selector per component. The type annotation of the
--   global variable is updated to match: every component gets its own type
--   (the types 'shouldSplit' reports, which 'separateLambda' also uses for
--   the new binders). Argument types that are not applied (partial
--   application) are kept in the type. Surplus arguments beyond the visible
--   argument types stay applied, unchanged.
separateArguments :: HasCallStack => NormRewrite
separateArguments ctx e0@(Lam b eb) = do
  tcm <- Lens.view tcCache
  case separateLambda tcm ctx b eb of
    Just e1 -> changed e1
    Nothing -> return e0
separateArguments (TransformContext is0 _) e@(collectArgsTicks -> (Var g, args, ticks))
  | isGlobalId g = do
    let (argTys0,resTy) = splitFunForallTy (varType g)
        -- Only the arguments that line up with a quantifier/argument type
        -- of @g@ are candidates for splitting. 'zip' would silently drop the
        -- remaining argument types (partial application) or the remaining
        -- arguments (e.g. when the result type is instantiated to a function
        -- type), so both are carried over unchanged.
        nAligned = min (length argTys0) (length args)
        (argTysAligned,argTysRest) = splitAt nAligned argTys0
        (argsAligned,argsRest) = splitAt nAligned args
    (concat -> args1, Monoid.getAny -> hasChanged)
      <- listen (Monad.zipWithM splitArg argTysAligned argsAligned)
    if hasChanged
      then
        let (argTys1,args2) = unzip args1
            gTy = mkPolyFunTy resTy (argTys1 ++ argTysRest)
        in return (mkApps (mkTicks (Var g {varType = gTy}) ticks) (args2 ++ argsRest))
      else
        return e
  where
    -- Split a single (type or term) argument, paired with the corresponding
    -- quantifier/argument type of the global function.
    splitArg
      :: Either TyVar Type
      -> Either Term Type
      -> NormalizeSession [(Either TyVar Type,Either Term Type)]
    splitArg tv arg@(Right _) = return [(tv,arg)]
    splitArg ty arg@(Left tmArg) = do
      tcm <- Lens.view tcCache
      let argTy = termType tcm tmArg
      case shouldSplit tcm argTy of
        Just (_,argTys@(_:_:_)) -> do
          tmArgs <- mapM (mkSelectorCase ($(curLoc) ++ "splitArg") is0 tcm tmArg 1)
                         [0..length argTys - 1]
          -- Component i has type @argTys !! i@, not the type of the whole
          -- product.
          changed (zip (map Right argTys) (map Left tmArgs))
        _ ->
          return [(ty,arg)]
separateArguments _ e = return e
{-# SCC separateArguments #-}
-- | With aggressive X-optimisation enabled, drop case alternatives whose
-- right-hand side is an error primitive ('isPrimError'). Depending on what
-- remains:
--
-- * nothing remains: the case becomes @errorX@;
-- * one alternative remains: 'xOptimizeSingle';
-- * several remain: 'xOptimizeMany'.
--
-- Cases without error alternatives are left alone.
xOptimize :: HasCallStack => NormRewrite
xOptimize (TransformContext is0 _) e@(Case subj ty alts) =
  Lens.view aggressiveXOpt >>= \case
    False -> return e
    True -> do
      (errAlts, liveAlts) <- List.partitionM (isPrimError . snd) alts
      if null errAlts
        then return e
        else case liveAlts of
          [] -> changed (Prim (PrimInfo "Clash.XException.errorX" ty WorkConstant))
          [alt] -> xOptimizeSingle is0 subj alt
          _ -> xOptimizeMany is0 subj ty liveAlts
xOptimize _ e = return e
{-# SCC xOptimize #-}
-- | Replace a case by the right-hand side of its only remaining alternative.
-- For a data pattern, the subject is let-bound to a fresh "subj" binder. Each
-- pattern variable is let-bound to a selection of its field from that
-- binder ('mkFieldSelector').
xOptimizeSingle :: InScopeSet -> Term -> Alt -> NormalizeSession Term
xOptimizeSingle is subj (DataPat dc tvs vars, expr) = do
  tcm <- Lens.view tcCache
  subjId <- mkInternalVar is "subj" (termType tcm subj)
  let selectField = mkFieldSelector is subjId dc tvs (map varType vars)
  fieldBinds <- sequence (zipWith selectField vars [0..])
  changed (Letrec ((subjId, subj) : fieldBinds) expr)
xOptimizeSingle _ _ (_, expr) = changed expr
-- | Rebuild a case from its non-error alternatives. If one of them is already
-- a default alternative, they are used as-is. Otherwise the first
-- alternative is turned into the default alternative (via
-- 'xOptimizeSingle'), placed after the rest. The list must be non-empty.
xOptimizeMany
  :: HasCallStack
  => InScopeSet
  -> Term
  -> Type
  -> [Alt]
  -> NormalizeSession Term
xOptimizeMany _ _ _ [] =
  error $ $(curLoc) ++ "Report as bug: xOptimizeMany error: No defined alternatives"
xOptimizeMany is subj ty defs@(d:ds)
  | hasDefault = changed (Case subj ty defs)
  | otherwise = do
      newAlt <- xOptimizeSingle is subj d
      changed (Case subj ty (ds ++ [(DefaultPat, newAlt)]))
  where
    hasDefault = any (\(pat, _) -> pat == DefaultPat) defs
-- | Build the let-binding @nm = case subj of dc tvs fields -> fields !! index@.
-- The case uses fresh "field" binders typed by @fieldTys@, and the result
-- type @fieldTys !! index@. @index@ must be a valid index into @fieldTys@.
mkFieldSelector
  :: MonadUnique m
  => InScopeSet
  -> Id
  -> DataCon
  -> [TyVar]
  -> [Type]
  -> Id
  -> Int
  -> m LetBinding
mkFieldSelector is0 subj dc tvs fieldTys nm index = do
  freshFields <- mapM (mkInternalVar is0 "field") fieldTys
  let selected = freshFields !! index
      resultTy = fieldTys !! index
      selectAlt = (DataPat dc tvs freshFields, Var selected)
  return (nm, Case (Var subj) resultTy [selectAlt])
-- | Is the term an application of a primitive whose compiled definition is a
-- blackbox with a template consisting of a single 'Err' element?
isPrimError :: Term -> NormalizeSession Bool
isPrimError (collectArgs -> (Prim pInfo, _)) =
  isErrBlackBox . (>>= extractPrim)
    <$> Lens.use (extra . primitives . Lens.at (primName pInfo))
 where
  isErrBlackBox (Just BlackBox{template=(BBTemplate [Err _])}) = True
  isErrBlackBox _ = False
isPrimError _ = return False