module Language.Hakaru.Disintegrate
(
disintegrateWithVar
, disintegrate
, densityWithVar
, density
, observe
, determine
, perform
, atomize
, constrainValue
, constrainOutcome
) where
#if __GLASGOW_HASKELL__ < 710
import Data.Functor ((<$>))
import Data.Foldable (Foldable, foldMap)
import Data.Traversable (Traversable)
import Control.Applicative (Applicative(..))
#endif
import Control.Applicative (Alternative(..))
import Control.Monad ((<=<), guard)
import Data.Functor.Compose (Compose(..))
import qualified Data.Traversable as T
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as L
import qualified Data.Text as Text
import qualified Data.IntMap as IM
import Data.Sequence (Seq)
import qualified Data.Sequence as S
import Data.Proxy (KProxy(..))
import qualified Data.Set as Set (fromList)
import Data.Maybe (fromMaybe)
import Language.Hakaru.Syntax.IClasses
import Data.Number.Natural
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import qualified Language.Hakaru.Types.Coercion as C
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.DatumCase (DatumEvaluator, MatchResult(..), matchBranches)
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Evaluation.Types
import Language.Hakaru.Evaluation.Lazy hiding (evaluate,update)
import Language.Hakaru.Evaluation.DisintegrationMonad
import qualified Language.Hakaru.Syntax.Prelude as P
import qualified Language.Hakaru.Expect as E
#ifdef __TRACE_DISINTEGRATE__
import qualified Text.PrettyPrint as PP
import Language.Hakaru.Pretty.Haskell
import Debug.Trace (trace, traceM)
#endif
-- | Abstract a term over a variable, producing a Hakaru lambda.
--
-- The result is the syntax node @Lam_ :$ bind x body :* End@.
lam_ :: (ABT Term abt) => Variable a -> abt '[] b -> abt '[] (a ':-> b)
lam_ v body = syn $ Lam_ :$ bind v body :* End
-- | Disintegrate a measure over pairs with respect to its first component.
--
-- A variable @x@ (name hint @hint@, type @typ@) is created with the ID
-- given by 'nextFreeOrBind' on @m@ (presumably so it cannot collide with
-- any variable already in @m@). Inside 'runDis' the measure is performed,
-- the resulting pair is split with 'emitUnpair', the first component is
-- constrained to equal @x@ with 'constrainValue', and the second component
-- is returned. Every program produced by 'runDis' is wrapped in a lambda
-- over @x@, so the result is a list of candidate disintegrations.
disintegrateWithVar
    :: (ABT Term abt)
    => Text.Text
    -> Sing a
    -> abt '[] ('HMeasure (HPair a b))
    -> [abt '[] (a ':-> 'HMeasure b)]
disintegrateWithVar hint typ m =
    let x = Variable hint (nextFreeOrBind m) typ
    -- NOTE(review): the terms handed to runDis are presumably the ones whose
    -- variables the generated code must avoid; confirm against runDis.
    in map (lam_ x) . flip runDis [Some2 m, Some2 (var x)] $ do
        ab <- perform m
#ifdef __TRACE_DISINTEGRATE__
        ss <- getStatements
        trace ("-- disintegrate: finished perform\n"
            ++ show (pretty_Statements ss PP.$+$ PP.sep(prettyPrec_ 11 ab))
            ++ "\n") $ return ()
#endif
        (a,b) <- emitUnpair ab
#ifdef __TRACE_DISINTEGRATE__
        trace ("-- disintegrate: finished emitUnpair: "
            ++ show (pretty a, pretty b)) $ return ()
#endif
        -- Condition on the first component being the lambda's argument.
        constrainValue (var x) a
#ifdef __TRACE_DISINTEGRATE__
        ss <- getStatements
        locs <- getLocs
        traceM ("-- disintegrate: finished constrainValue\n"
            ++ show (pretty_Statements ss) ++ "\n"
            ++ show (prettyLocs locs)
            )
#endif
        return b
-- | 'disintegrateWithVar' with an empty name hint, using the type of the
-- first pair component read off the measure's own type.
disintegrate
    :: (ABT Term abt)
    => abt '[] ('HMeasure (HPair a b))
    -> [abt '[] (a ':-> 'HMeasure b)]
disintegrate m = disintegrateWithVar Text.empty keyType m
  where
    -- Type of the first component of the measured pair.
    keyType = fst . sUnPair . sUnMeasure $ typeOf m
-- | Compute density functions of a measure.
--
-- Creates a variable @x@ (hint @hint@, type @typ@) whose ID is the larger
-- of 'nextFree' and 'nextBind' on @m@. The measure is observed at @x@ with
-- 'observe', and each resulting measure is turned into its total mass with
-- 'E.total' and abstracted over @x@.
densityWithVar
    :: (ABT Term abt)
    => Text.Text
    -> Sing a
    -> abt '[] ('HMeasure a)
    -> [abt '[] (a ':-> 'HProb)]
densityWithVar hint typ m = map (lam_ x . E.total) (observe m (var x))
  where
    x = Variable hint (nextFree m `max` nextBind m) typ
-- | 'densityWithVar' with an empty name hint and the element type read off
-- the measure's own type.
density
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> [abt '[] (a ':-> 'HProb)]
density m = densityWithVar Text.empty (sUnMeasure (typeOf m)) m
-- | Condition a measure on its outcome being the given term.
--
-- Runs 'constrainOutcome' for @x@ against @m@ and then returns @x@,
-- collecting every program that 'runDis' produces.
observe
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> abt '[] a
    -> [abt '[] ('HMeasure a)]
observe m x = runDis conditioned [Some2 m, Some2 x]
  where
    conditioned = do
        constrainOutcome x m
        return x
-- | Pick the first candidate from a list of results, if there is one.
determine :: (ABT Term abt) => [abt '[] a] -> Maybe (abt '[] a)
determine candidates =
    case candidates of
      first : _ -> Just first
      []        -> Nothing
-- | Apply an effectful function to the first component of a pair, keeping
-- the second component unchanged.
firstM :: Functor f => (a -> f b) -> (a,c) -> f (b,c)
firstM f (x, y) = fmap (\x' -> (x', y)) (f x)
-- | The disintegration evaluator: the local 'evaluate' instantiated with
-- 'perform' for measures and 'evaluateCase' for @case@ expressions.
evaluate_ :: (ABT Term abt) => TermEvaluator abt (Dis abt)
evaluate_ e = evaluate perform evaluateCase e
-- | Build a lazy partial evaluator for Hakaru terms in the 'Dis' monad.
--
-- Two things are passed in: how to perform measures, and how to evaluate
-- @case@ expressions. The case evaluator receives the evaluator itself,
-- which ties the recursive knot. Variables are resolved with the local
-- 'update'. Other terms are dispatched on their constructor:
--
-- * introduction forms are already head-normal and are returned as 'Head_';
-- * summations are returned unevaluated as 'Neutral';
-- * eliminators (application, @let@, coercions, operators, @case@) are
--   reduced.
--
-- NOTE(review): the type variables @m@ and @p@ in the signature are unused.
evaluate
    :: forall abt m p
    .  (ABT Term abt)
    => MeasureEvaluator abt (Dis abt)
    -> (TermEvaluator abt (Dis abt) -> CaseEvaluator abt (Dis abt))
    -> TermEvaluator abt (Dis abt)
evaluate perform evaluateCase = goEvaluate
  where
    -- The supplied case evaluator, closed over this evaluator.
    evaluateCase_ :: CaseEvaluator abt (Dis abt)
    evaluateCase_ = evaluateCase goEvaluate

    goEvaluate :: TermEvaluator abt (Dis abt)
    goEvaluate e0 =
#ifdef __TRACE_DISINTEGRATE__
      getIndices >>= \inds ->
      trace ("-- goEvaluate: " ++ show (pretty e0)
             ++ "at " ++ show (ppInds inds)) $
#endif
      caseVarSyn e0 (update perform goEvaluate) $ \t ->
        case t of
          -- Introduction forms: already head-normal.
          Literal_ v -> return . Head_ $ WLiteral v
          Datum_ d -> return . Head_ $ WDatum d
          Empty_ typ -> return . Head_ $ WEmpty typ
          Array_ e1 e2 -> return . Head_ $ WArray e1 e2
          Lam_ :$ e1 :* End -> return . Head_ $ WLam e1
          Dirac :$ e1 :* End -> return . Head_ $ WDirac e1
          MBind :$ e1 :* e2 :* End -> return . Head_ $ WMBind e1 e2
          Plate :$ e1 :* e2 :* End -> return . Head_ $ WPlate e1 e2
          MeasureOp_ o :$ es -> return . Head_ $ WMeasureOp o es
          Superpose_ pes -> return . Head_ $ WSuperpose pes
          Reject_ typ -> return . Head_ $ WReject typ
          Integrate :$ e1 :* e2 :* e3 :* End ->
              return . Head_ $ WIntegrate e1 e2 e3
          -- Summations are not reduced; the whole term is left neutral.
          Summate h1 h2 :$ e1 :* e2 :* e3 :* End ->
              return . Neutral $ syn t
          -- Application: evaluate the function. A lambda is beta-reduced by
          -- pushing a lazy let of the argument; a neutral function yields
          -- a neutral application.
          App_ :$ e1 :* e2 :* End -> do
              w1 <- goEvaluate e1
              case w1 of
                  Neutral e1' -> return . Neutral $ P.app e1' e2
                  Head_ v1 -> evaluateApp v1
            where
              evaluateApp (WLam f) =
                  caseBind f $ \x f' -> do
                      i <- getIndices
                      push (SLet x (Thunk e2) i) f' goEvaluate
              evaluateApp _ = error "evaluate{App_}: the impossible happened"
          -- let: bind the right-hand side lazily at the current indices.
          Let_ :$ e1 :* e2 :* End -> do
              i <- getIndices
              caseBind e2 $ \x e2' ->
                  push (SLet x (Thunk e1) i) e2' goEvaluate
          CoerceTo_ c :$ e1 :* End -> C.coerceTo c <$> goEvaluate e1
          UnsafeFrom_ c :$ e1 :* End -> C.coerceFrom c <$> goEvaluate e1
          NaryOp_ o es -> evaluateNaryOp goEvaluate o es
          ArrayOp_ o :$ es -> evaluateArrayOp goEvaluate o es
          PrimOp_ o :$ es -> evaluatePrimOp goEvaluate o es
          Expect :$ e1 :* e2 :* End ->
              error "TODO: evaluate{Expect}: unclear how to handle this without cyclic dependencies"
          Case_ e bs -> evaluateCase_ e bs
          _ :$ _ -> error "evaluate: the impossible happened"
-- | Evaluate an n-ary operator application to weak-head normal form.
--
-- Operands are evaluated left to right.
--
-- * An operand whose whnf is a neutral application of the /same/ operator
--   is flattened: its operands are spliced onto the front of the queue.
-- * Adjacent head-normal results are folded together eagerly with
--   'evalOp'.
--
-- At the end, no remaining operands yields the operator's identity
-- element, a single operand is returned as-is, and anything else is
-- rebuilt as a neutral 'NaryOp_'.
evaluateNaryOp
    :: (ABT Term abt)
    => TermEvaluator abt (Dis abt)
    -> NaryOp a
    -> Seq (abt '[] a)
    -> Dis abt (Whnf abt a)
evaluateNaryOp evaluate_ = \o es -> mainLoop o (evalOp o) S.empty es
  where
    -- @ws@: operands already evaluated; @es@: operands still pending.
    mainLoop o op ws es =
        case S.viewl es of
          S.EmptyL -> return $
              case S.viewl ws of
                S.EmptyL -> identityElement o
                w S.:< ws'
                    | S.null ws' -> w
                    | otherwise ->
                        Neutral . syn . NaryOp_ o $ fmap fromWhnf ws
          e S.:< es' -> do
              w <- evaluate_ e
              case matchNaryOp o w of
                  Nothing -> mainLoop o op (snocLoop op ws w) es'
                  Just es2 -> mainLoop o op ws (es2 S.>< es')

    -- Append @w1@ to @ws@. While both @w1@ and the last element of @ws@
    -- are head-normal, they are combined with @op@ first. The combination
    -- is @op v1 v2@, with the newer value first.
    snocLoop
        :: (ABT syn abt)
        => (Head abt a -> Head abt a -> Head abt a)
        -> Seq (Whnf abt a)
        -> Whnf abt a
        -> Seq (Whnf abt a)
    snocLoop op ws w1 =
        case S.viewr ws of
          S.EmptyR -> S.singleton w1
          ws' S.:> w2 ->
              case (w1,w2) of
                (Head_ v1, Head_ v2) -> snocLoop op ws' (Head_ (op v1 v2))
                _ -> ws S.|> w1

    -- If @w@ is a neutral application of operator @o@ itself, return its
    -- operands so they can be flattened into the enclosing application.
    matchNaryOp
        :: (ABT Term abt)
        => NaryOp a
        -> Whnf abt a
        -> Maybe (Seq (abt '[] a))
    matchNaryOp o w =
        case w of
          Head_ _ -> Nothing
          Neutral e ->
              caseVarSyn e (const Nothing) $ \t ->
                  case t of
                    NaryOp_ o' es | o' == o -> Just es
                    _ -> Nothing

    -- The value of an application with no operands. Min and Max have no
    -- literal identity here, so they stay as a neutral empty application.
    identityElement :: (ABT Term abt) => NaryOp a -> Whnf abt a
    identityElement o =
        case o of
          And -> Head_ (WDatum dTrue)
          Or -> Head_ (WDatum dFalse)
          Xor -> Head_ (WDatum dFalse)
          Iff -> Head_ (WDatum dTrue)
          Min _ -> Neutral (syn (NaryOp_ o S.empty))
          Max _ -> Neutral (syn (NaryOp_ o S.empty))
          Sum HSemiring_Nat -> Head_ (WLiteral (LNat 0))
          Sum HSemiring_Int -> Head_ (WLiteral (LInt 0))
          Sum HSemiring_Prob -> Head_ (WLiteral (LProb 0))
          Sum HSemiring_Real -> Head_ (WLiteral (LReal 0))
          Prod HSemiring_Nat -> Head_ (WLiteral (LNat 1))
          Prod HSemiring_Int -> Head_ (WLiteral (LInt 1))
          Prod HSemiring_Prob -> Head_ (WLiteral (LProb 1))
          Prod HSemiring_Real -> Head_ (WLiteral (LReal 1))

    -- Combine two head-normal operands. The Boolean operators go through
    -- reify/reflect; Sum and Prod expect literals. Min and Max are not
    -- implemented, so two adjacent head-normal operands of Min/Max raise
    -- the TODO error.
    evalOp
        :: (ABT Term abt)
        => NaryOp a
        -> Head abt a
        -> Head abt a
        -> Head abt a
    evalOp And = \v1 v2 -> reflect (reify v1 && reify v2)
    evalOp Or = \v1 v2 -> reflect (reify v1 || reify v2)
    evalOp Xor = \v1 v2 -> reflect (reify v1 /= reify v2)
    evalOp Iff = \v1 v2 -> reflect (reify v1 == reify v2)
    evalOp (Min _) = error "TODO: evalOp{Min}"
    evalOp (Max _) = error "TODO: evalOp{Max}"
    evalOp (Sum theSemi) =
        \(WLiteral v1) (WLiteral v2) -> WLiteral $ evalSum theSemi v1 v2
    evalOp (Prod theSemi) =
        \(WLiteral v1) (WLiteral v2) -> WLiteral $ evalProd theSemi v1 v2

    -- Literal addition and multiplication for each semiring.
    evalSum, evalProd :: HSemiring a -> Literal a -> Literal a -> Literal a
    evalSum HSemiring_Nat = \(LNat n1) (LNat n2) -> LNat (n1 + n2)
    evalSum HSemiring_Int = \(LInt i1) (LInt i2) -> LInt (i1 + i2)
    evalSum HSemiring_Prob = \(LProb p1) (LProb p2) -> LProb (p1 + p2)
    evalSum HSemiring_Real = \(LReal r1) (LReal r2) -> LReal (r1 + r2)
    evalProd HSemiring_Nat = \(LNat n1) (LNat n2) -> LNat (n1 * n2)
    evalProd HSemiring_Int = \(LInt i1) (LInt i2) -> LInt (i1 * i2)
    evalProd HSemiring_Prob = \(LProb p1) (LProb p2) -> LProb (p1 * p2)
    evalProd HSemiring_Real = \(LReal r1) (LReal r2) -> LReal (r1 * r2)
-- | Whether a variable is the index variable of one of the currently
-- enclosing indices.
isIndex :: (ABT Term abt) => Variable 'HNat -> Dis abt Bool
isIndex v = elem v . map indVar <$> getIndices
-- | Analyse an array-indexing expression @e1 ! e2@ and hand the result to
-- one of several continuations, depending on what @e1@ and @e2@ evaluate
-- to.
--
-- * @kInd@: @e1@ is a literal array (or a variable bound to a 'MultiLoc')
--   and @e2@ is a variable that is one of the current indices. @kInd@
--   receives the array body instantiated at that index (for a 'MultiLoc',
--   a fresh location for the element).
-- * @kArr@: @e1@ is a literal array but the index variable is not one of
--   the current indices.
-- * @kSyn@: @e1@ evaluates to a neutral term that is not a variable.
-- * @kFree@: @e1@ is a neutral variable with no recorded location.
-- * @kMultiLoc@: @e1@ is a variable bound to a 'MultiLoc' and the index
--   variable is not one of the current indices.
--
-- If @e2@ does not evaluate to a neutral variable, the computation fails
-- with 'bot' (presumably; this relies on 'caseWhnf'/'caseVarSyn' sending
-- heads and non-variable syntax to their first/second continuations).
-- Any 'ArrayOp' other than 'Index' is an error.
indexArrayOp :: forall abt typs args a r
             .  ( ABT Term abt
                , typs ~ UnLCs args, args ~ LCs typs )
             => ArrayOp typs a
             -> SArgs abt args
             -> TermEvaluator abt (Dis abt)
             -> (abt '[] a -> Dis abt r)
             -> (Head abt ('HArray a)
                 -> Variable 'HNat
                 -> Dis abt r)
             -> (Term abt ('HArray a)
                 -> Dis abt r)
             -> (abt '[] ('HArray a)
                 -> Dis abt r)
             -> (abt '[] ('HArray a)
                 -> Variable 'HNat
                 -> Dis abt r)
             -> Dis abt r
indexArrayOp o@(Index _) (e1 :* e2 :* End) evaluate_ kInd kArr kSyn kFree kMultiLoc = do
    w1 <- evaluate_ e1
    case w1 of
        -- Literal array: instantiate its body at the index if possible.
        Head_ arr@(WArray _ b) -> caseBind b $ \x body ->
            evalIndex (kInd . flip (rename x) body) (kArr arr)
        Head_ (WEmpty _) -> error "TODO: indexArrayOp o (Empty_ :* _ :* End)"
        Head_ _ -> error "indexArrayOp: unknown whnf of array type"
        -- Neutral array: dispatch on the variable's location, if any.
        Neutral e1' -> flip (caseVarSyn e1') kSyn $ \x -> do
            locs <- getLocs
            case (lookupAssoc x locs) of
                Nothing -> kFree e1'
                Just (Loc _ _) -> error "indexArrayOp: impossible, we have a Neutral term"
                Just (MultiLoc l js) ->
                    evalIndex ((kInd . var =<<) . mkLoc Text.empty l . flip extendLocInds js)
                              (kMultiLoc e1')
  where
    -- Evaluate the index. If it is a variable that is one of the current
    -- indices, take @thenCase@; if it is some other variable, take
    -- @elseCase@.
    evalIndex :: (ABT Term abt)
              => (Variable 'HNat -> Dis abt r)
              -> (Variable 'HNat -> Dis abt r)
              -> Dis abt r
    evalIndex thenCase elseCase = do
        w2 <- evaluate_ e2
        caseWhnf w2 (const bot) $ \term ->
            flip (caseVarSyn term) (const bot) $ \v -> do
                isInd <- isIndex v
                if isInd then thenCase v else elseCase v
indexArrayOp _ _ _ _ _ _ _ _ = error "indexArrayOp called on incorrect ArrayOp"
-- | Evaluate an array operation to weak-head normal form.
--
-- * 'Index' is analysed by 'indexArrayOp'. Only an index into a literal
--   array at a current index variable is reduced (by evaluating the
--   instantiated body); every other case is rebuilt as a neutral index
--   expression.
-- * 'Size': a neutral array gives a neutral size, an empty array gives
--   literal 0, and a literal array's size expression is evaluated.
-- * 'Reduce' is not implemented yet.
evaluateArrayOp
    :: ( ABT Term abt
       , typs ~ UnLCs args, args ~ LCs typs)
    => TermEvaluator abt (Dis abt)
    -> ArrayOp typs a
    -> SArgs abt args
    -> Dis abt (Whnf abt a)
evaluateArrayOp evaluate_ = go
  where
    go o@(Index _) = \args@(_ :* e2 :* End) ->
        let returnIndex = return . Neutral . syn
        in indexArrayOp o args
            evaluate_
            evaluate_
            (\arr v -> returnIndex (ArrayOp_ o :$ fromHead arr :* var v :* End))
            (\s -> returnIndex (ArrayOp_ o :$ syn s :* e2 :* End))
            (\e1' -> returnIndex (ArrayOp_ o :$ e1' :* e2 :* End))
            (\e1' v -> returnIndex (ArrayOp_ o :$ e1' :* var v :* End))

    go o@(Size _) = \(e1 :* End) -> do
        w1 <- evaluate_ e1
        case w1 of
            Neutral e1' -> return . Neutral $ syn (ArrayOp_ o :$ e1' :* End)
            Head_ v1 ->
                case head2array v1 of
                    WAEmpty -> return . Head_ $ WLiteral (LNat 0)
                    WAArray e3 _ -> evaluate_ e3

    go (Reduce _) = \(e1 :* e2 :* e3 :* End) ->
        error "TODO: evaluateArrayOp{Reduce}"
-- | A view of an array-typed 'Head' (produced by 'head2array'). It is
-- either the empty array, or a size expression together with a body that
-- binds the index variable.
data ArrayHead :: ([Hakaru] -> Hakaru -> *) -> Hakaru -> * where
    -- | The empty array.
    WAEmpty :: ArrayHead abt a
    -- | An array: its size and its index-binding element body.
    WAArray
        :: !(abt '[] 'HNat)
        -> !(abt '[ 'HNat] a)
        -> ArrayHead abt a
-- | Project an array-typed 'Head' onto the 'ArrayHead' view.
head2array :: Head abt ('HArray a) -> ArrayHead abt a
head2array h =
    case h of
      WEmpty _      -> WAEmpty
      WArray n body -> WAArray n body
-- | Boolean connectives used when evaluating primitive operations:
-- implication, difference (@x@ and not @y@), NAND and NOR. Each one
-- inspects its first argument first and only looks at the second when
-- it has to.
impl, diff, nand, nor :: Bool -> Bool -> Bool
impl p q = if p then q else True
diff p q = if p then not q else False
nand p q = if p then not q else True
nor p q = if p then False else not q
-- | Evaluate a primitive-operation application to weak-head normal form.
--
-- * Boolean, comparison and arithmetic operators whose operands evaluate
--   to head-normal values are computed in Haskell via 'reify'/'reflect'.
--   If an operand is neutral, the application is rebuilt as a neutral
--   term.
-- * Transcendental and special functions ('Sin', 'Exp', 'GammaFunc', ...)
--   are always left neutral.
-- * Anything unhandled raises a @TODO@ error.
--
-- NOTE(review): the type variable @p@ in the signature is unused.
evaluatePrimOp
    :: forall abt p typs args a
    .  ( ABT Term abt
       , typs ~ UnLCs args, args ~ LCs typs)
    => TermEvaluator abt (Dis abt)
    -> PrimOp typs a
    -> SArgs abt args
    -> Dis abt (Whnf abt a)
evaluatePrimOp evaluate_ = go
  where
    -- Evaluate the operand, then rebuild the unary application as neutral.
    neu1 :: forall b c
         .  (abt '[] b -> abt '[] c)
         -> abt '[] b
         -> Dis abt (Whnf abt c)
    neu1 f e = (Neutral . f . fromWhnf) <$> evaluate_ e

    -- Evaluate both operands (left first), then rebuild as neutral.
    neu2 :: forall b c d
         .  (abt '[] b -> abt '[] c -> abt '[] d)
         -> abt '[] b
         -> abt '[] c
         -> Dis abt (Whnf abt d)
    neu2 f e1 e2 = do
        e1' <- fromWhnf <$> evaluate_ e1
        e2' <- fromWhnf <$> evaluate_ e2
        return . Neutral $ f e1' e2'

    -- Unary operator: apply the Haskell function @f'@ to a head value,
    -- otherwise rebuild with @f@.
    rr1 :: forall b b' c c'
        .  (Interp b b', Interp c c')
        => (b' -> c')
        -> (abt '[] b -> abt '[] c)
        -> abt '[] b
        -> Dis abt (Whnf abt c)
    rr1 f' f e = do
        w <- evaluate_ e
        return $
            case w of
              Neutral e' -> Neutral $ f e'
              Head_ v -> Head_ . reflect $ f' (reify v)

    -- Binary operator: compute with @f'@ only when both operands are
    -- heads, otherwise rebuild with @f@.
    rr2 :: forall b b' c c' d d'
        .  (Interp b b', Interp c c', Interp d d')
        => (b' -> c' -> d')
        -> (abt '[] b -> abt '[] c -> abt '[] d)
        -> abt '[] b
        -> abt '[] c
        -> Dis abt (Whnf abt d)
    rr2 f' f e1 e2 = do
        w1 <- evaluate_ e1
        w2 <- evaluate_ e2
        return $
            case w1 of
              Neutral e1' -> Neutral $ f e1' (fromWhnf w2)
              Head_ v1 ->
                  case w2 of
                    Neutral e2' -> Neutral $ f (fromWhnf w1) e2'
                    Head_ v2 -> Head_ . reflect $ f' (reify v1) (reify v2)

    -- Build a binary primitive-operation application node.
    primOp2_
        :: forall b c d
        .  PrimOp '[ b, c ] d -> abt '[] b -> abt '[] c -> abt '[] d
    primOp2_ o e1 e2 = syn (PrimOp_ o :$ e1 :* e2 :* End)

    go Not (e1 :* End) = rr1 not P.not e1
    go Impl (e1 :* e2 :* End) = rr2 impl (primOp2_ Impl) e1 e2
    go Diff (e1 :* e2 :* End) = rr2 diff (primOp2_ Diff) e1 e2
    go Nand (e1 :* e2 :* End) = rr2 nand P.nand e1 e2
    go Nor (e1 :* e2 :* End) = rr2 nor P.nor e1 e2
    go Pi End = return $ Neutral P.pi
    go Sin (e1 :* End) = neu1 P.sin e1
    go Cos (e1 :* End) = neu1 P.cos e1
    go Tan (e1 :* End) = neu1 P.tan e1
    go Asin (e1 :* End) = neu1 P.asin e1
    go Acos (e1 :* End) = neu1 P.acos e1
    go Atan (e1 :* End) = neu1 P.atan e1
    go Sinh (e1 :* End) = neu1 P.sinh e1
    go Cosh (e1 :* End) = neu1 P.cosh e1
    go Tanh (e1 :* End) = neu1 P.tanh e1
    go Asinh (e1 :* End) = neu1 P.asinh e1
    go Acosh (e1 :* End) = neu1 P.acosh e1
    go Atanh (e1 :* End) = neu1 P.atanh e1
    go RealPow (e1 :* e2 :* End) = neu2 (P.**) e1 e2
    go Exp (e1 :* End) = neu1 P.exp e1
    go Log (e1 :* End) = neu1 P.log e1
    go (Infinity h) End =
        case h of
          HIntegrable_Nat -> return . Neutral $ P.primOp0_ (Infinity h)
          HIntegrable_Prob -> return $ Neutral P.infinity
    go GammaFunc (e1 :* End) = neu1 P.gammaFunc e1
    go BetaFunc (e1 :* e2 :* End) = neu2 P.betaFunc e1 e2
    go (Equal theEq) (e1 :* e2 :* End) = rrEqual theEq e1 e2
    go (Less theOrd) (e1 :* e2 :* End) = rrLess theOrd e1 e2
    go (NatPow theSemi) (e1 :* e2 :* End) =
        case theSemi of
          HSemiring_Nat -> rr2 (\v1 v2 -> v1 ^ fromNatural v2) (P.^) e1 e2
          HSemiring_Int -> rr2 (\v1 v2 -> v1 ^ fromNatural v2) (P.^) e1 e2
          HSemiring_Prob -> rr2 (\v1 v2 -> v1 ^ fromNatural v2) (P.^) e1 e2
          HSemiring_Real -> rr2 (\v1 v2 -> v1 ^ fromNatural v2) (P.^) e1 e2
    go (Negate theRing) (e1 :* End) =
        case theRing of
          HRing_Int -> rr1 negate P.negate e1
          HRing_Real -> rr1 negate P.negate e1
    go (Abs theRing) (e1 :* End) =
        case theRing of
          HRing_Int -> rr1 (unsafeNatural . abs) P.abs_ e1
          HRing_Real -> rr1 (unsafeNonNegativeRational . abs) P.abs_ e1
    go (Signum theRing) (e1 :* End) =
        case theRing of
          HRing_Int -> rr1 signum P.signum e1
          HRing_Real -> rr1 signum P.signum e1
    go (Recip theFractional) (e1 :* End) =
        case theFractional of
          HFractional_Prob -> rr1 recip P.recip e1
          HFractional_Real -> rr1 recip P.recip e1
    go (NatRoot theRadical) (e1 :* e2 :* End) =
        case theRadical of
          HRadical_Prob -> neu2 (flip P.thRootOf) e1 e2
    go op _ = error $ "TODO: evaluatePrimOp{" ++ show op ++ "}"

    -- Equality. Scalar types go through rr2. Pairs are compared
    -- componentwise when both sides are heads, and the result is
    -- simplified when one component comparison is a known Boolean.
    rrEqual
        :: forall b. HEq b -> abt '[] b -> abt '[] b -> Dis abt (Whnf abt HBool)
    rrEqual theEq =
        case theEq of
          HEq_Nat -> rr2 (==) (P.==)
          HEq_Int -> rr2 (==) (P.==)
          HEq_Prob -> rr2 (==) (P.==)
          HEq_Real -> rr2 (==) (P.==)
          HEq_Array aEq -> error "TODO: rrEqual{HEq_Array}"
          HEq_Bool -> rr2 (==) (P.==)
          HEq_Unit -> rr2 (==) (P.==)
          HEq_Pair aEq bEq ->
              \e1 e2 -> do
                  w1 <- evaluate_ e1
                  w2 <- evaluate_ e2
                  case w1 of
                      Neutral e1' ->
                          return . Neutral
                              $ P.primOp2_ (Equal theEq) e1' (fromWhnf w2)
                      Head_ v1 ->
                          case w2 of
                              Neutral e2' ->
                                  return . Neutral
                                      $ P.primOp2_ (Equal theEq) (fromHead v1) e2'
                              Head_ v2 -> do
                                  let (v1a, v1b) = reifyPair v1
                                  let (v2a, v2b) = reifyPair v2
                                  wa <- rrEqual aEq v1a v2a
                                  wb <- rrEqual bEq v1b v2b
                                  return $
                                      case wa of
                                          Neutral ea ->
                                              case wb of
                                                  Neutral eb -> Neutral (ea P.&& eb)
                                                  Head_ vb
                                                      | reify vb -> wa
                                                      | otherwise -> Head_ $ WDatum dFalse
                                          Head_ va
                                              | reify va -> wb
                                              | otherwise -> Head_ $ WDatum dFalse
          HEq_Either aEq bEq -> error "TODO: rrEqual{HEq_Either}"

    -- Strict less-than. Scalar types go through rr2. For pairs, only the
    -- neutral cases are handled; comparing two pair heads is still TODO.
    rrLess
        :: forall b. HOrd b -> abt '[] b -> abt '[] b -> Dis abt (Whnf abt HBool)
    rrLess theOrd =
        case theOrd of
          HOrd_Nat -> rr2 (<) (P.<)
          HOrd_Int -> rr2 (<) (P.<)
          HOrd_Prob -> rr2 (<) (P.<)
          HOrd_Real -> rr2 (<) (P.<)
          HOrd_Array aOrd -> error "TODO: rrLess{HOrd_Array}"
          HOrd_Bool -> rr2 (<) (P.<)
          HOrd_Unit -> rr2 (<) (P.<)
          HOrd_Pair aOrd bOrd ->
              \e1 e2 -> do
                  w1 <- evaluate_ e1
                  w2 <- evaluate_ e2
                  case w1 of
                      Neutral e1' ->
                          return . Neutral
                              $ P.primOp2_ (Less theOrd) e1' (fromWhnf w2)
                      Head_ v1 ->
                          case w2 of
                              Neutral e2' ->
                                  return . Neutral
                                      $ P.primOp2_ (Less theOrd) (fromHead v1) e2'
                              Head_ v2 -> do
                                  let (v1a, v1b) = reifyPair v1
                                  let (v2a, v2b) = reifyPair v2
                                  error "TODO: rrLess{HOrd_Pair}"
          HOrd_Either aOrd bOrd -> error "TODO: rrLess{HOrd_Either}"
-- | Resolve a variable during evaluation.
--
-- A variable with no recorded location is returned as a neutral term. A
-- variable whose location is a single 'Loc' is looked up with 'select':
--
-- * an 'SBind' is performed with the supplied measure evaluator;
-- * an 'SLet' is evaluated with the supplied term evaluator, unless it
--   already holds a whnf.
--
-- In both cases the statement's indices are in force while it is
-- evaluated, the result is recorded as an 'SLet' holding the whnf
-- ('unsafePush'), and the value is renamed from the statement's index
-- variables to the location's index variables @jxs@. 'SWeight' statements
-- are skipped, and a 'MultiLoc' resolves to the variable itself.
update
    :: forall abt
    .  (ABT Term abt)
    => MeasureEvaluator abt (Dis abt)
    -> TermEvaluator abt (Dis abt)
    -> VariableEvaluator abt (Dis abt)
update perform evaluate_ x =
    do
        locs <- getLocs
        maybe (return $ Neutral (var x)) lookForLoc (lookupAssoc x locs)
  where
    lookForLoc (Loc l jxs) =
        (maybe (freeLocError l) return =<<) . select l $ \s ->
            case s of
                SBind l' e ixs -> do
                    Refl <- varEq l l'
                    Just $ do
                        w <- withIndices ixs $ perform (caseLazy e fromWhnf id)
                        unsafePush (SLet l (Whnf_ w) ixs)
#ifdef __TRACE_DISINTEGRATE__
                        trace ("-- updated "
                            ++ show (ppStatement 11 s)
                            ++ " to "
                            ++ show (ppStatement 11 (SLet l (Whnf_ w) ixs))
                            ) $ return ()
#endif
                        -- Rename from the statement's indices to the location's.
                        let as = toAssocs $ zipWith Assoc (map indVar ixs) jxs
                            w' = renames as (fromWhnf w)
                        inds <- getIndices
                        withIndices inds $ return (fromMaybe (Neutral w') (toWhnf w'))
                SLet l' e ixs -> do
                    Refl <- varEq l l'
                    Just $ do
                        w <- withIndices ixs $ caseLazy e return evaluate_
                        unsafePush (SLet l (Whnf_ w) ixs)
                        let as = toAssocs $ zipWith Assoc (map indVar ixs) jxs
                            w' = renames as (fromWhnf w)
                        inds <- getIndices
                        withIndices inds $ return (fromMaybe (Neutral w') (toWhnf w'))
                SWeight _ _ -> Nothing
                -- NOTE(review): unlike the cases above, this does not check
                -- whether the guard binds l; any SGuard reached stops the
                -- search and yields var x as a neutral term.
                SGuard ls pat scrutinee i -> Just . return . Neutral $ var x
    lookForLoc (MultiLoc l jxs) = return (Neutral $ var x)
-- | The @case@ evaluator used during disintegration.
--
-- It combines two alternatives with '<|>':
--
-- * 'defaultCaseEvaluator';
-- * a 'choose' over all branches. Each branch pushes an 'SGuard', at the
--   current indices, recording that the scrutinee matches that branch's
--   pattern, and then evaluates the branch body.
evaluateCase
    :: forall abt
    .  (ABT Term abt)
    => TermEvaluator abt (Dis abt)
    -> CaseEvaluator abt (Dis abt)
evaluateCase evaluate_ scrutinee branches =
    defaultCaseEvaluator evaluate_ scrutinee branches
        <|> choose (map guardedBranch branches)
  where
    -- Assume the scrutinee matches this branch and evaluate its body.
    guardedBranch (Branch pat body) =
        let (vars, body') = caseBinds body
        in do
            inds <- getIndices
            push (SGuard vars pat (Thunk scrutinee) inds) body' evaluate_
-- | Datum evaluator for pattern matching: evaluate the term with
-- 'evaluate_' and view the whnf as a datum, when it is one.
evaluateDatum :: (ABT Term abt) => DatumEvaluator (abt '[]) (Dis abt)
evaluateDatum e = do
    w <- evaluate_ e
    return (viewWhnfDatum w)
-- | Perform a measure term in the 'Dis' monad, returning the whnf of a
-- value drawn from it.
--
-- * Variables are resolved with 'update' and the resulting whnf is
--   performed.
-- * @dirac e@ evaluates @e@.
-- * Primitive measures are handled by @performMeasureOp@.
-- * A monadic bind pushes an 'SBind' and performs the continuation.
-- * @plate@ is pushed with 'pushPlate' and its variable is returned.
-- * A superposition forks one continuation per branch, each pushing an
--   'SWeight' for its weight. It fails with 'bot' when there are enclosing
--   indices and more than one branch.
-- * @let@ pushes an 'SLet'.
-- * Anything else is evaluated and the whnf is performed; a neutral
--   measure is emitted as an 'MBind'.
perform :: forall abt. (ABT Term abt) => MeasureEvaluator abt (Dis abt)
perform = \e0 ->
#ifdef __TRACE_DISINTEGRATE__
    getStatements >>= \ss ->
    getLocs >>= \locs ->
    getIndices >>= \inds ->
    trace ("\n-- perform --\n"
        ++ "at " ++ show (ppInds inds) ++ "\n"
        ++ show (prettyLocs locs) ++ "\n"
        ++ show (pretty_Statements_withTerm ss e0)
        ++ "\n") $
#endif
    caseVarSyn e0 performVar performTerm
  where
    performTerm :: forall a. Term abt ('HMeasure a) -> Dis abt (Whnf abt a)
    performTerm (Dirac :$ e1 :* End) = evaluate_ e1
    performTerm (MeasureOp_ o :$ es) = performMeasureOp o es
    performTerm (MBind :$ e1 :* e2 :* End) =
        caseBind e2 $ \x e2' -> do
            inds <- getIndices
            push (SBind x (Thunk e1) inds) e2' perform
    performTerm (Plate :$ e1 :* e2 :* End) = do
        x1 <- pushPlate e1 e2
        return (Neutral (var x1))
    performTerm (Superpose_ pes) = do
        inds <- getIndices
        if not (null inds) && L.length pes > 1 then bot else
            emitFork_ (P.superpose . fmap ((,) P.one))
                      (fmap (\(p,e) -> push (SWeight (Thunk p) inds) e perform)
                            pes)
    performTerm (Let_ :$ e1 :* e2 :* End) =
        caseBind e2 $ \x e2' -> do
            inds <- getIndices
            push (SLet x (Thunk e1) inds) e2' perform
    performTerm t0 = do
        w <- evaluate_ (syn t0)
#ifdef __TRACE_DISINTEGRATE__
        trace ("-- perform: finished evaluate, with:\n" ++ show (PP.sep(prettyPrec_ 11 w))) $ return ()
#endif
        performWhnf w

    performVar :: forall a. Variable ('HMeasure a) -> Dis abt (Whnf abt a)
    performVar = performWhnf <=< update perform evaluate_

    -- A head-normal measure is performed again; a neutral one is emitted
    -- as an MBind and its bound variable is returned.
    performWhnf
        :: forall a. Whnf abt ('HMeasure a) -> Dis abt (Whnf abt a)
    performWhnf (Head_ v) = perform $ fromHead v
    performWhnf (Neutral e) = (Neutral . var) <$> emitMBind e

    -- Primitive measures. "nice" atomizes the arguments and emits the
    -- measure as-is. "complete" rewrites Normal and Uniform as a lebesgue
    -- draw weighted by the density (plus a range guard for Uniform). The
    -- two are combined with '<|>'.
    performMeasureOp
        :: forall typs args a
        .  (typs ~ UnLCs args, args ~ LCs typs)
        => MeasureOp typs a
        -> SArgs abt args
        -> Dis abt (Whnf abt a)
    performMeasureOp = \o es -> nice o es <|> complete o es
      where
        nice
            :: MeasureOp typs a
            -> SArgs abt args
            -> Dis abt (Whnf abt a)
        nice o es = do
            es' <- traverse21 atomizeCore es
            x <- emitMBind $ syn (MeasureOp_ o :$ es')
            return (Neutral $ var x)

        complete
            :: MeasureOp typs a
            -> SArgs abt args
            -> Dis abt (Whnf abt a)
        complete Normal = \(mu :* sd :* End) -> do
            x <- var <$> emitMBind P.lebesgue
            pushWeight (P.densityNormal mu sd x)
            return (Neutral x)
        complete Uniform = \(lo :* hi :* End) -> do
            x <- var <$> emitMBind P.lebesgue
            pushGuard (lo P.< x P.&& x P.< hi)
            pushWeight (P.densityUniform lo hi x)
            return (Neutral x)
        complete _ = \_ -> bot

    -- Record a weight statement at the current indices (via unsafePush).
    pushWeight :: (ABT Term abt) => abt '[] 'HProb -> Dis abt ()
    pushWeight w = do
        inds <- getIndices
        unsafePush $ SWeight (Thunk w) inds

    -- Record a Boolean guard (a match against pTrue) at the current indices.
    pushGuard :: (ABT Term abt) => abt '[] HBool -> Dis abt ()
    pushGuard b = do
        inds <- getIndices
        unsafePush $ SGuard Nil1 pTrue (Thunk b) inds
-- | Evaluate a term to whnf, then atomize its immediate subterms with
-- 'atomizeCore' (via 'traverse21').
atomize :: (ABT Term abt) => TermEvaluator abt (Dis abt)
atomize e =
#ifdef __TRACE_DISINTEGRATE__
    trace ("\n-- atomize --\n" ++ show (pretty e)) $
#endif
    evaluate_ e >>= traverse21 atomizeCore
-- | Atomize a possibly-binding term, unless that is unnecessary.
--
-- If the term's free variables are disjoint from the variables of the
-- statements in the current context ('getHeapVars'), the term is returned
-- unchanged. Otherwise its binders are opened, the body is atomized with
-- 'atomize', and the binders are wrapped back around the result.
atomizeCore :: (ABT Term abt) => abt xs a -> Dis abt (abt xs a)
atomizeCore e = do
    heapVars <- getHeapVars
    if noOverlap heapVars (freeVars e)
        then return e
        else
            let (ys, e') = caseBinds e
            in binds_ ys . fromWhnf <$> atomize e'
  where
    -- True when the two variable sets share no variable IDs.
    noOverlap s1 s2 = IM.null (IM.intersection (unVarSet s1) (unVarSet s2))
-- | The union of 'statementVars' over all statements in the current
-- context. The set is recomputed on every call.
getHeapVars :: Dis abt (VarSet ('KProxy :: KProxy Hakaru))
getHeapVars =
    Dis $ \_ k heap ->
        let vars = foldMap statementVars (statements heap)
        in k vars heap
-- | Emit code under which the term @e0@ evaluates to the value @v0@.
--
-- * Variables are handled by 'constrainVariable'.
-- * Arrays are constrained elementwise under a fresh index.
-- * Indexing is analysed with 'indexArrayOp'.
-- * Data constructors go through 'constrainDatum'.
-- * Primitive measures go through 'constrainValueMeasureOp'; operators
--   go through 'constrainNaryOp' / 'constrainPrimOp'.
-- * Coercions are moved onto @v0@.
-- * @let@ and @case@ push the corresponding statements.
--
-- Literals, measures, superpositions and rejections fail with 'bot'.
-- Several forms are still TODO errors.
constrainValue :: (ABT Term abt) => abt '[] a -> abt '[] a -> Dis abt ()
constrainValue v0 e0 =
#ifdef __TRACE_DISINTEGRATE__
    getStatements >>= \ss ->
    getLocs >>= \locs ->
    getIndices >>= \inds ->
    trace ("\n-- constrainValue: " ++ show (pretty v0) ++ "\n"
        ++ show (pretty_Statements_withTerm ss e0) ++ "\n"
        ++ "at " ++ show (ppInds inds) ++ "\n"
        ++ show (prettyLocs locs) ++ "\n"
        ) $
#endif
    caseVarSyn e0 (constrainVariable v0) $ \t ->
        case t of
          Empty_ _ -> error "TODO: disintegrate arrays"
          -- Constrain element j of the array to element j of v0, under a
          -- fresh index j of size n.
          Array_ n e ->
              caseBind e $ \x body -> do
                  j <- freshInd n
                  let x' = indVar j
                      body' = rename x x' body
                  inds <- getIndices
                  withIndices (extendIndices j inds) $
                      constrainValue (v0 P.! (var x')) body'
          ArrayOp_ o@(Index _) :$ args -> indexArrayOp o args
              evaluate_
              (constrainValue v0)
              (const $ const bot)
              (const bot)
              (const bot)
              (const $ const bot)
          ArrayOp_ _ :$ _ -> error "TODO: disintegrate arrays"
          Lam_ :$ _ :* End -> error "TODO: disintegrate lambdas"
          App_ :$ _ :* _ :* End -> error "TODO: disintegrate lambdas"
          Integrate :$ _ :* _ :* _ :* End ->
              error "TODO: disintegrate integration"
          Summate _ _ :$ _ :* _ :* _ :* End ->
              error "TODO: disintegrate integration"
          Literal_ v -> bot
          Datum_ d -> constrainDatum v0 d
          Dirac :$ _ :* End -> bot
          MBind :$ _ :* _ :* End -> bot
          MeasureOp_ o :$ es -> constrainValueMeasureOp v0 o es
          Superpose_ pes -> bot
          Reject_ _ -> bot
          Let_ :$ e1 :* e2 :* End ->
              caseBind e2 $ \x e2' ->
                  push (SLet x (Thunk e1) []) e2' (constrainValue v0)
          -- Move the coercion to the other side, inverted.
          CoerceTo_ c :$ e1 :* End ->
              constrainValue (P.unsafeFrom_ c v0) e1
          UnsafeFrom_ c :$ e1 :* End ->
              constrainValue (P.coerceTo_ c v0) e1
          NaryOp_ o es -> constrainNaryOp v0 o es
          PrimOp_ o :$ es -> constrainPrimOp v0 o es
          Expect :$ e1 :* e2 :* End -> error "TODO: constrainValue{Expect}"
          -- If the scrutinee already selects a branch, that branch's body
          -- is one alternative (combined with '<|>' with guarding every
          -- branch). If matching is stuck, guard every branch.
          Case_ e bs -> do
              match <- matchBranches evaluateDatum e bs
              case match of
                  Nothing ->
                      error "constrainValue{Case_}: nothing matched!"
                  Just GotStuck ->
                      constrainBranches v0 e bs
                  Just (Matched rho body) ->
                      pushes (toStatements rho) body (constrainValue v0)
                          <|> constrainBranches v0 e bs
          _ :$ _ -> error "constrainValue: the impossible happened"
-- | Constrain a @case@ expression's value to @v0@ by choosing among its
-- branches. Each alternative pushes an 'SGuard' (with no indices)
-- asserting that the scrutinee matches the branch pattern, then
-- constrains that branch's body.
constrainBranches
    :: (ABT Term abt)
    => abt '[] a
    -> abt '[] b
    -> [Branch b abt a]
    -> Dis abt ()
constrainBranches v0 e bs = choose (fmap guarded bs)
  where
    guarded (Branch pat body) =
        let (vars, body') = caseBinds body
            guardStmt = SGuard vars pat (Thunk e) []
        in push guardStmt body' (constrainValue v0)
-- | Constrain a data-constructor application to equal @v0@.
--
-- Fresh variables are created, one per component of the datum. A @case@
-- on @v0@ is emitted: the datum's own pattern binds those variables
-- around the rest of the program, and any other shape rejects. Each
-- component expression is then constrained to equal its fresh variable.
constrainDatum
    :: (ABT Term abt) => abt '[] a -> Datum (abt '[]) a -> Dis abt ()
constrainDatum v0 d =
    case patternOfDatum d of
    PatternOfDatum pat es -> do
        xs <- freshVars $ fmap11 (Hint Text.empty . typeOf) es
        emit_ $ \body ->
            syn $ Case_ v0
                [ Branch pat (binds_ xs body)
                , Branch PWild (P.reject $ (typeOf body))
                ]
        constrainValues xs es
-- | Pairwise 'constrainValue': constrain each expression to equal the
-- variable in the same position, from left to right. The two lists have
-- the same type-level length, so the mismatched case cannot occur.
constrainValues
    :: (ABT Term abt)
    => List1 Variable xs
    -> List1 (abt '[]) xs
    -> Dis abt ()
constrainValues (Cons1 x xs) (Cons1 e es) = do
    constrainValue (var x) e
    constrainValues xs es
constrainValues Nil1 Nil1 = return ()
constrainValues _ _ = error "constrainValues: the impossible happened"
-- | A datum split into a pattern that binds a variable at every component
-- position, together with the component expressions in the same order.
-- The list of bound types @xs@ is existential.
data PatternOfDatum (ast :: Hakaru -> *) (a :: Hakaru) =
    forall xs. PatternOfDatum
        !(Pattern xs a)
        !(List1 ast xs)
-- | Turn a datum into a 'PatternOfDatum'. Every component (both 'Konst'
-- and 'Ident') becomes a 'PVar' in the pattern, and its expression is
-- collected in left-to-right order.
--
-- The helpers use continuation-passing style because the list of bound
-- types is only known inside each case.
patternOfDatum :: Datum ast a -> PatternOfDatum ast a
patternOfDatum =
    \(Datum hint _typ d) ->
    podCode d $ \p es ->
    PatternOfDatum (PDatum hint p) es
  where
    -- Walk the sum-of-products code down to the chosen constructor.
    podCode
        :: DatumCode xss ast a
        -> (forall bs. PDatumCode xss bs a -> List1 ast bs -> r)
        -> r
    podCode (Inr d) k = podCode d $ \ p es -> k (PInr p) es
    podCode (Inl d) k = podStruct d $ \ p es -> k (PInl p) es

    -- Walk the constructor's fields, appending their bound-variable lists.
    podStruct
        :: DatumStruct xs ast a
        -> (forall bs. PDatumStruct xs bs a -> List1 ast bs -> r)
        -> r
    podStruct (Et d1 d2) k =
        podFun d1 $ \p1 es1 ->
        podStruct d2 $ \p2 es2 ->
        k (PEt p1 p2) (es1 `append1` es2)
    podStruct Done k = k PDone Nil1

    -- A single field: bind it with a variable pattern.
    podFun
        :: DatumFun x ast a
        -> (forall bs. PDatumFun x bs a -> List1 ast bs -> r)
        -> r
    podFun (Konst e) k = k (PKonst PVar) (Cons1 e Nil1)
    podFun (Ident e) k = k (PIdent PVar) (Cons1 e Nil1)
-- | Constrain a variable's value to @v0@.
--
-- A variable with no recorded location fails with 'bot'.
--
-- For a single 'Loc', the binding statement is found with 'select'. It
-- must have as many indices as the location, and the location's index
-- variables must be a permutation of the current indices. Then:
--
-- * an 'SBind' has its measure's outcome constrained ('constrainOutcome');
-- * an 'SLet' has its right-hand side's value constrained
--   ('constrainValue').
--
-- In both cases the statement is replaced by an 'SLet' of @v0@.
--
-- For a 'MultiLoc', a fresh index over the innermost dimension is
-- introduced, and each element of @v0@ is constrained against a fresh
-- location for the corresponding element.
constrainVariable
    :: (ABT Term abt) => abt '[] a -> Variable a -> Dis abt ()
constrainVariable v0 x =
    do
        locs <- getLocs
        maybe bot lookForLoc (lookupAssoc x locs)
  where
    lookForLoc (Loc l jxs) =
        let
          -- Same length, and the same set of index variables.
          permutes is js = length is == length js &&
                           Set.fromList is == Set.fromList (map indVar js)
        in (maybe (freeLocError l) return =<<) . select l $ \s ->
            case s of
              SBind l' e ixs -> do
                  Refl <- varEq l l'
                  guard (length ixs == length jxs)
                  Just $ do
                      inds <- getIndices
                      guard (jxs `permutes` inds)
                      e' <- apply (zip ixs inds) (fromLazy e)
                      constrainOutcome v0 e'
                      unsafePush (SLet l (Whnf_ (Neutral v0)) inds)
              SLet l' e ixs -> do
                  Refl <- varEq l l'
                  guard (length ixs == length jxs)
                  Just $ do
                      inds <- getIndices
                      guard (jxs `permutes` inds)
                      e' <- apply (zip ixs inds) (fromLazy e)
                      constrainValue v0 e'
                      unsafePush (SLet l (Whnf_ (Neutral v0)) inds)
              SWeight _ _ -> Nothing
              SGuard ls' pat scrutinee i -> error "TODO: constrainVariable{SGuard}"
    lookForLoc (MultiLoc l jxs) = do
#ifdef __TRACE_DISINTEGRATE__
        traceM $ "looking for MultiLoc: " ++ show (prettyLoc (MultiLoc l jxs))
#endif
        n <- sizeInnermostInd l
        j <- freshInd n
        x' <- mkLoc Text.empty l (extendLocInds (indVar j) jxs)
        inds <- getIndices
        withIndices (extendIndices j inds) $
            constrainValue (v0 P.! (var $ indVar j)) (var x')
-- | Constrain a primitive-measure term to equal @v0@.
--
-- 'Lebesgue' and 'Counting' fail with 'bot'. Every other measure is
-- rebuilt with the corresponding primed Prelude constructor
-- (@P.uniform'@, @P.normal'@, ...) and constrained with 'constrainValue'.
constrainValueMeasureOp
    :: forall abt typs args a
    .  (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] ('HMeasure a)
    -> MeasureOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainValueMeasureOp v0 = dispatch
  where
    dispatch :: MeasureOp typs a -> SArgs abt args -> Dis abt ()
    dispatch Lebesgue End = bot
    dispatch Counting End = bot
    dispatch Categorical (e1 :* End) =
        constrainValue v0 (P.categorical e1)
    dispatch Uniform (e1 :* e2 :* End) =
        constrainValue v0 (P.uniform' e1 e2)
    dispatch Normal (e1 :* e2 :* End) =
        constrainValue v0 (P.normal' e1 e2)
    dispatch Poisson (e1 :* End) =
        constrainValue v0 (P.poisson' e1)
    dispatch Gamma (e1 :* e2 :* End) =
        constrainValue v0 (P.gamma' e1 e2)
    dispatch Beta (e1 :* e2 :* End) =
        constrainValue v0 (P.beta' e1 e2)
-- | Constrain an n-ary operator application to equal @v0@.
--
-- * 'Sum': each operand @e@ is tried in turn (via 'lubSeq'). The sum of
--   the other operands is atomized, and @e@ is constrained to @v0@ minus
--   that sum.
-- * 'Prod': the same with division, after emitting the weight
--   @1 / |rest|@ (converted to 'HProb' by 'toProb_abs').
-- * 'Max': for each operand @e@ (alternatives via 'chooseSeq'), emit a
--   guard that the maximum of the others is less than @v0@, then
--   constrain @e@ to @v0@.
-- * Any other operator is a @TODO@ error.
constrainNaryOp
    :: (ABT Term abt)
    => abt '[] a
    -> NaryOp a
    -> Seq (abt '[] a)
    -> Dis abt ()
constrainNaryOp v0 (Sum theSemi) =
    lubSeq $ \before e after -> do
        rest <- atomize $ syn (NaryOp_ (Sum theSemi) (before S.>< after))
        v <- evaluate_ $ P.unsafeMinus_ theSemi v0 (fromWhnf rest)
        constrainValue (fromWhnf v) e
constrainNaryOp v0 (Prod theSemi) =
    lubSeq $ \before e after -> do
        restW <- atomize $ syn (NaryOp_ (Prod theSemi) (before S.>< after))
        let rest = fromWhnf restW
        emitWeight $ P.recip (toProb_abs theSemi rest)
        v <- evaluate_ $ P.unsafeDiv_ theSemi v0 rest
        constrainValue (fromWhnf v) e
constrainNaryOp v0 (Max theOrd) =
    chooseSeq $ \before e after -> do
        rest <- atomize $ syn (NaryOp_ (Max theOrd) (before S.>< after))
        emitGuard $ P.primOp2_ (Less theOrd) (fromWhnf rest) v0
        constrainValue v0 e
constrainNaryOp _ o = error $ "TODO: constrainNaryOp{" ++ show o ++ "}"
-- | Convert a semiring value to 'HProb' by taking its absolute value.
-- Nat values need only a coercion, Prob values are unchanged, and Int
-- values take @abs_@ and then go through Nat.
toProb_abs :: (ABT Term abt) => HSemiring a -> abt '[] a -> abt '[] 'HProb
toProb_abs theSemi =
    case theSemi of
      HSemiring_Nat  -> P.nat2prob
      HSemiring_Int  -> P.nat2prob . P.abs_
      HSemiring_Prob -> id
      HSemiring_Real -> P.abs_
-- | Call @f before x after@ for every position of @x@ in the sequence,
-- from left to right, and combine the results with '<|>'. An empty
-- sequence gives 'empty'.
lubSeq :: (Alternative m) => (Seq a -> a -> Seq a -> m b) -> Seq a -> m b
lubSeq f = walk S.empty . S.viewl
  where
    walk _ S.EmptyL = empty
    walk seen (x S.:< rest) = f seen x rest <|> walk (seen S.|> x) (S.viewl rest)
-- | Build @f before x after@ for every position of @x@ in the sequence,
-- from left to right, and combine the computations with 'choose'.
chooseSeq :: (ABT Term abt)
          => (Seq a -> a -> Seq a -> Dis abt b)
          -> Seq a
          -> Dis abt b
chooseSeq f = choose . splits S.empty . S.viewl
  where
    splits _ S.EmptyL = []
    splits seen (x S.:< rest) = f seen x rest : splits (seen S.|> x) (S.viewl rest)
-- | Constrain a primitive-operation application to equal @v0@.
--
-- For an invertible (or piecewise-invertible) operation @y = f(x)@, the
-- argument is solved for (@x = f^-1(v0)@). The absolute Jacobian of the
-- inverse, @|d f^-1 / dy|@ at @v0@, is emitted as a weight with
-- 'emitWeight', and the argument is constrained recursively. Periodic
-- functions draw an integer winding number @n@ from 'P.counting' and
-- superpose their preimages. Operations with no density ('Pi', real
-- 'Signum') fail with 'bot'. Everything else is a @TODO@ error.
constrainPrimOp
    :: forall abt typs args a
    .  (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] a
    -> PrimOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainPrimOp v0 = go
  where
    error_TODO op = error $ "TODO: constrainPrimOp{" ++ op ++"}"

    go :: PrimOp typs a -> SArgs abt args -> Dis abt ()
    go Not = \(e1 :* End) -> error_TODO "Not"
    go Impl = \(e1 :* e2 :* End) -> error_TODO "Impl"
    go Diff = \(e1 :* e2 :* End) -> error_TODO "Diff"
    go Nand = \(e1 :* e2 :* End) -> error_TODO "Nand"
    go Nor = \(e1 :* e2 :* End) -> error_TODO "Nor"
    go Pi = \End -> bot

    -- y = sin x, so x = 2*pi*n + asin y or 2*pi*n + pi - asin y, and
    -- |d asin y / dy| = 1 / sqrt (1 - y^2).
    go Sin = \(e1 :* End) -> do
        x0 <- emitLet' v0
        n <- var <$> emitMBind P.counting
        let tau_n = P.real_ 2 P.* P.fromInt n P.* P.pi
        emitGuard (P.negate P.one P.< x0 P.&& x0 P.< P.one)
        v <- var <$> emitSuperpose
            [ P.dirac (tau_n P.+ P.asin x0)
            , P.dirac (tau_n P.+ P.pi P.- P.asin x0)
            ]
        emitWeight
            . P.recip
            . P.sqrt
            . P.unsafeProb
            $ (P.one P.- x0 P.^ P.nat_ 2)
        constrainValue v e1

    -- y = cos x, so x = 2*pi*n + acos y or 2*pi*n - acos y (cos is even).
    -- The second root is NOT acos y + pi, since cos (acos y + pi) = -y.
    -- Both roots have |d acos y / dy| = 1 / sqrt (1 - y^2).
    go Cos = \(e1 :* End) -> do
        x0 <- emitLet' v0
        n <- var <$> emitMBind P.counting
        let tau_n = P.real_ 2 P.* P.fromInt n P.* P.pi
        emitGuard (P.negate P.one P.< x0 P.&& x0 P.< P.one)
        r <- emitLet' (P.acos x0)
        v <- var <$> emitSuperpose
            [ P.dirac (tau_n P.+ r)
            , P.dirac (tau_n P.- r)
            ]
        emitWeight
            . P.recip
            . P.sqrt
            . P.unsafeProb
            $ (P.one P.- x0 P.^ P.nat_ 2)
        constrainValue v e1

    -- y = tan x, so x = pi*n + atan y, and |d atan y / dy| = 1 / (1 + y^2).
    go Tan = \(e1 :* End) -> do
        x0 <- emitLet' v0
        n <- var <$> emitMBind P.counting
        r <- emitLet' (P.fromInt n P.* P.pi P.+ P.atan x0)
        emitWeight $ P.recip (P.one P.+ P.square x0)
        constrainValue r e1

    -- y = asin x, so x = sin y, with Jacobian |cos y|.
    -- NOTE(review): there is no guard that y lies in [-pi/2, pi/2].
    go Asin = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight $ P.unsafeProb (P.cos x0)
        constrainValue (P.sin x0) e1

    -- y = acos x, so x = cos y, with Jacobian |sin y|.
    go Acos = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight $ P.unsafeProb (P.sin x0)
        constrainValue (P.cos x0) e1

    -- y = atan x, so x = tan y, with Jacobian 1 / cos^2 y.
    go Atan = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight $ P.recip (P.unsafeProb (P.cos x0 P.^ P.nat_ 2))
        constrainValue (P.tan x0) e1

    go Sinh = \(e1 :* End) -> error_TODO "Sinh"
    go Cosh = \(e1 :* End) -> error_TODO "Cosh"
    go Tanh = \(e1 :* End) -> error_TODO "Tanh"
    go Asinh = \(e1 :* End) -> error_TODO "Asinh"
    go Acosh = \(e1 :* End) -> error_TODO "Acosh"
    go Atanh = \(e1 :* End) -> error_TODO "Atanh"

    -- y = e1 ** e2. The two alternatives (combined with '<|>'):
    -- * solve for the exponent: e2 = log y / log e1, with Jacobian
    --   1 / (y * |log e1|);
    -- * solve for the base: e1 = y ** (1/e2), with Jacobian
    --   |e1 / (e2 * y)|.
    go RealPow = \(e1 :* e2 :* End) ->
        do
            u <- emitLet' v0
            let w = P.recip (u P.* P.unsafeProb (P.abs (P.log e1)))
            emitWeight w
            constrainValue (P.log u P./ P.log e1) e2
        <|> do
            u <- emitLet' v0
            let ex = v0 P.** P.recip e2
            let w = P.abs (P.fromProb ex P./ (e2 P.* P.fromProb u))
            emitWeight $ P.unsafeProb w
            constrainValue ex e1

    -- y = exp x, so x = log y, with Jacobian 1 / y.
    go Exp = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight (P.recip x0)
        constrainValue (P.log x0) e1

    -- y = log x, so x = exp y, with Jacobian exp y.
    go Log = \(e1 :* End) -> do
        exp_x0 <- emitLet' (P.exp v0)
        emitWeight exp_x0
        constrainValue exp_x0 e1

    go (Infinity _) = \End -> error_TODO "Infinity"
    go GammaFunc = \(e1 :* End) -> error_TODO "GammaFunc"
    go BetaFunc = \(e1 :* e2 :* End) -> error_TODO "BetaFunc"
    go (Equal theOrd) = \(e1 :* e2 :* End) -> error_TODO "Equal"
    go (Less theOrd) = \(e1 :* e2 :* End) -> error_TODO "Less"
    go (NatPow theSemi) = \(e1 :* e2 :* End) -> error_TODO "NatPow"

    -- Negation is its own inverse and has unit Jacobian.
    go (Negate theRing) = \(e1 :* End) ->
        let negate_v0 = syn (PrimOp_ (Negate theRing) :$ v0 :* End)
        in constrainValue negate_v0 e1

    -- y = abs x. For y > 0 both x = y and x = -y are possible; for y = 0
    -- only x = 0; otherwise reject. Each root has unit Jacobian.
    go (Abs theRing) = \(e1 :* End) -> do
        let theSemi = hSemiring_HRing theRing
            theOrd =
                case theRing of
                  HRing_Int -> HOrd_Int
                  HRing_Real -> HOrd_Real
            theEq = hEq_HOrd theOrd
            signed = C.singletonCoercion (C.Signed theRing)
            zero = P.zero_ theSemi
            lt = P.primOp2_ $ Less theOrd
            eq = P.primOp2_ $ Equal theEq
            neg = P.primOp1_ $ Negate theRing
        x0 <- emitLet' (P.coerceTo_ signed v0)
        v <- var <$> emitMBind
            (P.if_ (lt zero x0)
                (P.dirac x0 P.<|> P.dirac (neg x0))
                (P.if_ (eq zero x0)
                    (P.dirac zero)
                    (P.reject . SMeasure $ typeOf zero)))
        constrainValue v e1

    -- signum has no density over the reals. Over Int, enumerate the
    -- integers and keep those whose sign is y.
    go (Signum theRing) = \(e1 :* End) ->
        case theRing of
          HRing_Real -> bot
          HRing_Int -> do
              x <- var <$> emitMBind P.counting
              emitGuard $ P.signum x P.== v0
              constrainValue x e1

    -- y = 1/x, so x = 1/y, with Jacobian 1 / y^2.
    go (Recip theFractional) = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight
            . P.recip
            . P.unsafeProbFraction_ theFractional
            $ square (hSemiring_HFractional theFractional) x0
        constrainValue (P.primOp1_ (Recip theFractional) x0) e1

    -- y = e1 ^ (1/n), where n = e2 (the n-th root). So e1 = y ^ n, and the
    -- Jacobian is d(y^n)/dy = n * y^(n-1). (Using n * y would only be
    -- correct for n = 2.) n = 0 is not a meaningful root, and unsafeMinus_
    -- does not guard against it.
    go (NatRoot theRadical) = \(e1 :* e2 :* End) ->
        case theRadical of
          HRadical_Prob -> do
              x0 <- emitLet' v0
              u2 <- fromWhnf <$> atomize e2
              let u2_minus_1 = P.unsafeMinus_ HSemiring_Nat u2 (P.nat_ 1)
              emitWeight (P.nat2prob u2 P.* (x0 P.^ u2_minus_1))
              constrainValue (x0 P.^ u2) e1

    go (Erf theContinuous) = \(e1 :* End) ->
        error "TODO: constrainPrimOp: need InvErf to disintegrate Erf"
-- | Square an expression by applying the 'NatPow' primitive with a literal
-- natural-number exponent of @2@.
--
-- The 'HSemiring' argument selects which semiring's 'NatPow' is applied, so
-- the result has the same Hakaru type as the input.
square :: (ABT Term abt) => HSemiring a -> abt '[] a -> abt '[] a
square semi x = syn (PrimOp_ (NatPow semi) :$ x :* two :* End)
    where
    two = P.nat_ 2
-- | Constrain the outcome of a measure expression to equal a given value.
--
-- @constrainOutcome v0 e0@ evaluates the measure @e0@ with 'evaluate_' and
-- then, by case on the resulting head, emits or pushes whatever is needed so
-- that the outcome of @e0@ is constrained to @v0@. A neutral (stuck) measure
-- fails with 'bot'.
constrainOutcome
    :: forall abt a
    .  (ABT Term abt)
    => abt '[] a                -- ^ the value the outcome must equal
    -> abt '[] ('HMeasure a)    -- ^ the measure whose outcome is constrained
    -> Dis abt ()
constrainOutcome v0 e0 =
#ifdef __TRACE_DISINTEGRATE__
    getLocs >>= \locs ->
    getIndices >>= \inds ->
    trace (
        let s = "-- constrainOutcome"
        in "\n" ++ s ++ ": "
            ++ show (pretty v0)
            ++ "\n" ++ replicate (length s) ' ' ++ ": "
            ++ show (pretty e0) ++ "\n"
            ++ "at " ++ show (ppInds inds) ++ "\n"
            ++ show (prettyLocs locs)
        ) $
#endif
    evaluate_ e0 >>= \w0 ->
        case w0 of
          Head_   v -> go v
          Neutral _ -> bot
    where
    impossible = error "constrainOutcome: the impossible happened"

    go :: Head abt ('HMeasure a) -> Dis abt ()
    -- Heads that are treated as unreachable for a measure-typed term.
    go (WLiteral    _)   = impossible
    go (WCoerceTo   _ _) = impossible
    go (WUnsafeFrom _ _) = impossible
    -- Primitive measures are handled by their own dispatcher.
    go (WMeasureOp o es) = constrainOutcomeMeasureOp v0 o es
    -- The outcome of @dirac e1@ is @e1@ itself.
    go (WDirac e1)       = constrainValue v0 e1
    -- Push the bind (tagged with the current indices), then keep
    -- constraining the outcome of the body.
    go (WMBind e1 e2)    =
        caseBind e2 $ \x body ->
            getIndices >>= \inds ->
                push (SBind x (Thunk e1) inds) body (constrainOutcome v0)
    -- Push the plate and constrain the variable it returns.
    go (WPlate e1 e2)    = pushPlate e1 e2 >>= constrainValue v0 . var
    go (WChain _ _ _)    = error "TODO: constrainOutcome{Chain}"
    -- Emit a rejection typed by the continuation's measure.
    go (WReject _)       = emit_ $ \m -> P.reject (typeOf m)
    -- Fork one branch per summand, each weighted by its pushed weight.
    -- Under non-empty indices only a single summand can be handled;
    -- otherwise fail with 'bot'.
    go (WSuperpose pes)  =
        getIndices >>= \inds ->
            let branch (p, e) =
                    push (SWeight (Thunk p) inds) e (constrainOutcome v0)
            in if null inds || L.length pes <= 1
                then emitFork_ (P.superpose . fmap ((,) P.one)) (fmap branch pes)
                else bot
-- | Constrain the outcome of a primitive measure ('MeasureOp') to @v0@.
--
-- For each primitive, this pushes the measure's density evaluated at @v0@
-- as a weight. Where a clause checks a support condition, it also pushes a
-- guard on @v0@ and, for some measures, on the parameters. 'Lebesgue' and
-- 'Counting' push nothing at all.
constrainOutcomeMeasureOp
    :: (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] a
    -> MeasureOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainOutcomeMeasureOp v0 = go
    where
    -- NOTE(review): presumably densities here are taken relative to these
    -- base measures themselves, hence no weight or guard.
    go Lebesgue = \End -> return ()
    go Counting = \End -> return ()

    go Categorical = \(ps :* End) ->
        pushWeight (P.densityCategorical ps v0)

    -- The guarded clauses let-bind the outcome first; presumably so that
    -- @v0@ is not duplicated between the guard and the weight.
    go Uniform = \(lo :* hi :* End) ->
        emitLet' v0 >>= \x ->
            pushGuard (lo P.<= x P.&& x P.<= hi) >>
            pushWeight (P.densityUniform lo hi x)

    go Normal = \(mu :* sd :* End) ->
        pushWeight (P.densityNormal mu sd v0)

    go Poisson = \(rate :* End) ->
        emitLet' v0 >>= \k ->
            pushGuard (P.nat_ 0 P.<= k P.&& P.prob_ 0 P.< rate) >>
            pushWeight (P.densityPoisson rate k)

    -- Outcome and both parameters must be strictly positive.
    go Gamma = \(p1 :* p2 :* End) ->
        emitLet' v0 >>= \x ->
            pushGuard (P.prob_ 0 P.< x P.&&
                       P.prob_ 0 P.< p1 P.&&
                       P.prob_ 0 P.< p2) >>
            pushWeight (P.densityGamma p1 p2 x)

    -- Outcome must lie in [0, 1]; both parameters strictly positive.
    go Beta = \(p1 :* p2 :* End) ->
        emitLet' v0 >>= \x ->
            pushGuard (P.prob_ 0 P.<= x P.&&
                       P.prob_ 1 P.>= x P.&&
                       P.prob_ 0 P.< p1 P.&&
                       P.prob_ 0 P.< p2) >>
            pushWeight (P.densityBeta p1 p2 x)