module Language.Hakaru.Observe where
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import qualified Language.Hakaru.Syntax.Prelude as P
import Language.Hakaru.Syntax.TypeOf
-- | Observe the value @a@ as an outcome of the measure @m@.
--
-- Both arguments are wrapped in 'LC_' and handed to 'observeAST',
-- which walks the syntax of @m@ and does the actual work.
observe
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> abt '[] a
    -> abt '[] ('HMeasure a)
observe m = observeAST (LC_ m) . LC_
-- | Return a copy of the variable @v@ whose 'varID' is replaced by the
-- larger of 'nextFree' and 'nextBind' of the term @e@; every other field
-- of @v@ is kept unchanged.
--
-- NOTE(review): presumably the resulting ID is used by no free and no
-- bound variable of @e@ — confirm against 'nextFree' / 'nextBind'.
freshenVarRe
    :: ABT syn abt => Variable (a :: k) -> abt '[] (b :: k) -> Variable a
freshenVarRe v e =
    let freshID = max (nextFree e) (nextBind e)
    in v { varID = freshID }
-- | Worker for 'observe': observe the value @a@ as an outcome of the
-- measure @m@ by recursing on the syntax of @m@.
--
-- * 'Let_' and 'MBind': the binding is kept and the observation is
--   pushed into its body.
-- * 'Plate': in the body for index @i@, element @i@ of the array @a@
--   is observed.
-- * 'MeasureOp_': handled by 'observeMeasureOp'.
-- * A variable ('observeVar') or any other term shape: 'error'.
--
-- Pushing @a@ underneath a binder of @m@ would capture any free
-- variable of @a@ whose ID equals the binder's. So every binder is first
-- renamed to a variable that is fresh for both @m@ and @a@.
observeAST
    :: (ABT Term abt)
    => LC_ abt ('HMeasure a)
    -> LC_ abt a
    -> abt '[] ('HMeasure a)
observeAST (LC_ m) (LC_ a) =
    caseVarSyn m observeVar $ \ast ->
        case ast of
        Let_ :$ e1 :* e2 :* End ->
            caseBind e2 $ \x e2' ->
                let x'   = fresh x
                    e2'' = rename x x' e2'
                in syn (Let_ :$ e1 :* bind x' (observe e2'' a) :* End)
        MBind :$ e1 :* e2 :* End ->
            caseBind e2 $ \x e2' ->
                let x'   = fresh x
                    e2'' = rename x x' e2'
                in syn (MBind :$ e1 :* bind x' (observe e2'' a) :* End)
        Plate :$ e1 :* e2 :* End ->
            caseBind e2 $ \x e2' ->
                let x'   = fresh x
                    e2'' = rename x x' e2'
                    -- the element of @a@ observed in iteration x'
                    a'   = syn (ArrayOp_ (Index (sUnMeasure $ typeOf e2''))
                                :$ a
                                :* var x'
                                :* End)
                in syn (Plate :$ e1 :* bind x' (observe e2'' a') :* End)
        MeasureOp_ op :$ es -> observeMeasureOp op es a
        _ -> error "observe can only be applied to measure primitives"
  where
    -- Copy of @v@ whose ID is the larger of the IDs 'freshenVarRe' picks
    -- for @m@ and for @a@. It therefore clashes with no name in @m@ and
    -- cannot capture a free variable of @a@.
    -- The signature is required: under MonoLocalBinds this binding would
    -- otherwise not be generalised, yet it is used at several types.
    fresh :: Variable (b :: Hakaru) -> Variable b
    fresh v =
        v { varID = max (varID (freshenVarRe v m)) (varID (freshenVarRe v a)) }
-- | Handler passed to 'caseVarSyn' by 'observeAST' for the case where the
-- measure is a variable rather than a syntactic term. There is no
-- structure to push the observation into, so this always fails, with
-- the same message 'observeAST' uses for unsupported term shapes.
observeVar :: r
observeVar = error "observe can only be applied to measure primitives"
-- | Observe the value @x@ from the primitive measure obtained by
-- applying a measure operator to its arguments.
--
-- * 'Normal' (mean, standard deviation): @'P.dirac' x@ weighted by
--   'P.densityNormal' evaluated at @x@.
-- * 'Uniform' (lower, upper): when @lower <= x <= upper@, @'P.dirac' x@
--   weighted by @1 / (upper - lower)@ (converted with 'P.unsafeProb');
--   otherwise 'P.reject'.
-- * Every other operator is not implemented yet and calls 'error'.
observeMeasureOp
    :: (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => MeasureOp typs a
    -> SArgs abt args
    -> abt '[] a
    -> abt '[] ('HMeasure a)
observeMeasureOp Normal (mean :* stddev :* End) x =
    P.withWeight weight (P.dirac x)
  where
    weight = P.densityNormal mean stddev x
observeMeasureOp Uniform (lower :* upper :* End) x =
    P.if_ inSupport
        (P.withWeight weight (P.dirac x))
        (P.reject (SMeasure SReal))
  where
    inSupport = lower P.<= x P.&& x P.<= upper
    weight    = P.unsafeProb (P.recip (upper P.- lower))
observeMeasureOp _ _ _ = error "TODO{Observe:observeMeasureOp}"