{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-full-laziness #-}
{-# OPTIONS_HADDOCK not-home #-}
module Numeric.AD.Internal.Reverse
( Reverse(..)
, Tape(..)
, Head(..)
, Cells(..)
, reifyTape
, reifyTypeableTape
, partials
, partialArrayOf
, partialMapOf
, derivativeOf
, derivativeOf'
, bind
, unbind
, unbindMap
, unbindWith
, unbindMapWithDefault
, var
, varId
, primal
) where
import Data.Functor
import Control.Monad hiding (mapM)
import Control.Monad.ST
import Control.Monad.Trans.State
import Data.Array.ST
import Data.Array
import Data.Array.Unsafe as Unsafe
import Data.IORef
import Data.IntMap (IntMap, fromDistinctAscList, findWithDefault)
import Data.Number.Erf
import Data.Proxy
import Data.Reflection
import Data.Traversable (mapM)
import Data.Typeable
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode
import Prelude hiding (mapM)
import System.IO.Unsafe (unsafePerformIO)
import Unsafe.Coerce
-- | The entries recorded on a tape, stored as a singly linked list.
--
-- The payload type @a@ of 'Unary' and 'Binary' does not appear in the result
-- type, so it is existentially quantified: a consumer holding a 'Cells' value
-- cannot recover the payload type from the value alone.
data Cells where
  -- | The empty tape.
  Nil    :: Cells
  -- | An entry carrying one strict index and one payload, followed by the
  -- remainder of the tape.
  Unary  :: {-# UNPACK #-} !Int -> a -> Cells -> Cells
  -- | An entry carrying two strict indices and two payloads, followed by the
  -- remainder of the tape.
  Binary :: {-# UNPACK #-} !Int -> {-# UNPACK #-} !Int -> a -> a -> Cells -> Cells
-- | Discard leading entries from a tape.
--
-- @dropCells n cells@ removes @n@ entries from the front of @cells@, stopping
-- early when the tape runs out. A 'Unary' and a 'Binary' entry each count as
-- one. The counter is decremented strictly so no chain of @n - 1@ thunks
-- builds up. A negative count never reaches the @0@ case, so it drops the
-- whole tape.
dropCells :: Int -> Cells -> Cells
dropCells 0 xs = xs
dropCells _ Nil = Nil
dropCells n (Unary _ _ xs) = (dropCells $! n - 1) xs
dropCells n (Binary _ _ _ _ xs) = (dropCells $! n - 1) xs
-- | The current state of a tape: a strict, unpacked counter paired with the
-- recorded 'Cells'.
--
-- 'un' and 'bin' increment the counter by one each time they push a cell and
-- hand the incremented value back to their caller.
data Head where
  Head :: {-# UNPACK #-} !Int -> Cells -> Head
-- | A mutable tape: a reference to the current 'Head'.
newtype Tape = Tape
  { getTape :: IORef Head -- ^ The reference holding the tape state.
  }
-- | Push a 'Unary' cell onto a tape state.
--
-- @un i di hd@ prepends @Unary i di@ to the cells of @hd@ and increments its
-- counter. The result is the new 'Head' paired with the incremented counter.
-- Both components are forced with 'seq' before the pair is produced, so
-- demanding the pair also evaluates the new 'Head' to weak head normal form.
un :: Int -> a -> Head -> (Head, Int)
un i di (Head r t) = h `seq` r' `seq` (h, r')
  where
    r' = r + 1
    h = Head r' (Unary i di t)
{-# INLINE un #-}
-- | Push a 'Binary' cell onto a tape state.
--
-- @bin i j di dj hd@ prepends @Binary i j di dj@ to the cells of @hd@ and
-- increments its counter. The result is the new 'Head' paired with the
-- incremented counter. Both components are forced with 'seq' before the pair
-- is produced.
bin :: Int -> Int -> a -> a -> Head -> (Head, Int)
bin i j di dj (Head r t) = h `seq` r' `seq` (h, r')
  where
    r' = r + 1
    h = Head r' (Binary i j di dj t)
{-# INLINE bin #-}
-- | Atomically apply a state transition to the tape reflected by @s@.
--
-- The proxy argument only selects which reified 'Tape' to use. The transition
-- function receives the current 'Head' and returns the replacement 'Head'
-- together with a result, which is what the action yields.
modifyTape :: Reifies s Tape => p s -> (Head -> (Head, r)) -> IO r
modifyTape p = atomicModifyIORef (getTape (reflect p))
{-# INLINE modifyTape #-}
-- | Record a one-input operation on the tape and build the resulting node.
--
-- Arguments, in order: the function @f@ applied to the input value, the
-- payload @di@ stored in the new 'Unary' cell, the index @i@ of the input,
-- and the input value @b@.
--
-- The tape is updated through 'unsafePerformIO'. The counter returned by 'un'
-- becomes the index of the new 'Reverse' node, and @f b@ is forced with '$!'
-- before the node is constructed.
unarily :: forall s a. Reifies s Tape => (a -> a) -> a -> Int -> a -> Reverse s a
unarily f di i b =
  Reverse (unsafePerformIO (modifyTape (Proxy :: Proxy s) (un i di))) $! f b
{-# INLINE unarily #-}
-- | Record a two-input operation on the tape and build the resulting node.
--
-- Arguments, in order: the function @f@ applied to both input values, the
-- payloads @di@ and @dj@ stored in the new 'Binary' cell, then the index and
-- value of the first input (@i@, @b@) and of the second input (@j@, @c@).
--
-- The tape is updated through 'unsafePerformIO'. The counter returned by
-- 'bin' becomes the index of the new 'Reverse' node, and @f b c@ is forced
-- with '$!' before the node is constructed.
binarily :: forall s a. Reifies s Tape => (a -> a -> a) -> a -> a -> Int -> a -> Int -> a -> Reverse s a
binarily f di dj i b j c =
  Reverse (unsafePerformIO (modifyTape (Proxy :: Proxy s) (bin i j di dj))) $! f b c
{-# INLINE binarily #-}
-- | A value tracked against the tape identified by the phantom parameter @s@.
data Reverse s a where
  -- | The constant zero, carrying no value.
  Zero :: Reverse s a
  -- | A constant carrying only its value.
  Lift :: a -> Reverse s a
  -- | A node with a strict index and its value. 'unarily', 'binarily' and
  -- 'var' are the functions in this module that build this constructor.
  Reverse :: {-# UNPACK #-} !Int -> a -> Reverse s a
  deriving (Show, Typeable)
-- | 'Mode' instance for 'Reverse': the scalar type is the carried value type.
instance (Reifies s Tape, Num a) => Mode (Reverse s a) where
  type Scalar (Reverse s a) = a

  -- Only the dedicated 'Zero' constructor is recognised as zero; a 'Lift'
  -- that happens to hold 0 is not inspected.
  isKnownZero Zero = True
  isKnownZero _ = False

  -- 'Zero' and 'Lift' are constants; a 'Reverse' node is not.
  asKnownConstant Zero = Just 0
  asKnownConstant (Lift n) = Just n
  asKnownConstant _ = Nothing

  isKnownConstant Reverse{} = False
  isKnownConstant _ = True

  auto = Lift
  zero = Zero

  -- Scaling by a scalar: the supplied derivative is the scalar itself
  -- (or its reciprocal for division), independent of the input value.
  a *^ b = lift1 (a *) (\_ -> auto a) b
  a ^* b = lift1 (* b) (\_ -> auto b) a
  a ^/ b = lift1 (/ b) (\_ -> auto (recip b)) a
-- | Addition of two 'Reverse' values, expressed through 'binary' with both
-- partial derivatives equal to 1.
(<+>) :: (Reifies s Tape, Num a) => Reverse s a -> Reverse s a -> Reverse s a
(<+>) = binary (+) 1 1
-- | Extract the plain value carried by a 'Reverse'.
--
-- 'Zero' yields @0@; 'Lift' and 'Reverse' yield the value they hold (the
-- index of a 'Reverse' node is ignored).
primal :: Num a => Reverse s a -> a
primal Zero = 0
primal (Lift a) = a
primal (Reverse _ a) = a
-- | 'Jacobian' instance for 'Reverse'. Derivatives are passed around wrapped
-- in 'Id'.
--
-- Throughout, constant inputs ('Zero' and 'Lift') produce a 'Lift' result
-- without touching the tape; only 'Reverse' inputs cause a cell to be
-- recorded via 'unarily' or 'binarily'.
instance (Reifies s Tape, Num a) => Jacobian (Reverse s a) where
  type D (Reverse s a) = Id a

  -- One-input operation with an already computed derivative.
  unary f _ Zero = Lift (f 0)
  unary f _ (Lift a) = Lift (f a)
  unary f (Id dadi) (Reverse i b) = unarily f dadi i b

  -- The derivative is computed from the input value.
  lift1 f df b = unary f (df (Id pb)) b
    where
      pb = primal b

  -- The derivative is computed from the output and the input value; the
  -- output is computed once and reused through 'const'.
  lift1_ f df b = unary (const a) (df (Id a) (Id pb)) b
    where
      pb = primal b
      a = f pb

  -- Two-input operation with already computed partial derivatives.
  -- Both inputs constant: no tape entry.
  binary f _ _ Zero Zero = Lift (f 0 0)
  binary f _ _ Zero (Lift c) = Lift (f 0 c)
  binary f _ _ (Lift b) Zero = Lift (f b 0)
  binary f _ _ (Lift b) (Lift c) = Lift (f b c)
  -- Exactly one input is a node: record a unary cell for that input, with
  -- the constant input baked into the function.
  binary f _ (Id dadc) Zero (Reverse i c) = unarily (f 0) dadc i c
  binary f _ (Id dadc) (Lift b) (Reverse i c) = unarily (f b) dadc i c
  binary f (Id dadb) _ (Reverse i b) Zero = unarily (`f` 0) dadb i b
  binary f (Id dadb) _ (Reverse i b) (Lift c) = unarily (`f` c) dadb i b
  -- Both inputs are nodes: record a binary cell.
  binary f (Id dadb) (Id dadc) (Reverse i b) (Reverse j c) = binarily f dadb dadc i b j c

  -- The partial derivatives are computed from the two input values.
  lift2 f df b c = binary f dadb dadc b c
    where
      (dadb, dadc) = df (Id (primal b)) (Id (primal c))

  -- The partial derivatives are computed from the output and both input
  -- values; the output is computed once and reused.
  lift2_ f df b c = binary (\_ _ -> a) dadb dadc b c
    where
      pb = primal b
      pc = primal c
      a = f pb pc
      (dadb, dadc) = df (Id a) (Id pb) (Id pc)
-- | Multiplication of two 'Reverse' values via 'lift2': the partial
-- derivative with respect to each argument is the other argument.
mul :: (Reifies s Tape, Num a) => Reverse s a -> Reverse s a -> Reverse s a
mul = lift2 (*) (\x y -> (y, x))
#define BODY1(x) (Reifies s Tape,x) =>
#define BODY2(x,y) (Reifies s Tape,x,y) =>
#define HEAD (Reverse s a)
#include "instances.h"
-- | Sum of all entries returned by 'partials' for the given value.
--
-- The proxy argument is ignored at the value level; it only fixes the tape
-- type @s@.
derivativeOf :: (Reifies s Tape, Num a) => Proxy s -> Reverse s a -> a
derivativeOf _ = sum . partials
{-# INLINE derivativeOf #-}
-- | Pair the plain value of the argument ('primal') with the result of
-- 'derivativeOf' for it.
derivativeOf' :: (Reifies s Tape, Num a) => Proxy s -> Reverse s a -> (a, a)
derivativeOf' p r = (primal r, derivativeOf p r)
{-# INLINE derivativeOf' #-}
-- | Walk the tape from its front to its end, accumulating into the array.
--
-- @backPropagate k cells ss@ treats @k@ as the array slot belonging to the
-- cell at the front of @cells@. For each cell it reads the value stored at
-- slot @k@, multiplies it by the payload(s) in the cell, and adds the product
-- to the slot(s) named by the cell's index field(s). It then continues with
-- @k - 1@ and the rest of the tape. When the tape is exhausted the current
-- @k@ is returned.
--
-- The payloads in 'Cells' have an existentially hidden type, so they are cast
-- to the array element type with 'unsafeCoerce'.
-- NOTE(review): this is only sound if the tape was recorded at the same
-- element type as the array; nothing in this function checks that.
backPropagate :: Num a => Int -> Cells -> STArray s Int a -> ST s Int
backPropagate k Nil _ = return k
backPropagate k (Unary i g xs) ss = do
  da <- readArray ss k
  db <- readArray ss i
  writeArray ss i $! db + unsafeCoerce g * da
  (backPropagate $! k - 1) xs ss
backPropagate k (Binary i j g h xs) ss = do
  da <- readArray ss k
  db <- readArray ss i
  writeArray ss i $! db + unsafeCoerce g * da
  -- Slot j is read only after slot i has been written, so when i == j both
  -- contributions are accumulated.
  dc <- readArray ss j
  writeArray ss j $! dc + unsafeCoerce h * da
  (backPropagate $! k - 1) xs ss
-- | Compute the list of partial derivatives for a value from the tape
-- reflected by @s@.
--
-- Constants ('Zero' and 'Lift') yield the empty list. For a node with index
-- @k@ the steps are:
--
-- 1. Read the current tape state with 'unsafePerformIO'.
-- 2. Drop the @n - k@ entries at the front of the tape, where @n@ is the
--    tape's current counter.
-- 3. Allocate an array indexed @0..k@ filled with 0, set slot @k@ to 1, and
--    run 'backPropagate' over the remaining entries.
-- 4. Return slots @0..vs@ of the array, where @vs@ is the index at which
--    'backPropagate' stopped.
--
-- The mutable array is frozen with @unsafeFreeze@; it is not written to
-- afterwards.
{-# SPECIALIZE partials :: Reifies s Tape => Reverse s Double -> [Double] #-}
partials :: forall s a. (Reifies s Tape, Num a) => Reverse s a -> [a]
partials Zero = []
partials (Lift _) = []
partials (Reverse k _) = map (sensitivities !) [0 .. vs]
  where
    Head n t = unsafePerformIO $ readIORef (getTape (reflect (Proxy :: Proxy s)))
    tk = dropCells (n - k) t
    (vs, sensitivities) = runST $ do
      ss <- newArray (0, k) 0
      writeArray ss k 1
      v <- backPropagate k tk ss
      frozen <- Unsafe.unsafeFreeze ss
      return (v, frozen)
-- | Collect the result of 'partials' into an 'Array' with the given bounds.
--
-- The partials are paired with the indices @0, 1, 2, ...@ in list order and
-- accumulated with @(+)@ over an array initialised to 0. The proxy argument
-- is ignored at the value level; it only fixes the tape type @s@.
partialArrayOf :: (Reifies s Tape, Num a) => Proxy s -> (Int, Int) -> Reverse s a -> Array Int a
partialArrayOf _ vbounds = accumArray (+) 0 vbounds . zip [0 ..] . partials
{-# INLINE partialArrayOf #-}
-- | Collect the result of 'partials' into an 'IntMap' keyed by position.
--
-- The partials are paired with the keys @0, 1, 2, ...@ in list order. Those
-- keys are strictly ascending, which is what 'fromDistinctAscList' requires.
-- The proxy argument is ignored at the value level; it only fixes the tape
-- type @s@.
partialMapOf :: (Reifies s Tape, Num a) => Proxy s -> Reverse s a -> IntMap a
partialMapOf _ = fromDistinctAscList . zip [0 ..] . partials
{-# INLINE partialMapOf #-}
-- | Create a fresh tape and run a continuation with a type-level handle to it.
--
-- The new tape has its counter initialised to @vs@ and no recorded cells. The
-- reference is allocated inside 'unsafePerformIO', and the function is marked
-- @NOINLINE@.
reifyTape :: Int -> (forall s. Reifies s Tape => Proxy s -> r) -> r
reifyTape vs k = unsafePerformIO $ do
  h <- newIORef (Head vs Nil)
  return (reify (Tape h) k)
{-# NOINLINE reifyTape #-}
-- | Like 'reifyTape', but the continuation additionally receives a 'Typeable'
-- constraint on the tape type, obtained through 'reifyTypeable'.
--
-- The new tape has its counter initialised to @vs@ and no recorded cells. The
-- reference is allocated inside 'unsafePerformIO', and the function is marked
-- @NOINLINE@.
reifyTypeableTape :: Int -> (forall s. (Typeable s, Reifies s Tape) => Proxy s -> r) -> r
reifyTypeableTape vs k = unsafePerformIO $ do
  h <- newIORef (Head vs Nil)
  return (reifyTypeable (Tape h) k)
{-# NOINLINE reifyTypeableTape #-}
-- | Build a 'Reverse' node from a value and an index. Note the argument
-- order: value first, index second. Nothing is recorded on the tape.
var :: a -> Int -> Reverse s a
var a v = Reverse v a
-- | Return the index stored in a 'Reverse' node.
--
-- This function is partial: applied to 'Zero' or 'Lift' it calls 'error'
-- with the message @"varId: not a Var"@.
varId :: Reverse s a -> Int
varId (Reverse v _) = v
varId _ = error "varId: not a Var"
-- | Turn every element of a container into a 'Reverse' node with a fresh
-- index.
--
-- Indices are assigned in traversal order starting at 0. The second component
-- of the result is @(0, hi)@, where @hi@ is the final counter value, i.e. the
-- number of elements traversed (one more than the last index handed out).
bind :: Traversable f => f a -> (f (Reverse s a), (Int,Int))
bind xs = (r, (0, hi))
  where
    (r, hi) = runState (mapM freshVar xs) 0
    -- Tag the value with the current counter and advance it, forcing the
    -- new counter so the state does not accumulate (+ 1) thunks.
    freshVar a = state $ \s -> let s' = s + 1 in s' `seq` (var a s, s')
-- | Replace every node in the container by the array entry at its index.
--
-- Each element's index is obtained with 'varId', so this is partial on
-- elements that are not 'Reverse' nodes. Array indexing with '!' fails for
-- indices outside the array bounds.
unbind :: Functor f => f (Reverse s a) -> Array Int a -> f a
unbind xs ys = fmap (\v -> ys ! varId v) xs
-- | Combine every node's plain value with the array entry at its index.
--
-- For each element @v@ the result is @f (primal v) (ys ! varId v)@. The
-- lookup is partial on elements that are not 'Reverse' nodes (see 'varId')
-- and on indices outside the array bounds.
unbindWith :: (Functor f, Num a) => (a -> b -> c) -> f (Reverse s a) -> Array Int b -> f c
unbindWith f xs ys = fmap (\v -> f (primal v) (ys ! varId v)) xs
-- | Replace every node in the container by the map entry at its index,
-- using 0 when the index is absent from the map.
--
-- Each element's index is obtained with 'varId', so this is partial on
-- elements that are not 'Reverse' nodes.
unbindMap :: (Functor f, Num a) => f (Reverse s a) -> IntMap a -> f a
unbindMap xs ys = fmap (\v -> findWithDefault 0 (varId v) ys) xs
-- | Combine every node's plain value with the map entry at its index,
-- substituting the default @z@ when the index is absent from the map.
--
-- For each element @v@ the result is
-- @f (primal v) (findWithDefault z (varId v) ys)@. The index lookup is
-- partial on elements that are not 'Reverse' nodes (see 'varId').
unbindMapWithDefault :: (Functor f, Num a) => b -> (a -> b -> c) -> f (Reverse s a) -> IntMap b -> f c
unbindMapWithDefault z f xs ys = fmap (\v -> f (primal v) $ findWithDefault z (varId v) ys) xs