module Language.Hakaru.Expect
( normalize
, total
, expect
) where
import Prelude (($), (.), error, reverse)
import qualified Data.Text as Text
import Data.Functor ((<$>))
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import qualified Data.List.NonEmpty as NE
import Control.Monad
import Language.Hakaru.Syntax.IClasses (Some2(..), Functor11(..))
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Types.Sing
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.DatumABT
import Language.Hakaru.Syntax.AST hiding (Expect)
import qualified Language.Hakaru.Syntax.AST as AST
import Language.Hakaru.Syntax.TypeOf (typeOf)
import qualified Language.Hakaru.Syntax.Prelude as P
import Language.Hakaru.Evaluation.Types
import Language.Hakaru.Evaluation.ExpectMonad
#ifdef __TRACE_DISINTEGRATE__
import Prelude (show, (++))
import qualified Text.PrettyPrint as PP
import Language.Hakaru.Pretty.Haskell (pretty)
import Language.Hakaru.Evaluation.Types (ppStatement)
import Debug.Trace (trace)
#endif
-- | Rescale a measure so that its total mass is one: the measure is
-- reweighted by the reciprocal of 'total' (its expectation of the
-- constant function 1).
normalize
    :: (ABT Term abt) => abt '[] ('HMeasure a) -> abt '[] ('HMeasure a)
normalize m = P.withWeight inverseMass m
  where
    inverseMass = P.recip (total m)
-- | Total mass of a measure: the expectation, under @m@, of the function
-- that ignores its argument and always returns one.
total :: (ABT Term abt) => abt '[] ('HMeasure a) -> abt '[] 'HProb
total m = expect m constOne
  where
    -- A one-argument binder over the measure's element type whose body
    -- is the constant 'P.one'.
    constOne = binder Text.empty (sUnMeasure (typeOf m)) (\_ -> P.one)
-- | @expect e f@ builds a term for the expectation of the function @f@
-- (a one-variable binder) under the measure @e@, by running the CPS
-- computation 'expectTerm' with @f@ as the final continuation.
expect
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> abt '[a] 'HProb
    -> abt '[] 'HProb
expect e f = runExpect (expectTerm e) f roots
  where
    -- Both input terms are handed to 'runExpect'. NOTE(review): presumably
    -- so that fresh variables avoid names occurring in them; confirm in
    -- Language.Hakaru.Evaluation.ExpectMonad.
    roots = [Some2 e, Some2 f]
-- | Fallback for a measure that cannot be analysed further. Allocate a
-- fresh variable of the measure's element type, and push a statement that
-- wraps the rest of the residual program in an 'AST.Expect' over @e@
-- binding that variable. Returns the variable as a term.
residualizeExpect
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> Expect abt (abt '[] a)
residualizeExpect e = do
    v <- freshVar Text.empty (sUnMeasure (typeOf e))
    let wrapInExpect rest = syn (AST.Expect :$ e :* bind v rest :* End)
    unsafePush (SStuff1 v wrapInExpect [])
    return (var v)
-- | Smart constructor for @let@. If the bound expression is already a
-- variable, that variable is substituted directly into the body. Otherwise
-- an explicit 'Let_' node is built.
let_ :: (ABT Term abt) => abt '[] a -> abt '[a] b -> abt '[] b
let_ e f = caseVarSyn e renameInBody (\_ -> explicitLet)
  where
    -- Replace the binder of @f@ by the existing variable.
    renameInBody x = caseBind f (\y body -> subst y (var x) body)
    -- General case: keep the binding in the residual program.
    explicitLet = syn (Let_ :$ e :* f :* End)
-- | Expectation of a measure-valued @case@ expression whose scrutinee
-- could not be reduced.
--
-- First, the statements currently pending in the context are taken out of
-- it and emitted around the whole @case@. NOTE(review): presumably this is
-- so they are not residualized separately inside every branch; confirm.
-- Next, each branch gets fresh pattern variables, and its body becomes an
-- (as yet unrun) 'expectTerm' computation. Finally, a residual 'Case_' on
-- the same scrutinee is built, in which every branch runs with the same
-- continuation and context.
expectCase
    :: (ABT Term abt)
    => abt '[] a
    -> [Branch a abt ('HMeasure b)]
    -> Expect abt (abt '[] b)
expectCase scrutinee bs = do
    -- Return the whole current context and continue with its statements
    -- cleared.
    ctx <- Expect $ \c h -> c h (h {statements = []})
    -- Emit the saved statements around the rest of the residual program.
    -- ('residualizeExpectListContext' comes from ExpectMonad; presumably it
    -- wraps the term in ctx's statements.)
    Expect $ \c h -> residualizeExpectListContext (c () h) ctx
    -- For every branch: rename its pattern variables to fresh ones,
    -- substitute them into the body, and turn the body into an
    -- 'expectTerm' computation (not yet run).
    gms <- T.for bs $ \(Branch pat body) ->
        let (vars, body') = caseBinds body
        in  (\vars' ->
                let rho = toAssocs1 vars (fmap11 var vars')
                in  GBranch pat vars' (expectTerm $ substs rho body')
            ) <$> freshenVars vars
    -- Run each branch's computation with the shared continuation @c@ and
    -- context @h@, then reassemble the case.
    Expect $ \c h ->
        syn $ Case_ scrutinee
            [ fromGBranch $ fmap (\m -> unExpect m c h) gm
            | gm <- gms
            ]
#ifdef __TRACE_DISINTEGRATE__
-- | Debug-only: read the statements pending in the current context,
-- leaving the context unchanged.
getStatements :: Expect abt [Statement abt 'ExpectP]
getStatements = Expect (\k ctx -> k (statements ctx) ctx)
#endif
-- | Compute, in CPS, the expectation of a measure term.
--
-- The term is evaluated with 'pureEvaluate', and the resulting normal
-- form is dispatched on:
--
-- * @dirac v@ yields @v@;
-- * @m >>= k@ takes the expectation of @m@, let-binds its value into @k@
--   and continues with the expectation of the result;
-- * primitive measures are handled by 'expectMeasureOp';
-- * @superpose@ goes to 'expectSuperpose', and @reject@ is treated as the
--   empty superposition;
-- * plate and chain are not implemented yet and call 'error';
-- * literal and coercion heads should be impossible at measure type and
--   call 'error';
-- * a neutral (stuck) term is residualized with 'residualizeExpect',
--   except a stuck @case@, which is pushed into its branches by
--   'expectCase'.
expectTerm
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> Expect abt (abt '[] a)
expectTerm e = do
#ifdef __TRACE_DISINTEGRATE__
    ss <- getStatements
    trace ("\n-- expectTerm --\n"
        ++ show (pretty_Statements_withTerm ss e)
        ++ "\n") $ return ()
#endif
    evaluated <- pureEvaluate e
    case evaluated of
        Head_ (WDirac v)        -> return v
        Head_ (WMBind m k)      -> expectTerm m >>= \v -> expectTerm (let_ v k)
        Head_ (WMeasureOp o es) -> expectMeasureOp o es
        Head_ (WSuperpose pes)  -> expectSuperpose (NE.toList pes)
        Head_ (WReject _)       -> expectSuperpose []
        Head_ (WPlate _ _)      -> error "TODO: expect{Plate}"
        Head_ (WChain _ _ _)    -> error "TODO: expect{Chain}"
        Head_ (WLiteral _)      -> error "expect: the impossible happened"
        Head_ (WCoerceTo _ _)   -> error "expect: the impossible happened"
        Head_ (WUnsafeFrom _ _) -> error "expect: the impossible happened"
        Neutral stuck ->
            caseVarSyn stuck (residualizeExpect . var) $ \t ->
                case t of
                    Case_ scrut branches -> expectCase scrut branches
                    _                    -> residualizeExpect stuck
-- | Expectation of a weighted superposition of measures: the sum, over
-- each pair @(p, m)@, of @p@ times the expectation of @m@ under the
-- current continuation.
--
-- The empty list (which 'expectTerm' passes for @reject@) gives
-- @P.sum []@.
expectSuperpose
    :: (ABT Term abt)
    => [(abt '[] 'HProb, abt '[] ('HMeasure a))]
    -> Expect abt (abt '[] a)
expectSuperpose pes = do
#ifdef __TRACE_DISINTEGRATE__
    ss <- getStatements
    -- 'Superpose_' needs a non-empty list, but this function is also called
    -- with @[]@ (for @reject@). 'NE.fromList' would throw there, so guard
    -- the conversion instead of crashing trace builds.
    case NE.nonEmpty pes of
        Nothing ->
            trace "\n-- expectSuperpose --\n<reject>\n" $ return ()
        Just pes' ->
            trace ("\n-- expectSuperpose --\n"
                ++ show (pretty_Statements_withTerm ss (syn $ Superpose_ pes'))
                ++ "\n") $ return ()
#endif
    -- Emit the pending statements first, so they enclose the whole sum
    -- rather than being left in the context handed to each summand.
    emitExpectListContext
    Expect $ \c h ->
        P.sum [ p P.* unExpect (expectTerm e) c h | (p,e) <- pes]
-- | Take every statement pending in the context, clear them from it, and
-- emit them into the residual program. The stored list is emitted in
-- reverse order, and each emitted statement wraps everything emitted
-- after it. An 'SLet' becomes a 'Let_' binding; an 'SStuff0' or 'SStuff1'
-- applies its stored wrapper function to the rest of the program.
emitExpectListContext :: forall abt. (ABT Term abt) => Expect abt ()
emitExpectListContext = do
    pending <- Expect $ \k h -> k (statements h) (h {statements = []})
    F.for_ (reverse pending) emit
  where
    emit :: Statement abt 'ExpectP -> Expect abt ()
    emit stmt =
#ifdef __TRACE_DISINTEGRATE__
        trace ("\n-- emitExpectListContext: " ++ show (ppStatement 0 stmt)) $
#endif
        case stmt of
            SStuff0 wrap _   -> Expect $ \k h -> wrap (k () h)
            SStuff1 _ wrap _ -> Expect $ \k h -> wrap (k () h)
            SLet x rhs _     -> Expect $ \k h ->
                syn (Let_ :$ fromLazy rhs :* bind x (k () h) :* End)
-- | Allocate a fresh real variable and push a statement that wraps the
-- rest of the residual program in an 'Integrate' of that variable from
-- @lo@ to @hi@. Returns the new variable.
pushIntegrate
    :: (ABT Term abt)
    => abt '[] 'HReal
    -> abt '[] 'HReal
    -> Expect abt (Variable 'HReal)
pushIntegrate lo hi = do
    v <- freshVar Text.empty SReal
    let integral rest = syn (Integrate :$ lo :* hi :* bind v rest :* End)
    unsafePush (SStuff1 v integral [])
    return v
-- | Discrete counterpart of 'pushIntegrate'. Allocate a fresh variable of
-- type @a@ and push a statement that wraps the rest of the residual
-- program in a 'Summate' of that variable from @lo@ to @hi@. Returns the
-- new variable.
pushSummate
    :: (ABT Term abt, HDiscrete_ a, SingI a)
    => abt '[] a
    -> abt '[] a
    -> Expect abt (Variable a)
pushSummate lo hi = do
    v <- freshVar Text.empty sing
    let summation rest =
            syn (Summate hDiscrete hSemiring
                :$ lo :* hi :* bind v rest :* End)
    unsafePush (SStuff1 v summation [])
    return v
-- | Name an expression. A variable is returned as-is. Any other term is
-- let-bound to a fresh variable around the rest of the residual program,
-- and that fresh variable is returned.
pushLet :: (ABT Term abt) => abt '[] a -> Expect abt (Variable a)
pushLet e = caseVarSyn e return bindFresh
  where
    bindFresh _ = do
        v <- freshVar Text.empty (typeOf e)
        let letBinding rest = syn (Let_ :$ e :* bind v rest :* End)
        unsafePush (SStuff1 v letBinding [])
        return v
-- | Expectation of a primitive measure, in CPS.
--
-- Most cases introduce the random variable as a freshly bound variable
-- via 'pushIntegrate' or 'pushSummate'. 'Categorical' instead binds an
-- array index. Density weights and positivity guards are pushed as
-- 'SStuff0' wrappers around the continuation. The variable (or, for
-- 'Categorical', the continuation's result) is then returned.
--
-- NOTE(review): statements pushed earlier presumably become the outer
-- wrappers ('emitExpectListContext' emits the stored list in reverse), so
-- the order of the pushes in each case is significant. Do not reorder them.
expectMeasureOp
    :: forall abt typs args a
    .  (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => MeasureOp typs a
    -> SArgs abt args
    -> Expect abt (abt '[] a)
-- Lebesgue: integrate the continuation over (-infinity, infinity).
expectMeasureOp Lebesgue = \End ->
    var <$> pushIntegrate P.negativeInfinity P.infinity
-- Counting: sum the continuation from -infinity to infinity.
expectMeasureOp Counting = \End ->
    var <$> pushSummate P.negativeInfinity P.infinity
-- Categorical: let-bind the weight array and its total (from
-- 'P.summateV'). Push a guard so the residual term is zero unless the total
-- is positive. Then build an array whose i-th entry is weight_i times the
-- continuation at index i, combine it with 'P.summateV', and divide by the
-- total.
expectMeasureOp Categorical = \(ps :* End) -> do
    ps' <- var <$> pushLet ps
    tot <- var <$> pushLet (P.summateV ps')
    unsafePush (SStuff0 (\c -> P.if_ (P.zero P.< tot) c P.zero) [])
    i <- freshVar Text.empty SNat
    Expect $ \c h ->
        P.summateV
            (syn (Array_ (P.size ps') (bind i ((ps' P.! var i) P.* c (var i) h))))
        P./ tot
-- Uniform: let-bind the bounds, since each is used twice (as an
-- integration bound and in the density). Integrate between them,
-- weighting by 'P.densityUniform'.
expectMeasureOp Uniform = \(lo :* hi :* End) -> do
    lo' <- var <$> pushLet lo
    hi' <- var <$> pushLet hi
    x <- var <$> pushIntegrate lo' hi'
    unsafePush (SStuff0 (\c -> P.densityUniform lo' hi' x P.* c) [])
    return x
-- Normal: integrate over the whole real line, weighting by
-- 'P.densityNormal'. NOTE(review): unlike Uniform's bounds, mu and sd are
-- not let-bound here. If the density mentions them more than once, they
-- are copied into the residual term. Confirm this is intended.
expectMeasureOp Normal = \(mu :* sd :* End) -> do
    x <- var <$> pushIntegrate P.negativeInfinity P.infinity
    unsafePush (SStuff0 (\c -> P.densityNormal mu sd x P.* c) [])
    return x
-- Poisson: let-bind the rate and push a guard so the residual term is zero
-- unless the rate is positive. Then sum from zero to infinity, weighting by
-- 'P.densityPoisson'.
expectMeasureOp Poisson = \(l :* End) -> do
    l' <- var <$> pushLet l
    unsafePush (SStuff0 (\c -> P.if_ (P.zero P.< l') c P.zero) [])
    x <- var <$> pushSummate P.zero P.infinity
    unsafePush (SStuff0 (\c -> P.densityPoisson l' x P.* c) [])
    return x
-- Gamma: integrate a real x from zero to infinity. Let-bind it as a prob
-- via 'P.unsafeProb', weight by 'P.densityGamma', and return the prob
-- variable.
expectMeasureOp Gamma = \(shape :* scale :* End) -> do
    x <- var <$> pushIntegrate P.zero P.infinity
    x_ <- var <$> pushLet (P.unsafeProb x)
    unsafePush (SStuff0 (\c -> P.densityGamma shape scale x_ P.* c) [])
    return x_
-- Beta: as Gamma, but integrating over [zero, one] and weighting by
-- 'P.densityBeta'.
expectMeasureOp Beta = \(a :* b :* End) -> do
    x <- var <$> pushIntegrate P.zero P.one
    x_ <- var <$> pushLet (P.unsafeProb x)
    unsafePush (SStuff0 (\c -> P.densityBeta a b x_ P.* c) [])
    return x_