{-#LANGUAGE CPP#-}
module Foreign.Storable.Generic.Plugin.Internal.Compile
(
compileExpr
, tryCompileExpr
, intToExpr
, intSubstitution
, offsetSubstitution
, offsetSubstitutionTree
, OffsetScope(..)
, getScopeId
, getScopeExpr
, intListExpr
, exprToIntList
, isLitOrGlobal
, inScopeAll
, isIndexer
, caseExprIndex
, compileGStorableBind
, lintBind
, replaceIdsBind
, compileGroups
)
where
import Prelude hiding ((<>))
import CoreSyn (Bind(..),Expr(..), CoreExpr, CoreBind, CoreProgram, Alt, AltCon(..), isId, Unfolding(..))
import Literal (Literal(..))
#if MIN_VERSION_GLASGOW_HASKELL(8,6,0,0)
import Literal (LitNumType(..))
#endif
import Id (isLocalId, isGlobalId,setIdInfo,Id)
import IdInfo (IdInfo(..))
import Var (Var(..), idInfo)
import Name (getOccName,mkOccName, getSrcSpan)
import OccName (OccName(..), occNameString)
import qualified Name as N (varName, tvName, tcClsName)
import SrcLoc (noSrcSpan, SrcSpan)
import Unique (getUnique)
import HscMain (hscCompileCoreExpr)
import HscTypes (HscEnv,ModGuts(..))
import CoreMonad (CoreM,CoreToDo(..), getHscEnv, getDynFlags)
import CoreLint (lintExpr)
import BasicTypes (CompilerPhase(..))
import Type (isAlgType, splitTyConApp_maybe)
import TyCon (tyConName, algTyConRhs, visibleDataCons)
import TyCoRep (Type(..), TyBinder(..), TyLit(..))
import TysWiredIn
import TysPrim (intPrimTy)
import DataCon (dataConWorkId,dataConOrigArgTys)
import MkCore (mkWildValBinder)
import Outputable (cat, ppr, SDoc, showSDocUnsafe)
import Outputable (Outputable(..),($$), ($+$), vcat, empty,text, (<>), (<+>), nest, int, comma)
import CoreMonad (putMsg, putMsgS)
import GHCi.RemoteTypes
import PrelNames (buildIdKey, augmentIdKey)
import DataCon (dataConWorkId)
import BasicTypes (Boxity(..))
import Unsafe.Coerce
import Data.List
import Data.Maybe
import Data.Either
import Debug.Trace
import Control.Monad.IO.Class
import Control.Monad
import Control.Applicative hiding (empty)
import Control.Exception
import Foreign.Storable.Generic.Plugin.Internal.Helpers
import Foreign.Storable.Generic.Plugin.Internal.Error
import Foreign.Storable.Generic.Plugin.Internal.Predicates
import Foreign.Storable.Generic.Plugin.Internal.Types
-- | Compile a Core expression and return the value it denotes.
--
-- The expression is compiled with 'hscCompileCoreExpr', the resulting
-- foreign reference is dereferenced with 'localRef', and the value is then
-- coerced to whatever type the caller asks for.  The coercion is unchecked:
-- requesting a type the expression does not have is undefined behaviour.
compileExpr :: HscEnv -> CoreExpr -> SrcSpan -> IO a
compileExpr hsc_env expr src_span = do
    foreign_hval <- hscCompileCoreExpr hsc_env src_span expr
    unsafeCoerce <$> withForeignRef foreign_hval localRef
-- | Run 'compileExpr' inside 'CoreM', turning any exception thrown while
-- compiling into a 'CompilationError' for the binding @id = core_expr@.
--
-- The source span handed to the compiler is the one attached to @id@.
tryCompileExpr :: Id -> CoreExpr -> CoreM (Either Error a)
tryCompileExpr id core_expr = do
    hsc_env <- getHscEnv
    outcome <- liftIO (try (compileExpr hsc_env core_expr (getSrcSpan id)))
    return $ either (Left . toError) Right outcome
  where
    -- The exception is only reported through its 'show' output.
    toError :: SomeException -> Error
    toError se = CompilationError (NonRec id core_expr) (stringToPpr $ show se)
-- | Build a Core integer literal from any integral value.
--
-- GHC 8.6 replaced the 'MachInt' literal constructor with 'LitNumber', hence
-- the two variants.
intLiteral :: (Integral a) => a -> CoreExpr
#if MIN_VERSION_GLASGOW_HASKELL(8,6,0,0)
intLiteral i = Lit (LitNumber LitNumInt (fromIntegral i) intPrimTy)
#else
intLiteral = Lit . MachInt . fromIntegral
#endif
-- | Build the Core expression @\\_ -> I# i@: a lambda whose wildcard binder
-- has type @t@ and whose body is the boxed 'Int' @i@.
intToExpr :: Type -> Int -> CoreExpr
intToExpr t i = Lam (mkWildValBinder t) boxed_int
  where
    boxed_int = App (Var (dataConWorkId intDataCon)) (intLiteral i)
-- | Replace the right-hand side of a binding by the 'Int' it evaluates to.
--
-- The body of the binding is compiled and run with 'tryCompileExpr', and the
-- result is spliced back in as a literal via 'intToExpr'.  Which part of the
-- right-hand side gets compiled depends on its shape:
--
-- * three nested lambdas: the innermost body is compiled and the two outer
--   binders are kept;
-- * a single lambda: its body is compiled;
-- * an application: see the comments on that clause.
--
-- Every failure is reported as a 'Left': recursive bindings are not
-- supported, compilation errors are passed through, and a right-hand side of
-- any other shape yields a 'CompilationError' that carries the
-- pretty-printed expression.  The function never calls 'error', so the
-- caller decides whether a failure is fatal.
intSubstitution :: CoreBind -> CoreM (Either Error CoreBind)
intSubstitution b@(Rec _) = return $ Left $ CompilationNotSupported b
intSubstitution b@(NonRec id (Lam l1 (Lam l2 (Lam _ expr)))) = do
    -- Compile first, then look the type up, so the compilation is attempted
    -- even when the type lookup fails.
    the_integer <- tryCompileExpr id expr :: CoreM (Either Error Int)
    return $ case getGStorableType (varType id) of
        Just t  -> NonRec id . Lam l1 . Lam l2 . intToExpr t <$> the_integer
        Nothing -> Left $ CompilationError b (text "Type not found")
intSubstitution b@(NonRec id (Lam _ expr)) = do
    the_integer <- tryCompileExpr id expr :: CoreM (Either Error Int)
    return $ case getGStorableType (varType id) of
        Just t  -> NonRec id . intToExpr t <$> the_integer
        Nothing -> Left $ CompilationError b (text "Type not found")
intSubstitution (NonRec id (App expr _)) = case expr of
    -- A triple lambda in function position is substituted on its own; the
    -- argument it was applied to is discarded.
    Lam _ (Lam _ (Lam _ _)) -> intSubstitution $ NonRec id expr
    -- Two applications: substitute the head, then strip the two outer
    -- lambdas from the result.  Any other result (including a 'Left') is
    -- returned unchanged.
    App e _ -> do
        subs <- intSubstitution $ NonRec id e
        return $ case subs of
            Right (NonRec i (Lam _ (Lam _ body))) -> Right $ NonRec i body
            other                                 -> other
    _ -> intSubstitutionWorker id expr
-- Remaining shapes (case, let, variables, ...) cannot be handled.  Report
-- them like every other failure instead of crashing the compiler.
intSubstitution b@(NonRec _ e) =
    return $ Left $ CompilationError b (text "Unsupported expression:" $$ ppr e)
-- | Compile @expr@ to an 'Int' and rebuild the binding for @id@ around the
-- result using 'intToExpr'.
--
-- Fails with a 'CompilationError' when 'getGStorableType' finds no type in
-- the type of @id@; compilation failures are passed through unchanged.
intSubstitutionWorker :: Id -> CoreExpr -> CoreM (Either Error CoreBind)
intSubstitutionWorker id expr = do
    compiled <- tryCompileExpr id expr :: CoreM (Either Error Int)
    return $ case getGStorableType (varType id) of
        Just ty -> NonRec id . intToExpr ty <$> compiled
        Nothing -> Left $ CompilationError (NonRec id expr) (text "Type not found")
-- | Run 'offsetSubstitutionTree' over the right-hand side of a non-recursive
-- binding.  Recursive bindings are rejected with 'CompilationNotSupported'.
--
-- Errors coming out of the tree traversal are re-wrapped so that they carry
-- the whole binding @b@: an 'OtherError' becomes a 'CompilationError' for
-- @b@, and a nested 'CompilationError' is pretty-printed into the message
-- of a new one for @b@.  The re-wrapped result is what gets returned, so
-- the caller can always recover the failed binding from the error.
offsetSubstitution :: CoreBind -> CoreM (Either Error CoreBind)
offsetSubstitution b@(Rec _) = return $ Left $ CompilationNotSupported b
offsetSubstitution b@(NonRec id expr) = do
    e_subs <- offsetSubstitutionTree [] expr
    let ne_subs = case e_subs of
            -- Attach the binding to an error that has none.
            Left (OtherError sdoc)
                -> Left $ CompilationError b sdoc
            -- Keep the inner error's text, but report it against @b@.
            Left err@(CompilationError _ _)
                -> Left $ CompilationError b (pprError Some err)
            other -> other
    return $ NonRec id <$> ne_subs
-- | A variable together with the expression that replaces it during offset
-- substitution.
data OffsetScope
    = IntList Id CoreExpr
    -- ^ Identifier paired with a replacement expression; 'exprToIntList'
    -- fills it with a list-of-'Int' expression.
    | IntPrimVal Id CoreExpr
    -- ^ Identifier paired with a replacement expression; 'exprToIntVal'
    -- fills it with an integer literal.
-- | The identifier stored in a scope entry.
getScopeId :: OffsetScope -> Id
getScopeId scope_entry = case scope_entry of
    IntList    var _ -> var
    IntPrimVal var _ -> var
-- | The replacement expression stored in a scope entry.
getScopeExpr :: OffsetScope -> CoreExpr
getScopeExpr scope_entry = case scope_entry of
    IntList    _ replacement -> replacement
    IntPrimVal _ replacement -> replacement
-- | Both constructors print the same way: the identifier, its unique, a
-- comma, and the replacement expression.
instance Outputable OffsetScope where
    ppr scope_entry = case scope_entry of
        IntList    var replacement -> render var replacement
        IntPrimVal var replacement -> render var replacement
      where
        render var replacement =
            ppr var <+> ppr (getUnique var) <+> comma <+> ppr replacement
    pprPrec _ el = ppr el
-- | Build the Core expression for a list of 'Int's, i.e. the given numbers
-- consed in order onto an empty list instantiated at 'Int'.
intListExpr :: [Int] -> CoreExpr
intListExpr list = intListExpr' (reverse list) nil_at_int
  where
    -- @[] \@Int@
    nil_at_int = Var (dataConWorkId nilDataCon) `App` Type intTy
-- | Cons every number of the list, boxed as an 'Int', onto the accumulator
-- expression.  The first element is consed first, so it ends up innermost:
-- the numbers appear in the resulting expression in reverse order.
intListExpr' :: [Int] -> CoreExpr -> CoreExpr
intListExpr' ints acc = foldr prepend acc (reverse ints)
  where
    -- @(:) \@Int@
    cons_at_int = App (Var (dataConWorkId consDataCon)) (Type intTy)
    -- @I# n@
    boxed n = App (Var (dataConWorkId intDataCon)) (intLiteral n)
    prepend n rest = App (App cons_at_int (boxed n)) rest
-- | Compile the expression to a list of 'Int's and, on success, return an
-- 'IntList' scope entry pairing the identifier with the literal list
-- expression built by 'intListExpr'.
exprToIntList :: Id -> CoreExpr -> CoreM (Either Error OffsetScope)
exprToIntList id core_expr =
    fmap (IntList id . intListExpr) <$> tryCompileExpr id core_expr
-- | The Core integer literal for the given 'Int' (see 'intLiteral').
intPrimValExpr :: Int -> CoreExpr
intPrimValExpr = intLiteral
-- | Compile the expression to an 'Int' and, on success, return an
-- 'IntPrimVal' scope entry pairing the identifier with the corresponding
-- integer literal.
exprToIntVal :: Id -> CoreExpr -> CoreM (Either Error OffsetScope)
exprToIntVal id core_expr =
    fmap (IntPrimVal id . intPrimValExpr) <$> tryCompileExpr id core_expr
-- | Return the expression itself when it is a literal or a reference to a
-- global identifier, and 'Nothing' otherwise.
isLitOrGlobal :: CoreExpr -> Maybe CoreExpr
isLitOrGlobal e = case e of
    Lit _                -> Just e
    Var v | isGlobalId v -> Just e
    _                    -> Nothing
-- | Look a variable up in the scope.
--
-- Returns the replacement expression of the first scope entry whose
-- identifier equals the variable and has the same 'OccName'.  Anything that
-- is not a variable is never in scope.
inScopeAll :: [OffsetScope] -> CoreExpr -> Maybe CoreExpr
inScopeAll scope (Var v_id) = getScopeExpr <$> find binds_var scope
  where
    binds_var el =
        let scoped = getScopeId el
        in  scoped == v_id
                && getOccName (varName scoped) == getOccName (varName v_id)
inScopeAll _ _ = Nothing
-- | Is the identifier's occurrence name the variable name @$w!!@?
isIndexer :: Id
          -> Bool
isIndexer id = getOccName (varName id) == indexer_occ
  where
    indexer_occ = mkOccName N.varName "$w!!"
-- | Recognise an indexing expression of the shape
-- @indexer \@Int offsets index@ and substitute what the scope knows.
--
-- The match succeeds only when the head is a variable accepted by
-- 'isIndexer', the type argument passes 'isIntType', and the index is either
-- in scope or a literal / global identifier.  The offsets argument is
-- replaced by its scope entry when there is one and kept as-is otherwise.
caseExprIndex :: [OffsetScope] -> CoreExpr -> Maybe CoreExpr
caseExprIndex scope expr = case expr of
    App (App (App ix_var@(Var ix_id) t_int@(Type intt)) offsets) lit
        | Just lit_expr <- inScopeAll scope lit <|> isLitOrGlobal lit
        , isIntType intt
        , isIndexer ix_id
        -> let list_expr = fromMaybe offsets (inScopeAll scope offsets)
           in  Just $ App (App (App ix_var t_int) list_expr) lit_expr
    _ -> Nothing
-- | Walk a Core expression and replace offset computations by literals.
--
-- * Literals, types and coercions are returned unchanged.
-- * Applications, casts, ticks and lambdas are rebuilt from their
--   substituted children.
-- * A non-recursive @let@ whose binder satisfies 'isOffsetsId' is removed:
--   its right-hand side is compiled to an integer list ('exprToIntList'),
--   and the body is traversed with that entry added to the scope.
-- * Any other @let@ is rebuilt from its substituted body and bindings.
-- * A single-alternative @case@ on the 'Int' data constructor whose
--   scrutinee is recognised by 'caseExprIndex' is removed: the scrutinee is
--   compiled to an integer ('exprToIntVal'), and the alternative's body is
--   traversed with the unboxed binder added to the scope.
-- * Any other @case@ is rebuilt from its substituted scrutinee and
--   alternatives.
-- * A variable is replaced by its scope entry, if it has one.
--
-- The first error encountered is returned as a 'Left'.
offsetSubstitutionTree :: [OffsetScope] -> CoreExpr -> CoreM (Either Error CoreExpr)
offsetSubstitutionTree scope expr = case expr of
    Lit _      -> untouched
    Type _     -> untouched
    Coercion _ -> untouched
    App fun arg -> do
        fun' <- descend fun
        arg' <- descend arg
        return $ App <$> fun' <*> arg'
    Cast inner co      -> fmap (\e -> Cast e co) <$> descend inner
    Tick tickish inner -> fmap (Tick tickish) <$> descend inner
    Lam binder body    -> fmap (Lam binder) <$> descend body
    _   | Let (NonRec offset_id offset_expr) body <- expr
        , isOffsetsId offset_id
        -> exprToIntList offset_id offset_expr >>= continueWith body
        | Let bind body <- expr
        -> do
            -- The body is traversed before the bindings.
            body' <- descend body
            bind' <- descendBind bind
            return $ Let <$> bind' <*> body'
        | Case scrut _ _ [(DataAlt con, [x_id], alt_body)] <- expr
        , con == intDataCon
        , Just indexed <- caseExprIndex scope scrut
        -> exprToIntVal x_id indexed >>= continueWith alt_body
        | Case scrut case_bndr ty alts <- expr
        -> do
            -- The alternatives are traversed before the scrutinee.
            alts'  <- mapM descendAlt alts
            scrut' <- descend scrut
            return $ case find (\(_, _, res) -> isLeft res) alts' of
                Just (_, _, failure) -> failure
                Nothing -> Case <$> scrut'
                                <*> pure case_bndr
                                <*> pure ty
                                <*> pure [ (alt_con, args, rhs)
                                         | (alt_con, args, Right rhs) <- alts' ]
        | Var _ <- expr
        -> case inScopeAll scope expr <|> Just expr of
            Just e  -> return $ Right e
            Nothing -> return $ Left $ OtherError (text "This shouldn't happen."
                           $$ text "`m_subs <|> Just e` cannot be `Nothing`.")
        | otherwise
        -> return $ Left $ OtherError $ (text "Unsupported expression:" $$ ppr expr)
  where
    untouched = return (Right expr)
    descend   = offsetSubstitutionTree scope

    -- Traverse @inner@ with one more scope entry, or propagate the error
    -- that prevented the entry from being built.
    continueWith _     (Left err)    = return (Left err)
    continueWith inner (Right entry) = offsetSubstitutionTree (entry : scope) inner

    descendBind (NonRec bndr rhs) = fmap (NonRec bndr) <$> descend rhs
    descendBind (Rec pairs) = do
        results <- mapM (\(bndr, rhs) -> fmap ((,) bndr) <$> descend rhs) pairs
        return $ case lefts results of
            []      -> Right (Rec (rights results))
            (err:_) -> Left err

    descendAlt (alt_con, args, rhs) = (,,) alt_con args <$> descend rhs
-- | Pick the substitution that fits a binding, based on its identifier.
--
-- Size and alignment bindings (plain, specialised or choice variants) go
-- through 'intSubstitution'; peek and poke bindings go through
-- 'offsetSubstitution'.  Everything else, including every recursive
-- binding, is reported as 'CompilationNotSupported'.
compileGStorableBind :: CoreBind -> CoreM (Either Error CoreBind)
compileGStorableBind core_bind = case core_bind of
    NonRec id _
        | any ($ id) int_valued   -> intSubstitution core_bind
        | any ($ id) offset_based -> offsetSubstitution core_bind
    _ -> return $ Left $ CompilationNotSupported core_bind
  where
    int_valued =
        [ isSizeOfId, isSpecSizeOfId, isChoiceSizeOfId
        , isAlignmentId, isSpecAlignmentId, isChoiceAlignmentId
        ]
    offset_based =
        [ isPeekId, isSpecPeekId, isChoicePeekId
        , isPokeId, isSpecPokeId, isChoicePokeId
        ]
-- | Make the unfolding template of a binder agree with its (new)
-- right-hand side.
--
-- Only binders that are 'Id's and currently carry a 'CoreUnfolding' are
-- touched, because 'uf_tmpl' exists on that constructor alone: a record
-- update on any other unfolding would fail at runtime once the unfolding is
-- inspected.  All other bindings, including recursive groups, are returned
-- unchanged.
replaceUnfoldingBind :: CoreBind -> CoreBind
replaceUnfoldingBind b@(NonRec id expr)
    | isId id
    , unfolding@CoreUnfolding{} <- unfoldingInfo id_info
    = NonRec (setIdInfo id (id_info { unfoldingInfo = unfolding { uf_tmpl = expr } })) expr
    | otherwise
    = b
  where
    -- Only forced after the 'isId' guard has succeeded.
    id_info = idInfo id
replaceUnfoldingBind b@(Rec _) = b
-- | Lint the right-hand side(s) of a freshly built binding.
--
-- The first argument is the original binding and the second the new one.
-- When 'lintExpr' reports a problem, the result is a 'CompilationError' that
-- carries the /original/ binding together with the lint messages (all of
-- them, stacked vertically, for a recursive group); otherwise the new
-- binding is returned.
lintBind :: CoreBind
         -> CoreBind
         -> CoreM (Either Error CoreBind)
lintBind b_old b = do
    dyn_flags <- getDynFlags
    let lint = lintExpr dyn_flags []
    return $ case b of
        NonRec _ expr -> maybe (Right b) (Left . CompilationError b_old) (lint expr)
        Rec bs -> case mapMaybe (lint . snd) bs of
            []   -> Right b
            errs -> Left $ CompilationError b_old (vcat errs)
-- | Apply 'replaceIds' to every right-hand side of a binding, leaving the
-- binders themselves untouched.
replaceIdsBind :: [CoreBind]
               -> [CoreBind]
               -> CoreBind
               -> CoreBind
replaceIdsBind gstorable_bs other_bs bind = case bind of
    NonRec bndr rhs -> NonRec bndr (inline rhs)
    Rec pairs       -> Rec [(bndr, inline rhs) | (bndr, rhs) <- pairs]
  where
    inline = replaceIds gstorable_bs other_bs
-- | Inline local identifiers using the right-hand sides found in the two
-- given binding lists.
--
-- A local variable is looked up, in this order, among
--
-- 1. the non-recursive bindings of the first list,
-- 2. the non-recursive bindings of the second list,
-- 3. the recursive groups of the first list,
-- 4. the recursive groups of the second list,
--
-- and replaced by the (recursively processed) right-hand side of the first
-- hit.  A recursive group only counts as a hit when it is the single group
-- of its list that mentions the variable; the right-hand side is then
-- processed with that list replaced by the remaining recursive groups only.
-- Global variables and variables without a hit are left alone.  All other
-- expression forms are traversed structurally; binders are never changed.
replaceIds :: [CoreBind]
           -> [CoreBind]
           -> CoreExpr
           -> CoreExpr
replaceIds gstorable_bs other_bs expr = case expr of
    Var var
        | isLocalId var -> resolve var
        | otherwise     -> expr
    App fun arg             -> App (go fun) (go arg)
    Lam bndr body           -> Lam bndr (go body)
    Let bind body           -> Let (replaceIdsBind gstorable_bs other_bs bind) (go body)
    Case scrut bndr ty alts ->
        Case (go scrut) bndr ty [(con, args, go rhs) | (con, args, rhs) <- alts]
    Cast inner co           -> Cast (go inner) co
    Tick tickish inner      -> Tick tickish (go inner)
    _                       -> expr
  where
    go = replaceIds gstorable_bs other_bs

    resolve var
        | Just rhs <- nonRecRhs var gstorable_bs
        = go rhs
        | Just rhs <- nonRecRhs var other_bs
        = go rhs
        | Just (rhs, rest) <- recRhs var gstorable_bs
        = replaceIds (map Rec rest) other_bs rhs
        | Just (rhs, rest) <- recRhs var other_bs
        = replaceIds gstorable_bs (map Rec rest) rhs
        | otherwise
        = expr

    -- Right-hand side of the first non-recursive binding of the variable.
    nonRecRhs var binds = lookup var [(bndr, rhs) | NonRec bndr rhs <- binds]

    -- Right-hand side of the variable inside the only recursive group that
    -- binds it, together with all the other recursive groups.
    recRhs var binds =
        case partition (elem var . map fst) [grp | Rec grp <- binds] of
            ([grp], rest) -> (\rhs -> (rhs, rest)) <$> lookup var grp
            _             -> Nothing
-- | Compile the binding groups one after another.
--
-- Entry point for 'compileGroups_rec': starts at depth 0 with no bindings
-- substituted and none rejected yet.
compileGroups :: Flags
              -> [[CoreBind]]
              -> [CoreBind]
              -> CoreM [CoreBind]
compileGroups flags bind_groups bind_rest =
    compileGroups_rec flags start_depth bind_groups bind_rest no_subs no_rejects
  where
    start_depth = 0
    no_subs     = []
    no_rejects  = []
-- | Worker for 'compileGroups'.
--
-- Arguments, in order: the flags, the current depth (only used in error
-- reports), the groups still to be processed, the remaining bindings, the
-- bindings compiled so far, and the bindings that failed so far.
--
-- For every group, each binding first has its identifiers replaced
-- ('replaceIdsBind'), is then compiled ('compileGStorableBind'), has its
-- unfolding updated and is linted.  Successes are prepended to the compiled
-- bindings; failures are handed to 'compileGroups_error' and whatever
-- bindings it returns are prepended to the failed ones.  When no groups are
-- left, the compiled bindings followed by the failed ones are returned.
compileGroups_rec :: Flags
                  -> Int
                  -> [[CoreBind]]
                  -> [CoreBind]
                  -> [CoreBind]
                  -> [CoreBind]
                  -> CoreM [CoreBind]
compileGroups_rec _ _ [] _ subs not_subs = return (subs ++ not_subs)
compileGroups_rec flags d (bg:bgs) bind_rest subs not_subs = do
    results <- mapM (compileAndLint . replaceIdsBind bind_rest subs) bg
    let (errors, compiled) = partitionEithers results
    not_compiled <- compileGroups_error flags d errors
    compileGroups_rec flags (d + 1) bgs bind_rest (compiled ++ subs) (not_compiled ++ not_subs)
  where
    -- Lint errors are reported against the binding as it was before
    -- compilation.
    compileAndLint bind = compileGStorableBind bind >>= \result -> case result of
        Right bind' -> lintBind bind (replaceUnfoldingBind bind')
        failure     -> return failure
-- | Report the errors of one compilation round and recover the bindings
-- they belong to.
--
-- When there is at least one error, a message is emitted with 'putMsg': an
-- empty document at verbosity 'None', otherwise a header naming the depth
-- followed by the indented errors.  If the crash flag is set and there is at
-- least one error, the function aborts with 'error'.  Otherwise it returns
-- the bindings carried by the 'CompilationNotSupported' and
-- 'CompilationError' errors; other errors contribute no binding.
compileGroups_error :: Flags
                    -> Int
                    -> [Error]
                    -> CoreM [CoreBind]
compileGroups_error flags d errors = do
    unless (null errors) $ putMsg report
    when (to_crash && not (null errors)) $ error "Crashing..."
    return $ mapMaybe failedBind errors
  where
    Flags verb to_crash = flags

    report = case verb of
        None -> empty
        _    -> text "Errors while compiling and substituting bindings at depth " <+> int d <> text ":"
                    $$ nest 4 (vcat (map (pprError verb) errors))

    failedBind err = case err of
        CompilationNotSupported bind -> Just bind
        CompilationError bind _      -> Just bind
        _                            -> Nothing