module Language.Hakaru.CodeGen.Wrapper
( wrapProgram
, PrintConfig(..)
) where
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Syntax.TypeCheck
import Language.Hakaru.Syntax.TypeOf (typeOf)
import Language.Hakaru.Types.Sing
import Language.Hakaru.CodeGen.CodeGenMonad
import Language.Hakaru.CodeGen.Flatten
import Language.Hakaru.CodeGen.Types
import Language.Hakaru.CodeGen.AST
import Language.Hakaru.Types.DataKind (Hakaru(..))
import Control.Monad.State.Strict
import Prelude as P hiding (unlines,exp)
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
-- | Generate the external C declarations for a type-checked Hakaru program.
--
-- The include directives chosen by 'header' for the program's type are
-- declared first.  After that:
--
-- * a function-typed program is handed to 'flattenTopLambda', using the
--   supplied name (reserved beforehand) or an identifier obtained from
--   @genIdent' "fn"@ when no name is given;
--
-- * any other program without a name is wrapped by 'mainFunction';
--
-- * a name supplied for a program that is not a function is rejected with
--   an 'error' carrying an explanatory message.
wrapProgram
  :: TypedAST (TrivialABT Term) -- ^ the program together with its type
  -> Maybe String               -- ^ optional name for the generated function
  -> PrintConfig                -- ^ printing options passed to 'mainFunction'
  -> CodeGen ()
wrapProgram tast@(TypedAST typ _) mn pc =
  do mapM_ (extDeclare . CPPExt) (header typ)
     baseCG
  where
    baseCG :: CodeGen ()
    baseCG = case (tast, mn) of
      ( TypedAST (SFun _ _) abt, Just name ) ->
        do reserveName name
           flattenTopLambda abt $ Ident name
      ( TypedAST (SFun _ _) abt, Nothing ) ->
        genIdent' "fn" >>= \name ->
          flattenTopLambda abt name
      -- A name only makes sense for a function; fail with a message
      -- instead of an anonymous 'undefined'.
      ( TypedAST _ _, Just name ) ->
        error $ "wrapProgram: function name " ++ show name
             ++ " was supplied, but the program is not a function"
      ( TypedAST typ' abt, Nothing ) ->
        mainFunction pc typ' abt
-- | Include directives required by a program of the given type.
--
-- Every program includes @stdlib.h@, @stdio.h@ and @math.h@; a measure
-- additionally includes @time.h@, listed first.
header :: Sing (a :: Hakaru) -> [Preprocessor]
header typ = map PPInclude (includesFor typ)
  where
    includesFor :: Sing (b :: Hakaru) -> [String]
    includesFor (SMeasure _) = "time.h" : common
    includesFor _            = common

    common :: [String]
    common = ["stdlib.h", "stdio.h", "math.h"]
-- | Build the C @main@ function for a program that is not a function.
--
-- For a measure, the @mdata@ struct for the sample type is declared, a
-- @void@ function named @measure@ taking an @mdata@ pointer parameter is
-- generated whose body flattens the program into that parameter, and
-- 'printf' is then given the @measure@ function to emit the printing code.
--
-- For any other type, a @result@ variable is declared, the program is
-- flattened into it, and 'printf' emits the code that prints it.
--
-- In both cases a @return 0@ is appended, and the declarations and
-- statements accumulated in the 'CodeGen' state are reversed and packaged
-- as the body of an @int main@ definition.
mainFunction
  :: ABT Term abt
  => PrintConfig
  -> Sing (a :: Hakaru)
  -> abt '[] (a :: Hakaru)
  -> CodeGen ()
mainFunction pc typ@(SMeasure t) abt =
  do mapM_ reserveName ["measure", "mdata", "main"]
     extDeclare (mdataStruct t)
     -- the sampling function: fills the mdata it is handed, returns nothing
     funCG CVoid sampleId [mdataPtrDeclaration t mdataId] $
       do flattenABT abt (CVar mdataId)
          putStat (CReturn Nothing)
     printf pc typ (CVar sampleId)
     putStat (CReturn (Just (intE 0)))
     !cg <- get
     extDeclare . CFunDefExt $
       functionDef SInt
                   (Ident "main")
                   []
                   (P.reverse (declarations cg))
                   (P.reverse (statements cg))
  where
    sampleId = Ident "measure"
    mdataId  = Ident "mdata"
mainFunction pc typ abt =
  do mapM_ reserveName ["result", "main"]
     declare typ resultId
     flattenABT abt resultE
     printf pc typ resultE
     putStat (CReturn (Just (intE 0)))
     cg <- get
     extDeclare . CFunDefExt $
       functionDef SInt
                   (Ident "main")
                   []
                   (P.reverse (declarations cg))
                   (P.reverse (statements cg))
  where
    resultId = Ident "result"
    resultE  = CVar resultId
-- | Options controlling how a generated program prints its result.
data PrintConfig = PrintConfig
  { showWeights   :: Bool
    -- ^ when set, the weight of a measure sample is printed before the sample
  , showProbInLog :: Bool
    -- ^ when set, probabilities are printed as @exp(x)@ without applying
    -- @exp@ to the value; otherwise @exp@ is applied and the plain number
    -- is printed
  } deriving (Show)
-- | Emit the C statements that print a value of the given Hakaru type.
--
-- * For a measure, the expression argument is the sampling function.  A
--   variable of the measure type is declared and an endless @while (1)@
--   loop is emitted that calls the sampling function on its address and
--   prints the sample (preceded by the weight when 'showWeights' is set)
--   whenever @exp@ of the weight is greater than zero.
--
-- * For an array, a counter is declared and a @for@ loop prints every
--   element of the @data@ member, up to the @size@ member, between
--   @"[ "@ and @"]\\n"@.
--
-- * Any other value is printed with a single @printf@ call.
--
-- Every argument goes through 'printfArg', so probabilities are treated
-- the same way whether they appear as a scalar, a measure sample, a weight
-- or an array element.
printf
  :: PrintConfig
  -> Sing (a :: Hakaru)
  -> CExpr
  -> CodeGen ()
printf pc mt@(SMeasure t) sampleFunc =
  do mId <- genIdent' "m"
     declare mt mId
     let mE = CVar mId
         getSampleS = CExpr . Just $ CCall sampleFunc [address mE]
         -- the weight is handled like a probability
         weightArgs
           | showWeights pc = [printfArg pc SProb (mdataWeight mE)]
           | otherwise      = []
         printSampleE = CExpr . Just . printfCall
                      $ [stringE $ printfText pc mt "\n"]
                     ++ weightArgs
                     ++ [printfArg pc t (mdataSample mE)]
         -- only print samples whose exponentiated weight is positive
         hasWeight = (exp $ mdataWeight mE) .>. (floatE 0)
         wrapSampleFunc = CCompound
           [ CBlockStat getSampleS
           , CBlockStat $ CIf hasWeight printSampleE Nothing ]
     putStat $ CWhile (intE 1) wrapSampleFunc False
printf pc (SArray t) arg =
  do iterId <- genIdent' "it"
     declare SInt iterId
     let iter     = CVar iterId
         dataPtr  = CMember arg (Ident "data") True
         sizeVar  = CMember arg (Ident "size") True
         cond     = iter .<. sizeVar
         inc      = CUnary CPostIncOp iter
         currInd  = indirect (dataPtr .+. iter)
         -- elements get the same probability treatment as scalars, so the
         -- argument always agrees with the format from 'printfText'
         loopBody = putExprStat $ printfCall
                      [ stringE $ printfText pc t " "
                      , printfArg pc t currInd ]
     putString "[ "
     mkSequential
     forCG (iter .=. (intE 0)) cond inc loopBody
     putString "]\n"
  where putString s = putExprStat $ printfCall [stringE s]
printf pc typ arg =
  putExprStat $ printfCall
    [ stringE $ printfText pc typ "\n"
    , printfArg pc typ arg ]

-- | A call to C's @printf@ with the given arguments.
printfCall :: [CExpr] -> CExpr
printfCall = CCall (CVar . Ident $ "printf")

-- | The expression to pass to @printf@ for a value of the given type.
--
-- A probability is passed unchanged when 'showProbInLog' is set (the format
-- string then wraps it in @exp(...)@) and through @exp@ otherwise.  Values
-- of every other type are passed unchanged.
printfArg :: PrintConfig -> Sing (a :: Hakaru) -> CExpr -> CExpr
printfArg pc SProb e
  | showProbInLog pc = e
  | otherwise        = exp e
printfArg _ _ e = e
-- | The @printf@ format for a value of the given type, as a function that
-- prepends the format to a caller-supplied suffix (a separator or newline).
--
-- Probabilities use @exp(%.15f)@ when 'showProbInLog' is set and @%.15f@
-- otherwise.  A measure uses the format of its sample type, preceded by a
-- probability-style weight field and a space when 'showWeights' is set.  An
-- array uses the format of its element type, and a function contributes
-- nothing.
printfText :: PrintConfig -> Sing (a :: Hakaru) -> (String -> String)
printfText cfg ty = case ty of
  SInt  -> ("%d" ++)
  SNat  -> ("%d" ++)
  SProb -> (probFmt ++)
  SReal -> ("%.17f" ++)
  SMeasure t
    | showWeights cfg -> \s -> weightFmt ++ printfText cfg t s
    | otherwise       -> printfText cfg t
  SArray t  -> printfText cfg t
  SFun _ _  -> id
  SData _ _ -> ("TODO: printft datum" ++)
  where
    probFmt, weightFmt :: String
    probFmt   = if showProbInLog cfg then "exp(%.15f)"  else "%.15f"
    weightFmt = if showProbInLog cfg then "exp(%.15f) " else "%.15f "
-- | Turn a term with leading lambdas into one external C function
-- definition with the given name.
--
-- The leading lambdas are peeled off by @coalesceLambda@; each bound
-- variable becomes a parameter declaration.  The remaining body is
-- flattened into a fresh @out@ variable, which the function returns.  The
-- body is generated against a copy of the current state whose statement and
-- declaration lists are emptied; afterwards the caller's lists are restored
-- and the collected ones (reversed) become the function body.
flattenTopLambda
  :: ABT Term abt
  => abt '[] a
  -> Ident
  -> CodeGen ()
flattenTopLambda abt name =
  coalesceLambda abt $ \vars body ->
    do argDecls <- sequence (foldMap11 (\v -> [createIdent v >>= mkVarDecl v]) vars)
       outer <- get
       let retTyp  = typeOf body
           genBody = do outId <- genIdent' "out"
                        declare retTyp outId
                        let outE = CVar outId
                        flattenABT body outE
                        putStat (CReturn (Just outE))
           -- run the body generator with empty statement/declaration lists
           inner   = snd $ runState genBody (outer { statements   = []
                                                   , declarations = [] })
       -- keep everything the body generation changed except the two lists
       put (inner { statements   = statements outer
                  , declarations = declarations outer })
       extDeclare . CFunDefExt $
         functionDef retTyp
                     name
                     argDecls
                     (P.reverse (declarations inner))
                     (P.reverse (statements inner))
  where
    -- Collect the variables of all leading lambdas and hand them, with the
    -- innermost body, to the continuation.
    coalesceLambda
      :: ABT Term abt
      => abt '[] a
      -> ( forall (ys :: [Hakaru]) b
         . List1 Variable ys -> abt '[] b -> r)
      -> r
    coalesceLambda e k =
      caseVarSyn e (\_ -> done) $ \term ->
        case term of
          (Lam_ :$ bound :* End) ->
            caseBind bound $ \v inner ->
              coalesceLambda inner $ \vs rest -> k (Cons1 v vs) rest
          _ -> done
      where
        -- not a lambda: no variables, the term itself is the body
        done = k Nil1 e

    -- Parameter declaration for a lambda-bound variable; array and datum
    -- parameters also declare their struct externally.
    mkVarDecl :: Variable (a :: Hakaru) -> Ident -> CodeGen CDecl
    mkVarDecl var@(Variable _ _ ty) ident =
      case ty of
        SInt          -> return (typeDeclaration SInt ident)
        SNat          -> return (typeDeclaration SNat ident)
        SProb         -> return (typeDeclaration SProb ident)
        SReal         -> return (typeDeclaration SReal ident)
        SArray t      -> extDeclare (arrayStruct t)
                           >> return (arrayDeclaration t ident)
        d@(SData _ _) -> extDeclare (datumStruct d)
                           >> return (datumDeclaration d ident)
        _             -> error $ "flattenSCon.Lam_.mkVarDecl cannot handle vars of type " ++ show var