{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Numeric.MCMC.Hamiltonian (
mcmc
, chain
, hamiltonian
, Target(..)
, MWC.create
, MWC.createSystemRandom
, MWC.withSystemRandom
, MWC.asGenIO
) where
import Control.Lens hiding (index)
import Control.Monad (replicateM)
import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Primitive (PrimState, PrimMonad)
import Control.Monad.Trans.State.Strict hiding (state)
import qualified Data.Foldable as Foldable (sum)
import Data.Maybe (fromMaybe)
import Data.Sampling.Types
import Data.Traversable (for)
import Pipes hiding (for, next)
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability (Prob, Gen)
import qualified System.Random.MWC.Probability as MWC
-- | Run a Hamiltonian Monte Carlo chain and print the visited states.
--
--   Starting from the supplied position, the chain is advanced with
--   'hamiltonian'; the first @n@ resulting states are written to stdout
--   with 'print'.
mcmc
  :: ( MonadIO m, PrimMonad m
     , Num (IxValue (t Double)), Show (t Double), Traversable t
     , FunctorWithIndex (Index (t Double)) t, Ixed (t Double)
     , IxValue (t Double) ~ Double)
  => Int                -- ^ number of states to print
  -> Double             -- ^ leapfrog step size
  -> Int                -- ^ number of leapfrog steps per transition
  -> t Double           -- ^ starting position
  -> Target (t Double)  -- ^ target to sample from
  -> Gen (PrimState m)  -- ^ generator supplying the randomness
  -> m ()
mcmc n step leaps position target gen = runEffect $
        drive step leaps origin gen
    >-> Pipes.take n
    >-> Pipes.mapM_ (liftIO . print)
  where
    -- Initial chain state.  The fields are spelled out explicitly (as in
    -- 'chain') so that no argument has to shadow a record selector.
    origin = Chain {
        chainScore    = lTarget target position
      , chainTunables = Nothing
      , chainTarget   = target
      , chainPosition = position
      }
-- | Run a Hamiltonian Monte Carlo chain and collect the visited states.
--
--   Starting from the supplied position, the chain is advanced with
--   'hamiltonian' and the first @n@ resulting states are returned as a
--   list, in the order they were produced.
chain
  :: (PrimMonad m, Traversable f
     , FunctorWithIndex (Index (f Double)) f, Ixed (f Double)
     , IxValue (f Double) ~ Double)
  => Int                -- ^ number of states to collect
  -> Double             -- ^ leapfrog step size
  -> Int                -- ^ number of leapfrog steps per transition
  -> f Double           -- ^ starting position
  -> Target (f Double)  -- ^ target to sample from
  -> Gen (PrimState m)  -- ^ generator supplying the randomness
  -> m [Chain (f Double) b]
chain n step leaps position target gen = runEffect $
        drive step leaps origin gen
    >-> collect n
  where
    -- Initial chain state: scored at the starting position, no tunables.
    origin = Chain {
        chainScore    = lTarget target position
      , chainTunables = Nothing
      , chainTarget   = target
      , chainPosition = position
      }

    -- Await exactly @size@ upstream values and return them as a list.
    -- The replication runs in 'Codensity' and is lowered back to a
    -- 'Consumer' afterwards.
    collect :: Monad m => Int -> Consumer a m [a]
    collect size = lowerCodensity $
      replicateM size (lift Pipes.await)
-- | An endless producer of chain states.
--
--   Each iteration runs one 'hamiltonian' transition from the current
--   state using the supplied generator, yields the resulting state, and
--   then continues from that state.  The producer never returns on its
--   own (its result type is fully polymorphic); downstream decides how
--   many states to consume.
drive
  :: (Num (IxValue (t Double)), Traversable t
     , FunctorWithIndex (Index (t Double)) t, Ixed (t Double)
     , PrimMonad m, IxValue (t Double) ~ Double)
  => Double              -- ^ leapfrog step size
  -> Int                 -- ^ number of leapfrog steps per transition
  -> Chain (t Double) b  -- ^ state to start from
  -> Gen (PrimState m)   -- ^ generator supplying the randomness
  -> Producer (Chain (t Double) b) m c
drive step leaps = loop where
  loop state prng = do
    -- Run a single transition in the state monad and keep its final state.
    next <- lift (MWC.sample (execStateT (hamiltonian step leaps) state) prng)
    yield next
    loop next prng
-- | A single Hamiltonian Monte Carlo transition.
--
--   Reads the current chain state, proposes a new position by leapfrog
--   integration, accepts or rejects it via 'nextState', and writes the
--   resulting state back.  The first argument is the leapfrog step size,
--   the second the number of leapfrog steps.
hamiltonian
  :: (Num (IxValue (t Double)), Traversable t
     , FunctorWithIndex (Index (t Double)) t, Ixed (t Double), PrimMonad m
     , IxValue (t Double) ~ Double)
  => Double -> Int -> Transition m (Chain (t Double) b)
hamiltonian e l = do
  Chain {..} <- get
  -- Fresh momentum: one standard-normal draw per component of the position.
  r0 <- lift (for chainPosition (const MWC.standardNormal))
  -- Uniform variate compared against the acceptance probability.
  zc <- lift (MWC.uniform :: PrimMonad m => Prob m Double)
  let (q, r) = leapfrogIntegrator chainTarget e l (chainPosition, r0)
      perturbed      = nextState chainTarget (chainPosition, q) (r0, r) zc
      perturbedScore = lTarget chainTarget perturbed
  -- Target and tunables are carried over unchanged.
  put (Chain chainTarget perturbedScore perturbed chainTunables)
-- | Choose between the current and the proposed position.
--
--   The position pair is @(current, proposed)@ and the momentum pair holds
--   the momenta that go with them.  The proposal is returned when @z@ is
--   strictly below the acceptance probability computed by 'acceptProb';
--   otherwise the current position is kept.
nextState
  :: (Foldable s, Foldable t, FunctorWithIndex (Index (t Double)) t
     , FunctorWithIndex (Index (s Double)) s, Ixed (s Double)
     , Ixed (t Double), IxValue (t Double) ~ Double
     , IxValue (s Double) ~ Double)
  => Target b
  -> (b, b)
  -> (s Double, t Double)
  -> Double
  -> b
nextState target position momentum z
    | z < pAccept = snd position
    | otherwise   = fst position
  where
    pAccept = acceptProb target position momentum
-- | Acceptance probability for a move from @(q0, r0)@ to @(q1, r1)@.
--
--   Computed as @exp (min 0 d)@, where @d@ is the value of
--   'auxilliaryTarget' at the proposed pair minus its value at the current
--   pair, so the result never exceeds 1.
acceptProb
  :: (Foldable t, Foldable s, FunctorWithIndex (Index (t Double)) t
     , FunctorWithIndex (Index (s Double)) s, Ixed (t Double)
     , Ixed (s Double), IxValue (t Double) ~ Double
     , IxValue (s Double) ~ Double)
  => Target a
  -> (a, a)
  -> (s Double, t Double)
  -> Double
acceptProb target (q0, q1) (r0, r1) = exp . min 0 $
  auxilliaryTarget target (q1, r1) - auxilliaryTarget target (q0, r0)
-- | Value of the augmented (position, momentum) target.
--
--   This is the target's 'lTarget' evaluated at the position, minus half
--   the inner product of the momentum with itself.
auxilliaryTarget
  :: (Foldable t, FunctorWithIndex (Index (t Double)) t
     , Ixed (t Double), IxValue (t Double) ~ Double)
  => Target a
  -> (a, t Double)
  -> Double
auxilliaryTarget target (t, r) = f t - 0.5 * innerProduct r r where
  f = lTarget target
-- | Inner product of two indexed containers.
--
--   Elements are paired by index with 'gzipWith', multiplied, and the
--   products summed.  Fails (via 'gzipWith') if the second container has
--   no element at an index that is present in the first.
innerProduct
  :: (Num (IxValue s), Foldable t, FunctorWithIndex (Index s) t, Ixed s)
  => t (IxValue s) -> s -> IxValue s
innerProduct xs ys = Foldable.sum $ gzipWith (*) xs ys
-- | Index-driven generalisation of @zipWith@.
--
--   For every element @x@ of the first container, found at index @j@, the
--   element of the second container at the same index is looked up and
--   combined with @x@ using the supplied function.  The result has the
--   shape of the first container.
--
--   Partial: calls 'error' if the second container has no element at one
--   of the first container's indices.
gzipWith
  :: (FunctorWithIndex (Index s) f, Ixed s)
  => (a -> IxValue s -> b) -> f a -> s -> f b
gzipWith f xs ys = imap (\j x -> f x (fromMaybe err (ys ^? ix j))) xs where
  err = error "gzipWith: invalid index"
-- | Apply 'leapfrog' repeatedly to a (position, momentum) pair.
--
--   The third argument is the number of leapfrog steps to take, each with
--   step size given by the second argument.  A non-positive step count
--   returns the input pair unchanged.
leapfrogIntegrator
  :: (Num (IxValue (f Double))
     , FunctorWithIndex (Index (f Double)) t
     , FunctorWithIndex (Index (t Double)) f
     , Ixed (f Double), Ixed (t Double)
     , IxValue (f Double) ~ Double
     , IxValue (t Double) ~ Double)
  => Target (f Double)
  -> Double
  -> Int
  -> (f Double, t (IxValue (f Double)))
  -> (f Double, t (IxValue (f Double)))
leapfrogIntegrator target e l (q0, r0) = go q0 r0 l where
  go q r n
    -- Stop on any non-positive count: matching on exactly 0 would count
    -- down forever if a negative number of steps were requested.
    | n <= 0    = (q, r)
    | otherwise = go q1 r1 (pred n)
    where
      (q1, r1) = leapfrog target e (q, r)
-- | A single leapfrog step on a (position, momentum) pair.
--
--   The momentum is first advanced by 'adjustMomentum' (which uses half
--   the step size), the position is then advanced by 'adjustPosition'
--   using that intermediate momentum, and finally the momentum is advanced
--   by 'adjustMomentum' again at the new position.
leapfrog
  :: (Num (IxValue (f Double))
     , FunctorWithIndex (Index (f Double)) t
     , FunctorWithIndex (Index (t Double)) f
     , Ixed (t Double), Ixed (f Double)
     , IxValue (f Double) ~ Double, IxValue (t Double) ~ Double)
  => Target (f Double)
  -> Double
  -> (f Double, t (IxValue (f Double)))
  -> (f Double, t (IxValue (f Double)))
leapfrog target e (q, r) = (qf, rf) where
  rm = adjustMomentum target e (q, r)    -- momentum after the first half step
  qf = adjustPosition e (rm, q)          -- position after the full step
  rf = adjustMomentum target e (qf, rm)  -- momentum after the second half step
-- | Momentum update used by 'leapfrog'.
--
--   Adds @0.5 * e@ times the target's gradient ('glTarget'), evaluated at
--   the position @q@, to the momentum @r@, pairing elements by index.
--
--   Partial: calls 'error' if the target carries no gradient.
adjustMomentum
  :: (Functor f, Num (IxValue (f Double))
     , FunctorWithIndex (Index (f Double)) t, Ixed (f Double))
  => Target (f Double)
  -> Double
  -> (f Double, t (IxValue (f Double)))
  -> t (IxValue (f Double))
adjustMomentum target e (q, r) = r .+ ((0.5 * e) .* g q) where
  g   = fromMaybe err (glTarget target)
  err = error "adjustMomentum: no gradient provided"
-- | Position update used by 'leapfrog'.
--
--   Note the argument order of the pair: the momentum comes first, the
--   position second.  The result is the position plus @e@ times the
--   momentum, with elements paired by index.
adjustPosition
  :: (Functor f, Num (IxValue (f Double))
     , FunctorWithIndex (Index (f Double)) t, Ixed (f Double))
  => Double
  -> (f Double, t (IxValue (f Double)))
  -> t (IxValue (f Double))
adjustPosition e (r, q) = q .+ (e .* r)
-- | Scalar multiplication: multiply every element of a container by @z@.
(.*) :: (Num a, Functor f) => a -> f a -> f a
z .* xs = fmap (* z) xs
-- | Element-wise addition of two indexed containers.
--
--   Elements are paired by index via 'gzipWith'; the result has the shape
--   of the left operand.  Fails (via 'gzipWith') if the right operand has
--   no element at one of the left operand's indices.
(.+)
  :: (Num (IxValue t), FunctorWithIndex (Index t) f, Ixed t)
  => f (IxValue t)
  -> t
  -> f (IxValue t)
(.+) = gzipWith (+)