{-# LANGUAGE CPP #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Core.Util where
import Control.Concurrent.Supply (Supply, freshId)
import qualified Control.Lens as Lens
import Control.Monad.Trans.Except (Except, throwE)
import qualified Data.HashSet as HashSet
import qualified Data.Graph as Graph
import Data.List (foldl', mapAccumR)
import Data.List.Extra (zipEqual)
import Data.Maybe
(fromJust, isJust, mapMaybe, catMaybes)
import qualified Data.Set as Set
import qualified Data.Set.Lens as Lens
import qualified Data.Text as T
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup
#endif
import PrelNames (ipClassKey)
import Unique (getKey)
import Clash.Core.DataCon
import Clash.Core.EqSolver
import Clash.Core.FreeVars (tyFVsOfTypes, typeFreeVars, freeLocalIds)
import Clash.Core.Name
(Name (..), OccName, mkUnsafeInternalName, mkUnsafeSystemName)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst
import Clash.Core.Term
import Clash.Core.TyCon (TyConMap, tyConDataCons)
import Clash.Core.Type
import Clash.Core.TysPrim (typeNatKind)
import Clash.Core.Var (Id, Var(..), mkLocalId, mkTyVar)
import Clash.Core.VarEnv
import Clash.Debug (traceIf)
import Clash.Unique
import Clash.Util
-- | Build a @Vec@ term out of a list of element terms.
--
-- The outermost @Cons@ is annotated with the given length, every nested
-- @Cons@ with one less, and the terminating @Nil@ always with length 0.
-- The length is not checked against the number of elements.
mkVec :: DataCon -- ^ The Nil constructor
      -> DataCon -- ^ The Cons (:>) constructor
      -> Type    -- ^ Element type
      -> Integer -- ^ Length of the vector
      -> [Term]  -- ^ Elements to put in the vector
      -> Term
mkVec nilCon consCon resTy = build
  where
    build len items = case items of
      [] ->
        mkApps (Data nilCon)
          [ Right (natLit 0)
          , Right resTy
          , Left (primCo nilCoercion)
          ]
      (hd:tl) ->
        mkApps (Data consCon)
          [ Right (natLit len)
          , Right resTy
          , Right (natLit (len - 1))
          , Left (primCo (consCoercion len))
          , Left hd
          , Left (build (len - 1) tl)
          ]

    natLit = LitTy . NumTy

    -- The first instantiated argument type of each constructor is the type
    -- of its coercion argument.
    nilCoercion =
      head (fromJust $! dataConInstArgTys nilCon [natLit 0, resTy])

    consCoercion len =
      head (fromJust $! dataConInstArgTys consCon
                          [natLit len, resTy, natLit (len - 1)])
-- | Put the given elements in front of an existing vector term, producing
-- @x0 :> x1 :> ... :> vec@.
--
-- The outermost @Cons@ is annotated with the given length and each nested
-- @Cons@ with one less; the supplied vector is used unchanged as the tail.
appendToVec :: DataCon -- ^ The Cons (:>) constructor
            -> Type    -- ^ Element type
            -> Term    -- ^ The vector to put the elements in front of
            -> Integer -- ^ Length of the resulting vector
            -> [Term]  -- ^ Elements to add
            -> Term
appendToVec consCon resTy vec len items =
  foldr consOnto vec (zip (iterate (subtract 1) len) items)
  where
    consOnto (k, el) rest =
      mkApps (Data consCon)
        [ Right (natLit k)
        , Right resTy
        , Right (natLit (k - 1))
        , Left (primCo (consCoercion k))
        , Left el
        , Left rest
        ]

    natLit = LitTy . NumTy

    consCoercion k =
      head (fromJust $! dataConInstArgTys consCon
                          [natLit k, resTy, natLit (k - 1)])
-- | Create let-bindings that take apart the first @maxN@ elements of a
-- vector term.
--
-- For every element it produces a variable referring to that element plus
-- two bindings: one projecting the element out of the current @Cons@, and one
-- projecting the remaining tail, which is then taken apart in turn.
extractElems
  :: Supply
  -- ^ Unique supply
  -> InScopeSet
  -- ^ (Superset of) in-scope variables
  -> DataCon
  -- ^ The Cons (:>) constructor
  -> Type
  -- ^ The element type
  -> Char
  -- ^ Character added to the generated binder names
  -> Integer
  -- ^ Number of elements to extract
  -> Term
  -- ^ The vector
  -> (Supply, [(Term,[LetBinding])])
extractElems supply inScope consCon resTy s maxN vec =
  first fst (walk maxN (supply, inScope) vec)
  where
    walk :: Integer
         -> (Supply, InScopeSet)
         -> Term
         -> ((Supply, InScopeSet), [(Term, [LetBinding])])
    walk 0 st _ = (st, [])
    walk n st0 subj =
      (st3, (Var elId, [(elId, headProj), (restId, tailProj)]) : more)
      where
        suffix = s `T.cons` T.pack (show (maxN - n))

        Just argTys =
          dataConInstArgTys consCon
            [LitTy (NumTy n), resTy, LitTy (NumTy (n - 1))]
        tailTy = last argTys

        (st1, mTv) = mkUniqSystemTyVar st0 ("m", typeNatKind)
        (st2, [elId, restId, coId, hdId, tlId]) =
          mapAccumR mkUniqSystemId st1 $
            zip [ "el" `T.append` suffix
                , "rest" `T.append` suffix
                , "_co_"
                , "el"
                , "rest"
                ]
                (resTy : tailTy : argTys)

        pat      = DataPat consCon [mTv] [coId, hdId, tlId]
        headProj = Case subj resTy [(pat, Var hdId)]
        tailProj = Case subj tailTy [(pat, Var tlId)]

        (st3, more) = walk (n - 1) st2 (Var restId)
-- | Create let-bindings that take apart every leaf of an @RTree@ term.
--
-- Returns the updated unique supply, the variables referring to the leaf
-- elements (left sub-tree before right sub-tree), and the let-bindings that
-- define them together with the intermediate sub-trees.
extractTElems
  :: Supply
  -- ^ Unique supply
  -> InScopeSet
  -- ^ (Superset of) in-scope variables
  -> DataCon
  -- ^ The LR (leaf) constructor
  -> DataCon
  -- ^ The BR (branch) constructor
  -> Type
  -- ^ The element type
  -> Char
  -- ^ Character added to the generated binder names
  -> Integer
  -- ^ Depth of the tree: the outermost BR is annotated with this number,
  -- leaves with 0
  -> Term
  -- ^ The tree
  -> (Supply,([Term],[LetBinding]))
extractTElems supply inScope lrCon brCon resTy s maxN tree =
  first fst (go maxN [0..(2^(maxN+1))-2] [0..(2^maxN - 1)] (supply,inScope) tree)
  where
    -- go n bs ks st e: e is a (sub-)tree at depth n. bs are labels used to
    -- name the sub-tree binders, ks are labels used to name the leaf binders
    -- (initially one per leaf). Both lists are halved at every branch, giving
    -- each sub-tree its own range of labels.
    go :: Integer
       -> [Int]
       -> [Int]
       -> (Supply,InScopeSet)
       -> Term
       -> ((Supply,InScopeSet),([Term],[LetBinding]))
    -- Leaf: bind the element stored inside the LR constructor.
    go 0 _ ks uniqs0 e = (uniqs1,([elNVar],[(elNId, rhs)]))
      where
        tys = [LitTy (NumTy 0),resTy]
        (Just idTys) = dataConInstArgTys lrCon tys
        (uniqs1,[elNId,co,el]) =
          mapAccumR mkUniqSystemId uniqs0 $ zip
            [ "el" `T.append` (s `T.cons` T.pack (show (head ks)))
            , "_co_"
            , "el"
            ]
            (resTy:idTys)
        elNVar = Var elNId
        pat = DataPat lrCon [] [co,el]
        rhs = Case e resTy [(pat,Var el)]
    -- Branch: bind the left and right sub-trees of the BR constructor, then
    -- recurse into both, threading the supply left-to-right.
    go n bs ks uniqs0 e =
      (uniqs4
      ,(lVars ++ rVars,(ltNId, ltRhs):
                       (rtNId, rtRhs):
                       (lBinds ++ rBinds)))
      where
        tys = [LitTy (NumTy n),resTy,LitTy (NumTy (n-1))]
        (Just idTys) = dataConInstArgTys brCon tys
        (uniqs1,mTV) = mkUniqSystemTyVar uniqs0 ("m",typeNatKind)
        -- b0/b1 label this level's two sub-tree binders; the rest of each
        -- half is passed down to the corresponding sub-tree.
        (b0:bL,b1:bR) = splitAt (length bs `div` 2) bs
        -- The last constructor field is a sub-tree; both sub-trees share it.
        brTy = last idTys
        (uniqs2,[ltNId,rtNId,co,lt,rt]) =
          mapAccumR mkUniqSystemId uniqs1 $ zip
            ["lt" `T.append` (s `T.cons` T.pack (show b0))
            ,"rt" `T.append` (s `T.cons` T.pack (show b1))
            ,"_co_"
            ,"lt"
            ,"rt"
            ]
            (brTy:brTy:idTys)
        ltVar = Var ltNId
        rtVar = Var rtNId
        pat = DataPat brCon [mTV] [co,lt,rt]
        ltRhs = Case e brTy [(pat,Var lt)]
        rtRhs = Case e brTy [(pat,Var rt)]
        (kL,kR) = splitAt (length ks `div` 2) ks
        (uniqs3,(lVars,lBinds)) = go (n-1) bL kL uniqs2 ltVar
        (uniqs4,(rVars,rBinds)) = go (n-1) bR kR uniqs3 rtVar
-- | Build an @RTree@ term out of a list of leaf element terms.
--
-- The element list is split in half at every BR node, so it is expected to
-- hold exactly @2^d@ elements for a tree of depth @d@ (not checked). An empty
-- list is rejected with an error. Without that check, splitting @[]@ yields
-- @([],[])@ again and again, which would produce an infinite term.
mkRTree :: DataCon -- ^ The LR constructor
        -> DataCon -- ^ The BR constructor
        -> Type    -- ^ Element type
        -> Integer -- ^ Depth of the tree
        -> [Term]  -- ^ Elements
        -> Term
mkRTree lrCon brCon resTy = go
  where
    go _ [x] =
      mkApps (Data lrCon) [ Right (LitTy (NumTy 0))
                          , Right resTy
                          , Left (primCo lrCoTy)
                          , Left x
                          ]
    -- Only reachable when the caller passes no elements at all: for two or
    -- more elements both halves below are non-empty.
    go _ [] =
      error ($(curLoc) ++ "mkRTree: cannot build an RTree from an empty list of elements")
    go n xs =
      let (xsL,xsR) = splitAt (length xs `div` 2) xs
      in  mkApps (Data brCon) [ Right (LitTy (NumTy n))
                              , Right resTy
                              , Right (LitTy (NumTy (n-1)))
                              , Left (primCo (brCoTy n))
                              , Left (go (n-1) xsL)
                              , Left (go (n-1) xsR)
                              ]
    -- The first instantiated argument type of each constructor is the type
    -- of its coercion argument.
    lrCoTy = head (fromJust $! dataConInstArgTys lrCon [(LitTy (NumTy 0))
                                                       ,resTy])
    brCoTy n = head (fromJust $! dataConInstArgTys brCon
                                   [(LitTy (NumTy n))
                                   ,resTy
                                   ,(LitTy (NumTy (n-1)))])
-- | Determine whether a type is a @Signal@, @BiSignalIn@ or @BiSignalOut@
-- type, or an algebraic type whose instantiated constructor fields
-- (transitively) contain one.
--
-- Each type constructor is unfolded at most once along a path, so recursive
-- types do not loop. Constructors missing from the 'TyConMap' are treated as
-- non-signal types, and a trace message is emitted.
isSignalType :: TyConMap -> Type -> Bool
isSignalType tcm ty = go HashSet.empty ty
  where
    go tcSeen (tyView -> TyConApp tcNm args) = case nameOcc tcNm of
      "Clash.Signal.Internal.Signal"      -> True
      "Clash.Signal.BiSignal.BiSignalIn"  -> True
      -- BiSignalOut lives next to BiSignalIn in Clash.Signal.BiSignal; the
      -- "Internal" spelling is kept so previously matched names still match.
      "Clash.Signal.BiSignal.BiSignalOut" -> True
      "Clash.Signal.Internal.BiSignalOut" -> True
      _ | tcNm `HashSet.member` tcSeen    -> False -- do not follow recursive types
        | otherwise -> case lookupUniqMap tcNm tcm of
            Just tc ->
              let dcs         = tyConDataCons tc
                  dcInsArgTys = concat
                              $ mapMaybe (`dataConInstArgTys` args) dcs
                  tcSeen'     = HashSet.insert tcNm tcSeen
              in  any (go tcSeen') dcInsArgTys
            Nothing ->
              traceIf True ($(curLoc) ++ "isSignalType: " ++ show tcNm
                              ++ " not found.") False
    go _ _ = False
-- | Whether a type is an application of @Clash.Signal.Internal.Enable@,
-- possibly hidden behind layers that 'coreView1' can unfold.
isEnable
  :: TyConMap
  -> Type
  -> Bool
isEnable m ty0 = case tyView ty0 of
  TyConApp tcNm _
    | nameOcc tcNm == "Clash.Signal.Internal.Enable" -> True
  _ -> maybe False (isEnable m) (coreView1 m ty0)
-- | Whether a type, after unfolding everything 'coreView1' can unfold, is an
-- application of @Clash.Signal.Internal.Clock@ or @Clash.Signal.Internal.Reset@.
isClockOrReset
  :: TyConMap
  -> Type
  -> Bool
isClockOrReset m ty = case coreView1 m ty of
  Just unfolded -> isClockOrReset m unfolded
  Nothing -> case tyView ty of
    TyConApp tcNm _ ->
      nameOcc tcNm `elem` [ "Clash.Signal.Internal.Clock"
                          , "Clash.Signal.Internal.Reset"
                          ]
    _ -> False
-- | Reduce a type to a type-level natural literal and return its value.
-- Anything 'coreView1' can unfold is unfolded first; any other type is an
-- error in the 'Except' monad.
tyNatSize :: TyConMap
          -> Type
          -> Except String Integer
tyNatSize m ty = case coreView1 m ty of
  Just unfolded -> tyNatSize m unfolded
  Nothing -> case ty of
    LitTy (NumTy i) -> pure i
    _ -> throwE ($(curLoc) ++ "Cannot reduce to an integer:\n" ++ showPpr ty)
-- | Create a fresh type variable with a system name based on the given
-- 'OccName' and the given kind.
--
-- The variable is made unique with respect to the in-scope set and then
-- added to it. Returns the advanced supply and the extended in-scope set.
mkUniqSystemTyVar
  :: (Supply, InScopeSet)
  -> (OccName, Kind)
  -> ((Supply, InScopeSet), TyVar)
mkUniqSystemTyVar (supply, inScope) (nm, ki) = ((supply', inScope'), tv)
  where
    (u, supply') = freshId supply
    tv           = uniqAway inScope (mkTyVar ki (mkUnsafeSystemName nm u))
    inScope'     = extendInScopeSet inScope tv
-- | Create a fresh local identifier with a system name based on the given
-- 'OccName' and the given type.
--
-- The identifier is made unique with respect to the in-scope set and then
-- added to it. Returns the advanced supply and the extended in-scope set.
mkUniqSystemId
  :: (Supply, InScopeSet)
  -> (OccName, Type)
  -> ((Supply,InScopeSet), Id)
mkUniqSystemId (supply, inScope) (nm, ty) = ((supply', inScope'), ident)
  where
    (u, supply') = freshId supply
    ident        = uniqAway inScope (mkLocalId ty (mkUnsafeSystemName nm u))
    inScope'     = extendInScopeSet inScope ident
-- | Like 'mkUniqSystemId', but the fresh local identifier gets an internal
-- name ('mkUnsafeInternalName') instead of a system name.
mkUniqInternalId
  :: (Supply, InScopeSet)
  -> (OccName, Type)
  -> ((Supply,InScopeSet), Id)
mkUniqInternalId (supply, inScope) (nm, ty) = ((supply', inScope'), ident)
  where
    (u, supply') = freshId supply
    ident        = uniqAway inScope (mkLocalId ty (mkUnsafeInternalName nm u))
    inScope'     = extendInScopeSet inScope ident
-- | Instantiate the argument types of a data constructor with the given
-- type arguments, then refine them using type equalities found among those
-- argument types.
--
-- Unlike 'dataConInstArgTys', @inst_tys@ is zipped only against the
-- universal type variables ('dcUnivTyVars'). The existential type variables
-- are kept and are rewritten using the solutions returned by
-- 'solveNonAbsurds'. This repeats until no more solutions are found. In the
-- code visible here the result is always 'Just'.
dataConInstArgTysE
  :: HasCallStack
  => InScopeSet
  -> TyConMap
  -> DataCon
  -> [Type]
  -> Maybe [Type]
dataConInstArgTysE is0 tcm (MkData { dcArgTys, dcExtTyVars, dcUnivTyVars }) inst_tys = do
  -- Put the existentials and the free variables of the instantiation in
  -- scope before substituting (presumably to avoid variable capture).
  let is1 = extendInScopeSetList is0 dcExtTyVars
      is2 = unionInScope is1 (mkInScopeSet (tyFVsOfTypes inst_tys))
      subst = extendTvSubstList (mkSubst is2) (zip dcUnivTyVars inst_tys)
  go
    (substGlobalsInExistentials is0 dcExtTyVars (zip dcUnivTyVars inst_tys))
    (map (substTy subst) dcArgTys)
  where
    -- Fixed-point loop: collect the equations 'typeEq' finds in the current
    -- argument types, solve them, and substitute the solutions into both the
    -- existentials and the argument types. Stops when no solutions remain.
    -- NOTE(review): termination relies on 'solveNonAbsurds' eventually
    -- returning [] for the refined arguments.
    go
      :: [TyVar]
      -> [Type]
      -> Maybe [Type]
    go exts0 args0 =
      let eqs = catMaybes (map (typeEq tcm) args0) in
      case solveNonAbsurds tcm eqs of
        [] ->
          Just args0
        sols ->
          go exts1 args1
          where
            exts1 = substInExistentialsList is0 exts0 sols
            is2 = extendInScopeSetList is0 exts1
            subst = extendTvSubstList (mkSubst is2) sols
            args1 = map (substTy subst) args0
-- | Instantiate all type variables of a data constructor (universal ones
-- first, then existential ones) and return its argument types.
--
-- Returns 'Nothing' when the number of supplied types differs from the
-- number of type variables.
dataConInstArgTys :: DataCon -> [Type] -> Maybe [Type]
dataConInstArgTys (MkData { dcArgTys, dcUnivTyVars, dcExtTyVars }) inst_tys
  | length allTvs /= length inst_tys = Nothing
  | otherwise = Just (map (substTyWith allTvs inst_tys) dcArgTys)
  where
    allTvs = dcUnivTyVars ++ dcExtTyVars
-- | Placeholder term used for coercion arguments: a primitive named @_CO_@
-- with the given type, marked 'WorkNever'.
primCo
  :: Type
  -> Term
primCo ty = Prim coInfo
  where
    coInfo = PrimInfo "_CO_" ty WorkNever
-- | The @Clash.Transformations.undefined@ primitive applied to the given type.
undefinedTm
  :: Type
  -> Term
undefinedTm ty = TyApp undefPrim ty
  where
    undefPrim =
      Prim (PrimInfo "Clash.Transformations.undefined" undefinedTy WorkNever)
-- | Instantiate the universal type variables of a data constructor with the
-- given types and return its argument types.
--
-- The in-scope set for the substitution contains the free variables of the
-- given types plus the constructor's existential type variables. The
-- pairing uses 'zipEqual' (presumably failing on a length mismatch; confirm).
substArgTys
  :: DataCon
  -> [Type]
  -> [Type]
substArgTys dc args = map (substTy subst) (dcArgTys dc)
  where
    argsFreeVars =
      foldl' unionVarSet emptyVarSet
        [ Lens.foldMapOf typeFreeVars unitVarSet arg | arg <- args ]
    scope = mkInScopeSet (argsFreeVars `unionVarSet` mkVarSet (dcExtTyVars dc))
    subst = extendTvSubstList (mkSubst scope) (dcUnivTyVars dc `zipEqual` args)
-- | Reduce a type to a type-level literal and render it as a 'String'.
-- Symbols are returned verbatim and naturals are rendered with 'show'.
-- Anything 'coreView1' can unfold is unfolded first; any other type is an
-- error in the 'Except' monad.
tyLitShow
  :: TyConMap
  -> Type
  -> Except String String
tyLitShow m ty = case coreView1 m ty of
  Just unfolded -> tyLitShow m unfolded
  Nothing -> case ty of
    LitTy (SymTy str) -> pure str
    LitTy (NumTy i)   -> pure (show i)
    _ -> throwE ($(curLoc) ++ "Cannot reduce to a string:\n" ++ showPpr ty)
-- | Determine whether a value of the given type should be split into its
-- constructor fields; see 'shouldSplit0'. A @SimIO a@ type is looked
-- through to @a@, and other types are first unfolded with 'coreView'.
shouldSplit
  :: TyConMap
  -> Type
  -> Maybe (Term,[Type])
shouldSplit tcm ty = case tyView ty of
  TyConApp tcNm [tyArg]
    | nameOcc tcNm == "Clash.Explicit.SimIO.SimIO" -> shouldSplit tcm tyArg
  _ -> shouldSplit0 tcm (tyView (coreView tcm ty))
-- | Worker for 'shouldSplit', operating on a 'TypeView'.
--
-- Succeeds for an application of a type constructor with exactly one data
-- constructor when any of its (unfolded) field types is a clock, reset,
-- enable, @File@ or @Handle@ type, or itself needs splitting. The exception
-- is the hidden-domain constraint tuple recognised by @isHidden@. On success
-- it returns the data constructor applied to the type arguments, together
-- with the instantiated field types.
shouldSplit0
  :: TyConMap
  -> TypeView
  -> Maybe (Term,[Type])
shouldSplit0 tcm (TyConApp tcNm tyArgs)
  | Just tc <- lookupUniqMap tcNm tcm
  , [dc] <- tyConDataCons tc
  , let fieldTys = substArgTys dc tyArgs
  , let fieldViews = map (tyView . coreView tcm) fieldTys
  = if any needsSplit fieldViews && not (isHidden tcNm tyArgs)
      then Just (mkApps (Data dc) (map Right tyArgs), fieldTys)
      else Nothing
  where
    needsSplit :: TypeView -> Bool
    needsSplit tv = isJust (shouldSplit0 tcm tv) || isSplittableTyCon tv

    -- A constraint tuple of an implicit parameter holding a splittable type
    -- and a KnownDomain constraint is not split.
    isHidden :: Name a -> [Type] -> Bool
    isHidden nm [a1, a2]
      | TyConApp a2Nm _ <- tyView a2
      = and [ nameOcc nm == "GHC.Classes.(%,%)"
            , isSplittableTyCon (tyView (stripIP a1))
            , nameOcc a2Nm == "Clash.Signal.Internal.KnownDomain"
            ]
    isHidden _ _ = False

    isSplittableTyCon :: TypeView -> Bool
    isSplittableTyCon (TyConApp tcNm0 _) = nameOcc tcNm0 `elem` splittableTyCons
    isSplittableTyCon _ = False

    splittableTyCons :: [OccName]
    splittableTyCons =
      [ "Clash.Signal.Internal.Clock"
      , "Clash.Signal.Internal.Reset"
      , "Clash.Signal.Internal.Enable"
      , "Clash.Explicit.SimIO.File"
      , "GHC.IO.Handle.Types.Handle"
      ]
shouldSplit0 _ _ = Nothing
-- | Replace every type that 'shouldSplit' says must be split by its
-- (recursively split) field types, keeping all other types in place.
splitShouldSplit
  :: TyConMap
  -> [Type]
  -> [Type]
splitShouldSplit tcm = concatMap expand
  where
    expand ty = maybe [ty] (splitShouldSplit tcm . snd) (shouldSplit tcm ty)
-- | Strip the implicit-parameter class wrapper: for an application of the
-- IP class (@IP name ty@), return @ty@; return any other type unchanged.
stripIP :: Type -> Type
stripIP t = case tyView t of
  TyConApp tcNm [_, a2]
    | nameUniq tcNm == getKey ipClassKey -> a2
  _ -> t
-- | Reorder the bindings of a 'Letrec' in post-order of a depth-first
-- traversal of their dependency graph. For bindings that are not mutually
-- recursive, every binding then comes after the bindings it refers to.
-- Terms other than 'Letrec' are returned unchanged.
inverseTopSortLetBindings
  :: HasCallStack
  => Term
  -> Term
inverseTopSortLetBindings (Letrec bndrs0 res) = Letrec bndrs1 res
  where
    -- One node per binding, keyed on the binder's unique, with an edge to
    -- each free local identifier of its right-hand side.
    toNode bndr@(i, e) =
      (bndr, varUniq i, map varUniq (Set.elems (Lens.setOf freeLocalIds e)))

    (graph, vertexInfo, _) = Graph.graphFromEdges (map toNode bndrs0)

    bndrs1 = [ bndr | v <- postOrder graph, let (bndr, _, _) = vertexInfo v ]

    postOrder :: Graph.Graph -> [Graph.Vertex]
    postOrder g = foldr visit [] (Graph.dff g)

    -- Emit all descendants of a node before the node itself.
    visit :: Graph.Tree a -> [a] -> [a]
    visit (Graph.Node v kids) acc = foldr visit (v : acc) kids
inverseTopSortLetBindings e = e
{-# SCC inverseTopSortLetBindings #-}