{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2015-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-----------------------------------------------------------------------------

module Numeric.AD.Newton.Double
  (
  -- * Newton's Method (Forward AD)
    findZero
  , inverse
  , fixedPoint
  , extremum
  -- * Gradient Ascent/Descent (Reverse AD)
  , conjugateGradientDescent
  , conjugateGradientAscent
  ) where

import Data.Foldable (all, sum)
import Data.Traversable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Forward (Forward)
import Numeric.AD.Internal.Forward.Double (ForwardDouble)
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Or
import Numeric.AD.Internal.Type (AD(..))
import Numeric.AD.Mode
import Numeric.AD.Rank1.Kahn.Double as Kahn (KahnDouble, grad)
import qualified Numeric.AD.Rank1.Newton.Double as Rank1
import Prelude hiding (all, mapM, sum)

-- | 'findZero' searches for a root of a scalar function with Newton's
-- method and returns a stream of ever more accurate approximations
-- (subject to the usual caveats of Newton's method). Once the stream
-- becomes constant ("it converges"), it ends.
--
-- Examples:
--
-- >>> take 10 $ findZero (\x->x^2-4) 1
-- [1.0,2.5,2.05,2.000609756097561,2.0000000929222947,2.000000000000002,2.0]
findZero
  :: (forall s. AD s ForwardDouble -> AD s ForwardDouble)
  -> Double
  -> [Double]
-- Specialise the rank-2 function to a plain 'ForwardDouble' function by
-- wrapping its input with 'AD' and unwrapping its output with 'runAD',
-- then delegate to the rank-1 implementation.
findZero f = Rank1.findZero (\x -> runAD (f (AD x)))
{-# INLINE findZero #-}

-- | 'inverse' numerically inverts a scalar function with Newton's method.
-- @inverse f x0 y@ starts from the guess @x0@ and looks for an @x@ such
-- that @f x@ equals @y@, returning a stream of ever more accurate
-- approximations (subject to the usual caveats of Newton's method). Once
-- the stream becomes constant ("it converges"), it ends.
--
-- Example:
--
-- >>> last $ take 10 $ inverse sqrt 1 (sqrt 10)
-- 10.0
inverse
  :: (forall s. AD s ForwardDouble -> AD s ForwardDouble)
  -> Double
  -> Double
  -> [Double]
-- Wrap/unwrap through 'AD' to obtain a rank-1 'ForwardDouble' function
-- and hand it to the rank-1 implementation.
inverse f = Rank1.inverse (\x -> runAD (f (AD x)))
{-# INLINE inverse  #-}

-- | 'fixedPoint' finds a fixed point of a scalar function (a value @x@
-- with @f x = x@) using Newton's method, returning a stream of ever more
-- accurate approximations (subject to the usual caveats of Newton's
-- method).
--
-- Once the stream becomes constant ("it converges"), it ends.
--
-- >>> last $ take 10 $ fixedPoint cos 1
-- 0.7390851332151607
fixedPoint
  :: (forall s. AD s ForwardDouble -> AD s ForwardDouble)
  -> Double
  -> [Double]
-- Wrap/unwrap through 'AD' to obtain a rank-1 'ForwardDouble' function
-- and hand it to the rank-1 implementation.
fixedPoint f = Rank1.fixedPoint (\x -> runAD (f (AD x)))
{-# INLINE fixedPoint #-}

-- | 'extremum' searches for an extremum of a scalar function with
-- Newton's method, using forward-on-forward mode
-- ('On' ('Forward' 'ForwardDouble')). It returns a stream of ever more
-- accurate approximations (subject to the usual caveats of Newton's
-- method). Once the stream becomes constant ("it converges"), it ends.
--
-- >>> last $ take 10 $ extremum cos 1
-- 0.0
extremum
  :: (forall s. AD s (On (Forward ForwardDouble)) -> AD s (On (Forward ForwardDouble)))
  -> Double
  -> [Double]
-- Wrap/unwrap through 'AD' to obtain a rank-1 forward-on-forward
-- function and hand it to the rank-1 implementation.
extremum f = Rank1.extremum (\x -> runAD (f (AD x)))
{-# INLINE extremum #-}

-- | Perform a conjugate gradient descent. The gradient is computed with
-- reverse mode automatic differentiation, and extrema are computed with
-- forward-on-forward mode.
--
-- >>> let sq x = x * x
-- >>> let rosenbrock [x,y] = sq (1 - x) + 100 * sq (y - sq x)
-- >>> rosenbrock [0,0]
-- 1
-- >>> rosenbrock (conjugateGradientDescent rosenbrock [0, 0] !! 5) < 0.1
-- True
conjugateGradientDescent
  :: Traversable f
  => (forall s. Chosen s => f (Or s (On (Forward ForwardDouble)) KahnDouble) -> Or s (On (Forward ForwardDouble)) KahnDouble)
  -> f Double -> [f Double]
-- Descending on the objective is the same as ascending on its negation.
conjugateGradientDescent f = conjugateGradientAscent (\xs -> negate (f xs))
{-# INLINE conjugateGradientDescent #-}

-- | Run a function written over 'Or' in its 'F' (left) branch. Each input
-- is injected with 'L', the function is applied, and the result is
-- projected back out with 'runL'.
lfu :: Functor f => (f (Or F a b) -> Or F a b) -> f a -> a
lfu f xs = runL (f (fmap L xs))

-- | Run a function written over 'Or' in its 'T' (right) branch. Each input
-- is injected with 'R', the function is applied, and the result is
-- projected back out with 'runR'.
rfu :: Functor f => (f (Or T a b) -> Or T a b) -> f b -> b
rfu f xs = runR (f (fmap R xs))

-- | Perform a conjugate gradient ascent. The gradient is computed with
-- reverse mode automatic differentiation ('KahnDouble'). The line search
-- along each search direction uses forward-on-forward mode
-- ('On' ('Forward' 'ForwardDouble')).
--
-- The result is the stream of iterates, beginning with the starting point.
-- Successive search directions are combined with the Fletcher-Reeves
-- coefficient. The stream is cut off just before the first iterate that
-- has a NaN component.
conjugateGradientAscent
  :: Traversable f
  => (forall s. Chosen s => f (Or s (On (Forward ForwardDouble)) KahnDouble) -> Or s (On (Forward ForwardDouble)) KahnDouble)
  -> f Double -> [f Double]
conjugateGradientAscent f x0 =
  -- NaN anywhere in an iterate signals a breakdown (e.g. 0/0 in a
  -- degenerate line search), so the stream stops there.
  takeWhile (all (not . isNaN)) (step x0 g0 (inner g0 g0))
  where
    -- Euclidean inner product; elements are paired up with 'zipWithT'.
    inner u v = sum (zipWithT (*) u v)

    -- Gradient of the objective, evaluated through the 'T' (reverse,
    -- 'KahnDouble') branch of 'Or'.
    gradAt = Kahn.grad (rfu f)

    -- The first search direction is the gradient at the starting point.
    g0 = gradAt x0

    -- Emit the current iterate xi, then move to the next one.
    -- di is the current search direction.
    -- deltai is the squared norm of the gradient at xi.
    step xi di deltai = xi : step xNext dNext deltaNext
      where
        -- The objective restricted to the line xi + t * di, evaluated
        -- through the 'F' (forward-on-forward) branch of 'Or'.
        alongLine t = lfu f (zipWithT (\x d -> auto x + t * auto d) xi di)
        -- Step length: the last of at most 20 approximations that
        -- 'Rank1.extremum' produces when started from t = 0.
        -- NOTE(review): 'last' is partial; this relies on that stream
        -- never being empty.
        alpha = last (take 20 (Rank1.extremum alongLine 0))
        xNext = zipWithT (\x d -> x + alpha * d) xi di
        rNext = gradAt xNext
        deltaNext = inner rNext rNext
        -- Fletcher-Reeves: ratio of successive squared gradient norms.
        beta = deltaNext / deltai
        dNext = zipWithT (\r d -> r + beta * d) rNext di
{-# INLINE conjugateGradientAscent #-}