{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Numeric.AD.Newton.Double
(
findZero
, inverse
, fixedPoint
, extremum
, conjugateGradientDescent
, conjugateGradientAscent
) where
import Data.Foldable (all, sum)
import Data.Traversable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Forward (Forward)
import Numeric.AD.Internal.Forward.Double (ForwardDouble)
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Or
import Numeric.AD.Internal.Type (AD(..))
import Numeric.AD.Mode
import Numeric.AD.Rank1.Kahn.Double as Kahn (KahnDouble, grad)
import qualified Numeric.AD.Rank1.Newton.Double as Rank1
import Prelude hiding (all, mapM, sum)
-- | Rank-2 front end for @Rank1.findZero@.
--
-- The supplied function is polymorphic in the tag @s@ of the 'AD' wrapper.
-- It is adapted to a plain @ForwardDouble -> ForwardDouble@ by wrapping the
-- input with the 'AD' constructor and unwrapping the result with 'runAD',
-- and is then handed to the rank-1 implementation unchanged.
--
-- The 'Double' argument and the resulting list are passed straight through
-- (presumably the starting guess and the successive iterates; see
-- @Rank1.findZero@ to confirm).
findZero
  :: (forall s. AD s ForwardDouble -> AD s ForwardDouble)
  -> Double
  -> [Double]
findZero f = Rank1.findZero (runAD . f . AD)
{-# INLINE findZero #-}
-- | Rank-2 front end for @Rank1.inverse@.
--
-- The tag-polymorphic function is adapted to a plain
-- @ForwardDouble -> ForwardDouble@ by wrapping its input with 'AD' and
-- unwrapping its output with 'runAD'; everything else is delegated to the
-- rank-1 implementation.
--
-- The two 'Double' arguments are forwarded in the order given (presumably
-- an initial guess followed by the target value; see @Rank1.inverse@ to
-- confirm their meaning).
inverse
  :: (forall s. AD s ForwardDouble -> AD s ForwardDouble)
  -> Double
  -> Double
  -> [Double]
inverse f = Rank1.inverse (runAD . f . AD)
{-# INLINE inverse #-}
-- | Rank-2 front end for @Rank1.fixedPoint@.
--
-- The tag-polymorphic function is adapted to a plain
-- @ForwardDouble -> ForwardDouble@ by wrapping its input with 'AD' and
-- unwrapping its output with 'runAD'; the search itself is performed by the
-- rank-1 implementation.
--
-- The 'Double' argument and the resulting list are passed straight through
-- (presumably the starting guess and the successive iterates; see
-- @Rank1.fixedPoint@ to confirm).
fixedPoint
  :: (forall s. AD s ForwardDouble -> AD s ForwardDouble)
  -> Double
  -> [Double]
fixedPoint f = Rank1.fixedPoint (runAD . f . AD)
{-# INLINE fixedPoint #-}
-- | Rank-2 front end for @Rank1.extremum@.
--
-- The function works on the nested mode @On (Forward ForwardDouble)@. The
-- tag-polymorphic 'AD' wrapper is stripped by wrapping the input with 'AD'
-- and unwrapping the output with 'runAD', after which the rank-1
-- implementation does all the work.
--
-- The 'Double' argument and the resulting list are passed straight through
-- (presumably the starting guess and the successive iterates; see
-- @Rank1.extremum@ to confirm).
extremum
  :: (forall s. AD s (On (Forward ForwardDouble))
             -> AD s (On (Forward ForwardDouble)))
  -> Double
  -> [Double]
extremum f = Rank1.extremum (runAD . f . AD)
{-# INLINE extremum #-}
-- | Conjugate gradient descent, expressed as ascent on the negated
-- objective.
--
-- The objective is composed with 'negate' and handed to
-- 'conjugateGradientAscent'; the starting point and the resulting list of
-- points are passed through unchanged.
conjugateGradientDescent
  :: Traversable f
  => (forall s. Chosen s
        => f (Or s (On (Forward ForwardDouble)) KahnDouble)
        -> Or s (On (Forward ForwardDouble)) KahnDouble)
  -> f Double
  -> [f Double]
conjugateGradientDescent f = conjugateGradientAscent (negate . f)
{-# INLINE conjugateGradientDescent #-}
-- | Run a function written against @Or F a b@ on its left-hand type @a@.
--
-- Every element of the input container is injected with the 'L'
-- constructor, the function is applied, and the result is projected back
-- out with 'runL'.
lfu :: Functor f => (f (Or F a b) -> Or F a b) -> f a -> a
lfu f = runL . f . fmap L
-- | Run a function written against @Or T a b@ on its right-hand type @b@.
--
-- Every element of the input container is injected with the 'R'
-- constructor, the function is applied, and the result is projected back
-- out with 'runR'.
rfu :: Functor f => (f (Or T a b) -> Or T a b) -> f b -> b
rfu f = runR . f . fmap R
-- | Conjugate gradient ascent with a derivative-based line search.
--
-- The objective is written once against the @Or s@ type so that it can be
-- run in two modes:
--
-- * via 'rfu' on 'KahnDouble', to obtain the gradient with @Kahn.grad@;
--
-- * via 'lfu' on @On (Forward ForwardDouble)@, to feed the one-dimensional
--   line search performed by @Rank1.extremum@.
--
-- The result is the list of visited points, beginning with the starting
-- point @x0@. The list is cut off just before the first point that has a
-- NaN component.
conjugateGradientAscent
  :: Traversable f
  => (forall s. Chosen s
        => f (Or s (On (Forward ForwardDouble)) KahnDouble)
        -> Or s (On (Forward ForwardDouble)) KahnDouble)
  -> f Double
  -> [f Double]
conjugateGradientAscent f x0 =
    -- For a Double, @a == a@ is False only when @a@ is NaN, so this keeps
    -- points until one of them contains a NaN.
    takeWhile (all (\a -> a == a)) (go x0 d0 d0 delta0)
  where
    -- Sum of the element-wise products of two containers.
    dot x y = sum $ zipWithT (*) x y

    -- The initial search direction is the gradient at the starting point.
    d0 = Kahn.grad (rfu f) x0

    -- Squared norm of the initial gradient.
    delta0 = dot d0 d0

    -- One step: emit the current point, then continue from the next one.
    -- Arguments: current point, current gradient (unused), current search
    -- direction, squared norm of the current gradient.
    go xi _ri di deltai = xi : go xi1 ri1 di1 deltai1
      where
        -- Step length along @di@: the last of at most 20 iterates produced
        -- by the one-dimensional search, started from 0.
        -- NOTE(review): 'last' is partial; this relies on @Rank1.extremum@
        -- never returning an empty list -- confirm.
        ai = last $ take 20 $
               Rank1.extremum
                 (\a -> lfu f $ zipWithT (\x d -> auto x + a * auto d) xi di)
                 0

        -- Next point: move @ai@ along the current direction.
        xi1 = zipWithT (\x d -> x + ai * d) xi di

        -- Gradient at the next point and its squared norm.
        ri1 = Kahn.grad (rfu f) xi1
        deltai1 = dot ri1 ri1

        -- Ratio of successive squared gradient norms (the Fletcher-Reeves
        -- form of the update coefficient).
        bi1 = deltai1 / deltai

        -- Next direction: new gradient plus @bi1@ times the old direction.
        di1 = zipWithT (\r d -> r + bi1 * d) ri1 di
{-# INLINE conjugateGradientAscent #-}