{-# LANGUAGE TypeOperators
           , KindSignatures
           , DataKinds
           , TypeFamilies
           , GADTs
           , FlexibleInstances
           , NoImplicitPrelude
           , ScopedTypeVariables
           , FlexibleContexts
           , Rank2Types
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2016.04.22
-- |
-- Module      :  Language.Hakaru.Syntax.Prelude
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- A replacement for Haskell's Prelude, using the familiar symbols
-- in order to construct 'AST's and 'ABT's. This is only necessary
-- if we want to use Hakaru as an embedded language in Haskell, but
-- it also provides some examples of how to use the infrastructure.
--
-- TODO: is there a way to get rid of the need to specify @'[]@ everywhere in here? Some sort of distinction between the Var vs the Open parts of View?
----------------------------------------------------------------
module Language.Hakaru.Syntax.Prelude
    (
    -- * Basic syntax
    -- ** Types and coercions
      ann_, triv, memo
    , coerceTo_, fromProb, nat2int, nat2prob, fromInt, nat2real
    , unsafeFrom_, unsafeProb, unsafeProbFraction, unsafeProbFraction_, unsafeProbSemiring, unsafeProbSemiring_
    -- ** Numeric literals
    , literal_, nat_, int_, prob_, real_
    , fromRational, half, third
    -- ** Booleans
    , true, false, bool_, if_
    , not, (&&), and, (||), or, nand, nor
    -- ** Equality and ordering
    , (==), (/=), (<), (<=), (>), (>=), min, minimum, max, maximum
    -- ** Semirings
    , zero, zero_, one, one_, (+), sum, (*), prod, (^), square
    , unsafeMinusNat, unsafeMinusProb, unsafeMinus, unsafeMinus_
    , unsafeDiv, unsafeDiv_
    -- ** Rings
    , (-), negate, negative, abs, abs_, signum
    -- ** Fractional
    , (/), recip, (^^)
    -- ** Radical
    , sqrt, thRootOf
    -- ** Integration
    , integrate, summate, product
    -- ** Continuous
    , RealProb(..), Integrable(..)
    , betaFunc
    , log, logBase
    , negativeInfinity
    -- *** Trig
    , sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh
    
    -- * Measures
    -- ** Abstract nonsense
    , dirac, (<$>), (<*>), (<*), (*>), (>>=), (>>), bindx, liftM2
    -- ** Linear operators
    , superpose, (<|>)
    , weight, withWeight, weightedDirac
    , reject, guard, withGuard
    -- ** Measure operators
    -- | When two versions of the same operator are given, the one without the prime builds an AST using the built-in operator, whereas the one with the prime is a default definition in terms of more primitive measure operators.
    , lebesgue
    , counting
    , densityCategorical, categorical, categorical'
    , densityUniform, uniform, uniform'
    , densityNormal, normal, normal'
    , densityPoisson, poisson, poisson'
    , densityGamma, gamma, gamma'
    , densityBeta, beta, beta', beta''
    , plateWithVar, plate, plate'
    , chain, chain'
    , invgamma
    , exponential
    , chi2
    , cauchy
    , laplace
    , studentT
    , weibull
    , bern
    , mix
    , binomial
    , negativeBinomial
    , geometric
    , multinomial
    , dirichlet

    -- * Data types (other than booleans)
    , datum_
    -- * Case and Branch
    , case_, branch
    -- ** HUnit
    , unit
    -- ** HPair
    , pair, pair_, unpair, fst, snd, swap
    -- ** HEither
    , left, right, uneither
    -- ** HMaybe
    , nothing, just, maybe, unmaybe
    -- ** HList
    , nil, cons, list

    -- * Lambda calculus
    , lam, lamWithVar, let_
    , app, app2, app3

    -- * Arrays
    , empty, arrayWithVar, array, (!), size, reduce
    , sumV, summateV, appendV, mapV, mapWithIndex, normalizeV, constV, unitV, zipWithV

    -- * Implementation details
    , primOp0_, primOp1_, primOp2_, primOp3_
    , arrayOp0_, arrayOp1_, arrayOp2_, arrayOp3_
    , measure0_, measure1_, measure2_
    , unsafeNaryOp_, naryOp_withIdentity, naryOp2_
    ) where

-- TODO: implement and use Prelude's fromInteger and fromRational, so we can use numeric literals!
import Prelude (Maybe(..), Bool(..), Integer, Rational, ($), flip, const, error)
import qualified Prelude
import           Data.Sequence       (Seq)
import qualified Data.Sequence       as Seq
import qualified Data.Text           as Text
import           Data.List.NonEmpty  (NonEmpty(..))
import qualified Data.List.NonEmpty  as L
import           Data.Semigroup      (Semigroup(..))
import           Control.Category    (Category(..))

import Data.Number.Natural
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing (Sing(..), SingI(sing), sUnPair, sUnEither, sUnMaybe, sUnMeasure, sUnArray)
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.ABT hiding (View(..))

----------------------------------------------------------------
----- Helper combinators for defining our EDSL
{-
Below we implement a lot of simple optimizations; however, these
optimizations only apply if the client uses the type class methods
to produce the AST. We should implement a stand-alone function which
performs these sorts of optimizations, as a program transformation.
-}
-- TODO: constant propogation

-- TODO: NBE to get rid of administrative redexes.
app :: (ABT Term abt) => abt '[] (a ':-> b) -> abt '[] a -> abt '[] b
app e1 e2 = syn (App_ :$ e1 :* e2 :* End)

app2 :: (ABT Term abt) => abt '[] (a ':-> b ':-> c) -> abt '[] a -> abt '[] b -> abt '[] c
app2 = (app .) . app

app3 :: (ABT Term abt) => abt '[] (a ':-> b ':-> c ':-> d) -> abt '[] a -> abt '[] b -> abt '[] c -> abt '[] d
app3 = (app2 .) . app

triv :: TrivialABT Term '[] a -> TrivialABT Term '[] a
triv = id

memo :: MemoizedABT Term '[] a -> MemoizedABT Term '[] a
memo = id

primOp0_ :: (ABT Term abt) => PrimOp '[] a -> abt '[] a
primOp0_ o = syn (PrimOp_ o :$ End)

primOp1_
    :: (ABT Term abt)
    => PrimOp '[ a ] b
    -> abt '[] a -> abt '[] b
primOp1_ o e1 = syn (PrimOp_ o :$ e1 :* End)

primOp2_
    :: (ABT Term abt)
    => PrimOp '[ a, b ] c
    -> abt '[] a -> abt '[] b -> abt '[] c
primOp2_ o e1 e2 = syn (PrimOp_ o :$ e1 :* e2 :* End)

primOp3_
    :: (ABT Term abt)
    => PrimOp '[ a, b, c ] d
    -> abt '[] a -> abt '[] b -> abt '[] c -> abt '[] d
primOp3_ o e1 e2 e3 = syn (PrimOp_ o :$ e1 :* e2 :* e3 :* End)

arrayOp0_ :: (ABT Term abt) => ArrayOp '[] a -> abt '[] a
arrayOp0_ o = syn (ArrayOp_ o :$ End)

arrayOp1_
    :: (ABT Term abt)
    => ArrayOp '[ a ] b
    -> abt '[] a -> abt '[] b
arrayOp1_ o e1 = syn (ArrayOp_ o :$ e1 :* End)

arrayOp2_
    :: (ABT Term abt)
    => ArrayOp '[ a, b ] c
    -> abt '[] a -> abt '[] b -> abt '[] c
arrayOp2_ o e1 e2 = syn (ArrayOp_ o :$ e1 :* e2 :* End)

arrayOp3_
    :: (ABT Term abt)
    => ArrayOp '[ a, b, c ] d
    -> abt '[] a -> abt '[] b -> abt '[] c -> abt '[] d
arrayOp3_ o e1 e2 e3 = syn (ArrayOp_ o :$ e1 :* e2 :* e3 :* End)

measure0_ :: (ABT Term abt) => MeasureOp '[] a -> abt '[] ('HMeasure a)
measure0_ o = syn (MeasureOp_ o :$ End)

measure1_
    :: (ABT Term abt)
    => MeasureOp '[ a ] b
    -> abt '[] a -> abt '[] ('HMeasure b)
measure1_ o e1 = syn (MeasureOp_ o :$ e1 :* End)

measure2_
    :: (ABT Term abt)
    => MeasureOp '[ a, b ] c
    -> abt '[] a -> abt '[] b -> abt '[] ('HMeasure c)
measure2_ o e1 e2 = syn (MeasureOp_ o :$ e1 :* e2 :* End)


-- N.B., we don't take advantage of commutativity, for more predictable
-- AST outputs. However, that means we can end up being slow...
--
-- N.B., we also don't try to eliminate the identity elements or
-- do cancellations because (a) it's undecidable in general, and
-- (b) that's prolly better handled as a post-processing simplification
-- step
--
-- TODO: generalize these two from [] to Foldable?

-- | Apply an n-ary operator to a list. This smart constructor will
-- flatten nested calls to the same operator. And if there is exactly
-- one element in the flattened sequence, then it will remove the
-- 'NaryOp_' node from the AST.
--
-- N.B., if the flattened sequence is empty, this smart constructor
-- will return an AST which applies the operator to the empty
-- sequence; which may or may not be unsafe. If the operator has
-- an identity element, then it's fine (operating on the empty
-- sequence evaluates to the identity element). However, if the
-- operator doesn't have an identity, then the generated code will
-- error whenever we attempt to run it.
unsafeNaryOp_ :: (ABT Term abt) => NaryOp a -> [abt '[] a] -> abt '[] a
unsafeNaryOp_ o = naryOp_withIdentity o (syn $ NaryOp_ o Seq.empty)


-- | A variant of 'unsafeNaryOp_' which will replace operating over
-- the empty sequence with a specified identity element. The produced
-- AST has the same semantics, we're just preemptively
-- evaluating\/simplifying the 'NaryOp_' node of the AST.
--
-- N.B., this function does not simplify away the identity element
-- if it exists in the flattened sequence! We should add that in
-- the future.
naryOp_withIdentity
    :: (ABT Term abt) => NaryOp a -> abt '[] a -> [abt '[] a] -> abt '[] a
naryOp_withIdentity o i = go Seq.empty
    where
    go es [] =
        case Seq.viewl es of
        Seq.EmptyL   -> i
        e Seq.:< es' ->
            case Seq.viewl es' of
            Seq.EmptyL -> e
            _          -> syn $ NaryOp_ o es
    go es (e:es') =
        case matchNaryOp o e of
        Nothing   -> go (es Seq.|> e)    es'
        Just es'' -> go (es Seq.>< es'') es'


-- TODO: is this actually worth breaking out, performance-wise? Or should we simply use:
-- > naryOp2_ o x y = unsafeNaryOp_ o [x,y]
naryOp2_
    :: (ABT Term abt) => NaryOp a -> abt '[] a -> abt '[] a -> abt '[] a
naryOp2_ o x y =
    case (matchNaryOp o x, matchNaryOp o y) of
    (Just xs, Just ys) -> syn . NaryOp_ o $ xs Seq.>< ys
    (Just xs, Nothing) -> syn . NaryOp_ o $ xs Seq.|> y
    (Nothing, Just ys) -> syn . NaryOp_ o $ x  Seq.<| ys
    (Nothing, Nothing) -> syn . NaryOp_ o $ x  Seq.<| Seq.singleton y


matchNaryOp
    :: (ABT Term abt) => NaryOp a -> abt '[] a -> Maybe (Seq (abt '[] a))
matchNaryOp o e =
    caseVarSyn e
        (const Nothing)
        $ \t ->
            case t of
            NaryOp_ o' xs | o' Prelude.== o -> Just xs
            _ -> Nothing


----------------------------------------------------------------
----------------------------------------------------------------
----- Now for the actual EDSL

{-
infixr 9 `pair`

infixr 1 =<<
infixr 1 <=<, >=>
infixr 9 .
infixr 0 $
-}

infixl 1 >>=, >>
infixr 2 ||
infixr 3 &&
infix  4 ==, /=, <, <=, >, >=
infixl 4 <$>, <*>, <*, *> -- <$
infixl 6 +, -
infixl 7 *, /
infixr 8 ^, ^^, **
-- infixl9 is the default when things are unspecified
infixl 9 !, `app`, `thRootOf`

-- TODO: some infix notation reminiscent of \"::\"
-- TODO: actually do something with the type argument?
ann_ :: (ABT Term abt) => Sing a -> abt '[] a -> abt '[] a
ann_ _ e = e

coerceTo_ :: (ABT Term abt) => Coercion a b -> abt '[] a -> abt '[] b
coerceTo_ CNil e = e
coerceTo_ c    e = syn (CoerceTo_ c :$ e :* End)

unsafeFrom_ :: (ABT Term abt) => Coercion a b -> abt '[] b -> abt '[] a
unsafeFrom_ CNil e = e
unsafeFrom_ c    e = syn (UnsafeFrom_ c :$ e :* End)

literal_ :: (ABT Term abt) => Literal a  -> abt '[] a
literal_ = syn . Literal_
bool_    :: (ABT Term abt) => Bool     -> abt '[] HBool
bool_    = datum_ . (\b -> if b then dTrue else dFalse)
nat_     :: (ABT Term abt) => Natural  -> abt '[] 'HNat
nat_     = literal_ . LNat
int_     :: (ABT Term abt) => Integer  -> abt '[] 'HInt
int_     = literal_ . LInt
prob_    :: (ABT Term abt) => NonNegativeRational -> abt '[] 'HProb
prob_    = literal_ . LProb
real_    :: (ABT Term abt) => Rational -> abt '[] 'HReal
real_    = literal_ . LReal

fromRational
    :: forall abt a
    . (ABT Term abt, HFractional_ a)
    => Rational
    -> abt '[] a
fromRational =
    case (hFractional :: HFractional a) of
    HFractional_Prob -> prob_ . unsafeNonNegativeRational
    HFractional_Real -> real_

half :: forall abt a
     .  (ABT Term abt, HFractional_ a) => abt '[] a
half = fromRational (1 Prelude./ 2)

third :: (ABT Term abt, HFractional_ a) => abt '[] a
third = fromRational (1 Prelude./ 3)


-- Boolean operators
true, false :: (ABT Term abt) => abt '[] HBool
true  = bool_ True
false = bool_ False

-- TODO: simplifications: distribution, constant-propogation
-- TODO: do we really want to distribute /by default/? Clearly we'll want to do that in some optimization\/partial-evaluation pass, but do note that it makes terms larger in general...
not :: (ABT Term abt) => abt '[] HBool -> abt '[] HBool
not e =
    Prelude.maybe (primOp1_ Not e) id
        $ caseVarSyn e
            (const Nothing)
            $ \t ->
                case t of
                PrimOp_ Not :$ es' ->
                    case es' of
                    e' :* End -> Just e'
                    _         -> error "not: the impossible happened"
                NaryOp_ And xs ->
                    Just . syn . NaryOp_ Or  $ Prelude.fmap not xs
                NaryOp_ Or xs ->
                    Just . syn . NaryOp_ And $ Prelude.fmap not xs
                NaryOp_ Xor xs ->
                    Just . syn . NaryOp_ Iff $ Prelude.fmap not xs
                NaryOp_ Iff xs ->
                    Just . syn . NaryOp_ Xor $ Prelude.fmap not xs
                Literal_ _ -> error "not: the impossible happened"
                _ -> Nothing

and, or :: (ABT Term abt) => [abt '[] HBool] -> abt '[] HBool
and = naryOp_withIdentity And true
or  = naryOp_withIdentity Or  false

(&&), (||),
    -- (</=>), (<==>), (==>), (<==), (\\), (//) -- TODO: better names?
    nand, nor
    :: (ABT Term abt) => abt '[] HBool -> abt '[] HBool -> abt '[] HBool
(&&) = naryOp2_ And
(||) = naryOp2_ Or
-- (</=>) = naryOp2_ Xor
-- (<==>) = naryOp2_ Iff
-- (==>)  = primOp2_ Impl
-- (<==)  = flip (==>)
-- (\\)   = primOp2_ Diff
-- (//)   = flip (\\)
nand   = primOp2_ Nand
nor    = primOp2_ Nor


-- HEq & HOrder operators
(==), (/=)
    :: (ABT Term abt, HEq_ a) => abt '[] a -> abt '[] a -> abt '[] HBool
(==) = primOp2_ $ Equal hEq
(/=) = (not .) . (==)

(<), (<=), (>), (>=)
    :: (ABT Term abt, HOrd_ a) => abt '[] a -> abt '[] a -> abt '[] HBool
(<)    = primOp2_ $ Less hOrd
x <= y = not (x > y) -- or: @(x < y) || (x == y)@
(>)    = flip (<)
(>=)   = flip (<=)

min, max :: (ABT Term abt, HOrd_ a) => abt '[] a -> abt '[] a -> abt '[] a
min = naryOp2_ $ Min hOrd
max = naryOp2_ $ Max hOrd

-- TODO: if @a@ is bounded, then we can make these safe...
minimum, maximum :: (ABT Term abt, HOrd_ a) => [abt '[] a] -> abt '[] a
minimum = unsafeNaryOp_ $ Min hOrd
maximum = unsafeNaryOp_ $ Max hOrd


-- HSemiring operators
(+), (*)
    :: (ABT Term abt, HSemiring_ a) => abt '[] a -> abt '[] a -> abt '[] a
(+) = naryOp2_ $ Sum  hSemiring
(*) = naryOp2_ $ Prod hSemiring

zero, one :: forall abt a. (ABT Term abt, HSemiring_ a) => abt '[] a
zero = zero_ (hSemiring :: HSemiring a)
one  = one_  (hSemiring :: HSemiring a)

zero_, one_ :: (ABT Term abt) => HSemiring a -> abt '[] a
zero_ HSemiring_Nat  = literal_ $ LNat  0
zero_ HSemiring_Int  = literal_ $ LInt  0
zero_ HSemiring_Prob = literal_ $ LProb 0
zero_ HSemiring_Real = literal_ $ LReal 0
one_  HSemiring_Nat  = literal_ $ LNat  1
one_  HSemiring_Int  = literal_ $ LInt  1
one_  HSemiring_Prob = literal_ $ LProb 1
one_  HSemiring_Real = literal_ $ LReal 1

-- TODO: add a smart constructor for @HSemiring_ a => Natural -> abt '[] a@ and\/or @HRing_ a => Integer -> abt '[] a@

sum, prod :: (ABT Term abt, HSemiring_ a) => [abt '[] a] -> abt '[] a
sum  = naryOp_withIdentity (Sum  hSemiring) zero
prod = naryOp_withIdentity (Prod hSemiring) one

{-
sum, product :: (ABT Term abt, HSemiring_ a) => [abt '[] a] -> abt '[] a
sum     = unsafeNaryOp_ $ Sum  hSemiring
product = unsafeNaryOp_ $ Prod hSemiring
-}


-- TODO: simplifications
(^) :: (ABT Term abt, HSemiring_ a)
    => abt '[] a -> abt '[] 'HNat -> abt '[] a
(^) = primOp2_ $ NatPow hSemiring

-- TODO: this is actually safe, how can we capture that?
-- TODO: is this type restruction actually helpful anywhere for us?
-- If so, we ought to make this function polymorphic so that we can
-- use it for non-HRing HSemirings too...
square :: (ABT Term abt, HRing_ a) => abt '[] a -> abt '[] (NonNegative a)
square e = unsafeFrom_ signed (e ^ nat_ 2)


-- HRing operators
(-) :: (ABT Term abt, HRing_ a) => abt '[] a -> abt '[] a -> abt '[] a
x - y = x + negate y


-- TODO: do we really want to distribute negation over addition /by
-- default/? Clearly we'll want to do that in some
-- optimization\/partial-evaluation pass, but do note that it makes
-- terms larger in general...
negate :: (ABT Term abt, HRing_ a) => abt '[] a -> abt '[] a
negate e =
    Prelude.maybe (primOp1_ (Negate hRing) e) id
        $ caseVarSyn e
            (const Nothing)
            $ \t ->
                case t of
                -- TODO: need we case analyze the @HSemiring@?
                NaryOp_ (Sum theSemi) xs ->
                    Just . syn . NaryOp_ (Sum theSemi) $ Prelude.fmap negate xs
                -- TODO: need we case analyze the @HRing@?
                PrimOp_ (Negate _theRing) :$ es' ->
                    case es' of
                    e' :* End -> Just e'
                    _         -> error "negate: the impossible happened"
                _ -> Nothing


-- TODO: test case: @negative . square@ simplifies away the intermediate coercions. (cf., normal')
-- BUG: this can lead to ambiguity when used with the polymorphic functions of RealProb.
-- | An occasionally helpful variant of 'negate'.
negative :: (ABT Term abt, HRing_ a) => abt '[] (NonNegative a) -> abt '[] a
negative = negate . coerceTo_ signed


abs :: (ABT Term abt, HRing_ a) => abt '[] a -> abt '[] a
abs = coerceTo_ signed . abs_

abs_ :: (ABT Term abt, HRing_ a) => abt '[] a -> abt '[] (NonNegative a)
abs_ e = 
    Prelude.maybe (primOp1_ (Abs hRing) e) id
        $ caseVarSyn e
            (const Nothing)
            $ \t ->
                case t of
                -- BUG: can't use the 'Signed' pattern synonym here, because that /requires/ the input to be (NonNegative a), instead of giving us the information that it is.
                -- TODO: need we case analyze the @HRing@?
                CoerceTo_ (CCons (Signed _theRing) CNil) :$ es' ->
                    case es' of
                    e' :* End -> Just e'
                    _         -> error "abs_: the impossible happened"
                _ -> Nothing


-- TODO: any obvious simplifications? idempotent?
signum :: (ABT Term abt, HRing_ a) => abt '[] a -> abt '[] a
signum = primOp1_ $ Signum hRing


-- HFractional operators
(/) :: (ABT Term abt, HFractional_ a) => abt '[] a -> abt '[] a -> abt '[] a
x / y = x * recip y


-- TODO: generalize this pattern so we don't have to repeat it...
--
-- TODO: do we really want to distribute reciprocal over multiplication
-- /by default/? Clearly we'll want to do that in some
-- optimization\/partial-evaluation pass, but do note that it makes
-- terms larger in general...
recip :: (ABT Term abt, HFractional_ a) => abt '[] a -> abt '[] a
recip e0 =
    Prelude.maybe (primOp1_ (Recip hFractional) e0) id
        $ caseVarSyn e0
            (const Nothing)
            $ \t0 ->
                case t0 of
                -- TODO: need we case analyze the @HSemiring@?
                NaryOp_ (Prod theSemi) xs ->
                    Just . syn . NaryOp_ (Prod theSemi) $ Prelude.fmap recip xs
                -- TODO: need we case analyze the @HFractional@?
                PrimOp_ (Recip _theFrac) :$ es' ->
                    case es' of
                    e :* End -> Just e
                    _ -> error "recip: the impossible happened"
                _ -> Nothing


-- TODO: simplifications
-- TODO: a variant of 'if_' which gives us the evidence that the argument is non-negative, so we don't need to coerce or use 'abs_'
(^^) :: (ABT Term abt, HFractional_ a)
    => abt '[] a -> abt '[] 'HInt -> abt '[] a
x ^^ y =
    if_ (y < int_ 0)
        (recip x ^ abs_ y)
        (x ^ abs_ y)


-- HRadical operators
-- N.B., HProb is the only HRadical type (for now...)
-- TODO: simplifications
thRootOf
    :: (ABT Term abt, HRadical_ a)
    => abt '[] 'HNat -> abt '[] a -> abt '[] a
n `thRootOf` x = primOp2_ (NatRoot hRadical) x n

sqrt :: (ABT Term abt, HRadical_ a) => abt '[] a -> abt '[] a
sqrt = (nat_ 2 `thRootOf`)

{-
-- TODO: simplifications
(^+) :: (ABT Term abt, HRadical_ a)
    => abt '[] a -> abt '[] 'HPositiveRational -> abt '[] a
x ^+ y = casePositiveRational y $ \n d -> d `thRootOf` (x ^ n)

(^*) :: (ABT Term abt, HRadical_ a)
    => abt '[] a -> abt '[] 'HRational -> abt '[] a
x ^* y = caseRational y $ \n d -> d `thRootOf` (x ^^ n)
-}

betaFunc
    :: (ABT Term abt) => abt '[] 'HProb -> abt '[] 'HProb -> abt '[] 'HProb
betaFunc = primOp2_ BetaFunc


integrate
    :: (ABT Term abt)
    => abt '[] 'HReal
    -> abt '[] 'HReal
    -> (abt '[] 'HReal -> abt '[] 'HProb)
    -> abt '[] 'HProb
integrate lo hi f =
    syn (Integrate :$ lo :* hi :* binder Text.empty sing f :* End)

summate
    :: (ABT Term abt, HDiscrete_ a, HSemiring_ b, SingI a)
    => abt '[] a
    -> abt '[] a
    -> (abt '[] a -> abt '[] b)
    -> abt '[] b
summate lo hi f =
    syn (Summate hDiscrete hSemiring
         :$ lo :* hi :* binder Text.empty sing f :* End)

product
    :: (ABT Term abt, HDiscrete_ a, HSemiring_ b, SingI a)
    => abt '[] a
    -> abt '[] a
    -> (abt '[] a -> abt '[] b)
    -> abt '[] b
product lo hi f =
    syn (Product hDiscrete hSemiring
         :$ lo :* hi :* binder Text.empty sing f :* End)


class Integrable (a :: Hakaru) where
    infinity :: (ABT Term abt) => abt '[] a

instance Integrable 'HNat where
    infinity = primOp0_ (Infinity HIntegrable_Nat)

instance Integrable 'HInt where
    infinity = nat2int $ primOp0_ (Infinity HIntegrable_Nat)

instance Integrable 'HProb where
    infinity = primOp0_ (Infinity HIntegrable_Prob)

instance Integrable 'HReal where
    infinity = fromProb $ primOp0_ (Infinity HIntegrable_Prob)

-- HACK: we define this class in order to gain more polymorphism;
-- but, will it cause type inferencing issues? Excepting 'log'
-- (which should be moved out of the class) these are all safe.
class RealProb (a :: Hakaru) where
    (**) :: (ABT Term abt) => abt '[] 'HProb -> abt '[] a -> abt '[] 'HProb
    exp  :: (ABT Term abt) => abt '[] a -> abt '[] 'HProb
    erf  :: (ABT Term abt) => abt '[] a -> abt '[] a
    pi   :: (ABT Term abt) => abt '[] a
    gammaFunc :: (ABT Term abt) => abt '[] a -> abt '[] 'HProb

instance RealProb 'HReal where
    (**)      = primOp2_ RealPow
    exp       = primOp1_ Exp
    erf       = primOp1_ $ Erf hContinuous
    pi        = fromProb $ primOp0_ Pi
    gammaFunc = primOp1_ GammaFunc

instance RealProb 'HProb where
    x ** y    = primOp2_ RealPow x $ fromProb y
    exp       = primOp1_ Exp . fromProb
    erf       = primOp1_ $ Erf hContinuous
    pi        = primOp0_ Pi
    gammaFunc = primOp1_ GammaFunc . fromProb

log  :: (ABT Term abt) => abt '[] 'HProb -> abt '[] 'HReal
log = primOp1_ Log

logBase
    :: (ABT Term abt)
    => abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] 'HReal
logBase b x = log x / log b -- undefined when b == 1

sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh
    :: (ABT Term abt) => abt '[] 'HReal -> abt '[] 'HReal
sin    = primOp1_ Sin
cos    = primOp1_ Cos
tan    = primOp1_ Tan
asin   = primOp1_ Asin
acos   = primOp1_ Acos
atan   = primOp1_ Atan
sinh   = primOp1_ Sinh
cosh   = primOp1_ Cosh
tanh   = primOp1_ Tanh
asinh  = primOp1_ Asinh
acosh  = primOp1_ Acosh
atanh  = primOp1_ Atanh


----------------------------------------------------------------
datum_
    :: (ABT Term abt)
    => Datum (abt '[]) (HData' t)
    -> abt '[] (HData' t)
datum_ = syn . Datum_

case_
     :: (ABT Term abt)
     => abt '[] a
     -> [Branch a abt b]
     -> abt '[] b
case_ e bs = syn (Case_ e bs)

branch
    :: (ABT Term abt)
    => Pattern xs a
    -> abt xs b
    -> Branch a abt b
branch = Branch

unit :: (ABT Term abt) => abt '[] HUnit
unit = datum_ dUnit

pair
    :: (ABT Term abt, SingI a, SingI b)
    => abt '[] a -> abt '[] b -> abt '[] (HPair a b)
pair = (datum_ .) . dPair


pair_
    :: (ABT Term abt)
    => Sing a
    -> Sing b
    -> abt '[] a
    -> abt '[] b
    -> abt '[] (HPair a b)
pair_ a b = (datum_ .) . dPair_ a b


unpair
    :: forall abt a b c
    .  (ABT Term abt)
    => abt '[] (HPair a b)
    -> (abt '[] a -> abt '[] b -> abt '[] c)
    -> abt '[] c
unpair e hoas =
    let (aTyp,bTyp) = sUnPair $ typeOf e
        body        = hoas (var a) (var b)
        inc x       = 1 Prelude.+ x
        a           = Variable Text.empty (nextBind body)         aTyp
        b           = Variable Text.empty (inc . nextBind $ body) bTyp
    in case_ e
        [Branch (pPair PVar PVar)
           (bind a (bind b body))
        ]

fst :: (ABT Term abt)
    => abt '[] (HPair a b)
    -> abt '[] a
fst p = unpair p (\x _ -> x)

snd :: (ABT Term abt)
    => abt '[] (HPair a b)
    -> abt '[] b
snd p = unpair p (\_ y -> y)

swap :: (ABT Term abt, SingI a, SingI b)
    => abt '[] (HPair a b)
    -> abt '[] (HPair b a)
swap ab = unpair ab (flip pair)

left
    :: (ABT Term abt, SingI a, SingI b)
    => abt '[] a -> abt '[] (HEither a b)
left = datum_ . dLeft

right
    :: (ABT Term abt, SingI a, SingI b)
    => abt '[] b -> abt '[] (HEither a b)
right = datum_ . dRight

uneither
    :: (ABT Term abt)
    => abt '[] (HEither a b)
    -> (abt '[] a -> abt '[] c)
    -> (abt '[] b -> abt '[] c)
    -> abt '[] c
uneither e l r =
    let (a,b) = sUnEither $ typeOf e
    in case_ e
        [ Branch (pLeft  PVar) (binder Text.empty a l)
        , Branch (pRight PVar) (binder Text.empty b r)
        ]

if_ :: (ABT Term abt)
    => abt '[] HBool
    -> abt '[] a
    -> abt '[] a
    -> abt '[] a
if_ b t f =
    case_ b
     [ Branch pTrue  t
     , Branch pFalse f
     ]

nil :: (ABT Term abt, SingI a) => abt '[] (HList a)
nil = datum_ dNil

cons
    :: (ABT Term abt, SingI a)
    => abt '[] a -> abt '[] (HList a) -> abt '[] (HList a)
cons = (datum_ .) . dCons

list :: (ABT Term abt, SingI a) => [abt '[] a] -> abt '[] (HList a)
list = Prelude.foldr cons nil

nothing :: (ABT Term abt, SingI a) => abt '[] (HMaybe a)
nothing = datum_ dNothing

just :: (ABT Term abt, SingI a) => abt '[] a -> abt '[] (HMaybe a)
just = datum_ . dJust

maybe :: (ABT Term abt, SingI a) => Maybe (abt '[] a) -> abt '[] (HMaybe a)
maybe = Prelude.maybe nothing just

unmaybe
    :: (ABT Term abt)
    => abt '[] (HMaybe a)
    -> abt '[] b
    -> (abt '[] a -> abt '[] b)
    -> abt '[] b
unmaybe e n j = 
    case_ e
     [ Branch pNothing     n
     , Branch (pJust PVar) (binder Text.empty (sUnMaybe $ typeOf e) j)
     ]

unsafeProb :: (ABT Term abt) => abt '[] 'HReal -> abt '[] 'HProb
unsafeProb = unsafeFrom_ signed

fromProb   :: (ABT Term abt) => abt '[] 'HProb -> abt '[] 'HReal
fromProb   = coerceTo_ signed

nat2int    :: (ABT Term abt) => abt '[] 'HNat -> abt '[] 'HInt
nat2int    = coerceTo_ signed

fromInt    :: (ABT Term abt) => abt '[] 'HInt  -> abt '[] 'HReal
fromInt    = coerceTo_ continuous

nat2prob   :: (ABT Term abt) => abt '[] 'HNat  -> abt '[] 'HProb
nat2prob   = coerceTo_ continuous

nat2real   :: (ABT Term abt) => abt '[] 'HNat  -> abt '[] 'HReal
nat2real   = coerceTo_ (continuous . signed)

{- -- Uncomment only if we actually end up needing this anywhere
class FromNat (a :: Hakaru) where
    fromNat :: (ABT Term abt) => abt '[] 'HNat  -> abt '[] a

instance FromNat 'HNat  where fromNat = id
instance FromNat 'HInt  where fromNat = nat2int
instance FromNat 'HProb where fromNat = nat2prob
instance FromNat 'HReal where fromNat = fromProb . nat2prob
-}

unsafeProbFraction
    :: forall abt a
    .  (ABT Term abt, HFractional_ a)
    => abt '[] a
    -> abt '[] 'HProb
unsafeProbFraction e =
    unsafeProbFraction_ (hFractional :: HFractional a) e

unsafeProbFraction_
    :: (ABT Term abt)
    => HFractional a
    -> abt '[] a
    -> abt '[] 'HProb
unsafeProbFraction_ HFractional_Prob = id
unsafeProbFraction_ HFractional_Real = unsafeProb

unsafeProbSemiring
    :: forall abt a
    .  (ABT Term abt, HSemiring_ a)
    => abt '[] a
    -> abt '[] 'HProb
unsafeProbSemiring e =
    unsafeProbSemiring_ (hSemiring :: HSemiring a) e

unsafeProbSemiring_
    :: (ABT Term abt)
    => HSemiring a
    -> abt '[] a
    -> abt '[] 'HProb
unsafeProbSemiring_ HSemiring_Nat  = nat2prob
unsafeProbSemiring_ HSemiring_Int  = coerceTo_ continuous . unsafeFrom_ signed
unsafeProbSemiring_ HSemiring_Prob = id
unsafeProbSemiring_ HSemiring_Real = unsafeProb


negativeInfinity :: ( ABT Term abt
                    , HRing_ a
                    , Integrable a)
                 => abt '[] a
negativeInfinity = negate infinity

-- instance (ABT Term abt) => Lambda abt where
-- 'app' already defined

-- TODO: use 'typeOf' to remove the 'SingI' requirement somehow
-- | A variant of 'lamWithVar' for automatically computing the type
-- via 'sing'.
lam :: (ABT Term abt, SingI a)
    => (abt '[] a -> abt '[] b)
    -> abt '[] (a ':-> b)
lam = lamWithVar Text.empty sing

-- | Create a lambda abstraction. The first two arguments give the
-- hint and type of the lambda-bound variable in the result. If you
-- want to automatically fill those in, then see 'lam'.
lamWithVar
    :: (ABT Term abt)
    => Text.Text
    -> Sing a
    -> (abt '[] a -> abt '[] b)
    -> abt '[] (a ':-> b)
lamWithVar hint typ f = syn (Lam_ :$ binder hint typ f :* End)

{-
-- some test cases to make sure we tied-the-knot successfully:
> let
    lam :: (ABT Term abt)
        => String
        -> Sing a
        -> (abt '[] a -> abt '[] b)
        -> abt '[] (a ':-> b)
    lam name typ f = syn (Lam_ :$ binder name typ f :* End)
> lam "x" SInt (\x -> x) :: TrivialABT Term ('HInt ':-> 'HInt)
> lam "x" SInt (\x -> lam "y" SInt $ \y -> x < y) :: TrivialABT Term ('HInt ':-> 'HInt ':-> 'HBool)
-}

-- TODO: make this smarter so that if the @e@ is already a variable then we just plug it into @f@ instead of introducing the trivial let-binding.
let_
    :: (ABT Term abt)
    => abt '[] a
    -> (abt '[] a -> abt '[] b)
    -> abt '[] b
let_ e f = syn (Let_ :$ e :* binder Text.empty (typeOf e) f :* End)


----------------------------------------------------------------
array
    :: (ABT Term abt)
    => abt '[] 'HNat
    -> (abt '[] 'HNat -> abt '[] a)
    -> abt '[] ('HArray a)
array n =
    syn . Array_ n . binder Text.empty sing

arrayWithVar
    :: (ABT Term abt)
    => abt '[] 'HNat
    -> Variable 'HNat
    -> abt '[] a
    -> abt '[] ('HArray a)
arrayWithVar n x body =
    syn $ Array_ n (bind x body)


empty :: (ABT Term abt, SingI a) => abt '[] ('HArray a)
empty = syn (Empty_ sing)

(!) :: (ABT Term abt)
    => abt '[] ('HArray a) -> abt '[] 'HNat -> abt '[] a
(!) e = arrayOp2_ (Index . sUnArray $ typeOf e) e

size :: (ABT Term abt) => abt '[] ('HArray a) -> abt '[] 'HNat
size e = arrayOp1_ (Size . sUnArray $ typeOf e) e

reduce
    :: (ABT Term abt)
    => (abt '[] a -> abt '[] a -> abt '[] a)
    -> abt '[] a
    -> abt '[] ('HArray a)
    -> abt '[] a
reduce f e =
    let a  = typeOf e
        f' = lamWithVar Text.empty a $ \x ->
                lamWithVar Text.empty a $ \y -> f x y
    in arrayOp3_ (Reduce a) f' e

-- TODO: better names for all these. The \"V\" suffix doesn't make sense anymore since we're calling these arrays, not vectors...
-- TODO: bust these all out into their own place, since the API for arrays is gonna be huge

sumV :: (ABT Term abt, HSemiring_ a)
    => abt '[] ('HArray a) -> abt '[] a
sumV = reduce (+) zero -- equivalent to summateV if @a ~ 'HProb@

summateV :: (ABT Term abt) => abt '[] ('HArray 'HProb) -> abt '[] 'HProb
summateV x =
    summate (nat_ 0) (size x)
        (\i -> x ! i)

-- TODO: a variant of 'if_' for giving us evidence that the subtraction is sound.

unsafeMinusNat
    :: (ABT Term abt) => abt '[] 'HNat -> abt '[] 'HNat -> abt '[] 'HNat
unsafeMinusNat x y = unsafeFrom_ signed (nat2int x - nat2int y)

unsafeMinusProb
    :: (ABT Term abt) => abt '[] 'HProb -> abt '[] 'HProb -> abt '[] 'HProb
unsafeMinusProb x y = unsafeProb (fromProb x - fromProb y)

-- | For any semiring we can attempt subtraction by lifting to a
-- ring, subtracting there, and then lowering back to the semiring.
-- Of course, the lowering step may well fail.
unsafeMinus
    :: (ABT Term abt, HSemiring_ a) => abt '[] a -> abt '[] a -> abt '[] a
unsafeMinus = unsafeMinus_ hSemiring

-- | A variant of 'unsafeMinus' for explicitly passing the semiring
-- instance.
unsafeMinus_
    :: (ABT Term abt) => HSemiring a -> abt '[] a -> abt '[] a -> abt '[] a
unsafeMinus_ theSemi =
    signed_HSemiring theSemi $ \c ->
        let lift  = coerceTo_   c
            lower = unsafeFrom_ c
        in \e1 e2 -> lower (lift e1 - lift e2)

-- TODO: move to Coercion.hs?
-- | For any semiring, return a coercion to its ring completion.
-- Because this completion is existentially quantified, we must use
-- a cps trick to eliminate the existential.
signed_HSemiring
    :: HSemiring a -> (forall b. (HRing_ b) => Coercion a b -> r) -> r
signed_HSemiring c k =
    case c of
    HSemiring_Nat  -> k $ singletonCoercion (Signed HRing_Int)
    HSemiring_Int  -> k CNil
    HSemiring_Prob -> k $ singletonCoercion (Signed HRing_Real)
    HSemiring_Real -> k CNil

-- | For any semiring we can attempt division by lifting to a
-- semifield, dividing there, and then lowering back to the semiring.
-- Of course, the lowering step may well fail.
unsafeDiv
    :: (ABT Term abt, HSemiring_ a) => abt '[] a -> abt '[] a -> abt '[] a
unsafeDiv = unsafeDiv_ hSemiring

-- | A variant of 'unsafeDiv' for explicitly passing the semiring
-- instance.
unsafeDiv_
    :: (ABT Term abt) => HSemiring a -> abt '[] a -> abt '[] a -> abt '[] a
unsafeDiv_ theSemi =
    continuous_HSemiring theSemi $ \c ->
        let lift  = coerceTo_   c
            lower = unsafeFrom_ c
        in \e1 e2 -> lower (lift e1 / lift e2)

-- TODO: move to Coercion.hs?
-- | For any semiring, return a coercion to its semifield completion.
-- Because this completion is existentially quantified, we must use
-- a cps trick to eliminate the existential.
continuous_HSemiring
    :: HSemiring a -> (forall b. (HFractional_ b) => Coercion a b -> r) -> r
continuous_HSemiring c k =
    case c of
    HSemiring_Nat  -> k $ singletonCoercion (Continuous HContinuous_Prob)
    HSemiring_Int  -> k $ singletonCoercion (Continuous HContinuous_Real)
    HSemiring_Prob -> k CNil
    HSemiring_Real -> k CNil


appendV
    :: (ABT Term abt)
    => abt '[] ('HArray a) -> abt '[] ('HArray a) -> abt '[] ('HArray a)
appendV v1 v2 =
    array (size v1 + size v2) $ \i ->
        if_ (i < size v1)
            (v1 ! i)
            (v2 ! (i `unsafeMinusNat` size v1))

mapWithIndex
    :: (ABT Term abt)
    => (abt '[] 'HNat -> abt '[] a -> abt '[] b)
    -> abt '[] ('HArray a)
    -> abt '[] ('HArray b)
mapWithIndex f v = array (size v) $ \i -> f i (v ! i)

mapV
    :: (ABT Term abt)
    => (abt '[] a -> abt '[] b)
    -> abt '[] ('HArray a)
    -> abt '[] ('HArray b)
mapV f v = array (size v) $ \i -> f (v ! i)

normalizeV
    :: (ABT Term abt)
    => abt '[] ('HArray 'HProb)
    -> abt '[] ('HArray 'HProb)
normalizeV x = mapV (/ sumV x) x

constV
    :: (ABT Term abt) => abt '[] 'HNat -> abt '[] b -> abt '[] ('HArray b)
constV n c = array n (const c)

unitV
    :: (ABT Term abt)
    => abt '[] 'HNat
    -> abt '[] 'HNat
    -> abt '[] ('HArray 'HProb)
unitV s i = array s (\j -> if_ (i == j) (prob_ 1) (prob_ 0))

zipWithV
    :: (ABT Term abt)
    => (abt '[] a -> abt '[] b -> abt '[] c)
    -> abt '[] ('HArray a)
    -> abt '[] ('HArray b)
    -> abt '[] ('HArray c)
zipWithV f v1 v2 =
    array (size v1) (\i -> f (v1 ! i) (v2 ! i))


----------------------------------------------------------------
(>>=)
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> (abt '[] a -> abt '[] ('HMeasure b))
    -> abt '[] ('HMeasure b)
m >>= f =
    syn (MBind :$ m
               :* binder Text.empty (sUnMeasure $ typeOf m) f
               :* End)


dirac :: (ABT Term abt) => abt '[] a -> abt '[] ('HMeasure a)
dirac e1 = syn (Dirac :$ e1 :* End)


-- TODO: can we use let-binding instead of (>>=)-binding (i.e., for when the dirac is immediately (>>=)-bound again...)?
(<$>)
    :: (ABT Term abt, SingI a)
    => (abt '[] a -> abt '[] b)
    -> abt '[] ('HMeasure a)
    -> abt '[] ('HMeasure b)
f <$> m = m >>= dirac . f

-- | N.B, this function may introduce administrative redexes.
-- Moreover, it's not clear that we should even allow the type
-- @'HMeasure (a ':-> b)@!
(<*>)
    :: (ABT Term abt, SingI a, SingI b)
    => abt '[] ('HMeasure (a ':-> b))
    -> abt '[] ('HMeasure a)
    -> abt '[] ('HMeasure b)
mf <*> mx = mf >>= \f -> app f <$> mx

-- TODO: ensure that @dirac a *> n@ simplifies to just @n@, regardless of @a@ but especially when @a = unit@.
(*>), (>>)
    :: (ABT Term abt, SingI a)
    => abt '[] ('HMeasure a)
    -> abt '[] ('HMeasure b)
    -> abt '[] ('HMeasure b)
m *> n = m >>= \_ -> n
(>>) = (*>)

-- TODO: ensure that @m <* dirac a@ simplifies to just @m@, regardless of @a@ but especially when @a = unit@.
(<*)
    :: (ABT Term abt, SingI a, SingI b)
    => abt '[] ('HMeasure a)
    -> abt '[] ('HMeasure b)
    -> abt '[] ('HMeasure a)
m <* n = m >>= \a -> n *> dirac a

bindx
    :: (ABT Term abt, SingI a, SingI b)
    => abt '[] ('HMeasure a)
    -> (abt '[] a -> abt '[] ('HMeasure b))
    -> abt '[] ('HMeasure (HPair a b))
m `bindx` f = m >>= \a -> pair a <$> f a

-- Defined because using @(<$>)@ and @(<*>)@ would introduce administrative redexes
liftM2
    :: (ABT Term abt, SingI a, SingI b)
    => (abt '[] a -> abt '[] b -> abt '[] c)
    -> abt '[] ('HMeasure a)
    -> abt '[] ('HMeasure b)
    -> abt '[] ('HMeasure c)
liftM2 f m n = m >>= \x -> f x <$> n


lebesgue :: (ABT Term abt) => abt '[] ('HMeasure 'HReal)
lebesgue = measure0_ Lebesgue

counting :: (ABT Term abt) => abt '[] ('HMeasure 'HInt)
counting = measure0_ Counting

-- TODO: make this smarter by collapsing nested @Superpose_@ similar to how we collapse nested NaryOps. Though beware, that could cause duplication of the computation for the probabilities\/weights; thus may want to only do it when the weights are constant values, or \"simplify\" things by generating let-bindings in order to share work.
--
-- TODO: can we make this smarter enough to handle empty lists?
superpose
    :: (ABT Term abt)
    => NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a))
    -> abt '[] ('HMeasure a)
superpose = syn . Superpose_

-- | The empty measure. Is called @fail@ in the Core Hakaru paper.
reject
    :: (ABT Term abt)
    => (Sing ('HMeasure a))
    -> abt '[] ('HMeasure a)
reject = syn . Reject_

-- | The sum of two measures. Is called @mplus@ in the Core Hakaru paper.
(<|>) :: (ABT Term abt)
      => abt '[] ('HMeasure a)
      -> abt '[] ('HMeasure a)
      -> abt '[] ('HMeasure a)
x <|> y =
    superpose $
        case (matchSuperpose x, matchSuperpose y) of
        (Just xs, Just ys) -> xs <> ys
        (Just xs, Nothing) -> (one, y) :| L.toList xs -- HACK: reordering!
        (Nothing, Just ys) -> (one, x) :| L.toList ys
        (Nothing, Nothing) -> (one, x) :| [(one, y)]

matchSuperpose
    :: (ABT Term abt) 
    => abt '[] ('HMeasure a)
    -> Maybe (NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a)))
matchSuperpose e =
    caseVarSyn e
        (const Nothing)
        $ \t ->
            case t of
            Superpose_ xs -> Just xs
            _ -> Nothing

-- TODO: we should ensure that the following reductions happen:
-- > (withWeight p m >> n) ---> withWeight p (m >> n)
-- > (m >> withWeight p n) ---> withWeight p (m >> n)
-- > withWeight 1 m ---> m
-- > withWeight p (withWeight q m) ---> withWeight (p*q) m
-- > (weight p >> m) ---> withWeight p m
--
-- | Adjust the weight of the current measure.
--
-- /N.B.,/ the name for this function is terribly inconsistent
-- across the literature, even just the Hakaru literature, let alone
-- the Hakaru code base. It is variously called \"factor\" or
-- \"weight\"; though \"factor\" is also used to mean the function
-- 'factor' or the function 'observe', and \"weight\" is also used
-- to mean the 'weight' function.
weight
    :: (ABT Term abt)
    => abt '[] 'HProb
    -> abt '[] ('HMeasure HUnit)
weight p = withWeight p (dirac unit)


-- | A variant of 'weight' which removes an administrative @(dirac
-- unit >>)@ redex.
--
-- TODO: ideally we'll be able to get rid of this function entirely,
-- and be able to trust optimization to clean up any redexes
-- introduced by 'weight'.
withWeight
    :: (ABT Term abt)
    => abt '[] 'HProb
    -> abt '[] ('HMeasure w)
    -> abt '[] ('HMeasure w)
withWeight p m = syn $ Superpose_ ((p, m) :| [])


-- | A particularly common use case of 'weight':
--
-- > weightedDirac e p
-- >     == weight p (dirac e)
-- >     == weight p *> dirac e
-- >     == dirac e <* weight p
weightedDirac
    :: (ABT Term abt, SingI a)
    => abt '[] a
    -> abt '[] 'HProb
    -> abt '[] ('HMeasure a)
weightedDirac e p = withWeight p (dirac e)


-- TODO: this taking of two arguments is as per the Core Hakaru specification; but for the EDSL, can we rephrase this as just taking the first argument, using @dirac unit@ for the else-branch, and then, making @(>>)@ work in the right way to plug the continuation measure in place of the @dirac unit@.
-- TODO: would it help inference\/simplification at all to move this into the AST as a primitive? I mean, it is a primitive of Core Hakaru afterall... Also, that would help clarify whether the (first)argument should actually be an @HBool@ or whether it should be some sort of proposition.

-- | Assert that a condition is true.
--
-- /N.B.,/ the name for this function is terribly inconsistent
-- across the literature, even just the Hakaru literature, let alone
-- the Hakaru code base. It is variously called \"factor\" or
-- \"observe\"; though \"factor\" is also used to mean the function
-- 'pose', and \"observe\" is also used to mean the backwards part
-- of Lazy.hs.
guard
    :: (ABT Term abt)
    => abt '[] HBool
    -> abt '[] ('HMeasure HUnit)
guard b = withGuard b (dirac unit)


-- | A variant of 'guard' which removes an administrative @(dirac
-- unit >>)@ redex.
--
-- TODO: ideally we'll be able to get rid of this function entirely,
-- and be able to trust optimization to clean up any redexes
-- introduced by 'guard'.
withGuard
    :: (ABT Term abt)
    => abt '[] HBool
    -> abt '[] ('HMeasure a)
    -> abt '[] ('HMeasure a)
withGuard b m = if_ b m (reject (typeOf m))


densityCategorical
    :: (ABT Term abt)
    => abt '[] ('HArray 'HProb)
    -> abt '[] 'HNat
    -> abt '[] 'HProb
densityCategorical v i = v ! i / summateV v

categorical, categorical'
    :: (ABT Term abt)
    => abt '[] ('HArray 'HProb)
    -> abt '[] ('HMeasure 'HNat)
categorical = measure1_ Categorical

-- TODO: a variant of 'if_' which gives us the evidence that the argument is non-negative, so we don't need to coerce or use 'abs_'
categorical' v =
    counting >>= \i ->
    withGuard (int_ 0 <= i && i < nat2int (size v)) $
    let_ (unsafeFrom_ signed i) $ \i_ ->
    weightedDirac i_ (densityCategorical v i_)


densityUniform
    :: (ABT Term abt)
    => abt '[] 'HReal
    -> abt '[] 'HReal
    -> abt '[] 'HReal
    -> abt '[] 'HProb
densityUniform lo hi _ = recip . unsafeProb $ hi - lo


-- TODO: make Uniform polymorphic, so that if the two inputs are
-- HProb then we know the measure must be over HProb too
uniform, uniform'
    :: (ABT Term abt)
    => abt '[] 'HReal
    -> abt '[] 'HReal
    -> abt '[] ('HMeasure 'HReal)
uniform = measure2_ Uniform

uniform' lo hi = 
    lebesgue >>= \x ->
    withGuard (lo < x && x < hi) $
        -- TODO: how can we capture that this 'unsafeProb' is safe? (and that this 'recip' isn't Infinity, for that matter)
    weightedDirac x (densityUniform lo hi x)

densityNormal
    :: (ABT Term abt)
    => abt '[] 'HReal
    -> abt '[] 'HProb
    -> abt '[] 'HReal
    -> abt '[] 'HProb
densityNormal mu sd x = 
    exp (negate ((x - mu) ^ nat_ 2)  -- TODO: use negative\/square instead of negate\/(^2)
         / fromProb (prob_ 2 * sd ^ nat_ 2)) -- TODO: use square?
     / sd / sqrt (prob_ 2 * pi)


normal, normal'
    :: (ABT Term abt)
    => abt '[] 'HReal
    -> abt '[] 'HProb
    -> abt '[] ('HMeasure 'HReal)
normal = measure2_ Normal

normal' mu sd  = 
    lebesgue >>= \x ->
    weightedDirac x (densityNormal mu sd x)


densityPoisson
    :: (ABT Term abt)
    => abt '[] 'HProb
    -> abt '[] 'HNat
    -> abt '[] 'HProb
densityPoisson l x =
     l ^ x
       / gammaFunc (nat2real (x + nat_ 1)) -- TODO: use factorial instead of gammaFunc...
       / exp l


poisson, poisson'
    :: (ABT Term abt) => abt '[] 'HProb -> abt '[] ('HMeasure 'HNat)
poisson = measure1_ Poisson

poisson' l = 
    counting >>= \x ->
    -- TODO: use 'SafeFrom_' instead of @if_ (x >= int_ 0)@ so we can prove that @unsafeFrom_ signed x@ is actually always safe.
    withGuard (int_ 0 <= x && prob_ 0 < l) $ -- N.B., @0 < l@ means simply that @l /= 0@; why phrase it the other way?
    let_ (unsafeFrom_ signed x) $ \x_ ->
        weightedDirac x_ (densityPoisson l x_)

densityGamma
    :: (ABT Term abt)
    => abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] 'HProb
densityGamma shape scale x =
    x ** (fromProb shape - real_ 1)
    * exp (negate . fromProb $ x / scale)
    / (scale ** shape * gammaFunc shape)


gamma, gamma'
    :: (ABT Term abt)
    => abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] ('HMeasure 'HProb)
gamma = measure2_ Gamma

gamma' shape scale =
    lebesgue >>= \x ->
    -- TODO: use 'SafeFrom_' instead of @if_ (real_ 0 < x)@ so we can prove that @unsafeProb x@ is actually always safe. Of course, then we'll need to mess around with checking (/=0) which'll get ugly... Use another SafeFrom_ with an associated NonZero type?
    withGuard (real_ 0 < x) $
    let_ (unsafeProb x) $ \ x_ ->
    weightedDirac x_ (densityGamma shape scale x_)

densityBeta
    :: (ABT Term abt)
    => abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] 'HProb
densityBeta a b x =
    x ** (fromProb a - real_ 1)
    * unsafeProb (real_ 1 - fromProb x) ** (fromProb b - real_ 1)
    / betaFunc a b

beta, beta', beta''
    :: (ABT Term abt)
    => abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] ('HMeasure 'HProb)
beta = measure2_ Beta

beta' a b =
    -- TODO: make Uniform polymorphic, so that if the two inputs are HProb then we know the measure must be over HProb too, and hence @unsafeProb x@ must always be safe. Alas, capturing the safety of @unsafeProb (1-x)@ would take a lot more work...
    unsafeProb <$> uniform (real_ 0) (real_ 1) >>= \x ->
    weightedDirac x (densityBeta a b x)

beta'' a b =
    gamma a (prob_ 1) >>= \x ->
    gamma b (prob_ 1) >>= \y ->
    dirac (x / (x+y))

plateWithVar
    :: (ABT Term abt)
    => abt '[] 'HNat
    -> Variable 'HNat
    -> abt '[] ('HMeasure a)
    -> abt '[] ('HMeasure ('HArray a))
plateWithVar e1 x e2 = syn (Plate :$ e1 :* bind x e2 :* End)
        
plate :: (ABT Term abt)
      => abt '[] 'HNat
      -> (abt '[] 'HNat -> abt '[] ('HMeasure a))
      -> abt '[] ('HMeasure ('HArray a))
plate e f = syn (Plate :$ e :* binder Text.empty sing f :* End)

plate'
    :: (ABT Term abt, SingI a)
    => abt '[] ('HArray ('HMeasure          a))
    -> abt '[] (         'HMeasure ('HArray a))

plate' v = reduce r z (mapV m v)
    where
    r   = liftM2 appendV
    z   = dirac empty
    m a = (array (nat_ 1) . const) <$> a


-- BUG: remove the 'SingI' requirement!
chain :: (ABT Term abt, SingI s)
      => abt '[] 'HNat
      -> abt '[] s
      -> (abt '[] s -> abt '[] ('HMeasure (HPair a s)))
      -> abt '[] ('HMeasure (HPair ('HArray a) s))
chain n s f = syn (Chain :$ n :* s :* binder Text.empty sing f :* End)

chain'
    :: (ABT Term abt, SingI s, SingI a)
    => abt '[] ('HArray (s ':-> 'HMeasure (HPair a s)))
    -> abt '[] s
    -> abt '[] ('HMeasure (HPair ('HArray a) s))

chain' v s0 = reduce r z (mapV m v) `app` s0
    where
    r x y = lam $ \s ->
            app x s >>= \v1s1 ->
            v1s1 `unpair` \v1 s1 ->
            app y s1 >>= \v2s2 ->
            v2s2 `unpair` \v2 s2 ->
            dirac $ pair (appendV v1 v2) s2
    z     = lam $ \s -> dirac (pair empty s)
    m a   = lam $ \s -> (`unpair` pair . array (nat_ 1) . const) <$> app a s


invgamma
    :: (ABT Term abt)
    => abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] ('HMeasure 'HProb)
invgamma k t = recip <$> gamma k (recip t)

exponential
    :: (ABT Term abt) => abt '[] 'HProb -> abt '[] ('HMeasure 'HProb)
exponential = gamma (prob_ 1)

chi2 :: (ABT Term abt) => abt '[] 'HProb -> abt '[] ('HMeasure 'HProb)
chi2 v = gamma (v / prob_ 2) (prob_ 2)

cauchy
    :: (ABT Term abt)
    => abt '[] 'HReal
    -> abt '[] 'HProb
    -> abt '[] ('HMeasure 'HReal)
cauchy loc scale =
    normal (real_ 0) (prob_ 1) >>= \x ->
    normal (real_ 0) (prob_ 1) >>= \y ->
    dirac $ loc + fromProb scale * x / y

laplace
    :: (ABT Term abt)
    => abt '[] 'HReal
    -> abt '[] 'HProb
    -> abt '[] ('HMeasure 'HReal)
laplace loc scale =
    exponential (prob_ 1) >>= \v ->
    normal (real_ 0) (prob_ 1) >>= \z ->
    dirac $ loc + z * fromProb (scale * sqrt (prob_ 2 * v))

studentT
    :: (ABT Term abt)
    => abt '[] 'HReal
    -> abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] ('HMeasure 'HReal)
studentT loc scale v =
    normal loc scale >>= \z ->
    chi2 v >>= \df ->
    dirac $ z * fromProb (sqrt (v / df))

weibull
    :: (ABT Term abt)
    => abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] ('HMeasure 'HProb)
weibull b k =
    exponential (prob_ 1) >>= \x ->
    dirac $ b * x ** recip k

-- BUG: would it be better to 'observe' that @p <= 1@ before doing the superpose? At least that way things would be /defined/ for all inputs...
bern :: (ABT Term abt) => abt '[] 'HProb -> abt '[] ('HMeasure HBool)
bern p = weightedDirac true  p
     <|> weightedDirac false (prob_ 1 `unsafeMinusProb` p)

mix :: (ABT Term abt)
    => abt '[] ('HArray 'HProb) -> abt '[] ('HMeasure 'HNat)
mix v = withWeight (sumV v) (categorical v)

binomial
    :: (ABT Term abt)
    => abt '[] 'HNat
    -> abt '[] 'HProb
    -> abt '[] ('HMeasure 'HInt)
binomial n p =
    sumV <$> plate n (const $ ((\b -> if_ b (int_ 1) (int_ 0)) <$> bern p))

-- BUG: would it be better to 'observe' that @p >= 1@ before doing everything? At least that way things would be /defined/ for all inputs...
negativeBinomial
    :: (ABT Term abt)
    => abt '[] 'HNat
    -> abt '[] 'HProb -- N.B., must actually be between 0 and 1
    -> abt '[] ('HMeasure 'HNat)
negativeBinomial r p =
    gamma (nat2prob r) (recip (recip p `unsafeMinusProb` prob_ 1)) >>= poisson

geometric :: (ABT Term abt) => abt '[] 'HProb -> abt '[] ('HMeasure 'HNat)
geometric = negativeBinomial (nat_ 1)


multinomial
    :: (ABT Term abt)
    => abt '[] 'HNat
    -> abt '[] ('HArray 'HProb)
    -> abt '[] ('HMeasure ('HArray 'HProb))
multinomial n v =
    reduce (liftM2 (zipWithV (+)))
        (dirac (constV (size v) (prob_ 0)))
        (constV n (unitV (size v) <$> categorical v))

dirichlet
    :: (ABT Term abt)
    => abt '[] ('HArray 'HProb)
    -> abt '[] ('HMeasure ('HArray 'HProb))
dirichlet a = normalizeV <$> plate (size a) (\ i -> a ! i  `gamma` prob_ 1)

----------------------------------------------------------------
----------------------------------------------------------- fin.