{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
module Clash.Normalize where
import Control.Concurrent.Supply (Supply)
import Control.Lens ((.=),(^.),_2,_5)
import qualified Control.Lens as Lens
import Data.Either (partitionEithers)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HashMap
import qualified Data.HashSet as HashSet
import qualified Data.IntMap as IntMap
import Data.IntMap.Strict (IntMap)
import Data.List
(groupBy, intersect, mapAccumL, sortBy)
import qualified Data.Map as Map
import qualified Data.Maybe as Maybe
import qualified Data.Set as Set
import qualified Data.Set.Lens as Lens
import Data.Semigroup ((<>))
import Data.Text.Prettyprint.Doc (vcat)
import Unbound.Generics.LocallyNameless (unembed, runLFreshM)
import BasicTypes (InlineSpec (..))
import SrcLoc (SrcSpan,noSrcSpan)
import Clash.Core.Evaluator (PrimEvaluator)
import Clash.Core.FreeVars (termFreeIds)
import Clash.Core.Name (Name (..), NameSort (..))
import Clash.Core.Pretty (showDoc, ppr)
import Clash.Core.Subst (substTms)
import Clash.Core.Term (Term (..), TmName, TmOccName)
import Clash.Core.Type (Type, splitCoreFunForallTy)
import Clash.Core.TyCon
(TyCon, TyConName, TyConOccName)
import Clash.Core.Util (collectArgs, mkApps, termType)
import Clash.Core.Var (Id,varName)
import Clash.Driver.Types
(BindingMap, ClashOpts (..), DebugLevel (..))
import Clash.Netlist.BlackBox.Types (BlackBoxTemplate)
import Clash.Netlist.Types (HWType (..))
import Clash.Netlist.Util
(splitNormalized, unsafeCoreTypeToHWType)
import Clash.Normalize.Strategy
import Clash.Normalize.Transformations
(appProp, bindConstantVar, caseCon, flattenLet, reduceConst, topLet)
import Clash.Normalize.Types
import Clash.Normalize.Util
import Clash.Primitives.Types (PrimMap)
import Clash.Rewrite.Combinators ((>->),(!->))
import Clash.Rewrite.Types
(RewriteEnv (..), RewriteState (..), bindings, curFun, dbgLevel, extra,
tcCache, topEntities, typeTranslator)
import Clash.Rewrite.Util (isUntranslatableType,
runRewrite,
runRewriteSession)
import Clash.Signal.Internal (ResetKind (..))
import Clash.Util
-- | Run a 'NormalizeSession' to completion.
--
-- Builds the read-only 'RewriteEnv' and the initial 'RewriteState' (whose
-- extra state is a fresh 'NormalizeState') from the given options and
-- environment, then executes the session with 'runRewriteSession'.
runNormalization
  :: ClashOpts
  -- ^ Options: debug level, allowZero, and the spec/inline limits
  -> Supply
  -- ^ Unique supply stored in the rewrite state
  -> BindingMap
  -- ^ Global binders
  -> (HashMap TyConOccName TyCon -> Bool -> Type -> Maybe (Either String HWType))
  -- ^ Hardware type translator
  -> HashMap TyConOccName TyCon
  -- ^ TyCon cache
  -> IntMap TyConName
  -- ^ Presumably tuple TyCons indexed by arity -- TODO confirm
  -> PrimEvaluator
  -- ^ Primitive evaluator
  -> PrimMap BlackBoxTemplate
  -- ^ Primitive definitions
  -> HashMap TmOccName Bool
  -- ^ Stored as-is in 'NormalizeState'; name suggests recursive-component info -- TODO confirm
  -> [TmOccName]
  -- ^ Top entities
  -> NormalizeSession a
  -- ^ Session to run
  -> a
runNormalization opts supply globals typeTrans tcm tupTcm eval primMap rcsMap topEnts =
    runRewriteSession env st
  where
    env = RewriteEnv (opt_dbgLevel opts)
                     typeTrans
                     tcm
                     tupTcm
                     eval
                     (opt_allowZero opts)
                     (HashSet.fromList topEnts)

    -- Placeholder for the "current function": forcing its name before
    -- 'curFun' has been assigned raises an error.
    noCurFun = (error $ $(curLoc) ++ "Report as bug: no curFun", noSrcSpan)

    st = RewriteState 0 globals supply noCurFun 0 (IntMap.empty, 0) normSt

    normSt = NormalizeState HashMap.empty
                            Map.empty
                            HashMap.empty
                            (opt_specLimit opts)
                            HashMap.empty
                            (opt_inlineLimit opts)
                            (opt_inlineFunctionLimit opts)
                            (opt_inlineConstantLimit opts)
                            primMap
                            rcsMap
-- | Normalize the given binders and, transitively, every global binder they
-- still refer to after normalization.
--
-- Works in rounds: each round runs 'normalize'' on the current worklist and
-- collects the referenced binders that still need normalizing; those form
-- the next round. Stops when a round produces an empty worklist. The results
-- of all rounds are merged into one 'BindingMap'.
--
-- The next worklist is de-duplicated (keeping first occurrences) before
-- recursing. 'normalize'' may report the same binder several times (once per
-- occurrence, and once per caller that uses it). Without de-duplication such
-- a binder is visited repeatedly in the next round, and its own dependencies
-- are then repeated too, so the duplication multiplies with call depth. The
-- repeated visits return the same entries again but redo the bookkeeping in
-- 'normalize'' and re-emit its traces.
normalize
  :: [TmOccName]
  -> NormalizeSession BindingMap
normalize [] = return HashMap.empty
normalize top = do
  (new,topNormalized) <- unzip <$> mapM normalize' top
  newNormalized <- normalize (dedup (concat new))
  return (HashMap.union (HashMap.fromList topNormalized) newNormalized)
  where
    -- Order-preserving removal of duplicates.
    dedup :: [TmOccName] -> [TmOccName]
    dedup = go HashSet.empty
      where
        go _ [] = []
        go seen (x:xs)
          | x `HashSet.member` seen = go seen xs
          | otherwise               = x : go (HashSet.insert x seen) xs
-- | Normalize a single global binder.
--
-- Looks @nm@ up in the global 'bindings'.
--
-- * When the binder's result type is translatable, its body is rewritten
--   with the @normalization@ strategy, memoized in @extra.normalized@ via
--   'makeCached'. Traces are emitted when the normalized body still refers
--   to @nm@ itself, and when 'clockResetErrors' reports potential
--   meta-stability issues for the binder's type.
-- * Otherwise the binding is returned unchanged, with a trace at debug level
--   'DebugFinal' or higher.
--
-- Returns the free ids of the resulting body that still need normalizing,
-- paired with the resulting binding. Excluded from that list are @nm@
-- itself, binders already in @extra.normalized@, and top entities.
--
-- Raises an 'error' when @nm@ is not in 'bindings'.
normalize'
  :: TmOccName
  -> NormalizeSession ([TmOccName],(TmOccName,(TmName,Type,SrcSpan,InlineSpec,Term)))
normalize' nm = do
  exprM <- HashMap.lookup nm <$> Lens.use bindings
  let nmS = showDoc nm
  case exprM of
    Just (nm',ty,sp,inl,tm) -> do
      tcm <- Lens.view tcCache
      let (_,resTy) = splitCoreFunForallTy tcm ty
      resTyRep <- not <$> isUntranslatableType False resTy
      if resTyRep
        then do
          tmNorm <- makeCached nm (extra.normalized) $ do
            curFun .= (nm',sp)
            tm' <- rewriteExpr ("normalization",normalization) (nmS,tm)
            ty' <- termType tcm tm'
            return (nm',ty',sp,inl,tm')
          let usedBndrs = Lens.toListOf termFreeIds (tmNorm ^. _5)
          traceIf (nm `elem` usedBndrs)
                  (concat [ $(curLoc),"Expr belonging to bndr: ",nmS ," (:: "
                          , showDoc (tmNorm ^. _2)
                          , ") remains recursive after normalization:\n"
                          , showDoc (tmNorm ^. _5) ])
                  (return ())
          tyTrans <- Lens.view typeTranslator
          case clockResetErrors sp tyTrans tcm ty of
            msgs@(_:_) -> traceIf True (concat (nmS:" (:: ":showDoc (tmNorm ^. _2)
                                                :")\nhas potentially dangerous meta-stability issues:\n\n"
                                                :msgs))
                                  (return ())
            _ -> return ()
          toNormalize <- pendingBndrs usedBndrs
          return (toNormalize,(nm,tmNorm))
        else do
          let usedBndrs = Lens.toListOf termFreeIds tm
          toNormalize <- pendingBndrs usedBndrs
          lvl <- Lens.view dbgLevel
          traceIf (lvl >= DebugFinal)
            (concat [$(curLoc), "Expr belonging to bndr: ", nmS, " (:: "
                    , showDoc ty
                    , ") has a non-representable return type."
                    , " Not normalising:\n", showDoc tm] )
            (return (toNormalize,(nm,(nm',ty,sp,inl,tm))))
    Nothing -> error $ $(curLoc) ++ "Expr belonging to bndr: " ++ nmS ++ " not found"
  where
    -- Binders from @used@ that still have to be normalized. They must not be
    -- @nm@ itself, not already in @extra.normalized@, and not a top entity.
    -- Membership is tested directly on the map/set. The previous version
    -- searched a list of all normalized keys, which is quadratic.
    pendingBndrs :: [TmOccName] -> NormalizeSession [TmOccName]
    pendingBndrs used = do
      prevNorm <- Lens.use (extra.normalized)
      topEnts  <- Lens.view topEntities
      return [ b | b <- used
                 , b /= nm
                 , not (b `HashMap.member` prevNorm)
                 , not (b `HashSet.member` topEnts) ]
-- | Apply a named rewrite strategy to a named expression with 'runRewrite'.
--
-- At debug level 'DebugFinal' or higher, the expression is traced both
-- before and after the rewrite.
rewriteExpr
  :: (String,NormRewrite)
  -> (String,Term)
  -> NormalizeSession Term
rewriteExpr (nrwS,nrw) (bndrS,expr) = do
  lvl <- Lens.view dbgLevel
  let verbose = lvl >= DebugFinal
      -- Shared layout of the "before"/"after" trace messages.
      banner stage doc = bndrS ++ stage ++ nrwS ++ ":\n\n" ++ doc ++ "\n"
  rewritten <- runRewrite nrwS nrw
                 (traceIf verbose (banner " before " (showDoc expr)) expr)
  traceIf verbose (banner " after " (showDoc rewritten)) (return rewritten)
-- | Check that no binder in the normalized binding map refers to itself.
--
-- Returns the map unchanged when no binder has its own name among the free
-- ids of its body. Otherwise raises an 'error' listing every offending
-- binder together with its body.
checkNonRecursive
  :: BindingMap
  -> BindingMap
checkNonRecursive norm
  | null recursive = norm
  | otherwise =
      error $ $(curLoc) ++ "Callgraph after normalisation contains following recursive components: "
           ++ show (vcat (runLFreshM (mapM pprComponent recursive)))
  where
    -- Binders whose body mentions the binder itself.
    recursive =
      [ (nm,tm)
      | (nm,(_,_,_,_,tm)) <- HashMap.toList norm
      , nm `elem` Lens.toListOf termFreeIds tm
      ]
    pprComponent (nm,tm) = (<>) <$> ppr nm <*> ppr tm
-- | Keep only the binders that appear in the call tree of @topEntity@, after
-- flattening that tree with 'flattenCallTree'.
--
-- When @topEntity@ is not a key of @norm@, @norm@ is returned unchanged.
cleanupGraph
  :: TmOccName
  -> BindingMap
  -> NormalizeSession BindingMap
cleanupGraph topEntity norm =
  case mkCallTree [] norm topEntity of
    Nothing -> return norm
    Just ct -> do
      ctFlat <- flattenCallTree ct
      let (_,reachable) = callTreeToList [] ctFlat
      return (HashMap.fromList reachable)
-- | Call graph of normalized binders, unfolded into a tree rooted at a
-- binder. Each node carries the binder's name and its binding.
data CallTree = CLeaf (TmOccName,(TmName,Type,SrcSpan,InlineSpec,Term))
              -- ^ A binder without children ('mkCallTree' builds this when
              -- the body has no free ids)
              | CBranch (TmOccName,(TmName,Type,SrcSpan,InlineSpec,Term)) [CallTree]
              -- ^ A binder together with the call trees of the binders it uses
-- | Unfold the call graph in @bindingMap@ into a 'CallTree' rooted at @root@.
--
-- @visited@ holds the ancestors of @root@. Used binders that are already in
-- @visited@ are not unfolded again, which cuts cycles. Used names that are
-- not keys of @bindingMap@ are dropped. Returns 'Nothing' when @root@ itself
-- is not in @bindingMap@.
mkCallTree
  :: [TmOccName]
  -> BindingMap
  -> TmOccName
  -> Maybe CallTree
mkCallTree visited bindingMap root = do
  rootTm <- HashMap.lookup root bindingMap
  let used     = Set.toList (Lens.setOf termFreeIds (rootTm ^. _5))
      children = Maybe.mapMaybe (mkCallTree (root:visited) bindingMap)
                                [ u | u <- used, u `notElem` visited ]
  return $ if null used
    then CLeaf (root,rootTm)
    else CBranch (root,rootTm) children
-- | Helper for eta-reducing an application.
--
-- @stripArgs allIds ids args@ matches @ids@ one-by-one against the front of
-- @args@. Each id must correspond to a term argument that is exactly a
-- variable with the same name. When every id is consumed this way and none
-- of the remaining term arguments has any of @allIds@ among its free ids, the
-- remaining arguments are returned. In every other case the result is
-- 'Nothing'. Callers pass both lists reversed, so trailing arguments are
-- the ones stripped.
stripArgs
  :: [TmOccName]
  -> [Id]
  -> [Either Term Type]
  -> Maybe [Either Term Type]
stripArgs allIds = go
  where
    go [] args
      | any mentionsId args = Nothing
      | otherwise           = Just args
    go (id_:ids) (Left (Var _ nm):args)
      | varName id_ == nm = go ids args
    go _ _ = Nothing

    -- Does a (term) argument mention any of @allIds@? Type arguments never do.
    mentionsId arg =
      let fvs = either (Lens.toListOf termFreeIds) (const []) arg
      in  not (null (fvs `intersect` allIds))
-- | Decide whether a call-tree node should be inlined into its parent.
--
-- Returns @Right ((name, body), children)@ when the node should be
-- substituted into its caller. This happens:
--
-- * always, for binders whose 'nameSort' is 'Internal';
-- * when 'splitNormalized' yields exactly one binding; the body is then
--   shortened with 'stripArgs' when possible;
-- * when the body is cheap ('isCheapFunction').
--
-- Otherwise returns @Left node@, keeping the node as a separate binder.
flattenNode
  :: CallTree
  -> NormalizeSession (Either CallTree ((TmOccName,Term),[CallTree]))
flattenNode node = case node of
  CLeaf (nm,(bndr,_,_,_,e))        -> decide bndr nm e []
  CBranch (nm,(bndr,_,_,_,e)) kids -> decide bndr nm e kids
  where
    -- Shared decision for leaves (no children) and branches.
    decide
      :: TmName
      -> TmOccName
      -> Term
      -> [CallTree]
      -> NormalizeSession (Either CallTree ((TmOccName,Term),[CallTree]))
    decide bndr nm e kids
      | Internal <- nameSort bndr = return (Right ((nm,e),kids))
      | otherwise = do
          tcm <- Lens.view tcCache
          norm <- splitNormalized tcm e
          case norm of
            Right (ids,[(_,bExpr)],_) -> do
              let (fun,args) = collectArgs (unembed bExpr)
              case stripArgs (map (nameOcc.varName) ids) (reverse ids) (reverse args) of
                Just remainder -> return (Right ((nm,mkApps fun (reverse remainder)),kids))
                Nothing        -> return (Right ((nm,e),kids))
            _ | isCheapFunction e -> return (Right ((nm,e),kids))
              | otherwise         -> return (Left node)
-- | Flatten a call tree bottom-up.
--
-- Children are flattened first. Every child that 'flattenNode' chooses to
-- inline is substituted into the parent's body, which is then re-simplified
-- with @flatten@; that child's own children are adopted by the parent. If
-- the resulting body is cheap ('isCheapFunction'), all remaining children
-- are substituted as well.
flattenCallTree
  :: CallTree
  -> NormalizeSession CallTree
flattenCallTree c@(CLeaf _) = return c
flattenCallTree (CBranch (nm,(nm',ty,sp,inl,tm)) used) = do
  children <- mapM flattenCallTree used
  (kept,inlined) <- partitionEithers <$> mapM flattenNode children
  let (substs,adopted) = unzip inlined
  tm1 <- if null substs
           then return tm
           else rewriteExpr ("flattenExpr",flatten) (showDoc nm, substTms substs tm)
  let remaining    = kept ++ concat adopted
      rebuild body = CBranch (nm,(nm',ty,sp,inl,body))
  if isCheapFunction tm1
    then do
      let (substs',grandChildren) = unzip (map inlineAll remaining)
      tm2 <- rewriteExpr ("flattenCheap",flatten) (showDoc nm, substTms substs' tm1)
      return (rebuild tm2 (concat grandChildren))
    else return (rebuild tm1 remaining)
  where
    flatten =
      innerMost (appProp >-> bindConstantVar >-> caseCon >-> reduceConst >-> flattenLet) !->
      topdownSucR topLet

    -- Unconditionally turn a node into a substitution plus its children.
    inlineAll (CLeaf (nm2,(_,_,_,_,e)))      = ((nm2,e),[])
    inlineAll (CBranch (nm2,(_,_,_,_,e)) us) = ((nm2,e),us)
-- | Collect the bindings of a call tree in pre-order.
--
-- A node whose name is already in the visited list is skipped together with
-- its subtree. Returns the final visited list together with the collected
-- @(name, binding)@ pairs.
callTreeToList
  :: [TmOccName]
  -> CallTree
  -> ([TmOccName],[(TmOccName,(TmName,Type,SrcSpan,InlineSpec,Term))])
callTreeToList visited node = case node of
  CLeaf entry        -> visit entry []
  CBranch entry kids -> visit entry kids
  where
    visit entry@(nm,_) kids
      | nm `elem` visited = (visited,[])
      | otherwise =
          let (visited',collected) = mapAccumL callTreeToList (nm:visited) kids
          in  (visited',entry : concat collected)
-- | Report clock and reset arguments of a function type that may cause
-- meta-stability issues.
--
-- The term-level argument types of @ty@ are translated to 'HWType's with
-- 'unsafeCoreTypeToHWType'. 'Clock' and 'Reset' arguments are grouped by
-- their name and 'Integer' field, which the messages treat as one clock or
-- reset domain. Two kinds of message can be produced:
--
-- * a group of two or more clocks; and
-- * a group of two or more resets, at least one of which is 'Asynchronous'.
--
-- Returns one message per offending group, clock groups first.
clockResetErrors
  :: SrcSpan
  -> (HashMap TyConOccName TyCon -> Bool -> Type -> Maybe (Either String HWType))
  -> HashMap TyConOccName TyCon
  -> Type
  -> [String]
clockResetErrors sp tyTran tcm ty =
    Maybe.mapMaybe reportClock clkGroups ++ Maybe.mapMaybe reportResets rstGroups
  where
    argTys = snd (partitionEithers (fst (splitCoreFunForallTy tcm ty)))
    toHW   = unsafeCoreTypeToHWType sp $(curLoc) tyTran tcm False
    hwArgs = [ (toHW t, t) | t <- argTys ]

    clkGroups =
      groupBy ((==) `on` fst) $ sortBy (compare `on` fst)
        [ ((nm,i),t) | (Clock nm i _,t) <- hwArgs ]
    rstGroups =
      groupBy ((==) `on` (fst . fst)) $ sortBy (compare `on` (fst . fst))
        [ (((nm,i),s),t) | (Reset nm i s,t) <- hwArgs ]

    -- One "* <type>" line per group member.
    bullets grp = concatMap (\(_,t) -> "* " ++ showDoc t ++ "\n") grp

    reportClock grp@(_:_:_) = Just $ concat
      [ "The following clocks:\n"
      , bullets grp
      , "belong to the same clock domain and should be connected to "
      , "the same clock source in order to prevent meta-stability "
      , "issues."
      ]
    reportClock _ = Nothing

    reportResets grp@(_:_:_)
      | any (\((_,kind),_) -> kind == Asynchronous) grp
      = Just $ concat
          [ "The following resets:\n"
          , bullets grp
          , "belong to the same reset domain, and one or more of these "
          , "resets is Asynchronous. Ensure that these resets are "
          , "synchronized in order to prevent meta-stability issues."
          ]
    reportResets _ = Nothing