{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module: Numeric.MCMC.Hamiltonian
-- Copyright: (c) 2015 Jared Tobin
-- License: MIT
--
-- Maintainer: Jared Tobin <jared@jtobin.ca>
-- Stability: unstable
-- Portability: ghc
--
-- This implementation performs Hamiltonian Monte Carlo using an identity mass
-- matrix.
--
-- The 'mcmc' function streams a trace to stdout to be processed elsewhere,
-- while the 'hamiltonian' transition can be used for more flexible purposes,
-- such as working with samples in memory.
--
-- See <http://arxiv.org/pdf/1206.1901.pdf Neal, 2012> for the definitive
-- reference of the algorithm.

module Numeric.MCMC.Hamiltonian (
    mcmc
  , chain
  , hamiltonian

  -- * Re-exported
  , Target(..)
  , MWC.create
  , MWC.createSystemRandom
  , MWC.withSystemRandom
  , MWC.asGenIO
  ) where

import Control.Lens hiding (index)
import Control.Monad (replicateM)
import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Primitive (PrimState, PrimMonad)
import Control.Monad.Trans.State.Strict hiding (state)
import qualified Data.Foldable as Foldable (sum)
import Data.Maybe (fromMaybe)
import Data.Sampling.Types
import Data.Traversable (for)
import Pipes hiding (for, next)
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability (Prob, Gen)
import qualified System.Random.MWC.Probability as MWC

-- | Trace @n@ iterations of a Markov chain and stream them to stdout.
--
-- The chain starts at the supplied position. Each of the first @n@ states
-- produced by the 'hamiltonian' transition is written with 'print' as soon
-- as it is drawn.
--
-- >>> withSystemRandom . asGenIO $ mcmc 10000 0.05 20 [0, 0] target
mcmc
  :: ( MonadIO m, PrimMonad m
     , Num (IxValue (t Double)), Show (t Double), Traversable t
     , FunctorWithIndex (Index (t Double)) t, Ixed (t Double)
     , IxValue (t Double) ~ Double)
  => Int                -- ^ number of states to print
  -> Double             -- ^ leapfrog step size
  -> Int                -- ^ number of leapfrog steps per transition
  -> t Double           -- ^ initial position
  -> Target (t Double)  -- ^ target to sample from
  -> Gen (PrimState m)  -- ^ random number generator
  -> m ()
mcmc n step leaps position target gen = runEffect $
  drive step leaps start gen >-> Pipes.take n >-> Pipes.mapM_ (liftIO . print)
  where
    -- Initial chain state: scored under the target, with no tuning parameters.
    start = Chain target (lTarget target position) position Nothing

-- | Trace @n@ iterations of a Markov chain and collect the results in a list.
--
-- The list holds the @n@ states produced after the starting position, in
-- the order in which they were drawn. The starting state itself is not
-- included.
--
-- >>> results <- withSystemRandom . asGenIO $ chain 1000 0.05 20 [0, 0] target
chain
  :: (PrimMonad m, Traversable f
     , FunctorWithIndex (Index (f Double)) f, Ixed (f Double)
     , IxValue (f Double) ~ Double)
  => Int                -- ^ number of states to collect
  -> Double             -- ^ leapfrog step size
  -> Int                -- ^ number of leapfrog steps per transition
  -> f Double           -- ^ initial position
  -> Target (f Double)  -- ^ target to sample from
  -> Gen (PrimState m)  -- ^ random number generator
  -> m [Chain (f Double) b]
chain n step leaps position target gen =
  runEffect (drive step leaps start gen >-> gather n)
  where
    -- Initial chain state: scored under the target, with no tuning parameters.
    start = Chain
      { chainTarget   = target
      , chainScore    = lTarget target position
      , chainPosition = position
      , chainTunables = Nothing
      }

    -- Await exactly @k@ values (none when @k <= 0@) and return them in
    -- arrival order.
    gather :: Monad m => Int -> Consumer a m [a]
    gather k = lowerCodensity (replicateM k (lift Pipes.await))

-- Drive a Markov chain. Starting from the given state, repeatedly apply the
-- 'hamiltonian' transition and yield each resulting state downstream. The
-- producer never finishes on its own (hence the free return type @c@), so
-- consumers decide how many states to take.
drive
  :: (Num (IxValue (t Double)), Traversable t
     , FunctorWithIndex (Index (t Double)) t, Ixed (t Double)
     , PrimMonad m, IxValue (t Double) ~ Double)
  => Double
  -> Int
  -> Chain (t Double) b
  -> Gen (PrimState m)
  -> Producer (Chain (t Double) b) m c
drive step leaps current prng = do
  successor <- lift (MWC.sample (execStateT transition current) prng)
  yield successor
  drive step leaps successor prng
  where
    transition = hamiltonian step leaps

-- | A Hamiltonian transition operator.
--
-- The transition proceeds as follows:
--
-- * Draw a standard-normal momentum for every coordinate of the current
--   position, then one uniform variate.
-- * Integrate with @l@ leapfrog steps of size @e@.
-- * Move to the end point or stay put, as decided by 'nextState'.
--
-- The chain's target and tunables carry over unchanged, and its score is
-- recomputed at the resulting position.
hamiltonian
  :: (Num (IxValue (t Double)), Traversable t
     , FunctorWithIndex (Index (t Double)) t, Ixed (t Double), PrimMonad m
     , IxValue (t Double) ~ Double)
  => Double -> Int -> Transition m (Chain (t Double) b)
hamiltonian e l = do
  current <- get
  let target   = chainTarget current
      position = chainPosition current
  -- The order of these two draws fixes how the random stream is consumed.
  momentum <- lift (for position (const MWC.standardNormal))
  u        <- lift (MWC.uniform :: PrimMonad m => Prob m Double)
  let (endPosition, endMomentum) =
        leapfrogIntegrator target e l (position, momentum)
      newPosition =
        nextState target (position, endPosition) (momentum, endMomentum) u
  put current
    { chainScore    = lTarget target newPosition
    , chainPosition = newPosition
    }

-- Calculate the next state of the chain. The position pair is
-- (current, proposal). The proposal is returned when the uniform draw @z@
-- falls strictly below the acceptance probability; otherwise the current
-- position is kept.
nextState
  :: (Foldable s, Foldable t, FunctorWithIndex (Index (t Double)) t
     , FunctorWithIndex (Index (s Double)) s, Ixed (s Double)
     , Ixed (t Double), IxValue (t Double) ~ Double
     , IxValue (s Double) ~ Double)
  => Target b
  -> (b, b)
  -> (s Double, t Double)
  -> Double
  -> b
nextState target position momentum z =
  if z < acceptProb target position momentum
    then snd position
    else fst position

-- Calculate the acceptance probability of a proposed move. This is the
-- exponential of the change in the momentum-augmented target, going from
-- (q0, r0) to (q1, r1), capped at one by clamping the exponent at zero.
acceptProb
  :: (Foldable t, Foldable s, FunctorWithIndex (Index (t Double)) t
     , FunctorWithIndex (Index (s Double)) s, Ixed (t Double)
     , Ixed (s Double), IxValue (t Double) ~ Double
     , IxValue (s Double) ~ Double)
  => Target a
  -> (a, a)
  -> (s Double, t Double)
  -> Double
acceptProb target (q0, q1) (r0, r1) = exp (min 0 (proposed - initial))
  where
    -- Called separately because the two momenta have different container
    -- types (s and t).
    proposed = auxilliaryTarget target (q1, r1)
    initial  = auxilliaryTarget target (q0, r0)

-- A momentum-augmented target. It returns the target's 'lTarget' score at
-- the position minus the kinetic term @0.5 * (r . r)@ of the momentum. The
-- score is treated as a log density, since 'acceptProb' exponentiates
-- differences of this value.
auxilliaryTarget
  :: (Foldable t, FunctorWithIndex (Index (t Double)) t
     , Ixed (t Double), IxValue (t Double) ~ Double)
  => Target a
  -> (a, t Double)
  -> Double
auxilliaryTarget target (position, momentum) =
  lTarget target position - kinetic
  where
    kinetic = 0.5 * innerProduct momentum momentum

-- Sum of elementwise products of two containers. Elements of the second
-- container are looked up by the indices of the first (see 'gzipWith').
innerProduct
  :: (Num (IxValue s), Foldable t, FunctorWithIndex (Index s) t, Ixed s)
  => t (IxValue s) -> s -> IxValue s
innerProduct xs = Foldable.sum . gzipWith (*) xs

-- A container-generic zipWith. Each element of @xs@ is combined with the
-- element of @ys@ at the same index. Only the indices of @xs@ are visited,
-- so extra elements of @ys@ are ignored. An index missing from @ys@ raises
-- an error, but only if @f@ actually demands that element.
gzipWith
  :: (FunctorWithIndex (Index s) f, Ixed s)
  => (a -> IxValue s -> b) -> f a -> s -> f b
gzipWith f xs ys = imap pairUp xs
  where
    pairUp j x = f x (fromMaybe missing (ys ^? ix j))
    missing    = error "gzipWith: invalid index"

-- The leapfrog or Stormer-Verlet integrator.
--
-- Applies 'leapfrog' @l@ times, each with step size @e@, starting from the
-- (position, momentum) pair @(q0, r0)@, and returns the final pair.
--
-- A non-positive step count performs no steps and returns the input pair
-- unchanged. Without the @n <= 0@ guard, a negative count would never reach
-- the base case and the recursion would effectively never terminate.
leapfrogIntegrator
  :: (Num (IxValue (f Double))
     , FunctorWithIndex (Index (f Double)) t
     , FunctorWithIndex (Index (t Double)) f
     , Ixed (f Double), Ixed (t Double)
     , IxValue (f Double) ~ Double
     , IxValue (t Double) ~ Double)
  => Target (f Double)
  -> Double
  -> Int
  -> (f Double, t (IxValue (f Double)))
  -> (f Double, t (IxValue (f Double)))
leapfrogIntegrator target e l (q0, r0) = go q0 r0 l where
  go q r n
    | n <= 0    = (q, r)
    | otherwise = go q1 r1 (pred n)
    where
      (q1, r1) = leapfrog target e (q, r)

-- A single leapfrog step, in three stages:
--
-- * a half-step momentum update at the starting position;
-- * a full-step position update driven by that momentum;
-- * a second half-step momentum update at the new position.
leapfrog
  :: (Num (IxValue (f Double))
     , FunctorWithIndex (Index (f Double)) t
     , FunctorWithIndex (Index (t Double)) f
     , Ixed (t Double), Ixed (f Double)
     , IxValue (f Double) ~ Double, IxValue (t Double) ~ Double)
  => Target (f Double)
  -> Double
  -> (f Double, t (IxValue (f Double)))
  -> (f Double, t (IxValue (f Double)))
leapfrog target e (q, r) =
  let rHalf = adjustMomentum target e (q, r)
      qNew  = adjustPosition e (rHalf, q)
      rNew  = adjustMomentum target e (qNew, rHalf)
  in  (qNew, rNew)

-- Advance the momentum @r@ by half a step (@0.5 * e@) along the target's
-- gradient evaluated at position @q@. If the target carries no gradient
-- ('glTarget' is 'Nothing'), this raises an error once the gradient is
-- actually demanded.
adjustMomentum
  :: (Functor f, Num (IxValue (f Double))
     , FunctorWithIndex (Index (f Double)) t, Ixed (f Double))
  => Target (f Double)
  -> Double
  -> (f Double, t (IxValue (f Double)))
  -> t (IxValue (f Double))
adjustMomentum target e (q, r) = r .+ (halfStep .* gradient q)
  where
    halfStep   = 0.5 * e
    gradient   = fromMaybe noGradient (glTarget target)
    noGradient = error "adjustMomentum: no gradient provided"

-- Advance the position by a full step of size @e@ along the momentum.
-- Note the pair order: (momentum, position), as passed by 'leapfrog'.
adjustPosition
  :: (Functor f, Num (IxValue (f Double))
     , FunctorWithIndex (Index (f Double)) t, Ixed (f Double))
  => Double
  -> (f Double, t (IxValue (f Double)))
  -> t (IxValue (f Double))
adjustPosition e (momentum, position) = position .+ scaledMomentum
  where
    scaledMomentum = e .* momentum

-- Scalar-vector product. Every element @x@ of the container is multiplied
-- by the scalar @z@, computed as @x * z@.
(.*) :: (Num a, Functor f) => a -> f a -> f a
z .* xs = (* z) <$> xs

-- Vector addition. The result is the elementwise sum of two containers,
-- aligned by the indices of the first one (via 'gzipWith').
(.+)
  :: (Num (IxValue t), FunctorWithIndex (Index t) f, Ixed t)
  => f (IxValue t)
  -> t
  -> f (IxValue t)
xs .+ ys = gzipWith (+) xs ys