{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Numeric.AD.Rank1.Newton
(
findZero
, findZeroNoEq
, inverse
, inverseNoEq
, fixedPoint
, fixedPointNoEq
, extremum
, extremumNoEq
, gradientDescent
, gradientAscent
) where
import Prelude hiding (all, mapM)
import Data.Foldable (all)
import Numeric.AD.Mode
import Numeric.AD.Rank1.Forward (Forward, diff, diff')
import Numeric.AD.Rank1.Kahn as Kahn (Kahn, gradWith')
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Combinators (takeWhileDifferent)
-- | Find a zero of a scalar function with Newton's method, starting from the
-- supplied initial guess.
--
-- The iterates come from 'findZeroNoEq' and are then filtered through
-- 'takeWhileDifferent', which is what requires the extra 'Eq' constraint.
--
-- NOTE(review): the truncation rule is whatever 'takeWhileDifferent' in
-- "Numeric.AD.Internal.Combinators" implements (presumably: stop once two
-- consecutive iterates compare equal) -- confirm there.
findZero :: (Fractional a, Eq a) => (Forward a -> Forward a) -> a -> [a]
findZero f = takeWhileDifferent . findZeroNoEq f
{-# INLINE findZero #-}
-- | Newton's method without the 'Eq' constraint.
--
-- Produces the stream @iterate step x0@, so the first element is the initial
-- guess itself and the list never terminates on its own; the caller decides
-- how many iterates to consume.
--
-- Each step maps @x@ to @x - y / y'@, where @(y, y')@ is the pair returned by
-- @diff' f x@ (presumably the value and first derivative of @f@ at @x@).
-- There is no guard against @y'@ being zero; the division is performed as is.
findZeroNoEq :: Fractional a => (Forward a -> Forward a) -> a -> [a]
findZeroNoEq f = iterate step
  where
    -- One Newton update from the current estimate.
    step x = x - y / y'
      where
        (y, y') = diff' f x
{-# INLINE findZeroNoEq #-}
-- | Invert a scalar function with Newton's method.
--
-- @inverse f x0 y@ searches, starting from the guess @x0@, for an @x@ with
-- @f x = y@. The iterates are those of 'inverseNoEq', filtered through
-- 'takeWhileDifferent' (hence the 'Eq' constraint).
--
-- Note the argument order: the initial guess comes before the target value.
inverse :: (Fractional a, Eq a) => (Forward a -> Forward a) -> a -> a -> [a]
inverse f x0 = takeWhileDifferent . inverseNoEq f x0
{-# INLINE inverse #-}
-- | Invert a scalar function with Newton's method, without the 'Eq'
-- constraint.
--
-- @inverseNoEq f x0 y@ runs 'findZeroNoEq' on @\\x -> f x - y@ from the
-- initial guess @x0@, so it yields the same kind of unbounded stream of
-- iterates. The target @y@ is lifted into the 'Forward' mode with 'auto'.
inverseNoEq :: Fractional a => (Forward a -> Forward a) -> a -> a -> [a]
inverseNoEq f x0 y = findZeroNoEq (\x -> f x - auto y) x0
{-# INLINE inverseNoEq #-}
-- | Find a fixed point of a scalar function with Newton's method, starting
-- from the supplied initial guess.
--
-- The iterates are those of 'fixedPointNoEq', filtered through
-- 'takeWhileDifferent' (hence the 'Eq' constraint).
fixedPoint :: (Fractional a, Eq a) => (Forward a -> Forward a) -> a -> [a]
fixedPoint f = takeWhileDifferent . fixedPointNoEq f
{-# INLINE fixedPoint #-}
-- | Find a fixed point of a scalar function with Newton's method, without
-- the 'Eq' constraint.
--
-- A fixed point of @f@ is a zero of @\\x -> f x - x@, so this simply hands
-- that function to 'findZeroNoEq' and returns its stream of iterates.
fixedPointNoEq :: Fractional a => (Forward a -> Forward a) -> a -> [a]
fixedPointNoEq f = findZeroNoEq (\x -> f x - x)
{-# INLINE fixedPointNoEq #-}
-- | Search for an extremum of a scalar function with Newton's method,
-- starting from the supplied initial guess.
--
-- The iterates are those of 'extremumNoEq', filtered through
-- 'takeWhileDifferent' (hence the 'Eq' constraint).
extremum :: (Fractional a, Eq a) => (On (Forward (Forward a)) -> On (Forward (Forward a))) -> a -> [a]
extremum f = takeWhileDifferent . extremumNoEq f
{-# INLINE extremum #-}
-- | Search for an extremum of a scalar function with Newton's method,
-- without the 'Eq' constraint.
--
-- The function is unwrapped from 'On' (@off . f . On@), differentiated once
-- with 'diff', and the result is handed to 'findZeroNoEq'; i.e. the search is
-- for a zero of the first derivative. 'findZeroNoEq' differentiates again
-- internally, which is why @f@ must operate on the nested mode
-- @'Forward' ('Forward' a)@.
--
-- Nothing here distinguishes minima from maxima (or other stationary
-- points): any zero of the derivative is acceptable to the iteration.
extremumNoEq :: Fractional a => (On (Forward (Forward a)) -> On (Forward (Forward a))) -> a -> [a]
extremumNoEq f = findZeroNoEq (diff (off . f . On))
{-# INLINE extremumNoEq #-}
-- | Gradient descent with an adaptive step size.
--
-- @gradientDescent f x0@ returns the list of accepted points visited while
-- trying to decrease @f@, starting from @x0@. The starting point itself is
-- not part of the output; the first element, if any, is the first accepted
-- step.
--
-- The step size @eta@ starts at @0.1@ and is adapted as follows:
--
-- * if a trial step would increase the objective, the step is rejected,
--   @eta@ is halved and the counter of consecutive accepted steps is reset;
--
-- * when a step is accepted while that counter equals 10, @eta@ is doubled
--   and the counter is reset; otherwise the counter is incremented.
--
-- The list ends when @eta@ compares equal to zero or when every component of
-- the current gradient compares equal to zero. If neither happens the list
-- is infinite.
gradientDescent :: (Traversable f, Fractional a, Ord a) => (f (Kahn a) -> Kahn a) -> f a -> [f a]
gradientDescent f x0 = go x0 fx0 xgx0 0.1 (0 :: Int)
  where
    -- Objective value at x0, plus each coordinate paired with the matching
    -- gradient component as produced by gradWith' (,).
    (fx0, xgx0) = Kahn.gradWith' (,) f x0

    -- go x fx xgx eta i:
    --   x    current point
    --   fx   objective value at x
    --   xgx  coordinates of x paired with their gradient components
    --   eta  current step size (strict)
    --   i    number of consecutive accepted steps since eta last changed (strict)
    go x fx xgx !eta !i
      | eta == 0     = []                     -- step size has collapsed to zero
      | fx1 > fx     = go x fx xgx (eta / 2) 0 -- overshot: retry from x with half the step
      | zeroGrad xgx = []                     -- gradient vanished: nowhere to go
      | otherwise    = x1 : if i == 10
                              then go x1 fx1 xgx1 (eta * 2) 0 -- steady progress: try a larger step
                              else go x1 fx1 xgx1 eta (i + 1)
      where
        -- True when every gradient component compares equal to zero.
        zeroGrad = all (\(_, g) -> g == 0)
        -- Trial point: move each coordinate against its gradient component.
        x1 = fmap (\(xi, gxi) -> xi - eta * gxi) xgx
        -- Objective and (coordinate, gradient) pairs at the trial point.
        (fx1, xgx1) = Kahn.gradWith' (,) f x1
{-# INLINE gradientDescent #-}
-- | Gradient ascent with an adaptive step size.
--
-- Implemented as 'gradientDescent' applied to the negated objective
-- (@negate . f@), so the step-size adaptation and termination conditions are
-- exactly those of 'gradientDescent'.
gradientAscent :: (Traversable f, Fractional a, Ord a) => (f (Kahn a) -> Kahn a) -> f a -> [f a]
gradientAscent f = gradientDescent (negate . f)
{-# INLINE gradientAscent #-}