{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-----------------------------------------------------------------------------

module Numeric.AD.Rank1.Newton
  (
  -- * Newton's Method (Forward)
    findZero
  , findZeroNoEq
  , inverse
  , inverseNoEq
  , fixedPoint
  , fixedPointNoEq
  , extremum
  , extremumNoEq
  -- * Gradient Ascent/Descent (Kahn)
  , gradientDescent
  , gradientAscent
  ) where

import Prelude hiding (all, mapM)
import Data.Foldable (all)
import Numeric.AD.Mode
import Numeric.AD.Rank1.Forward (Forward, diff, diff')
import Numeric.AD.Rank1.Kahn as Kahn (Kahn, gradWith')
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Combinators (takeWhileDifferent)

-- $setup
-- >>> import Data.Complex

-- | The 'findZero' function finds a zero of a scalar function using
-- Newton's method; its output is a stream of increasingly accurate
-- results.  (Modulo the usual caveats.) If the stream becomes constant
-- ("it converges"), no further elements are returned.
--
-- This is 'findZeroNoEq' with the stream truncated (via
-- 'takeWhileDifferent') once the approximations stop changing, which is
-- why an 'Eq' instance is required.
--
-- Examples:
--
-- >>> take 10 $ findZero (\x->x^2-4) 1
-- [1.0,2.5,2.05,2.000609756097561,2.0000000929222947,2.000000000000002,2.0]
--
-- >>> last $ take 10 $ findZero ((+1).(^2)) (1 :+ 1)
-- 0.0 :+ 1.0
findZero :: (Fractional a, Eq a) => (Forward a -> Forward a) -> a -> [a]
findZero f = takeWhileDifferent . findZeroNoEq f
{-# INLINE findZero #-}

-- | The 'findZeroNoEq' function behaves the same as 'findZero' except that it
-- doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
--
-- The result is the infinite list @[x0, x1, x2, ...]@, starting with the
-- initial guess, where each element is obtained from the previous one by
-- the Newton update @x - f x / f' x@.  The value and the derivative are
-- obtained together from a single call to 'diff''.  No guard is placed on
-- a vanishing derivative: a step taken where @f' x == 0@ divides by zero.
findZeroNoEq :: Fractional a => (Forward a -> Forward a) -> a -> [a]
findZeroNoEq f = iterate newtonStep
  where
    -- One Newton-Raphson update from the current approximation.
    newtonStep x = x - fx / dfx
      where
        (fx, dfx) = diff' f x
{-# INLINE findZeroNoEq #-}

-- | The 'inverse' function inverts a scalar function using
-- Newton's method; its output is a stream of increasingly accurate
-- results.  (Modulo the usual caveats.) If the stream becomes
-- constant ("it converges"), no further elements are returned.
--
-- @inverse f x0 y@ searches for an @x@ with @f x == y@, starting from the
-- initial guess @x0@.
--
-- Example:
--
-- >>> last $ take 10 $ inverse sqrt 1 (sqrt 10)
-- 10.0
inverse :: (Fractional a, Eq a) => (Forward a -> Forward a) -> a -> a -> [a]
inverse f x0 = takeWhileDifferent . inverseNoEq f x0
{-# INLINE inverse  #-}

-- | The 'inverseNoEq' function behaves the same as 'inverse' except that it
-- doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
--
-- It solves @f x = y@ by running 'findZeroNoEq' on @\\x -> f x - y@,
-- starting from @x0@.  The target @y@ is lifted with 'auto', so it is
-- treated as a constant when differentiating.
inverseNoEq :: Fractional a => (Forward a -> Forward a) -> a -> a -> [a]
inverseNoEq f x0 y = findZeroNoEq (\x -> f x - auto y) x0
{-# INLINE inverseNoEq #-}

-- | The 'fixedPoint' function finds a fixed point of a scalar
-- function using Newton's method; its output is a stream of
-- increasingly accurate results.  (Modulo the usual caveats.)
--
-- If the stream becomes constant ("it converges"), no further
-- elements are returned.
--
-- >>> last $ take 10 $ fixedPoint cos 1
-- 0.7390851332151607
fixedPoint :: (Fractional a, Eq a) => (Forward a -> Forward a) -> a -> [a]
fixedPoint f = takeWhileDifferent . fixedPointNoEq f
{-# INLINE fixedPoint #-}

-- | The 'fixedPointNoEq' function behaves the same as 'fixedPoint' except that
-- it doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
--
-- A fixed point of @f@ is a zero of @\\x -> f x - x@, so this simply runs
-- 'findZeroNoEq' on that function.
fixedPointNoEq :: Fractional a => (Forward a -> Forward a) -> a -> [a]
fixedPointNoEq f = findZeroNoEq (\x -> f x - x)
{-# INLINE fixedPointNoEq #-}

-- | The 'extremum' function finds an extremum of a scalar
-- function using Newton's method; produces a stream of increasingly
-- accurate results.  (Modulo the usual caveats.) If the stream
-- becomes constant ("it converges"), no further elements are returned.
--
-- More precisely, it searches for a stationary point, i.e. a zero of the
-- derivative.  It does not distinguish minima from maxima (or other
-- stationary points).
--
-- >>> last $ take 10 $ extremum cos 1
-- 0.0
extremum :: (Fractional a, Eq a) => (On (Forward (Forward a)) -> On (Forward (Forward a))) -> a -> [a]
extremum f = takeWhileDifferent . extremumNoEq f
{-# INLINE extremum #-}

-- | The 'extremumNoEq' function behaves the same as 'extremum' except that it
-- doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
--
-- The argument is wrapped with 'On' / unwrapped with 'off' around @f@.
-- The inner 'diff' produces @f'@, and 'findZeroNoEq' then runs Newton's
-- method on @f'@, differentiating it once more to obtain @f''@.
extremumNoEq :: Fractional a => (On (Forward (Forward a)) -> On (Forward (Forward a))) -> a -> [a]
extremumNoEq f = findZeroNoEq (diff (off . f . On))
{-# INLINE extremumNoEq #-}

-- | The 'gradientDescent' function performs a multivariate
-- optimization, based on the naive-gradient-descent in the file
-- @stalingrad\/examples\/flow-tests\/pre-saddle-1a.vlad@ from the
-- VLAD compiler Stalingrad sources.  Its output is a stream of
-- increasingly accurate results.  (Modulo the usual caveats.)
--
-- It uses reverse mode automatic differentiation to compute the gradient.
--
-- Step-size schedule:
--
-- * The step size starts at @0.1@.
-- * A step that would increase the objective is rejected and retried from
--   the same point with half the step size.
-- * On every 11th consecutive accepted step, the step size is doubled.
--
-- The stream ends when the step size becomes exactly @0@ or when the
-- gradient at the current point is exactly zero.  The starting point
-- itself is not included in the output.
gradientDescent :: (Traversable f, Fractional a, Ord a) => (f (Kahn a) -> Kahn a) -> f a -> [f a]
gradientDescent f x0 = go x0 fx0 xgx0 0.1 (0 :: Int)
  where
    -- Objective value and (coordinate, partial derivative) pairs at the start.
    (fx0, xgx0) = Kahn.gradWith' (,) f x0

    -- go x fx xgx eta i:
    --   x   current point, fx = f x,
    --   xgx each coordinate of x paired with its partial derivative,
    --   eta current step size,
    --   i   accepted steps since eta last changed (or since the start).
    go x fx xgx !eta !i
      | eta == 0     = [] -- step size is 0
      | fx1 > fx     = go x fx xgx (eta/2) 0 -- we stepped too far
      | zeroGrad xgx = [] -- gradient is 0
      | otherwise    = x1 : if i == 10
                            then go x1 fx1 xgx1 (eta*2) 0
                            else go x1 fx1 xgx1 eta (i+1)
      where
        zeroGrad = all (\(_,g) -> g == 0)
        -- Candidate point: one step of size eta against the gradient.
        x1 = fmap (\(xi,gxi) -> xi - eta * gxi) xgx
        -- Objective and gradient at the candidate, reused if it is accepted.
        (fx1, xgx1) = Kahn.gradWith' (,) f x1
{-# INLINE gradientDescent #-}

-- | Perform a gradient ascent using reverse mode automatic differentiation
-- to compute the gradient.
--
-- This is 'gradientDescent' applied to the negated objective.  It therefore
-- shares that function's step-size schedule and stopping rules, but climbs
-- @f@ instead of descending it.
gradientAscent :: (Traversable f, Fractional a, Ord a) => (f (Kahn a) -> Kahn a) -> f a -> [f a]
gradientAscent f = gradientDescent (negate . f)
{-# INLINE gradientAscent #-}