{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE ParallelListComp #-}
module Numeric.AD.Newton
(
findZero
, findZeroNoEq
, inverse
, inverseNoEq
, fixedPoint
, fixedPointNoEq
, extremum
, extremumNoEq
, gradientDescent, constrainedDescent, CC(..), eval
, gradientAscent
, conjugateGradientDescent
, conjugateGradientAscent
, stochasticGradientDescent
) where
import Data.Foldable (all, sum)
import Data.Reflection (Reifies)
import Data.Traversable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Forward (Forward)
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Or
import Numeric.AD.Internal.Reverse (Reverse, Tape)
import Numeric.AD.Internal.Type (AD(..))
import Numeric.AD.Mode
import Numeric.AD.Mode.Reverse as Reverse (gradWith, gradWith', grad')
import Numeric.AD.Rank1.Kahn as Kahn (Kahn, grad)
import qualified Numeric.AD.Rank1.Newton as Rank1
import Prelude hiding (all, mapM, sum)
-- | Rank-N front end for 'Rank1.findZero'.
--
-- The user function is lowered to a plain @'Forward' a -> 'Forward' a@
-- function by wrapping its argument with the 'AD' constructor and
-- unwrapping its result with 'runAD'. The iteration itself, and the list of
-- values returned, come entirely from 'Rank1.findZero'.
--
-- The @a@ argument is passed through unchanged (presumably the initial
-- guess; see 'Rank1.findZero'). The 'Eq' constraint is only forwarded; use
-- 'findZeroNoEq' when @a@ has no 'Eq' instance.
findZero
  :: (Fractional a, Eq a)
  => (forall s. AD s (Forward a) -> AD s (Forward a))
  -> a
  -> [a]
findZero f = Rank1.findZero (runAD . f . AD)
{-# INLINE findZero #-}
-- | Rank-N front end for 'Rank1.findZeroNoEq': the same lifting as
-- 'findZero' (@runAD . f . AD@), but requiring only 'Fractional'.
--
-- The @a@ argument is passed through unchanged (presumably the initial
-- guess). The resulting list is exactly what 'Rank1.findZeroNoEq' returns.
findZeroNoEq
  :: Fractional a
  => (forall s. AD s (Forward a) -> AD s (Forward a))
  -> a
  -> [a]
findZeroNoEq f = Rank1.findZeroNoEq (runAD . f . AD)
{-# INLINE findZeroNoEq #-}
-- | Rank-N front end for 'Rank1.inverse'.
--
-- The user function is lowered to @'Forward' a -> 'Forward' a@ via
-- @runAD . f . AD@. Both @a@ arguments are forwarded unchanged, in order
-- (presumably an initial guess and a target value; see 'Rank1.inverse').
-- The 'Eq' constraint is only forwarded; see 'inverseNoEq' for a variant
-- without it.
inverse
  :: (Fractional a, Eq a)
  => (forall s. AD s (Forward a) -> AD s (Forward a))
  -> a
  -> a
  -> [a]
inverse f = Rank1.inverse (runAD . f . AD)
{-# INLINE inverse #-}
-- | Rank-N front end for 'Rank1.inverseNoEq': the same lifting as
-- 'inverse' (@runAD . f . AD@), but requiring only 'Fractional'.
-- Both @a@ arguments are forwarded unchanged, in order.
inverseNoEq
  :: Fractional a
  => (forall s. AD s (Forward a) -> AD s (Forward a))
  -> a
  -> a
  -> [a]
inverseNoEq f = Rank1.inverseNoEq (runAD . f . AD)
{-# INLINE inverseNoEq #-}
-- | Rank-N front end for 'Rank1.fixedPoint'.
--
-- The user function is lowered to @'Forward' a -> 'Forward' a@ via
-- @runAD . f . AD@. The @a@ argument is forwarded unchanged (presumably the
-- starting point). The 'Eq' constraint is only forwarded; see
-- 'fixedPointNoEq' for a variant without it.
fixedPoint
  :: (Fractional a, Eq a)
  => (forall s. AD s (Forward a) -> AD s (Forward a))
  -> a
  -> [a]
fixedPoint f = Rank1.fixedPoint (runAD . f . AD)
{-# INLINE fixedPoint #-}
-- | Rank-N front end for 'Rank1.fixedPointNoEq': the same lifting as
-- 'fixedPoint' (@runAD . f . AD@), but requiring only 'Fractional'.
fixedPointNoEq
  :: Fractional a
  => (forall s. AD s (Forward a) -> AD s (Forward a))
  -> a
  -> [a]
fixedPointNoEq f = Rank1.fixedPointNoEq (runAD . f . AD)
{-# INLINE fixedPointNoEq #-}
-- | Rank-N front end for 'Rank1.extremum'.
--
-- The user function works in the nested mode @'On' ('Forward' ('Forward' a))@
-- (two forward layers; presumably so that 'Rank1.extremum' can take a
-- second derivative). It is lowered via @runAD . f . AD@, and the @a@
-- argument is forwarded unchanged. The 'Eq' constraint is only forwarded;
-- see 'extremumNoEq' for a variant without it.
extremum
  :: (Fractional a, Eq a)
  => (forall s. AD s (On (Forward (Forward a))) -> AD s (On (Forward (Forward a))))
  -> a
  -> [a]
extremum f = Rank1.extremum (runAD . f . AD)
{-# INLINE extremum #-}
-- | Rank-N front end for 'Rank1.extremumNoEq': the same lifting as
-- 'extremum' (@runAD . f . AD@ over @'On' ('Forward' ('Forward' a))@), but
-- requiring only 'Fractional'.
extremumNoEq
  :: Fractional a
  => (forall s. AD s (On (Forward (Forward a))) -> AD s (On (Forward (Forward a))))
  -> a
  -> [a]
extremumNoEq f = Rank1.extremumNoEq (runAD . f . AD)
{-# INLINE extremumNoEq #-}
-- | Minimise a scalar function of a 'Traversable' of inputs by gradient
-- descent with an adaptive step size.
--
-- The result is the list of accepted iterates; the starting point itself is
-- not included. The value and gradient at each point are computed together
-- with 'Reverse.gradWith''.
--
-- Starting from a step size @eta = 0.1@, each candidate step is
-- @x1 = x - eta * grad f x@ (component-wise). Then:
--
-- * if @f x1 > f x@ the step is rejected, @eta@ is halved and the
--   consecutive-accept counter is reset to 0;
-- * otherwise, if the gradient at @x@ is not all zero, @x1@ is emitted.
--   On an accepted step where the counter equals 10, @eta@ is doubled and
--   the counter reset; otherwise the counter is incremented.
--
-- The list ends when @eta@ becomes exactly @0@ (e.g. once repeated halving
-- underflows, for floating-point @a@), or when every partial derivative at
-- the current point is @0@.
gradientDescent
  :: (Traversable f, Fractional a, Ord a)
  => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -- ^ objective
  -> f a   -- ^ starting point
  -> [f a] -- ^ accepted iterates
gradientDescent f x0 = go x0 fx0 xgx0 0.1 (0 :: Int)
  where
    (fx0, xgx0) = Reverse.gradWith' (,) f x0

    -- go x fx xgx eta i:
    --   x   current point
    --   fx  objective value at x
    --   xgx (coordinate, partial derivative) pairs at x
    --   eta current step size
    --   i   number of consecutive accepted steps since the last reset
    go x fx xgx !eta !i
      | eta == 0     = []                        -- step size exhausted
      | fx1 > fx     = go x fx xgx (eta / 2) 0   -- overshot: retry smaller
      | zeroGrad xgx = []                        -- stationary point reached
      | otherwise    = x1 : if i == 10
                              then go x1 fx1 xgx1 (eta * 2) 0
                              else go x1 fx1 xgx1 eta (i + 1)
      where
        zeroGrad = all (\(_, g) -> g == 0)
        x1 = fmap (\(xi, gxi) -> xi - eta * gxi) xgx
        (fx1, xgx1) = Reverse.gradWith' (,) f x1
{-# INLINE gradientDescent #-}
-- | An environment @f a@ extended with one extra scalar, 'sValue'.
--
-- 'constrainedDescent' uses it to add a slack variable to the search space
-- while looking for a feasible starting point. The derived 'Functor',
-- 'Foldable' and 'Traversable' instances visit 'sValue' first and then every
-- element of 'origEnv'. As a result, gradient-based routines treat the
-- slack as just one more coordinate.
data SEnv (f :: * -> *) a = SEnv
  { sValue  :: a   -- ^ the extra scalar (the slack in 'constrainedDescent')
  , origEnv :: f a -- ^ the original environment
  }
  deriving (Functor, Foldable, Traversable)
-- | A boxed scalar function of an environment @f a@.
--
-- Objectives and constraints are rank-N reverse-mode functions. Boxing them
-- lets them be stored in lists and passed around as ordinary values.
--
-- A constraint @c@ passed to 'constrainedDescent' counts as satisfied where
-- @c x < 0@: the log barrier built by 'iHat' is infinite whenever the
-- constraint evaluates to @>= 0@ or to NaN.
data CC f a where
  CC :: forall f a. (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> CC f a
-- | Minimise an objective subject to constraints that must stay negative,
-- using log-barrier gradient descent.
--
-- With no constraints, this is 'gradientDescent' on the objective, with each
-- iterate paired with its objective value (computed by 'eval').
--
-- Otherwise the search runs in two phases, both through 'constrainedConvex'':
--
-- 1. /Feasibility./ A slack @s@ is appended to the environment ('SEnv'),
--    and every constraint @c@ is replaced by @c x - s@. The starting slack
--    is @s0 = 1 + maximum [c env]@, so every shifted constraint starts at
--    @<= -1@ and the barrier is finite. The slack itself is then minimised.
--    Results are taken from the first point with @s <= 0@ onwards, limited
--    to the first @2^20@ results. The first of those whose paired value is
--    not positive becomes the feasible start. If there is none, the result
--    is @[]@.
-- 2. /Optimisation./ The real objective is minimised under the original
--    constraints, starting from that feasible point.
--
-- NOTE(review): in the constrained case, each pair's value comes from
-- 'constrainedConvex''. That is the barrier-augmented objective, not the raw
-- objective returned in the unconstrained case. Confirm whether callers
-- rely on this before changing it.
constrainedDescent :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
                   => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
                   -> [CC f a]
                   -> f a
                   -> [(a,f a)]
constrainedDescent objF [] env =
  map (\x -> (eval objF x, x)) (gradientDescent objF env)
constrainedDescent objF cs env =
  let -- 'maximum' is safe: the empty-constraint case is handled above.
      s0 = 1 + maximum [eval c env | CC c <- cs]
      -- Shift every constraint by the slack carried in the SEnv.
      cs' = [CC (\(SEnv sVal rest) -> c rest - sVal) | CC c <- cs]
      envS = SEnv s0 env
      -- Phase 1: minimise the slack until it becomes non-positive.
      cc = constrainedConvex' (CC sValue) cs' envS ((<= 0) . sValue)
  in case dropWhile ((0 <) . fst) (take (2 ^ (20 :: Int)) cc) of
       [] -> []
       (_, envFeasible) : _ ->
         -- Phase 2: optimise the real objective from the feasible point.
         constrainedConvex' (CC objF) cs (origEnv envFeasible) (const True)
{-# INLINE constrainedDescent #-}
-- | Value of a reverse-mode scalar function at a point.
--
-- This is the first component of 'grad''; the gradient half of the result
-- is discarded.
eval :: (Traversable f, Fractional a, Ord a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> a
eval f e = fst (grad' f e)
{-# INLINE eval #-}
-- | Log-barrier method: repeatedly run 'gradientDescent' on the objective
-- plus barrier terms for the constraints, sharpening the barrier each round.
--
-- * @objF@ – the objective;
-- * @cs@   – the constraints (each to be kept negative; see 'iHat');
-- * @env@  – the starting point;
-- * @term@ – results are dropped until the first point satisfying @term@.
--
-- Round @k@, for @t = 2, 4, .., 2^64@, minimises @'mkOpt' objF cs t@,
-- starting from the last point reached in the previous round. Every round
-- except the final one is cut off after 20 results. The final round runs
-- for as long as 'gradientDescent' keeps producing points.
--
-- Each result pairs a point with the value, at that point, of its round's
-- barrier-augmented objective.
constrainedConvex' :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
                   => CC f a
                   -> [CC f a]
                   -> f a
                   -> (f a -> Bool)
                   -> [(a,f a)]
constrainedConvex' objF cs env term =
  let -- One barrier-augmented objective per value of t.
      os = map (mkOpt objF cs) tValues
      -- Round 0 is a placeholder holding only the starting point.
      -- Round k+1 starts from the last point of (truncated) round k.
      -- 'last' is safe: round 0 is a singleton, and every gD run starts with
      -- its starting point, so no round is empty.
      envs = [(placeholder, env)] :
             [ gD (snd (last e)) o
             | o <- os
             | e <- limEnvs
             ]
      -- Truncate all rounds but the last; keep the last one whole.
      limEnvs = zipWith id nrSteps envs
  in dropWhile (not . term . snd) (concat (drop 1 limEnvs))
  where
    -- Only the point of the placeholder pair is ever used, and round 0 is
    -- dropped from the output, so this value is never forced.
    placeholder :: a
    placeholder = error "constrainedConvex': placeholder value is never inspected"

    -- Barrier parameters 2, 4, .., 2^64.
    tValues = map realToFrac $ take 64 $ iterate (*2) (2 :: a)

    -- One truncation per placeholder/round before the last, then 'id'.
    nrSteps = [take 20 | _ <- [1 .. length tValues]] ++ [id]

    -- Gradient descent on one barrier objective, pairing each point
    -- (including the start) with its objective value.
    gD e (CC f) = (eval f e, e) : map (\x -> (eval f x, x)) (gradientDescent f e)
{-# INLINE constrainedConvex' #-}
-- | Barrier-augmented objective for barrier parameter @t@: the objective
-- plus @'iHat' t c@ for every constraint @c@, all evaluated at the same
-- point.
mkOpt :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
      => CC f a -> [CC f a]
      -> a -> CC f a
mkOpt (CC o) xs t = CC (\e -> o e + sum [iHat t c e | CC c <- xs])
{-# INLINE mkOpt #-}
-- | Log-barrier term for constraint @c@ with barrier parameter @t@.
--
-- At a point @e@ it is @-(1/t) * log (-(c e))@ while @c e < 0@. It is
-- @1/0@ (+Infinity for IEEE types) when @c e >= 0@ or @c e@ is NaN.
iHat :: forall a f. (Traversable f, RealFloat a, Floating a, Ord a)
     => a
     -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
     -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
iHat t c e =
  let r = c e
  in if r >= 0 || isNaN r
       then 1 / 0
       -- Reuse r instead of calling c again. The constraint is then
       -- evaluated once per call rather than twice.
       else (-1 / auto t) * log (negate r)
{-# INLINE iHat #-}
-- | Stochastic gradient descent starting from @x0@.
--
-- Each step moves against the gradient of @errorSingle sample@ with a fixed
-- step size of @0.001@. The samples in @d0@ are used in order: the first step
-- uses the gradient at @x0@ under the first sample, the next step uses the
-- second sample, and so on, wrapping around forever.
--
-- Returns the (infinite) list of iterates after @x0@; @x0@ itself is not
-- included. With no samples there is no gradient to follow, so the result is
-- empty rather than a crash from 'head' / 'cycle' on an empty list.
stochasticGradientDescent :: (Traversable f, Fractional a, Ord a)
                          => (forall s. Reifies s Tape => e -> f (Reverse s a) -> Reverse s a)
                          -> [e]
                          -> f a
                          -> [f a]
stochasticGradientDescent errorSingle d0 x0 = case d0 of
  [] -> []
  (firstSample : _) ->
    -- Pair each coordinate of x0 with its gradient under the first sample.
    -- The remaining steps consume the rest of the endless cycle of samples.
    go (Reverse.gradWith (,) (errorSingle firstSample) x0) 0.001 (drop 1 (cycle d0))
  where
    -- @xgx@ pairs each coordinate of the current point with its gradient.
    -- @eta@ is the step size.
    go _ _ [] = [] -- unreachable: 'cycle' of a non-empty list never ends
    go xgx !eta (sample : rest)
      -- Kept as a guard for a zero step size; never fires with 0.001.
      | eta == 0  = []
      | otherwise = x1 : go xgx1 eta rest
      where
        x1 = fmap (\(xi, gxi) -> xi - eta * gxi) xgx
        (_, xgx1) = Reverse.gradWith' (,) (errorSingle sample) x1
{-# INLINE stochasticGradientDescent #-}
-- | Gradient ascent on @f@, obtained by running 'gradientDescent' on the
-- negated objective @\\x -> negate (f x)@.
gradientAscent :: (Traversable f, Fractional a, Ord a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> [f a]
gradientAscent f = gradientDescent (\x -> negate (f x))
{-# INLINE gradientAscent #-}
-- | Conjugate gradient descent on @f@, obtained by running
-- 'conjugateGradientAscent' on the negated objective
-- @\\x -> negate (f x)@.
conjugateGradientDescent
  :: (Traversable f, Ord a, Fractional a)
  => (forall s. Chosen s => f (Or s (On (Forward (Forward a))) (Kahn a)) -> Or s (On (Forward (Forward a))) (Kahn a))
  -> f a -> [f a]
conjugateGradientDescent f = conjugateGradientAscent (\x -> negate (f x))
{-# INLINE conjugateGradientDescent #-}
-- | Run a function over 'Or' values tagged 'F' on plain left-hand inputs.
--
-- Every input is wrapped with 'L', @f@ is applied, and the left-hand result
-- is extracted with 'runL'.
lfu :: Functor f => (f (Or F a b) -> Or F a b) -> f a -> a
lfu f xs = runL (f (fmap L xs))
-- | Run a function over 'Or' values tagged 'T' on plain right-hand inputs.
--
-- Every input is wrapped with 'R', @f@ is applied, and the right-hand result
-- is extracted with 'runR'.
rfu :: Functor f => (f (Or T a b) -> Or T a b) -> f b -> b
rfu f xs = runR (f (fmap R xs))
-- | Conjugate gradient ascent on @f@ starting at @x0@.
--
-- The objective is polymorphic in its 'Chosen' tag. It is evaluated at
-- @On (Forward (Forward a))@ (via 'lfu') for the line search, and at
-- @Kahn a@ (via 'rfu' and 'Kahn.grad') for gradients.
--
-- Each step:
--
-- * chooses the step length @ai@ as the last of at most 20 iterates of
--   'Rank1.extremum' of @f@ along the search direction @di@, starting from 0;
-- * moves to @xi1 = xi + ai * di@;
-- * updates the direction to @di1 = ri1 + (|ri1|^2 / |ri|^2) * di@, where
--   @ri1@ is the gradient at @xi1@.
--
-- The result begins with @x0@. It ends before the first iterate containing a
-- NaN, detected as an element that is not equal to itself. It also ends
-- right after an iterate whose squared gradient norm is exactly zero. Such
-- an iterate is a stationary point, and the next step would divide by zero:
-- that yields NaN for IEEE types but throws for exact types such as
-- 'Rational'.
conjugateGradientAscent
  :: (Traversable f, Ord a, Fractional a)
  => (forall s. Chosen s => f (Or s (On (Forward (Forward a))) (Kahn a)) -> Or s (On (Forward (Forward a))) (Kahn a))
  -> f a -> [f a]
conjugateGradientAscent f x0 = takeWhile (all notNaN) (go x0 d0 d0 delta0)
  where
    -- Only NaN fails self-equality; this works for any 'Eq' scalar type.
    notNaN v = v == v
    dot x y = sum $ zipWithT (*) x y
    d0 = Kahn.grad (rfu f) x0
    delta0 = dot d0 d0
    go xi _ri di deltai = xi : rest
      where
        -- A zero squared gradient norm means xi is stationary. The line
        -- search along the (zero) direction and bi1 below would both divide
        -- by zero, so stop after emitting xi.
        rest
          | deltai == 0 = []
          | otherwise   = go xi1 ri1 di1 deltai1
        ai = last $ take 20 $ Rank1.extremum (\a -> lfu f $ zipWithT (\x d -> auto x + a * auto d) xi di) 0
        xi1 = zipWithT (\x d -> x + ai * d) xi di
        ri1 = Kahn.grad (rfu f) xi1
        deltai1 = dot ri1 ri1
        -- Fletcher-Reeves coefficient.
        bi1 = deltai1 / deltai
        di1 = zipWithT (\r d -> r + bi1 * d) ri1 di
{-# INLINE conjugateGradientAscent #-}