{-# LANGUAGE CPP, TupleSections #-}
module Vectorise.Exp
(
vectTopExpr
, vectTopExprs
, vectScalarFun
, vectScalarDFun
)
where
#include "HsVersions.h"
import Vectorise.Type.Type
import Vectorise.Var
import Vectorise.Convert
import Vectorise.Vect
import Vectorise.Env
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Utils
import CoreUtils
import MkCore
import CoreSyn
import CoreFVs
import Class
import DataCon
import TyCon
import TcType
import Type
import TyCoRep
import Var
import VarEnv
import VarSet
import NameSet
import Id
import BasicTypes( isStrongLoopBreaker )
import Literal
import TysPrim
import Outputable
import FastString
import DynFlags
import Util
import Control.Monad
import Data.Maybe
import Data.List
-- | Vectorise the right-hand side of a single top-level binding.
--
-- The expression is first annotated with vectorisation-avoidance information
-- and its scalar subcomputations are encapsulated. If the whole expression
-- ends up encapsulated, no vectorised version is needed and 'Nothing' is
-- returned. Otherwise the result carries whether the expression is classified
-- as parallel, its inlining decision, and the vectorised expression.
vectTopExpr :: Var -> CoreExpr -> VM (Maybe (Bool, Inline, CoreExpr))
vectTopExpr var expr =
  do annotated <- vectAvoidInfo emptyVarSet (freeVars expr)
     exprVI    <- encapsulateScalars annotated
     if isVIEncaps exprVI
       then return Nothing
       else do
         vExpr  <- closedV (inBind var (vectAnnPolyExpr False exprVI))
         inline <- computeInline exprVI
         return (Just (isVIParr exprVI, inline, vectorised vExpr))
-- | Decide whether the vectorised version of an expression should be inlined.
--
-- Dictionary expressions are never inlined. Ticks are looked through. A
-- lambda is inlined with the arity computed by 'polyArity' from its leading
-- type binders. Anything else is not inlined.
computeInline :: CoreExprWithVectInfo -> VM Inline
computeInline expr
  | isVIDict expr
  = return DontInline
  | otherwise
  = case expr of
      (_, AnnTick _ inner) -> computeInline inner
      (_, AnnLam _ _)      -> Inline <$> polyArity (fst (collectAnnTypeBinders expr))
      _                    -> return DontInline
-- | Vectorise the right-hand sides of a group of top-level bindings.
--
-- Returns 'Nothing' if every right-hand side is fully encapsulated scalar
-- code. If any right-hand side is classified as parallel, the analysis is
-- redone with all binders of the group treated as parallel variables. The
-- result pairs that parallelism flag with each binding's inlining decision
-- and vectorised expression, in binding order.
vectTopExprs :: [(Var, CoreExpr)] -> VM (Maybe (Bool, [(Inline, CoreExpr)]))
vectTopExprs binds =
  do initialVIs <- mapM (analyse emptyVarSet) rhss
     if all isVIEncaps initialVIs
       then return Nothing
       else do
         let anyParallel = any isVIParr initialVIs
         finalVIs <- if anyParallel
                       then mapM (analyse (mkVarSet bndrs)) rhss
                       else return initialVIs
         vResults <- zipWithM vectOne bndrs finalVIs
         return (Just (anyParallel, vResults))
  where
    (bndrs, rhss) = unzip binds

    -- annotate with avoidance info (given the parallel variables) and then
    -- encapsulate the scalar subcomputations
    analyse pvs = encapsulateScalars <=< vectAvoidInfo pvs . freeVars

    -- vectorise one binding; strong loop breakers are flagged to the vectoriser
    vectOne bndr exprVI =
      do vExpr  <- closedV . inBind bndr
                     $ vectAnnPolyExpr (isStrongLoopBreaker (idOccInfo bndr)) exprVI
         inline <- computeInline exprVI
         return (inline, vectorised vExpr)
-- | Vectorise a possibly polymorphic annotated expression.
--
-- Ticks are preserved around the vectorised result. Dictionary expressions
-- ('VIDict') are only vectorised by 'vectDictExpr'; they have no lifted
-- counterpart. Any other expression has its leading type binders abstracted
-- with 'polyAbstract' and its monomorphic body vectorised by 'vectFnExpr'.
--
-- The 'Bool' argument says whether the binding is a loop breaker.
vectAnnPolyExpr :: Bool -> CoreExprWithVectInfo -> VM VExpr
vectAnnPolyExpr loop_breaker (_, AnnTick tickish expr)
  = vTick tickish <$> vectAnnPolyExpr loop_breaker expr
vectAnnPolyExpr loop_breaker expr
  | isVIDict expr
  = (, noLiftedDict) <$> vectDictExpr (deAnnotate expr)
  | otherwise
  = polyAbstract tvs $ \args ->
      mapVect (mkLams $ tvs ++ args) <$> vectFnExpr False loop_breaker mono
  where
    (tvs, mono) = collectAnnTypeBinders expr

    -- Dictionary computations are never lifted; forcing the lifted component
    -- is a vectoriser bug. Report it with a descriptive message, as
    -- 'mkScalarFun' does, rather than a bare 'undefined'.
    noLiftedDict :: CoreExpr
    noLiftedDict = error "Vectorise.Exp.vectAnnPolyExpr: we don't lift dictionary expressions"
-- | Encapsulate scalar subcomputations of an annotated expression using
-- 'liftSimpleAndCase', so that they are later vectorised as scalar code.
--
-- A lambda, application, case, or let annotated 'VISimple' is handed to
-- 'liftSimpleAndCase' when all of its free variables have scalar type (or
-- are top-level) and either of these holds:
--
-- * vectorisation avoidance is aggressive (for lambdas, all binders must
--   also be scalar);
-- * for applications only, the application passes 'isSimpleApplication'.
--
-- A variable annotated 'VISimple' is always handed to 'liftSimpleAndCase'.
-- Types, coercions, literals, and other variables are returned unchanged.
-- All other nodes are rebuilt with their children processed recursively.
encapsulateScalars :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
encapsulateScalars ce@(_, AnnType _ty)
  = return ce
-- Coercion arguments, like type arguments, contain no computation to
-- encapsulate. 'vectAvoidInfo' does produce 'AnnCoercion' nodes, and without
-- this equation they fell through to a panic, e.g. as the argument of a
-- non-encapsulated application.
encapsulateScalars ce@(_, AnnCoercion _co)
  = return ce
encapsulateScalars ce@((_, VISimple), AnnVar _v)
  = liftSimpleAndCase ce
encapsulateScalars ce@(_, AnnVar _v)
  = return ce
encapsulateScalars ce@(_, AnnLit _)
  = return ce
encapsulateScalars ((fvs, vi), AnnTick tck expr)
  = do
    { encExpr <- encapsulateScalars expr
    ; return ((fvs, vi), AnnTick tck encExpr)
    }
encapsulateScalars ce@((fvs, vi), AnnLam bndr expr)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; bndrsS    <- allScalarVarType bndrs
    ; case (vi, vectAvoid && varsS && bndrsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encExpr <- encapsulateScalars expr
                            ; return ((fvs, vi), AnnLam bndr encExpr)
                            }
    }
  where
    -- all binders of the (possibly nested) lambda
    (bndrs, _) = collectAnnBndrs ce
encapsulateScalars ce@((fvs, vi), AnnApp ce1 ce2)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; case (vi, (vectAvoid || isSimpleApplication ce) && varsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encCe1 <- encapsulateScalars ce1
                            ; encCe2 <- encapsulateScalars ce2
                            ; return ((fvs, vi), AnnApp encCe1 encCe2)
                            }
    }
  where
    -- An atomic expression, or an application of an atomic function to an
    -- argument that is itself such an application (ticks and casts are
    -- looked through).
    isSimpleApplication :: CoreExprWithVectInfo -> Bool
    isSimpleApplication (_, AnnTick _ ce)   = isSimpleApplication ce
    isSimpleApplication (_, AnnCast ce _)   = isSimpleApplication ce
    isSimpleApplication ce | isSimple ce    = True
    isSimpleApplication (_, AnnApp ce1 ce2) = isSimple ce1 && isSimpleApplication ce2
    isSimpleApplication _                   = False

    -- a type, variable, or literal (looking through ticks and casts)
    isSimple :: CoreExprWithVectInfo -> Bool
    isSimple (_, AnnType {})   = True
    isSimple (_, AnnVar {})    = True
    isSimple (_, AnnLit {})    = True
    isSimple (_, AnnTick _ ce) = isSimple ce
    isSimple (_, AnnCast ce _) = isSimple ce
    isSimple _                 = False
encapsulateScalars ce@((fvs, vi), AnnCase scrut bndr ty alts)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; case (vi, vectAvoid && varsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encScrut <- encapsulateScalars scrut
                            ; encAlts  <- mapM encAlt alts
                            ; return ((fvs, vi), AnnCase encScrut bndr ty encAlts)
                            }
    }
  where
    encAlt (con, bndrs, expr) = (con, bndrs,) <$> encapsulateScalars expr
encapsulateScalars ce@((fvs, vi), AnnLet (AnnNonRec bndr expr1) expr2)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; case (vi, vectAvoid && varsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encExpr1 <- encapsulateScalars expr1
                            ; encExpr2 <- encapsulateScalars expr2
                            ; return ((fvs, vi), AnnLet (AnnNonRec bndr encExpr1) encExpr2)
                            }
    }
encapsulateScalars ce@((fvs, vi), AnnLet (AnnRec binds) expr)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; case (vi, vectAvoid && varsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encBinds <- mapM encBind binds
                            ; encExpr  <- encapsulateScalars expr
                            ; return ((fvs, vi), AnnLet (AnnRec encBinds) encExpr)
                            }
    }
  where
    encBind (bndr, expr) = (bndr,) <$> encapsulateScalars expr
encapsulateScalars ((fvs, vi), AnnCast expr coercion)
  = do
    { encExpr <- encapsulateScalars expr
    ; return ((fvs, vi), AnnCast encExpr coercion)
    }
-- | Encapsulate a scalar expression with 'liftSimple'. A case expression is
-- special-cased:
--
-- * if its scrutinee's type is classified 'VISimple', the whole case is
--   encapsulated;
-- * otherwise the case is kept and the function recurses into each
--   alternative's right-hand side. The rebuilt case node is annotated with
--   the scrutinee type's classification.
--   NOTE(review): confirm that this annotation, rather than the case's own
--   original classification, is intended.
liftSimpleAndCase :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
liftSimpleAndCase whole@((fvs, _), AnnCase scrut caseBndr resTy alts) =
  do scrutVI <- vectAvoidInfoTypeOf scrut
     case scrutVI of
       VISimple -> liftSimple whole
       _        -> do
         alts' <- forM alts $ \(con, altBndrs, rhs) ->
                    (con, altBndrs,) <$> liftSimpleAndCase rhs
         return ((fvs, scrutVI), AnnCase scrut caseBndr resTy alts')
liftSimpleAndCase other = liftSimple other
-- | Encapsulate a scalar ('VISimple') expression. It is abstracted over its
-- free variables that are not top-level, and the resulting lambda is
-- immediately applied to those same variables: @e@ becomes
-- @(\v1 .. vn -> e) v1 .. vn@.
--
-- The lambda nodes (and the wrapped body) are annotated 'VIEncaps'. The
-- application nodes and the variable arguments are annotated 'VISimple'.
--
-- A variable that is one of its own non-top-level free variables is returned
-- unchanged. Any other expression that is not annotated 'VISimple' causes a
-- panic.
liftSimple :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
liftSimple ((fvs, vi), AnnVar v)
  | v `elemDVarSet` fvs      -- a free local variable: encapsulating it would only
  && not (isToplevel v)      -- produce '(\v -> v) v'; top-level variables fall through
  = return $ ((fvs, vi), AnnVar v)
liftSimple aexpr@((fvs_orig, VISimple), expr)
  = do
    { let liftedExpr = mkAnnApps (mkAnnLams (reverse vars) fvs expr) vars
    ; traceVt "encapsulate:" $ ppr (deAnnotate aexpr) $$ text "==>" $$ ppr (deAnnotate liftedExpr)
    ; return $ liftedExpr
    }
  where
    -- the free variables to abstract over: top-level ones are excluded
    vars = dVarSetElems fvs
    fvs = filterDVarSet (not . isToplevel) fvs_orig

    -- Wrap lambdas around 'expr', innermost first. The variables are passed in
    -- reverse, so the outermost lambda binds the first element of 'vars'. Each
    -- step removes the bound variable from the free-variable annotation, which
    -- must therefore be empty once all variables are bound.
    mkAnnLams :: [Var] -> DVarSet -> AnnExpr' Var (DVarSet, VectAvoidInfo) -> CoreExprWithVectInfo
    mkAnnLams [] fvs expr = ASSERT(isEmptyDVarSet fvs)
                            ((emptyDVarSet, VIEncaps), expr)
    mkAnnLams (v:vs) fvs expr = mkAnnLams vs (fvs `delDVarSet` v) (AnnLam v ((fvs, VIEncaps), expr))

    -- apply the expression to the variables, left to right
    mkAnnApps :: CoreExprWithVectInfo -> [Var] -> CoreExprWithVectInfo
    mkAnnApps aexpr [] = aexpr
    mkAnnApps aexpr (v:vs) = mkAnnApps (mkAnnApp aexpr v) vs

    -- a single application; the applied variable becomes free in the result
    mkAnnApp :: CoreExprWithVectInfo -> Var -> CoreExprWithVectInfo
    mkAnnApp aexpr@((fvs, _vi), _expr) v
      = ((fvs `extendDVarSet` v, VISimple), AnnApp aexpr ((unitDVarSet v, VISimple), AnnVar v))
liftSimple aexpr
  = pprPanic "Vectorise.Exp.liftSimple: not simple" $ ppr (deAnnotate aexpr)
-- | Is the variable an 'Id' whose unfolding marks it as top-level?
--
-- 'OtherCon' and 'DFunUnfolding' unfoldings count as top-level.
-- 'CoreUnfolding' consults its 'uf_is_top' flag. Variables with no (or only
-- a boot) unfolding, and non-'Id' variables, are not top-level.
isToplevel :: Var -> Bool
isToplevel v = isId v && topUnfolding (realIdUnfolding v)
  where
    topUnfolding unf = case unf of
      NoUnfolding                       -> False
      BootUnfolding                     -> False
      OtherCon {}                       -> True
      DFunUnfolding {}                  -> True
      CoreUnfolding { uf_is_top = top } -> top
-- | Vectorise an annotated expression, producing the pair of its vectorised
-- and lifted versions.
vectExpr :: CoreExprWithVectInfo -> VM VExpr
vectExpr aexpr
    -- encapsulated expression of functional type => hand it to 'vectFnExpr'
  | (isFunTy . annExprType $ aexpr) && isVIEncaps aexpr
  = vectFnExpr True False aexpr
    -- any other encapsulated expression => treat it as a constant
  | isVIEncaps aexpr
  = traceVt "vectExpr (encapsulated constant):" (ppr . deAnnotate $ aexpr) >>
    vectConst (deAnnotate aexpr)

vectExpr (_, AnnVar v)
  = vectVar v

vectExpr (_, AnnLit lit)
  = vectConst $ Lit lit

vectExpr aexpr@(_, AnnLam _ _)
  = traceVt "vectExpr [AnnLam]:" (ppr . deAnnotate $ aexpr) >>
    vectFnExpr True False aexpr

-- A pattern-match failure call. 'pAT_ERROR_ID' takes a runtime-rep and a
-- result-type argument before the message, as the vectorised version already
-- assumed. The pattern therefore matches both type arguments. The lifted
-- version must supply the runtime-rep argument as well; it used to be
-- omitted, which made it inconsistent with the vectorised one.
vectExpr (_, AnnApp (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType _rep)) (_, AnnType ty)) err)
  | v == pAT_ERROR_ID
  = do
    { (vty, lty) <- vectAndLiftType ty
    ; return ( mkCoreApps (Var v) [Type (getRuntimeRep "vectExpr" vty), Type vty, err']
             , mkCoreApps (Var v) [Type (getRuntimeRep "vectExpr" lty), Type lty, err'] )
    }
  where
    err' = deAnnotate err

-- type application => polymorphic application
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectPolyApp e

-- data constructor applied to a literal
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just _con <- isDataConId_maybe v
  = do
    { let vexpr = App (Var v) (Lit lit)
    ; lexpr <- liftPD vexpr
    ; return (vexpr, lexpr)
    }

-- dictionary application => polymorphic application; otherwise a closure
-- application of the vectorised function to the vectorised argument
vectExpr e@(_, AnnApp fn arg)
  | isPredTy arg_ty
  = vectPolyApp e
  | otherwise
  = do
    {
    ; varg_ty <- vectType arg_ty
    ; vres_ty <- vectType res_ty
    ; vfn     <- vectExpr fn
    ; varg    <- vectExpr arg
    ; mkClosureApp varg_ty vres_ty vfn varg
    }
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

-- only cases on algebraic data types can be vectorised
vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts
  | otherwise
  = do
    { dflags <- getDynFlags
    ; cantVectorise dflags "Can't vectorise expression (no algebraic type constructor)" $
        ppr scrut_ty
    }
  where
    scrut_ty = exprType (deAnnotate scrut)

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
    { traceVt "let binding (non-recursive)" Outputable.empty
    ; vrhs <- localV $
                inBind bndr $
                  vectAnnPolyExpr False rhs
    ; traceVt "let body (non-recursive)" Outputable.empty
    ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
    ; return $ vLet (vNonRec vbndr vrhs) vbody
    }

vectExpr (_, AnnLet (AnnRec bs) body)
  = do
    { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs $ do
                                  { traceVt "let bindings (recursive)" Outputable.empty
                                  ; vrhss <- zipWithM vect_rhs bndrs rhss
                                  ; traceVt "let body (recursive)" Outputable.empty
                                  ; vbody <- vectExpr body
                                  ; return (vrhss, vbody)
                                  }
    ; return $ vLet (vRec vbndrs vrhss) vbody
    }
  where
    (bndrs, rhss) = unzip bs

    vect_rhs bndr rhs = localV $
                          inBind bndr $
                            vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) rhs

vectExpr (_, AnnTick tickish expr)
  = vTick tickish <$> vectExpr expr

vectExpr (_, AnnType ty)
  = vType <$> vectType ty

vectExpr e
  = do
    { dflags <- getDynFlags
    ; cantVectorise dflags "Can't vectorise expression (vectExpr)" $ ppr (deAnnotate e)
    }
-- | Vectorise an expression that may be a function.
--
-- * Dictionary lambdas keep their (vectorised) binder and recurse into the
--   body.
-- * Encapsulated value lambdas and encapsulated expressions of function type
--   go through 'vectScalarFun'.
-- * Other value lambdas are fully vectorised by 'vectLam'.
-- * Type lambdas cannot be vectorised here.
-- * Anything else goes to 'vectExpr'.
vectFnExpr :: Bool                  -- ^ passed on to 'vectLam' to decide inlining
           -> Bool                  -- ^ is the binding a loop breaker?
           -> CoreExprWithVectInfo  -- ^ expression to vectorise
           -> VM VExpr
vectFnExpr inline loop_breaker aexpr = case aexpr of
  (_, AnnLam bndr body)
    | not (isId bndr) -> do
        dflags <- getDynFlags
        cantVectorise dflags "Vectorise.Exp.vectFnExpr: Unexpected type lambda" $
          ppr (deAnnotate aexpr)
    | isPredTy (idType bndr) -> do
        vBndr <- vectBndr bndr
        mapVect (mkLams [vectorised vBndr]) <$> vectFnExpr inline loop_breaker body
    | isVIEncaps aexpr -> vectScalarFun (deAnnotate aexpr)
    | otherwise        -> vectLam inline loop_breaker aexpr
  _ | isFunTy (annExprType aexpr) && isVIEncaps aexpr -> vectScalarFun (deAnnotate aexpr)
    | otherwise                                        -> vectExpr aexpr
-- | Vectorise an application of a variable to type and dictionary arguments.
--
-- The expression is peeled from the outside in (see the 'where' clause):
-- @e0 = var \@tysInner dictsInner \@tysOuter dictsOuter@. Each group may be
-- empty. NOTE(review): this assumes 'collectAnnDictArgs' and
-- 'collectAnnTypeArgs' strip trailing argument runs; confirm.
--
-- * A local variable is re-instantiated ('polyApply') at the vectorised outer
--   types and applied to the vectorised outer dictionaries, in both its
--   vectorised and lifted versions. It must have no inner dictionaries.
-- * A global class-op or dfun ('isDictComp') is rebuilt from its vectorised
--   counterpart and treated as a constant.
-- * Any other global must have no inner dictionaries. It is re-instantiated
--   and treated as a constant.
--
-- Any head that is not a variable is reported as a higher-rank program.
vectPolyApp :: CoreExprWithVectInfo -> VM VExpr
vectPolyApp e0
  = case e4 of
      (_, AnnVar var)
        -> do {
              ; vVar <- lookupVar var
              ; traceVt "vectPolyApp of" (ppr var)
              ; vDictsOuter <- mapM vectDictExpr (map deAnnotate dictsOuter)
              ; vDictsInner <- mapM vectDictExpr (map deAnnotate dictsInner)
              ; vTysOuter <- mapM vectType tysOuter
              ; vTysInner <- mapM vectType tysInner
                -- instantiate at the outer types, then apply the outer dictionaries
              ; let reconstructOuter v = (`mkApps` vDictsOuter) <$> polyApply v vTysOuter
              ; case vVar of
                  Local (vv, lv)
                    -> do { MASSERT( null dictsInner )
                          ; traceVt " LOCAL" (text "")
                          ; (,) <$> reconstructOuter (Var vv) <*> reconstructOuter (Var lv)
                          }
                  Global vv
                    | isDictComp var
                    -> do {
                          ; ve <- if null dictsInner
                                  then
                                    return $ Var vv `mkTyApps` vTysOuter `mkApps` vDictsOuter
                                  else
                                    reconstructOuter
                                      (Var vv `mkTyApps` vTysInner `mkApps` vDictsInner)
                          ; traceVt " GLOBAL (dict):" (ppr ve)
                          ; vectConst ve
                          }
                    | otherwise
                    -> do { MASSERT( null dictsInner )
                          ; ve <- reconstructOuter (Var vv)
                          ; traceVt " GLOBAL (non-dict):" (ppr ve)
                          ; vectConst ve
                          }
              }
      _ -> pprSorry "Cannot vectorise programs with higher-rank types:" (ppr . deAnnotate $ e0)
  where
    -- peel argument groups off the application, outermost first
    (e1, dictsOuter) = collectAnnDictArgs e0
    (e2, tysOuter) = collectAnnTypeArgs e1
    (e3, dictsInner) = collectAnnDictArgs e2
    (e4, tysInner) = collectAnnTypeArgs e3

    -- class-method selectors and dictionary functions
    isDictComp var = (isJust . isClassOpId_maybe $ var) || isDFunId var
-- | Vectorise a dictionary computation. Only the vectorised version is
-- produced; dictionaries are never lifted.
--
-- Variables with a vectorised counterpart are replaced by it; others are kept.
-- Types are vectorised. Data constructors in case alternatives are replaced
-- by their vectorised versions. Literals cause a panic; casts and coercions
-- are not supported.
vectDictExpr :: CoreExpr -> VM CoreExpr
vectDictExpr expr = case expr of
  Var var ->
    let chooseVar Nothing                = var
        chooseVar (Just (Local (vv, _))) = vv
        chooseVar (Just (Global vv))     = vv
    in  Var . chooseVar <$> lookupVar_maybe var
  Lit lit ->
    pprPanic "Vectorise.Exp.vectDictExpr: literal in dictionary computation" (ppr lit)
  Lam bndr body ->
    Lam bndr <$> vectDictExpr body
  App fn arg ->
    App <$> vectDictExpr fn <*> vectDictExpr arg
  Case scrut bndr ty alts ->
    Case <$> vectDictExpr scrut <*> pure bndr <*> vectType ty <*> mapM vectAlt alts
  Let bind body ->
    Let <$> vectBind bind <*> vectDictExpr body
  Cast {} ->
    pprSorry "Vectorise.Exp.vectDictExpr: cast" (ppr expr)
  Tick tickish body ->
    Tick tickish <$> vectDictExpr body
  Type ty ->
    Type <$> vectType ty
  Coercion coe ->
    pprSorry "Vectorise.Exp.vectDictExpr: coercion" (ppr coe)
  where
    vectAlt (con, bs, rhs) = (,,) <$> vectAltCon con <*> pure bs <*> vectDictExpr rhs

    vectAltCon (DataAlt dc) =
      DataAlt <$> maybeV (text "Cannot vectorise data constructor:" <+> ppr dc) (lookupDataCon dc)
    vectAltCon (LitAlt lit) = return (LitAlt lit)
    vectAltCon DEFAULT      = return DEFAULT

    vectBind (NonRec b rhs) = NonRec b <$> vectDictExpr rhs
    vectBind (Rec pairs)    = Rec <$> mapM (\(b, rhs) -> (b,) <$> vectDictExpr rhs) pairs
-- | Vectorise an expression as a scalar function. Its argument and result
-- types are read off its (function) type and passed to 'mkScalarFun'.
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr =
  traceVt "vectScalarFun:" (ppr expr) >> mkScalarFun argTys resTy expr
  where
    (argTys, resTy) = splitFunTys (exprType expr)
-- | Build the vectorised and lifted versions of a scalar function with the
-- given argument and result types.
--
-- A function with a dictionary (predicate) result is vectorised by
-- 'vectDictExpr' and has no lifted version. Otherwise the function is hoisted
-- to a top-level binding, a scalar closure is built over it and the matching
-- 'zipScalars' function, and that closure is hoisted as well. The result is
-- the hoisted closure together with its 'liftPD' lifting.
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  | isPredTy res_ty
  = (, unused) <$> vectDictExpr expr
  | otherwise
  = do
      traceVt "mkScalarFun: " $ ppr expr $$ text " ::" <+> ppr funTy
      fnVar   <- hoistExpr (fsLit "fn") expr DontInline
      zipf    <- zipScalars arg_tys res_ty
      closure <- scalarClosure arg_tys res_ty (Var fnVar) (zipf `App` Var fnVar)
      cloVar  <- hoistExpr (fsLit "clo") closure DontInline
      (,) (Var cloVar) <$> liftPD (Var cloVar)
  where
    funTy  = mkFunTys arg_tys res_ty
    unused = error "Vectorise.Exp.mkScalarFun: we don't lift dictionary expressions"
-- | Vectorise a dictionary function whose methods are scalar.
--
-- The result abstracts over the dfun's type variables and over vectorised
-- versions of its context dictionaries. Each vectorised context dictionary is
-- converted back to an ordinary one ('unVectDict'), and the original dfun is
-- applied to them. Every superclass/method selector is applied to that
-- dictionary and vectorised as a scalar function. The results are packed with
-- the vectorised class data constructor.
--
-- A missing vectorised class data constructor is now reported through
-- 'maybeV' (as elsewhere in this module) instead of a monadic pattern-match
-- failure.
vectScalarDFun :: Var          -- ^ original dfun
               -> VM CoreExpr
vectScalarDFun var
  = do {
       ; mapM_ defLocalTyVar tvs
       ; vTheta     <- mapM vectType theta
       ; vThetaBndr <- mapM (newLocalVar (fsLit "vd")) vTheta
       ; let vThetaVars = varsToCoreExprs vThetaBndr
       ; thetaVars  <- mapM (newLocalVar (fsLit "d")) theta
       ; thetaExprs <- zipWithM unVectDict theta vThetaVars
       ; let thetaDictBinds = zipWith NonRec thetaVars thetaExprs
             dict           = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars
             scsOps         = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict])
                                  selIds
       ; vScsOps    <- mapM (\e -> vectorised <$> vectScalarFun e) scsOps
       ; vDataCon   <- maybeV dataConErr (lookupDataCon dataCon)
       ; vTys       <- mapM vectType tys
       ; let vBody = thetaDictBinds `mkLets` mkCoreConApps vDataCon (map Type vTys ++ vScsOps)
       ; return $ mkLams (tvs ++ vThetaBndr) vBody
       }
  where
    ty                = varType var
    (tvs, theta, pty) = tcSplitSigmaTy ty
    (cls, tys)        = tcSplitDFunHead pty
    selIds            = classAllSelIds cls
    dataCon           = classDataCon cls
    dataConErr        = text "vectScalarDFun: data constructor not vectorised" <+> ppr dataCon

    -- Rebuild an ordinary dictionary of class type 'ty' from a vectorised
    -- dictionary 'e' by selecting each of its fields and converting it back
    -- with 'fromVect'.
    unVectDict :: Type -> CoreExpr -> VM CoreExpr
    unVectDict ty e
      = do { vTys  <- mapM vectType tys
           ; let meths = map (\sel -> Var sel `mkTyApps` vTys `mkApps` [e]) selIds
           ; scOps <- zipWithM fromVect methTys meths
           ; return $ mkCoreConApps dataCon (map Type tys ++ scOps)
           }
      where
        (tycon, tys) = splitTyConApp ty
        Just dataCon = isDataProductTyCon_maybe tycon
        Just cls     = tyConClass_maybe tycon
        methTys      = dataConInstArgTys dataCon tys
        selIds       = classAllSelIds cls
-- | Fully vectorise a value lambda.
--
-- The lambda's free variables that have local vectorised counterparts are
-- split into dictionaries and non-dictionaries. The vectorised body abstracts
-- over the non-dictionary free variables followed by the lambda's own
-- binders. It is hoisted ('hoistPolyVExpr') over the in-scope type variables
-- and vectorised dictionaries and turned into closures ('buildClosures').
--
-- If the binding is a loop breaker, the lifted body is guarded by a case on
-- the lifting context: when the context is @0@, an empty PData array is
-- returned instead of the lifted body. NOTE(review): presumably this stops
-- infinite recursion on empty input; confirm.
vectLam :: Bool                  -- ^ should the hoisted function be inlined?
        -> Bool                  -- ^ is the binding a loop breaker?
        -> CoreExprWithVectInfo  -- ^ the lambda abstraction
        -> VM VExpr
vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _)
  = do { traceVt "fully vectorise a lambda expression" (ppr . deAnnotate $ expr)

       ; let (bndrs, body) = collectAnnValBinders expr

           -- in-scope type variables
       ; tyvars <- localTyVars

           -- free variables that have local vectorised counterparts
           -- (a failed 'Just' match simply skips the variable)
       ; vfvs <- readLEnv $ \env ->
                   [ (var, vv)
                   | var <- dVarSetElems fvs
                   , Just vv <- [lookupVarEnv (local_vars env) var]
                   ]
       ; let (vvs_dict, vvs_nondict)     = partition (isPredTy . varType . fst) vfvs
             (_fvs_dict, vfvs_dict)      = unzip vvs_dict
             (fvs_nondict, vfvs_nondict) = unzip vvs_nondict

       ; arg_tys <- mapM (vectType . idType) bndrs
       ; res_ty  <- vectType (exprType $ deAnnotate body)

       ; let arity      = length fvs_nondict + length bndrs
             vfvs_dict' = map vectorised vfvs_dict
       ; buildClosures tyvars vfvs_dict' vfvs_nondict arg_tys res_ty
         . hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity)
         $ do {
              ; lc              <- builtin liftingContext
              ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) $ vectExpr body
              ; vbody'          <- break_loop lc res_ty vbody
              ; return $ vLams lc vbndrs vbody'
              }
       }
  where
    maybe_inline n | inline    = Inline n
                   | otherwise = DontInline

    -- For loop breakers, return an empty PData array when the lifting context
    -- (an Int#, judging by the case on 'intPrimTy') is 0.
    break_loop lc ty (ve, le)
      | loop_breaker
      = do { dflags     <- getDynFlags
           ; emptyPData <- emptyPD ty
           ; lty        <- mkPDataType ty
           ; return (ve, mkWildCase (Var lc) intPrimTy lty
                           [(DEFAULT, [], le),
                            (LitAlt (mkMachInt dflags 0), [], emptyPData)])
           }
      | otherwise = return (ve, le)
vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda"
-- | Vectorise a case expression over an algebraic data type.
--
-- * A single DEFAULT alternative, or a single constructor alternative without
--   binders, is vectorised with 'vCaseDEFAULT'.
-- * A single constructor alternative with binders binds the vectorised
--   scrutinee, unwraps its PData representation ('pdataUnwrapScrut'), and
--   builds one wild case on the vectorised constructor and one on the PData
--   constructor.
-- * The general case sorts the alternatives by constructor tag (DEFAULT
--   first) and introduces a selector variable @sel@. Each alternative's local
--   free variables are packed by tag ('pack_var'), and the lifted bodies are
--   combined with 'combinePD'.
vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithVectInfo)]
            -> VM VExpr
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
  = do
    { traceVt "scrutinee (DEFAULT only)" Outputable.empty
    ; vscrut <- vectExpr scrut
    ; (vty, lty) <- vectAndLiftType ty
    ; traceVt "alternative body (DEFAULT only)" Outputable.empty
    ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
    ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody
    }
-- a single constructor alternative without binders behaves like DEFAULT
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
  = do
    { traceVt "scrutinee (one shot w/o binders)" Outputable.empty
    ; vscrut <- vectExpr scrut
    ; (vty, lty) <- vectAndLiftType ty
    ; traceVt "alternative body (one shot w/o binders)" Outputable.empty
    ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
    ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody
    }
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
    { traceVt "scrutinee (one shot w/ binders)" Outputable.empty
    ; vexpr <- vectExpr scrut
    ; (vty, lty) <- vectAndLiftType ty
    ; traceVt "alternative body (one shot w/ binders)" Outputable.empty
    ; (vbndr, (vbndrs, (vect_body, lift_body)))
        <- vect_scrut_bndr
         . vectBndrsIn bndrs
         $ vectExpr body
    ; let (vect_bndrs, lift_bndrs) = unzip vbndrs
      -- take apart the PData representation of the bound scrutinee
    ; (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
    ; vect_dc <- maybeV dataConErr (lookupDataCon dc)
    ; let vcase = mk_wild_case vscrut vty vect_dc vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
    ; return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
    }
  where
    -- a dead case binder still needs a name to bind the vectorised scrutinee
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise = vectBndrIn bndr

    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]

    dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
vectAlgCase tycon _ty_args scrut bndr ty alts
  = do
    { traceVt "scrutinee (general case)" Outputable.empty
    ; vexpr <- vectExpr scrut
    ; vect_tc <- vectTyCon tycon
    ; (vty, lty) <- vectAndLiftType ty
      -- the selector distinguishes the vectorised type constructor's data constructors
    ; let arity = length (tyConDataCons vect_tc)
    ; sel_ty <- builtin (selTy arity)
    ; sel_bndr <- newLocalVar (fsLit "sel") sel_ty
    ; let sel = Var sel_bndr
    ; traceVt "alternatives' body (general case)" Outputable.empty
    ; (vbndr, valts) <- vect_scrut_bndr
                        $ mapM (proc_alt arity sel vty lty) alts'
    ; let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
    ; (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
    ; let (vect_bodies, lift_bodies) = unzip vbodies
    ; vdummy <- newDummyVar (exprType vect_scrut)
    ; ldummy <- newDummyVar (exprType lift_scrut)
    ; let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
      -- combine the per-alternative lifted bodies according to the selector
    ; lc <- builtin liftingContext
    ; lbody <- combinePD vty (Var lc) sel lift_bodies
    ; let lift_case = Case lift_scrut ldummy lty
                           [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                             lbody)]
    ; return . vLet (vNonRec vbndr vexpr)
      $ (vect_case, lift_case)
    }
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise = vectBndrIn bndr

    -- alternatives ordered by constructor tag, DEFAULT first
    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts

    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT DEFAULT = EQ
    cmp DEFAULT _ = LT
    cmp _ DEFAULT = GT
    cmp _ _ = panic "vectAlgCase/cmp"

    -- Vectorise one constructor alternative. The alternative's local free
    -- variables (other than its binders) are packed by the constructor's tag,
    -- and the lifted body is wrapped in a case on the selector's element
    -- count for that tag.
    proc_alt arity sel _ lty (DataAlt dc, bndrs, body@((fvs_body, _), _))
      = do
          dflags <- getDynFlags
          vect_dc <- maybeV dataConErr (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag = mkDataConTag dflags vect_dc
              fvs = fvs_body `delDVarSetList` bndrs
          sel_tags <- liftM (`App` sel) (builtin (selTags arity))
          lc <- builtin liftingContext
          elems <- builtin (selElements arity ntag)
          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
               { binds <- mapM (pack_var (Var lc) sel_tags tag)
                          . filter isLocalId
                          $ dVarSetElems fvs
               ; traceVt "case alternative:" (ppr . deAnnotate $ body)
               ; (ve, le) <- vectExpr body
               ; return (ve, Case (elems `App` sel) lc lty
                                  [(DEFAULT, [], (mkLets (concat binds) le))])
               }
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)
      where
        dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

    -- For a local variable, bind a fresh clone of its lifted version to the
    -- lifted value packed by tag. Update the local environment to use the
    -- clone, and return the binding. Other variables yield no bindings.
    pack_var len tags t v
      = do
        { r <- lookupVar_maybe v
        ; case r of
            Just (Local (vv, lv)) ->
              do
              { lv' <- cloneVar lv
              ; expr <- packByTagPD (idType vv) (Var lv) len tags t
              ; updLEnv (\env -> env { local_vars = extendVarEnv (local_vars env) v (vv, lv') })
              ; return [(NonRec lv' expr)]
              }
            _ -> return []
        }
-- | Vectorisation-avoidance classification attached to each subexpression.
data VectAvoidInfo
  = VIParr     -- ^ involves a (potentially) parallel type or variable
  | VISimple   -- ^ of scalar type
  | VIComplex  -- ^ neither scalar nor parallel
  | VIEncaps   -- ^ an encapsulated scalar subcomputation
  | VIDict     -- ^ a dictionary (predicate-typed) computation
  deriving (Show, Eq)

-- | A Core expression in which every node is annotated with its free
-- variables and its 'VectAvoidInfo'.
type CoreExprWithVectInfo = AnnExpr Id (DVarSet, VectAvoidInfo)

-- | The type of an annotated expression, obtained by dropping its annotations.
annExprType :: AnnExpr Var ann -> Type
annExprType e = exprType (deAnnotate e)
-- | The 'VectAvoidInfo' annotation at the root of an annotated expression.
vectAvoidInfoOf :: CoreExprWithVectInfo -> VectAvoidInfo
vectAvoidInfoOf ((_, info), _) = info

-- | Is the root annotated 'VIParr'?
isVIParr :: CoreExprWithVectInfo -> Bool
isVIParr e = vectAvoidInfoOf e == VIParr

-- | Is the root annotated 'VIEncaps'?
isVIEncaps :: CoreExprWithVectInfo -> Bool
isVIEncaps e = vectAvoidInfoOf e == VIEncaps

-- | Is the root annotated 'VIDict'?
isVIDict :: CoreExprWithVectInfo -> Bool
isVIDict e = vectAvoidInfoOf e == VIDict
-- | @vi `unlessVIParr` other@ is 'VIParr' when @other@ is 'VIParr', and @vi@
-- otherwise. Parallelism of a component thus propagates to the whole.
unlessVIParr :: VectAvoidInfo -> VectAvoidInfo -> VectAvoidInfo
unlessVIParr vi other
  | other == VIParr = VIParr
  | otherwise       = vi

-- | Like 'unlessVIParr', reading the second classification off the root of
-- an annotated expression.
unlessVIParrExpr :: VectAvoidInfo -> CoreExprWithVectInfo -> VectAvoidInfo
infixl 9 `unlessVIParrExpr`
unlessVIParrExpr vi e = unlessVIParr vi (vectAvoidInfoOf e)
-- | Annotate every subexpression with its 'VectAvoidInfo'.
--
-- A node's classification is normally derived from its type
-- ('vectAvoidInfoTypeOf'), but becomes 'VIParr' if any relevant
-- subexpression is 'VIParr'. @pvs@ is the set of local variables known to be
-- parallel. A variable in @pvs@, or in 'globalParallelVars', is 'VIParr'.
-- @pvs@ is extended when analysing these scopes:
--
-- * a non-recursive let: the binder, if its right-hand side is 'VIParr' and
--   its type is not scalar;
-- * a letrec: the non-scalar binders with 'VIParr' right-hand sides (the
--   bindings are then re-analysed);
-- * a case alternative: all of its binders, if the scrutinee is 'VIParr' and
--   not all binders are scalar.
vectAvoidInfo :: VarSet -> CoreExprWithFVs -> VM CoreExprWithVectInfo
vectAvoidInfo pvs ce@(_, AnnVar v)
  = do
    { gpvs <- globalParallelVars
    ; vi <- if v `elemVarSet` pvs || v `elemDVarSet` gpvs
            then return VIParr
            else vectAvoidInfoTypeOf ce
    ; viTrace ce vi []
      -- explain why the variable was classified as parallel
    ; when (vi == VIParr) $
        traceVt " reason:" $ if v `elemVarSet` pvs then text "local" else
                             if v `elemDVarSet` gpvs then text "global" else text "parallel type"
    ; return ((fvs, vi), AnnVar v)
    }
  where
    fvs = freeVarsOf ce

vectAvoidInfo _pvs ce@(_, AnnLit lit)
  = do
    { vi <- vectAvoidInfoTypeOf ce
    ; viTrace ce vi []
    ; return ((fvs, vi), AnnLit lit)
    }
  where
    fvs = freeVarsOf ce

vectAvoidInfo pvs ce@(_, AnnApp e1 e2)
  = do
    { ceVI <- vectAvoidInfoTypeOf ce
    ; eVI1 <- vectAvoidInfo pvs e1
    ; eVI2 <- vectAvoidInfo pvs e2
      -- parallel if either the function or the argument is parallel
    ; let vi = ceVI `unlessVIParrExpr` eVI1 `unlessVIParrExpr` eVI2
    ; return ((fvs, vi), AnnApp eVI1 eVI2)
    }
  where
    fvs = freeVarsOf ce

vectAvoidInfo pvs ce@(_, AnnLam var body)
  = do
    { bodyVI <- vectAvoidInfo pvs body
    ; varVI <- vectAvoidInfoType $ varType var
      -- the body's classification, unless the binder's type is parallel
    ; let vi = vectAvoidInfoOf bodyVI `unlessVIParr` varVI
    ; return ((fvs, vi), AnnLam var bodyVI)
    }
  where
    fvs = freeVarsOf ce

vectAvoidInfo pvs ce@(_, AnnLet (AnnNonRec var e) body)
  = do
    { ceVI <- vectAvoidInfoTypeOf ce
    ; eVI <- vectAvoidInfo pvs e
    ; isScalarTy <- isScalar $ varType var
    ; (bodyVI, vi) <- if isVIParr eVI && not isScalarTy
                      then do -- parallel, non-scalar binder: it is parallel in the body
                        { bodyVI <- vectAvoidInfo (pvs `extendVarSet` var) body
                        ; return (bodyVI, VIParr)
                        }
                      else do -- the binding does not make the body parallel
                        { bodyVI <- vectAvoidInfo pvs body
                        ; return (bodyVI, ceVI `unlessVIParrExpr` bodyVI)
                        }
    ; return ((fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI)
    }
  where
    fvs = freeVarsOf ce

vectAvoidInfo pvs ce@(_, AnnLet (AnnRec bnds) body)
  = do
    { ceVI <- vectAvoidInfoTypeOf ce
    ; bndsVI <- mapM (vectAvoidInfoBnd pvs) bnds
    ; parrBndrs <- map fst <$> filterM isVIParrBnd bndsVI
    ; if not . null $ parrBndrs
      then do -- some binders are parallel: re-analyse the group with them in 'pvs'
        { new_pvs <- filterM ((not <$>) . isScalar . varType) parrBndrs
        ; let extendedPvs = pvs `extendVarSetList` new_pvs
        ; bndsVI <- mapM (vectAvoidInfoBnd extendedPvs) bnds
        ; bodyVI <- vectAvoidInfo extendedPvs body
        ; return ((fvs, VIParr), AnnLet (AnnRec bndsVI) bodyVI)
        }
      else do -- no parallel binders
        { bodyVI <- vectAvoidInfo pvs body
        ; let vi = ceVI `unlessVIParrExpr` bodyVI
        ; return ((fvs, vi), AnnLet (AnnRec bndsVI) bodyVI)
        }
    }
  where
    fvs = freeVarsOf ce
    vectAvoidInfoBnd pvs (var, e) = (var,) <$> vectAvoidInfo pvs e

    -- a binder counts as parallel if its RHS is parallel and its type is not scalar
    isVIParrBnd (var, eVI)
      = do
        { isScalarTy <- isScalar (varType var)
        ; return $ isVIParr eVI && not isScalarTy
        }

vectAvoidInfo pvs ce@(_, AnnCase e var ty alts)
  = do
    { ceVI <- vectAvoidInfoTypeOf ce
    ; eVI <- vectAvoidInfo pvs e
    ; altsVI <- mapM (vectAvoidInfoAlt (isVIParr eVI)) alts
      -- parallel if the scrutinee or any alternative is parallel
    ; let alteVIs = [eVI | (_, _, eVI) <- altsVI]
          vi = foldl unlessVIParrExpr ceVI (eVI:alteVIs)
    ; return ((fvs, vi), AnnCase eVI var ty altsVI)
    }
  where
    fvs = freeVarsOf ce
    vectAvoidInfoAlt scrutIsPar (con, bndrs, e)
      = do
        { allScalar <- allScalarVarType bndrs
        ; let altPvs | scrutIsPar && not allScalar = pvs `extendVarSetList` bndrs
                     | otherwise = pvs
        ; (con, bndrs,) <$> vectAvoidInfo altPvs e
        }

-- casts and ticks inherit the classification of the wrapped expression; the
-- cast's coercion annotation is marked 'VISimple'
vectAvoidInfo pvs ce@(_, AnnCast e (fvs_ann, ann))
  = do
    { eVI <- vectAvoidInfo pvs e
    ; return ((fvs, vectAvoidInfoOf eVI), AnnCast eVI ((freeVarsOfAnn fvs_ann, VISimple), ann))
    }
  where
    fvs = freeVarsOf ce

vectAvoidInfo pvs ce@(_, AnnTick tick e)
  = do
    { eVI <- vectAvoidInfo pvs e
    ; return ((fvs, vectAvoidInfoOf eVI), AnnTick tick eVI)
    }
  where
    fvs = freeVarsOf ce

-- types and coercions are always 'VISimple'
vectAvoidInfo _pvs ce@(_, AnnType ty)
  = return ((fvs, VISimple), AnnType ty)
  where
    fvs = freeVarsOf ce

vectAvoidInfo _pvs ce@(_, AnnCoercion coe)
  = return ((fvs, VISimple), AnnCoercion coe)
  where
    fvs = freeVarsOf ce
-- | Classify a type.
--
-- * Predicate types are 'VIDict'.
-- * A function type is 'VISimple' if both its argument and result are
--   'VISimple', and 'VIDict' if its result is 'VIDict'. Otherwise it is
--   'VIComplex', unless the argument or result is 'VIParr'.
-- * Any other type is 'VIParr' if 'maybeParrTy' holds, 'VISimple' if it is
--   scalar, and 'VIComplex' otherwise.
vectAvoidInfoType :: Type -> VM VectAvoidInfo
vectAvoidInfoType ty
  | isPredTy ty
  = return VIDict
  | Just (arg, res) <- splitFunTy_maybe ty
  = combineFun <$> vectAvoidInfoType arg <*> vectAvoidInfoType res
  | otherwise
  = do { parr <- maybeParrTy ty
       ; if parr
           then return VIParr
           else classify <$> isScalar ty
       }
  where
    combineFun VISimple VISimple = VISimple
    combineFun _        VIDict   = VIDict
    combineFun argVI    resVI    = VIComplex `unlessVIParr` argVI `unlessVIParr` resVI

    classify scalar = if scalar then VISimple else VIComplex
-- | Classify the type of an annotated expression (see 'vectAvoidInfoType').
vectAvoidInfoTypeOf :: AnnExpr Var ann -> VM VectAvoidInfo
vectAvoidInfoTypeOf e = vectAvoidInfoType (annExprType e)
-- | Could the type involve parallel arrays?
--
-- Type synonyms are expanded ('coreView') and classified again. A type
-- constructor application is parallel if the constructor is in
-- 'globalParallelTyCons' or any argument is parallel. Foralls are looked
-- through. Any other type is not parallel.
maybeParrTy :: Type -> VM Bool
maybeParrTy ty
  | Just expanded <- coreView ty
  = (== VIParr) <$> vectAvoidInfoType expanded
  | Just (tc, args) <- splitTyConApp_maybe ty
  = do { parallelTcs <- globalParallelTyCons
       ; if tyConName tc `elemNameSet` parallelTcs
           then return True
           else or <$> mapM maybeParrTy args
       }
  | ForAllTy _ inner <- ty
  = maybeParrTy inner
  | otherwise
  = return False
-- | Does every variable either have a scalar type or count as top-level
-- ('isToplevel')?
allScalarVarType :: [Var] -> VM Bool
allScalarVarType = fmap and . mapM scalarOrTop
  where
    scalarOrTop v = if isToplevel v then return True else isScalar (varType v)
-- | 'allScalarVarType' applied to the elements of a variable set.
allScalarVarTypeSet :: DVarSet -> VM Bool
allScalarVarTypeSet set = allScalarVarType (dVarSetElems set)
-- | Trace the classification of an expression, together with the
-- classifications of the given annotated subexpressions.
viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [CoreExprWithVectInfo] -> VM ()
viTrace ce vi vTs
  = traceVt ("vect info: " ++ show vi ++ "[" ++ subInfos ++ "]") (ppr (deAnnotate ce))
  where
    subInfos = concatMap (\t -> show (vectAvoidInfoOf t) ++ " ") vTs