{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-full-laziness #-}
{-# OPTIONS_HADDOCK not-home #-}
module Numeric.AD.Internal.Kahn
( Kahn(..)
, Tape(..)
, partials
, partialArray
, partialMap
, derivative
, derivative'
, vgrad, vgrad'
, Grad(..)
, bind
, unbind
, unbindMap
, unbindWith
, unbindMapWithDefault
, primal
, var
, varId
) where
import Control.Monad.ST
import Control.Monad hiding (mapM)
import Control.Monad.Trans.State
import qualified Data.List as List (foldl')
import Data.Array.ST
import Data.Array
import Data.IntMap (IntMap, fromListWith, findWithDefault)
import Data.Graph (Vertex, transposeG, Graph)
import Data.Number.Erf
import Data.Reify (reifyGraph, MuRef(..))
import qualified Data.Reify.Graph as Reified
import System.IO.Unsafe (unsafePerformIO)
import Data.Data (Data)
import Data.Typeable (Typeable)
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode
data Tape a t
= Zero
| Lift !a
| Var !a {-# UNPACK #-} !Int
| Binary !a a a t t
| Unary !a a t
deriving (Int -> Tape a t -> ShowS
[Tape a t] -> ShowS
Tape a t -> String
(Int -> Tape a t -> ShowS)
-> (Tape a t -> String) -> ([Tape a t] -> ShowS) -> Show (Tape a t)
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall a t. (Show a, Show t) => Int -> Tape a t -> ShowS
forall a t. (Show a, Show t) => [Tape a t] -> ShowS
forall a t. (Show a, Show t) => Tape a t -> String
$cshowsPrec :: forall a t. (Show a, Show t) => Int -> Tape a t -> ShowS
showsPrec :: Int -> Tape a t -> ShowS
$cshow :: forall a t. (Show a, Show t) => Tape a t -> String
show :: Tape a t -> String
$cshowList :: forall a t. (Show a, Show t) => [Tape a t] -> ShowS
showList :: [Tape a t] -> ShowS
Show, Typeable (Tape a t)
Typeable (Tape a t) =>
(forall (c :: * -> *).
(forall d b. Data d => c (d -> b) -> d -> c b)
-> (forall g. g -> c g) -> Tape a t -> c (Tape a t))
-> (forall (c :: * -> *).
(forall b r. Data b => c (b -> r) -> c r)
-> (forall r. r -> c r) -> Constr -> c (Tape a t))
-> (Tape a t -> Constr)
-> (Tape a t -> DataType)
-> (forall (t :: * -> *) (c :: * -> *).
Typeable t =>
(forall d. Data d => c (t d)) -> Maybe (c (Tape a t)))
-> (forall (t :: * -> * -> *) (c :: * -> *).
Typeable t =>
(forall d e. (Data d, Data e) => c (t d e))
-> Maybe (c (Tape a t)))
-> ((forall b. Data b => b -> b) -> Tape a t -> Tape a t)
-> (forall r r'.
(r -> r' -> r)
-> r -> (forall d. Data d => d -> r') -> Tape a t -> r)
-> (forall r r'.
(r' -> r -> r)
-> r -> (forall d. Data d => d -> r') -> Tape a t -> r)
-> (forall u. (forall d. Data d => d -> u) -> Tape a t -> [u])
-> (forall u. Int -> (forall d. Data d => d -> u) -> Tape a t -> u)
-> (forall (m :: * -> *).
Monad m =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t))
-> (forall (m :: * -> *).
MonadPlus m =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t))
-> (forall (m :: * -> *).
MonadPlus m =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t))
-> Data (Tape a t)
Tape a t -> Constr
Tape a t -> DataType
(forall b. Data b => b -> b) -> Tape a t -> Tape a t
forall a.
Typeable a =>
(forall (c :: * -> *).
(forall d b. Data d => c (d -> b) -> d -> c b)
-> (forall g. g -> c g) -> a -> c a)
-> (forall (c :: * -> *).
(forall b r. Data b => c (b -> r) -> c r)
-> (forall r. r -> c r) -> Constr -> c a)
-> (a -> Constr)
-> (a -> DataType)
-> (forall (t :: * -> *) (c :: * -> *).
Typeable t =>
(forall d. Data d => c (t d)) -> Maybe (c a))
-> (forall (t :: * -> * -> *) (c :: * -> *).
Typeable t =>
(forall d e. (Data d, Data e) => c (t d e)) -> Maybe (c a))
-> ((forall b. Data b => b -> b) -> a -> a)
-> (forall r r'.
(r -> r' -> r) -> r -> (forall d. Data d => d -> r') -> a -> r)
-> (forall r r'.
(r' -> r -> r) -> r -> (forall d. Data d => d -> r') -> a -> r)
-> (forall u. (forall d. Data d => d -> u) -> a -> [u])
-> (forall u. Int -> (forall d. Data d => d -> u) -> a -> u)
-> (forall (m :: * -> *).
Monad m =>
(forall d. Data d => d -> m d) -> a -> m a)
-> (forall (m :: * -> *).
MonadPlus m =>
(forall d. Data d => d -> m d) -> a -> m a)
-> (forall (m :: * -> *).
MonadPlus m =>
(forall d. Data d => d -> m d) -> a -> m a)
-> Data a
forall u. Int -> (forall d. Data d => d -> u) -> Tape a t -> u
forall u. (forall d. Data d => d -> u) -> Tape a t -> [u]
forall a t. (Data a, Data t) => Typeable (Tape a t)
forall a t. (Data a, Data t) => Tape a t -> Constr
forall a t. (Data a, Data t) => Tape a t -> DataType
forall a t.
(Data a, Data t) =>
(forall b. Data b => b -> b) -> Tape a t -> Tape a t
forall a t u.
(Data a, Data t) =>
Int -> (forall d. Data d => d -> u) -> Tape a t -> u
forall a t u.
(Data a, Data t) =>
(forall d. Data d => d -> u) -> Tape a t -> [u]
forall a t r r'.
(Data a, Data t) =>
(r -> r' -> r)
-> r -> (forall d. Data d => d -> r') -> Tape a t -> r
forall a t r r'.
(Data a, Data t) =>
(r' -> r -> r)
-> r -> (forall d. Data d => d -> r') -> Tape a t -> r
forall a t (m :: * -> *).
(Data a, Data t, Monad m) =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t)
forall a t (m :: * -> *).
(Data a, Data t, MonadPlus m) =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t)
forall a t (c :: * -> *).
(Data a, Data t) =>
(forall b r. Data b => c (b -> r) -> c r)
-> (forall r. r -> c r) -> Constr -> c (Tape a t)
forall a t (c :: * -> *).
(Data a, Data t) =>
(forall d b. Data d => c (d -> b) -> d -> c b)
-> (forall g. g -> c g) -> Tape a t -> c (Tape a t)
forall a t (t :: * -> *) (c :: * -> *).
(Data a, Data t, Typeable t) =>
(forall d. Data d => c (t d)) -> Maybe (c (Tape a t))
forall a t (t :: * -> * -> *) (c :: * -> *).
(Data a, Data t, Typeable t) =>
(forall d e. (Data d, Data e) => c (t d e)) -> Maybe (c (Tape a t))
forall r r'.
(r -> r' -> r)
-> r -> (forall d. Data d => d -> r') -> Tape a t -> r
forall r r'.
(r' -> r -> r)
-> r -> (forall d. Data d => d -> r') -> Tape a t -> r
forall (m :: * -> *).
Monad m =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t)
forall (m :: * -> *).
MonadPlus m =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t)
forall (c :: * -> *).
(forall b r. Data b => c (b -> r) -> c r)
-> (forall r. r -> c r) -> Constr -> c (Tape a t)
forall (c :: * -> *).
(forall d b. Data d => c (d -> b) -> d -> c b)
-> (forall g. g -> c g) -> Tape a t -> c (Tape a t)
forall (t :: * -> *) (c :: * -> *).
Typeable t =>
(forall d. Data d => c (t d)) -> Maybe (c (Tape a t))
forall (t :: * -> * -> *) (c :: * -> *).
Typeable t =>
(forall d e. (Data d, Data e) => c (t d e)) -> Maybe (c (Tape a t))
$cgfoldl :: forall a t (c :: * -> *).
(Data a, Data t) =>
(forall d b. Data d => c (d -> b) -> d -> c b)
-> (forall g. g -> c g) -> Tape a t -> c (Tape a t)
gfoldl :: forall (c :: * -> *).
(forall d b. Data d => c (d -> b) -> d -> c b)
-> (forall g. g -> c g) -> Tape a t -> c (Tape a t)
$cgunfold :: forall a t (c :: * -> *).
(Data a, Data t) =>
(forall b r. Data b => c (b -> r) -> c r)
-> (forall r. r -> c r) -> Constr -> c (Tape a t)
gunfold :: forall (c :: * -> *).
(forall b r. Data b => c (b -> r) -> c r)
-> (forall r. r -> c r) -> Constr -> c (Tape a t)
$ctoConstr :: forall a t. (Data a, Data t) => Tape a t -> Constr
toConstr :: Tape a t -> Constr
$cdataTypeOf :: forall a t. (Data a, Data t) => Tape a t -> DataType
dataTypeOf :: Tape a t -> DataType
$cdataCast1 :: forall a t (t :: * -> *) (c :: * -> *).
(Data a, Data t, Typeable t) =>
(forall d. Data d => c (t d)) -> Maybe (c (Tape a t))
dataCast1 :: forall (t :: * -> *) (c :: * -> *).
Typeable t =>
(forall d. Data d => c (t d)) -> Maybe (c (Tape a t))
$cdataCast2 :: forall a t (t :: * -> * -> *) (c :: * -> *).
(Data a, Data t, Typeable t) =>
(forall d e. (Data d, Data e) => c (t d e)) -> Maybe (c (Tape a t))
dataCast2 :: forall (t :: * -> * -> *) (c :: * -> *).
Typeable t =>
(forall d e. (Data d, Data e) => c (t d e)) -> Maybe (c (Tape a t))
$cgmapT :: forall a t.
(Data a, Data t) =>
(forall b. Data b => b -> b) -> Tape a t -> Tape a t
gmapT :: (forall b. Data b => b -> b) -> Tape a t -> Tape a t
$cgmapQl :: forall a t r r'.
(Data a, Data t) =>
(r -> r' -> r)
-> r -> (forall d. Data d => d -> r') -> Tape a t -> r
gmapQl :: forall r r'.
(r -> r' -> r)
-> r -> (forall d. Data d => d -> r') -> Tape a t -> r
$cgmapQr :: forall a t r r'.
(Data a, Data t) =>
(r' -> r -> r)
-> r -> (forall d. Data d => d -> r') -> Tape a t -> r
gmapQr :: forall r r'.
(r' -> r -> r)
-> r -> (forall d. Data d => d -> r') -> Tape a t -> r
$cgmapQ :: forall a t u.
(Data a, Data t) =>
(forall d. Data d => d -> u) -> Tape a t -> [u]
gmapQ :: forall u. (forall d. Data d => d -> u) -> Tape a t -> [u]
$cgmapQi :: forall a t u.
(Data a, Data t) =>
Int -> (forall d. Data d => d -> u) -> Tape a t -> u
gmapQi :: forall u. Int -> (forall d. Data d => d -> u) -> Tape a t -> u
$cgmapM :: forall a t (m :: * -> *).
(Data a, Data t, Monad m) =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t)
gmapM :: forall (m :: * -> *).
Monad m =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t)
$cgmapMp :: forall a t (m :: * -> *).
(Data a, Data t, MonadPlus m) =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t)
gmapMp :: forall (m :: * -> *).
MonadPlus m =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t)
$cgmapMo :: forall a t (m :: * -> *).
(Data a, Data t, MonadPlus m) =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t)
gmapMo :: forall (m :: * -> *).
MonadPlus m =>
(forall d. Data d => d -> m d) -> Tape a t -> m (Tape a t)
Data, Typeable)
newtype Kahn a = Kahn (Tape a (Kahn a)) deriving (Int -> Kahn a -> ShowS
[Kahn a] -> ShowS
Kahn a -> String
(Int -> Kahn a -> ShowS)
-> (Kahn a -> String) -> ([Kahn a] -> ShowS) -> Show (Kahn a)
forall a. Show a => Int -> Kahn a -> ShowS
forall a. Show a => [Kahn a] -> ShowS
forall a. Show a => Kahn a -> String
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
$cshowsPrec :: forall a. Show a => Int -> Kahn a -> ShowS
showsPrec :: Int -> Kahn a -> ShowS
$cshow :: forall a. Show a => Kahn a -> String
show :: Kahn a -> String
$cshowList :: forall a. Show a => [Kahn a] -> ShowS
showList :: [Kahn a] -> ShowS
Show, Typeable)
instance MuRef (Kahn a) where
type DeRef (Kahn a) = Tape a
mapDeRef :: forall (f :: * -> *) u.
Applicative f =>
(forall b. (MuRef b, DeRef (Kahn a) ~ DeRef b) => b -> f u)
-> Kahn a -> f (DeRef (Kahn a) u)
mapDeRef forall b. (MuRef b, DeRef (Kahn a) ~ DeRef b) => b -> f u
_ (Kahn Tape a (Kahn a)
Zero) = Tape a u -> f (Tape a u)
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Tape a u
forall a t. Tape a t
Zero
mapDeRef forall b. (MuRef b, DeRef (Kahn a) ~ DeRef b) => b -> f u
_ (Kahn (Lift a
a)) = Tape a u -> f (Tape a u)
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (a -> Tape a u
forall a t. a -> Tape a t
Lift a
a)
mapDeRef forall b. (MuRef b, DeRef (Kahn a) ~ DeRef b) => b -> f u
_ (Kahn (Var a
a Int
v)) = Tape a u -> f (Tape a u)
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (a -> Int -> Tape a u
forall a t. a -> Int -> Tape a t
Var a
a Int
v)
mapDeRef forall b. (MuRef b, DeRef (Kahn a) ~ DeRef b) => b -> f u
f (Kahn (Binary a
a a
dadb a
dadc Kahn a
b Kahn a
c)) = a -> a -> a -> u -> u -> Tape a u
forall a t. a -> a -> a -> t -> t -> Tape a t
Binary a
a a
dadb a
dadc (u -> u -> Tape a u) -> f u -> f (u -> Tape a u)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Kahn a -> f u
forall b. (MuRef b, DeRef (Kahn a) ~ DeRef b) => b -> f u
f Kahn a
b f (u -> Tape a u) -> f u -> f (Tape a u)
forall a b. f (a -> b) -> f a -> f b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Kahn a -> f u
forall b. (MuRef b, DeRef (Kahn a) ~ DeRef b) => b -> f u
f Kahn a
c
mapDeRef forall b. (MuRef b, DeRef (Kahn a) ~ DeRef b) => b -> f u
f (Kahn (Unary a
a a
dadb Kahn a
b)) = a -> a -> u -> Tape a u
forall a t. a -> a -> t -> Tape a t
Unary a
a a
dadb (u -> Tape a u) -> f u -> f (Tape a u)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Kahn a -> f u
forall b. (MuRef b, DeRef (Kahn a) ~ DeRef b) => b -> f u
f Kahn a
b
instance Num a => Mode (Kahn a) where
type Scalar (Kahn a) = a
isKnownZero :: Kahn a -> Bool
isKnownZero (Kahn Tape a (Kahn a)
Zero) = Bool
True
isKnownZero Kahn a
_ = Bool
False
asKnownConstant :: Kahn a -> Maybe (Scalar (Kahn a))
asKnownConstant (Kahn Tape a (Kahn a)
Zero) = a -> Maybe a
forall a. a -> Maybe a
Just a
0
asKnownConstant (Kahn (Lift a
n)) = a -> Maybe a
forall a. a -> Maybe a
Just a
n
asKnownConstant Kahn a
_ = Maybe a
Maybe (Scalar (Kahn a))
forall a. Maybe a
Nothing
isKnownConstant :: Kahn a -> Bool
isKnownConstant (Kahn Tape a (Kahn a)
Zero) = Bool
True
isKnownConstant (Kahn (Lift a
_)) = Bool
True
isKnownConstant Kahn a
_ = Bool
False
auto :: Scalar (Kahn a) -> Kahn a
auto Scalar (Kahn a)
a = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> Tape a (Kahn a)
forall a t. a -> Tape a t
Lift a
Scalar (Kahn a)
a)
zero :: Kahn a
zero = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn Tape a (Kahn a)
forall a t. Tape a t
Zero
Scalar (Kahn a)
a *^ :: Scalar (Kahn a) -> Kahn a -> Kahn a
*^ Kahn a
b = (Scalar (Kahn a) -> Scalar (Kahn a))
-> (D (Kahn a) -> D (Kahn a)) -> Kahn a -> Kahn a
forall t.
Jacobian t =>
(Scalar t -> Scalar t) -> (D t -> D t) -> t -> t
lift1 (Scalar (Kahn a)
a Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
forall a. Num a => a -> a -> a
*) (\D (Kahn a)
_ -> Scalar (Id a) -> Id a
forall t. Mode t => Scalar t -> t
auto Scalar (Id a)
Scalar (Kahn a)
a) Kahn a
b
Kahn a
a ^* :: Kahn a -> Scalar (Kahn a) -> Kahn a
^* Scalar (Kahn a)
b = (Scalar (Kahn a) -> Scalar (Kahn a))
-> (D (Kahn a) -> D (Kahn a)) -> Kahn a -> Kahn a
forall t.
Jacobian t =>
(Scalar t -> Scalar t) -> (D t -> D t) -> t -> t
lift1 (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
forall a. Num a => a -> a -> a
* Scalar (Kahn a)
b) (\D (Kahn a)
_ -> Scalar (Id a) -> Id a
forall t. Mode t => Scalar t -> t
auto Scalar (Id a)
Scalar (Kahn a)
b) Kahn a
a
Kahn a
a ^/ :: Fractional (Scalar (Kahn a)) => Kahn a -> Scalar (Kahn a) -> Kahn a
^/ Scalar (Kahn a)
b = (Scalar (Kahn a) -> Scalar (Kahn a))
-> (D (Kahn a) -> D (Kahn a)) -> Kahn a -> Kahn a
forall t.
Jacobian t =>
(Scalar t -> Scalar t) -> (D t -> D t) -> t -> t
lift1 (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
forall a. Fractional a => a -> a -> a
/ Scalar (Kahn a)
b) (\D (Kahn a)
_ -> Scalar (Id a) -> Id a
forall t. Mode t => Scalar t -> t
auto (a -> a
forall a. Fractional a => a -> a
recip a
Scalar (Kahn a)
b)) Kahn a
a
(<+>) :: Num a => Kahn a -> Kahn a -> Kahn a
<+> :: forall a. Num a => Kahn a -> Kahn a -> Kahn a
(<+>) = (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a))
-> D (Kahn a) -> D (Kahn a) -> Kahn a -> Kahn a -> Kahn a
forall t.
Jacobian t =>
(Scalar t -> Scalar t -> Scalar t) -> D t -> D t -> t -> t -> t
binary a -> a -> a
Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
forall a. Num a => a -> a -> a
(+) D (Kahn a)
Id a
1 D (Kahn a)
Id a
1
primal :: Num a => Kahn a -> a
primal :: forall a. Num a => Kahn a -> a
primal (Kahn Tape a (Kahn a)
Zero) = a
0
primal (Kahn (Lift a
a)) = a
a
primal (Kahn (Var a
a Int
_)) = a
a
primal (Kahn (Binary a
a a
_ a
_ Kahn a
_ Kahn a
_)) = a
a
primal (Kahn (Unary a
a a
_ Kahn a
_)) = a
a
instance Num a => Jacobian (Kahn a) where
type D (Kahn a) = Id a
unary :: (Scalar (Kahn a) -> Scalar (Kahn a))
-> D (Kahn a) -> Kahn a -> Kahn a
unary Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a)
_ (Kahn Tape a (Kahn a)
Zero) = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> Tape a (Kahn a)
forall a t. a -> Tape a t
Lift (Scalar (Kahn a) -> Scalar (Kahn a)
f a
Scalar (Kahn a)
0))
unary Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a)
_ (Kahn (Lift a
a)) = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> Tape a (Kahn a)
forall a t. a -> Tape a t
Lift (Scalar (Kahn a) -> Scalar (Kahn a)
f a
Scalar (Kahn a)
a))
unary Scalar (Kahn a) -> Scalar (Kahn a)
f (Id a
dadb) Kahn a
b = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> a -> Kahn a -> Tape a (Kahn a)
forall a t. a -> a -> t -> Tape a t
Unary (Scalar (Kahn a) -> Scalar (Kahn a)
f (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
b)) a
dadb Kahn a
b)
lift1 :: (Scalar (Kahn a) -> Scalar (Kahn a))
-> (D (Kahn a) -> D (Kahn a)) -> Kahn a -> Kahn a
lift1 Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a) -> D (Kahn a)
df Kahn a
b = (Scalar (Kahn a) -> Scalar (Kahn a))
-> D (Kahn a) -> Kahn a -> Kahn a
forall t. Jacobian t => (Scalar t -> Scalar t) -> D t -> t -> t
unary Scalar (Kahn a) -> Scalar (Kahn a)
f (D (Kahn a) -> D (Kahn a)
df (a -> Id a
forall a. a -> Id a
Id a
pb)) Kahn a
b where
pb :: a
pb = Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
b
lift1_ :: (Scalar (Kahn a) -> Scalar (Kahn a))
-> (D (Kahn a) -> D (Kahn a) -> D (Kahn a)) -> Kahn a -> Kahn a
lift1_ Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a) -> D (Kahn a) -> D (Kahn a)
df Kahn a
b = (Scalar (Kahn a) -> Scalar (Kahn a))
-> D (Kahn a) -> Kahn a -> Kahn a
forall t. Jacobian t => (Scalar t -> Scalar t) -> D t -> t -> t
unary (a -> a -> a
forall a b. a -> b -> a
const a
Scalar (Kahn a)
a) (D (Kahn a) -> D (Kahn a) -> D (Kahn a)
df (a -> Id a
forall a. a -> Id a
Id a
Scalar (Kahn a)
a) (a -> Id a
forall a. a -> Id a
Id a
pb)) Kahn a
b where
pb :: a
pb = Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
b
a :: Scalar (Kahn a)
a = Scalar (Kahn a) -> Scalar (Kahn a)
f a
Scalar (Kahn a)
pb
binary :: (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a))
-> D (Kahn a) -> D (Kahn a) -> Kahn a -> Kahn a -> Kahn a
binary Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a)
_ D (Kahn a)
_ (Kahn Tape a (Kahn a)
Zero) (Kahn Tape a (Kahn a)
Zero) = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> Tape a (Kahn a)
forall a t. a -> Tape a t
Lift (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f a
Scalar (Kahn a)
0 a
Scalar (Kahn a)
0))
binary Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a)
_ D (Kahn a)
_ (Kahn Tape a (Kahn a)
Zero) (Kahn (Lift a
c)) = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> Tape a (Kahn a)
forall a t. a -> Tape a t
Lift (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f a
Scalar (Kahn a)
0 a
Scalar (Kahn a)
c))
binary Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a)
_ D (Kahn a)
_ (Kahn (Lift a
b)) (Kahn Tape a (Kahn a)
Zero) = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> Tape a (Kahn a)
forall a t. a -> Tape a t
Lift (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f a
Scalar (Kahn a)
b a
Scalar (Kahn a)
0))
binary Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a)
_ D (Kahn a)
_ (Kahn (Lift a
b)) (Kahn (Lift a
c)) = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> Tape a (Kahn a)
forall a t. a -> Tape a t
Lift (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f a
Scalar (Kahn a)
b a
Scalar (Kahn a)
c))
binary Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a)
_ (Id a
dadc) (Kahn Tape a (Kahn a)
Zero) Kahn a
c = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> a -> Kahn a -> Tape a (Kahn a)
forall a t. a -> a -> t -> Tape a t
Unary (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f a
Scalar (Kahn a)
0 (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
c)) a
dadc Kahn a
c)
binary Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a)
_ (Id a
dadc) (Kahn (Lift a
b)) Kahn a
c = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> a -> Kahn a -> Tape a (Kahn a)
forall a t. a -> a -> t -> Tape a t
Unary (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f a
Scalar (Kahn a)
b (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
c)) a
dadc Kahn a
c)
binary Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f (Id a
dadb) D (Kahn a)
_ Kahn a
b (Kahn Tape a (Kahn a)
Zero) = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> a -> Kahn a -> Tape a (Kahn a)
forall a t. a -> a -> t -> Tape a t
Unary (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
b) a
Scalar (Kahn a)
0) a
dadb Kahn a
b)
binary Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f (Id a
dadb) D (Kahn a)
_ Kahn a
b (Kahn (Lift a
c)) = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> a -> Kahn a -> Tape a (Kahn a)
forall a t. a -> a -> t -> Tape a t
Unary (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
b) a
Scalar (Kahn a)
c) a
dadb Kahn a
b)
binary Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f (Id a
dadb) (Id a
dadc) Kahn a
b Kahn a
c = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> a -> a -> Kahn a -> Kahn a -> Tape a (Kahn a)
forall a t. a -> a -> a -> t -> t -> Tape a t
Binary (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
b) (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
c)) a
dadb a
dadc Kahn a
b Kahn a
c)
lift2 :: (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a))
-> (D (Kahn a) -> D (Kahn a) -> (D (Kahn a), D (Kahn a)))
-> Kahn a
-> Kahn a
-> Kahn a
lift2 Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a) -> D (Kahn a) -> (D (Kahn a), D (Kahn a))
df Kahn a
b Kahn a
c = (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a))
-> D (Kahn a) -> D (Kahn a) -> Kahn a -> Kahn a -> Kahn a
forall t.
Jacobian t =>
(Scalar t -> Scalar t -> Scalar t) -> D t -> D t -> t -> t -> t
binary Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a)
dadb D (Kahn a)
dadc Kahn a
b Kahn a
c where
(D (Kahn a)
dadb, D (Kahn a)
dadc) = D (Kahn a) -> D (Kahn a) -> (D (Kahn a), D (Kahn a))
df (a -> Id a
forall a. a -> Id a
Id (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
b)) (a -> Id a
forall a. a -> Id a
Id (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
c))
lift2_ :: (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a))
-> (D (Kahn a)
-> D (Kahn a) -> D (Kahn a) -> (D (Kahn a), D (Kahn a)))
-> Kahn a
-> Kahn a
-> Kahn a
lift2_ Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f D (Kahn a) -> D (Kahn a) -> D (Kahn a) -> (D (Kahn a), D (Kahn a))
df Kahn a
b Kahn a
c = (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a))
-> D (Kahn a) -> D (Kahn a) -> Kahn a -> Kahn a -> Kahn a
forall t.
Jacobian t =>
(Scalar t -> Scalar t -> Scalar t) -> D t -> D t -> t -> t -> t
binary (\Scalar (Kahn a)
_ Scalar (Kahn a)
_ -> Scalar (Kahn a)
a) D (Kahn a)
dadb D (Kahn a)
dadc Kahn a
b Kahn a
c where
pb :: a
pb = Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
b
pc :: a
pc = Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
c
a :: Scalar (Kahn a)
a = Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
f a
Scalar (Kahn a)
pb a
Scalar (Kahn a)
pc
(D (Kahn a)
dadb, D (Kahn a)
dadc) = D (Kahn a) -> D (Kahn a) -> D (Kahn a) -> (D (Kahn a), D (Kahn a))
df (a -> Id a
forall a. a -> Id a
Id a
Scalar (Kahn a)
a) (a -> Id a
forall a. a -> Id a
Id a
pb) (a -> Id a
forall a. a -> Id a
Id a
pc)
mul :: Num a => Kahn a -> Kahn a -> Kahn a
mul :: forall a. Num a => Kahn a -> Kahn a -> Kahn a
mul = (Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a))
-> (D (Kahn a) -> D (Kahn a) -> (D (Kahn a), D (Kahn a)))
-> Kahn a
-> Kahn a
-> Kahn a
forall t.
Jacobian t =>
(Scalar t -> Scalar t -> Scalar t)
-> (D t -> D t -> (D t, D t)) -> t -> t -> t
lift2 a -> a -> a
Scalar (Kahn a) -> Scalar (Kahn a) -> Scalar (Kahn a)
forall a. Num a => a -> a -> a
(*) (\D (Kahn a)
x D (Kahn a)
y -> (D (Kahn a)
y, D (Kahn a)
x))
#define HEAD (Kahn a)
#include <instances.h>
derivative :: Num a => Kahn a -> a
derivative :: forall a. Num a => Kahn a -> a
derivative = [a] -> a
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([a] -> a) -> (Kahn a -> [a]) -> Kahn a -> a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((Int, a) -> a) -> [(Int, a)] -> [a]
forall a b. (a -> b) -> [a] -> [b]
map (Int, a) -> a
forall a b. (a, b) -> b
snd ([(Int, a)] -> [a]) -> (Kahn a -> [(Int, a)]) -> Kahn a -> [a]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Kahn a -> [(Int, a)]
forall a. Num a => Kahn a -> [(Int, a)]
partials
{-# INLINE derivative #-}
derivative' :: Num a => Kahn a -> (a, a)
derivative' :: forall a. Num a => Kahn a -> (a, a)
derivative' Kahn a
r = (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
r, Kahn a -> a
forall a. Num a => Kahn a -> a
derivative Kahn a
r)
{-# INLINE derivative' #-}
backPropagate :: Num a => (Vertex -> (Tape a Int, Int, [Int])) -> STArray s Int a -> Vertex -> ST s ()
backPropagate :: forall a s.
Num a =>
(Int -> (Tape a Int, Int, [Int]))
-> STArray s Int a -> Int -> ST s ()
backPropagate Int -> (Tape a Int, Int, [Int])
vmap STArray s Int a
ss Int
v = case Tape a Int
node of
Unary a
_ a
g Int
b -> do
a
da <- STArray s Int a -> Int -> ST s a
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
a i e -> i -> m e
readArray STArray s Int a
ss Int
i
a
db <- STArray s Int a -> Int -> ST s a
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
a i e -> i -> m e
readArray STArray s Int a
ss Int
b
STArray s Int a -> Int -> a -> ST s ()
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
a i e -> i -> e -> m ()
writeArray STArray s Int a
ss Int
b (a
db a -> a -> a
forall a. Num a => a -> a -> a
+ a
ga -> a -> a
forall a. Num a => a -> a -> a
*a
da)
Binary a
_ a
gb a
gc Int
b Int
c -> do
a
da <- STArray s Int a -> Int -> ST s a
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
a i e -> i -> m e
readArray STArray s Int a
ss Int
i
a
db <- STArray s Int a -> Int -> ST s a
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
a i e -> i -> m e
readArray STArray s Int a
ss Int
b
STArray s Int a -> Int -> a -> ST s ()
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
a i e -> i -> e -> m ()
writeArray STArray s Int a
ss Int
b (a
db a -> a -> a
forall a. Num a => a -> a -> a
+ a
gba -> a -> a
forall a. Num a => a -> a -> a
*a
da)
a
dc <- STArray s Int a -> Int -> ST s a
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
a i e -> i -> m e
readArray STArray s Int a
ss Int
c
STArray s Int a -> Int -> a -> ST s ()
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
a i e -> i -> e -> m ()
writeArray STArray s Int a
ss Int
c (a
dc a -> a -> a
forall a. Num a => a -> a -> a
+ a
gca -> a -> a
forall a. Num a => a -> a -> a
*a
da)
Tape a Int
_ -> () -> ST s ()
forall a. a -> ST s a
forall (m :: * -> *) a. Monad m => a -> m a
return ()
where
(Tape a Int
node, Int
i, [Int]
_) = Int -> (Tape a Int, Int, [Int])
vmap Int
v
topSortAcyclic :: Graph -> [Vertex]
topSortAcyclic :: Graph -> [Int]
topSortAcyclic Graph
g = [Int] -> [Int]
forall a. [a] -> [a]
reverse ([Int] -> [Int]) -> [Int] -> [Int]
forall a b. (a -> b) -> a -> b
$ (forall s. ST s [Int]) -> [Int]
forall a. (forall s. ST s a) -> a
runST ((forall s. ST s [Int]) -> [Int])
-> (forall s. ST s [Int]) -> [Int]
forall a b. (a -> b) -> a -> b
$ do
STUArray s Int Bool
del <- (Int, Int) -> Bool -> ST s (STUArray s Int Bool)
forall i. Ix i => (i, i) -> Bool -> ST s (STUArray s i Bool)
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
(i, i) -> e -> m (a i e)
newArray (Graph -> (Int, Int)
forall i e. Array i e -> (i, i)
bounds Graph
g) Bool
False :: ST s (STUArray s Int Bool)
let tg :: Graph
tg = Graph -> Graph
transposeG Graph
g
starters :: [Int]
starters = [ Int
n | (Int
n, []) <- Graph -> [(Int, [Int])]
forall i e. Ix i => Array i e -> [(i, e)]
assocs Graph
tg ]
loop :: [Int] -> [Int] -> ST s [Int]
loop [] [Int]
rs = [Int] -> ST s [Int]
forall a. a -> ST s a
forall (m :: * -> *) a. Monad m => a -> m a
return [Int]
rs
loop (Int
n:[Int]
ns) [Int]
rs = do
STUArray s Int Bool -> Int -> Bool -> ST s ()
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
a i e -> i -> e -> m ()
writeArray STUArray s Int Bool
del Int
n Bool
True
let add :: [Int] -> ST s [Int]
add [] = [Int] -> ST s [Int]
forall a. a -> ST s a
forall (m :: * -> *) a. Monad m => a -> m a
return [Int]
ns
add (Int
m:[Int]
ms) = do
Bool
b <- [Int] -> ST s Bool
ok (Graph
tgGraph -> Int -> [Int]
forall i e. Ix i => Array i e -> i -> e
!Int
m)
[Int]
ms' <- [Int] -> ST s [Int]
add [Int]
ms
[Int] -> ST s [Int]
forall a. a -> ST s a
forall (m :: * -> *) a. Monad m => a -> m a
return ([Int] -> ST s [Int]) -> [Int] -> ST s [Int]
forall a b. (a -> b) -> a -> b
$ if Bool
b then Int
m Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: [Int]
ms' else [Int]
ms'
ok :: [Int] -> ST s Bool
ok [] = Bool -> ST s Bool
forall a. a -> ST s a
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
True
ok (Int
x:[Int]
xs) = do Bool
b <- STUArray s Int Bool -> Int -> ST s Bool
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
a i e -> i -> m e
readArray STUArray s Int Bool
del Int
x; if Bool
b then [Int] -> ST s Bool
ok [Int]
xs else Bool -> ST s Bool
forall a. a -> ST s a
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
False
[Int]
ns' <- [Int] -> ST s [Int]
add (Graph
gGraph -> Int -> [Int]
forall i e. Ix i => Array i e -> i -> e
!Int
n)
[Int] -> [Int] -> ST s [Int]
loop [Int]
ns' (Int
n Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: [Int]
rs)
[Int] -> [Int] -> ST s [Int]
loop [Int]
starters []
{-# SPECIALIZE partials :: Kahn Double -> [(Int, Double)] #-}
partials :: forall a. Num a => Kahn a -> [(Int, a)]
partials :: forall a. Num a => Kahn a -> [(Int, a)]
partials Kahn a
tape = [ let v :: a
v = Array Int a
sensitivities Array Int a -> Int -> a
forall i e. Ix i => Array i e -> i -> e
! Int
ix in a -> (Int, a) -> (Int, a)
forall a b. a -> b -> b
seq a
v (Int
ident, a
v) | (Int
ix, Var a
_ Int
ident) <- [(Int, Tape a Int)]
xs ] where
Reified.Graph [(Int, Tape a Int)]
xs Int
start = IO (Graph (Tape a)) -> Graph (Tape a)
forall a. IO a -> a
unsafePerformIO (IO (Graph (Tape a)) -> Graph (Tape a))
-> IO (Graph (Tape a)) -> Graph (Tape a)
forall a b. (a -> b) -> a -> b
$ Kahn a -> IO (Graph (DeRef (Kahn a)))
forall s. MuRef s => s -> IO (Graph (DeRef s))
reifyGraph Kahn a
tape
g :: Graph
g = (Int, Int) -> [(Int, [Int])] -> Graph
forall i e. Ix i => (i, i) -> [(i, e)] -> Array i e
array (Int, Int)
xsBounds [ (Int
i, Tape a Int -> [Int]
successors Tape a Int
t) | (Int
i, Tape a Int
t) <- [(Int, Tape a Int)]
xs ]
vertexMap :: Array Int (Tape a Int)
vertexMap = (Int, Int) -> [(Int, Tape a Int)] -> Array Int (Tape a Int)
forall i e. Ix i => (i, i) -> [(i, e)] -> Array i e
array (Int, Int)
xsBounds [(Int, Tape a Int)]
xs
vmap :: Int -> (Tape a Int, Int, [Int])
vmap Int
i = (Array Int (Tape a Int)
vertexMap Array Int (Tape a Int) -> Int -> Tape a Int
forall i e. Ix i => Array i e -> i -> e
! Int
i, Int
i, [])
xsBounds :: (Int, Int)
xsBounds = [(Int, Tape a Int)] -> (Int, Int)
forall {a} {b}. Ord a => [(a, b)] -> (a, a)
sbounds [(Int, Tape a Int)]
xs
sensitivities :: Array Int a
sensitivities = (forall s. ST s (STArray s Int a)) -> Array Int a
forall i e. (forall s. ST s (STArray s i e)) -> Array i e
runSTArray ((forall s. ST s (STArray s Int a)) -> Array Int a)
-> (forall s. ST s (STArray s Int a)) -> Array Int a
forall a b. (a -> b) -> a -> b
$ do
STArray s Int a
ss <- (Int, Int) -> a -> ST s (STArray s Int a)
forall i. Ix i => (i, i) -> a -> ST s (STArray s i a)
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
(i, i) -> e -> m (a i e)
newArray (Int, Int)
xsBounds a
0
STArray s Int a -> Int -> a -> ST s ()
forall (a :: * -> * -> *) e (m :: * -> *) i.
(MArray a e m, Ix i) =>
a i e -> i -> e -> m ()
writeArray STArray s Int a
ss Int
start a
1
[Int] -> (Int -> ST s ()) -> ST s ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (Graph -> [Int]
topSortAcyclic Graph
g) ((Int -> ST s ()) -> ST s ()) -> (Int -> ST s ()) -> ST s ()
forall a b. (a -> b) -> a -> b
$
(Int -> (Tape a Int, Int, [Int]))
-> STArray s Int a -> Int -> ST s ()
forall a s.
Num a =>
(Int -> (Tape a Int, Int, [Int]))
-> STArray s Int a -> Int -> ST s ()
backPropagate Int -> (Tape a Int, Int, [Int])
vmap STArray s Int a
ss
STArray s Int a -> ST s (STArray s Int a)
forall a. a -> ST s a
forall (m :: * -> *) a. Monad m => a -> m a
return STArray s Int a
ss
sbounds :: [(a, b)] -> (a, a)
sbounds ((a
a,b
_):[(a, b)]
as) = ((a, a) -> (a, b) -> (a, a)) -> (a, a) -> [(a, b)] -> (a, a)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
List.foldl' (\(a
lo,a
hi) (a
b,b
_) -> let lo' :: a
lo' = a -> a -> a
forall a. Ord a => a -> a -> a
min a
lo a
b; hi' :: a
hi' = a -> a -> a
forall a. Ord a => a -> a -> a
max a
hi a
b in a
lo' a -> (a, a) -> (a, a)
forall a b. a -> b -> b
`seq` a
hi' a -> (a, a) -> (a, a)
forall a b. a -> b -> b
`seq` (a
lo', a
hi')) (a
a,a
a) [(a, b)]
as
sbounds [(a, b)]
_ = (a, a)
forall a. HasCallStack => a
undefined
successors :: Tape a Int -> [Int]
successors :: Tape a Int -> [Int]
successors (Unary a
_ a
_ Int
b) = [Int
b]
successors (Binary a
_ a
_ a
_ Int
b Int
c) = if Int
b Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
c then [Int
b] else [Int
b,Int
c]
successors Tape a Int
_ = []
partialArray :: Num a => (Int, Int) -> Kahn a -> Array Int a
partialArray :: forall a. Num a => (Int, Int) -> Kahn a -> Array Int a
partialArray (Int, Int)
vbounds Kahn a
tape = (a -> a -> a) -> a -> (Int, Int) -> [(Int, a)] -> Array Int a
forall i e a.
Ix i =>
(e -> a -> e) -> e -> (i, i) -> [(i, a)] -> Array i e
accumArray a -> a -> a
forall a. Num a => a -> a -> a
(+) a
0 (Int, Int)
vbounds (Kahn a -> [(Int, a)]
forall a. Num a => Kahn a -> [(Int, a)]
partials Kahn a
tape)
{-# INLINE partialArray #-}
partialMap :: Num a => Kahn a -> IntMap a
partialMap :: forall a. Num a => Kahn a -> IntMap a
partialMap = (a -> a -> a) -> [(Int, a)] -> IntMap a
forall a. (a -> a -> a) -> [(Int, a)] -> IntMap a
fromListWith a -> a -> a
forall a. Num a => a -> a -> a
(+) ([(Int, a)] -> IntMap a)
-> (Kahn a -> [(Int, a)]) -> Kahn a -> IntMap a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Kahn a -> [(Int, a)]
forall a. Num a => Kahn a -> [(Int, a)]
partials
{-# INLINE partialMap #-}
class Num a => Grad i o o' a | i -> a o o', o -> a i o', o' -> a i o where
pack :: i -> [Kahn a] -> Kahn a
unpack :: ([a] -> [a]) -> o
unpack' :: ([a] -> (a, [a])) -> o'
instance Num a => Grad (Kahn a) [a] (a, [a]) a where
pack :: Kahn a -> [Kahn a] -> Kahn a
pack Kahn a
i [Kahn a]
_ = Kahn a
i
unpack :: ([a] -> [a]) -> [a]
unpack [a] -> [a]
f = [a] -> [a]
f []
unpack' :: ([a] -> (a, [a])) -> (a, [a])
unpack' [a] -> (a, [a])
f = [a] -> (a, [a])
f []
instance Grad i o o' a => Grad (Kahn a -> i) (a -> o) (a -> o') a where
pack :: (Kahn a -> i) -> [Kahn a] -> Kahn a
pack Kahn a -> i
f (Kahn a
a:[Kahn a]
as) = i -> [Kahn a] -> Kahn a
forall i o o' a. Grad i o o' a => i -> [Kahn a] -> Kahn a
pack (Kahn a -> i
f Kahn a
a) [Kahn a]
as
pack Kahn a -> i
_ [] = String -> Kahn a
forall a. HasCallStack => String -> a
error String
"Grad.pack: logic error"
unpack :: ([a] -> [a]) -> a -> o
unpack [a] -> [a]
f a
a = ([a] -> [a]) -> o
forall i o o' a. Grad i o o' a => ([a] -> [a]) -> o
unpack ([a] -> [a]
f ([a] -> [a]) -> ([a] -> [a]) -> [a] -> [a]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (a
aa -> [a] -> [a]
forall a. a -> [a] -> [a]
:))
unpack' :: ([a] -> (a, [a])) -> a -> o'
unpack' [a] -> (a, [a])
f a
a = ([a] -> (a, [a])) -> o'
forall i o o' a. Grad i o o' a => ([a] -> (a, [a])) -> o'
unpack' ([a] -> (a, [a])
f ([a] -> (a, [a])) -> ([a] -> [a]) -> [a] -> (a, [a])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (a
aa -> [a] -> [a]
forall a. a -> [a] -> [a]
:))
vgrad :: Grad i o o' a => i -> o
vgrad :: forall i o o' a. Grad i o o' a => i -> o
vgrad i
i = ([a] -> [a]) -> o
forall i o o' a. Grad i o o' a => ([a] -> [a]) -> o
unpack (([Kahn a] -> Kahn a) -> [a] -> [a]
forall {f :: * -> *} {a}.
(Traversable f, Num a) =>
(f (Kahn a) -> Kahn a) -> f a -> f a
unsafeGrad (i -> [Kahn a] -> Kahn a
forall i o o' a. Grad i o o' a => i -> [Kahn a] -> Kahn a
pack i
i)) where
unsafeGrad :: (f (Kahn a) -> Kahn a) -> f a -> f a
unsafeGrad f (Kahn a) -> Kahn a
f f a
as = f (Kahn a) -> Array Int a -> f a
forall (f :: * -> *) a.
Functor f =>
f (Kahn a) -> Array Int a -> f a
unbind f (Kahn a)
vs ((Int, Int) -> Kahn a -> Array Int a
forall a. Num a => (Int, Int) -> Kahn a -> Array Int a
partialArray (Int, Int)
bds (Kahn a -> Array Int a) -> Kahn a -> Array Int a
forall a b. (a -> b) -> a -> b
$ f (Kahn a) -> Kahn a
f f (Kahn a)
vs) where
(f (Kahn a)
vs,(Int, Int)
bds) = f a -> (f (Kahn a), (Int, Int))
forall (f :: * -> *) a.
Traversable f =>
f a -> (f (Kahn a), (Int, Int))
bind f a
as
vgrad' :: Grad i o o' a => i -> o'
vgrad' :: forall i o o' a. Grad i o o' a => i -> o'
vgrad' i
i = ([a] -> (a, [a])) -> o'
forall i o o' a. Grad i o o' a => ([a] -> (a, [a])) -> o'
unpack' (([Kahn a] -> Kahn a) -> [a] -> (a, [a])
forall {f :: * -> *} {a}.
(Traversable f, Num a) =>
(f (Kahn a) -> Kahn a) -> f a -> (a, f a)
unsafeGrad' (i -> [Kahn a] -> Kahn a
forall i o o' a. Grad i o o' a => i -> [Kahn a] -> Kahn a
pack i
i)) where
unsafeGrad' :: (f (Kahn a) -> Kahn a) -> f a -> (a, f a)
unsafeGrad' f (Kahn a) -> Kahn a
f f a
as = (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
r, f (Kahn a) -> Array Int a -> f a
forall (f :: * -> *) a.
Functor f =>
f (Kahn a) -> Array Int a -> f a
unbind f (Kahn a)
vs ((Int, Int) -> Kahn a -> Array Int a
forall a. Num a => (Int, Int) -> Kahn a -> Array Int a
partialArray (Int, Int)
bds Kahn a
r)) where
r :: Kahn a
r = f (Kahn a) -> Kahn a
f f (Kahn a)
vs
(f (Kahn a)
vs,(Int, Int)
bds) = f a -> (f (Kahn a), (Int, Int))
forall (f :: * -> *) a.
Traversable f =>
f a -> (f (Kahn a), (Int, Int))
bind f a
as
var :: a -> Int -> Kahn a
var :: forall a. a -> Int -> Kahn a
var a
a Int
v = Tape a (Kahn a) -> Kahn a
forall a. Tape a (Kahn a) -> Kahn a
Kahn (a -> Int -> Tape a (Kahn a)
forall a t. a -> Int -> Tape a t
Var a
a Int
v)
varId :: Kahn a -> Int
varId :: forall a. Kahn a -> Int
varId (Kahn (Var a
_ Int
v)) = Int
v
varId Kahn a
_ = String -> Int
forall a. HasCallStack => String -> a
error String
"varId: not a Var"
bind :: Traversable f => f a -> (f (Kahn a), (Int,Int))
bind :: forall (f :: * -> *) a.
Traversable f =>
f a -> (f (Kahn a), (Int, Int))
bind f a
xs = (f (Kahn a)
r,(Int
0,Int
hi)) where
(f (Kahn a)
r,Int
hi) = State Int (f (Kahn a)) -> Int -> (f (Kahn a), Int)
forall s a. State s a -> s -> (a, s)
runState ((a -> StateT Int Identity (Kahn a))
-> f a -> State Int (f (Kahn a))
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> f a -> m (f b)
mapM a -> StateT Int Identity (Kahn a)
forall {m :: * -> *} {a}. Monad m => a -> StateT Int m (Kahn a)
freshVar f a
xs) Int
0
freshVar :: a -> StateT Int m (Kahn a)
freshVar a
a = (Int -> (Kahn a, Int)) -> StateT Int m (Kahn a)
forall (m :: * -> *) s a. Monad m => (s -> (a, s)) -> StateT s m a
state ((Int -> (Kahn a, Int)) -> StateT Int m (Kahn a))
-> (Int -> (Kahn a, Int)) -> StateT Int m (Kahn a)
forall a b. (a -> b) -> a -> b
$ \Int
s -> let s' :: Int
s' = Int
s Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1 in Int
s' Int -> (Kahn a, Int) -> (Kahn a, Int)
forall a b. a -> b -> b
`seq` (a -> Int -> Kahn a
forall a. a -> Int -> Kahn a
var a
a Int
s, Int
s')
unbind :: Functor f => f (Kahn a) -> Array Int a -> f a
unbind :: forall (f :: * -> *) a.
Functor f =>
f (Kahn a) -> Array Int a -> f a
unbind f (Kahn a)
xs Array Int a
ys = (Kahn a -> a) -> f (Kahn a) -> f a
forall a b. (a -> b) -> f a -> f b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (\Kahn a
v -> Array Int a
ys Array Int a -> Int -> a
forall i e. Ix i => Array i e -> i -> e
! Kahn a -> Int
forall a. Kahn a -> Int
varId Kahn a
v) f (Kahn a)
xs
unbindWith :: (Functor f, Num a) => (a -> b -> c) -> f (Kahn a) -> Array Int b -> f c
unbindWith :: forall (f :: * -> *) a b c.
(Functor f, Num a) =>
(a -> b -> c) -> f (Kahn a) -> Array Int b -> f c
unbindWith a -> b -> c
f f (Kahn a)
xs Array Int b
ys = (Kahn a -> c) -> f (Kahn a) -> f c
forall a b. (a -> b) -> f a -> f b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (\Kahn a
v -> a -> b -> c
f (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
v) (Array Int b
ys Array Int b -> Int -> b
forall i e. Ix i => Array i e -> i -> e
! Kahn a -> Int
forall a. Kahn a -> Int
varId Kahn a
v)) f (Kahn a)
xs
unbindMap :: (Functor f, Num a) => f (Kahn a) -> IntMap a -> f a
unbindMap :: forall (f :: * -> *) a.
(Functor f, Num a) =>
f (Kahn a) -> IntMap a -> f a
unbindMap f (Kahn a)
xs IntMap a
ys = (Kahn a -> a) -> f (Kahn a) -> f a
forall a b. (a -> b) -> f a -> f b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (\Kahn a
v -> a -> Int -> IntMap a -> a
forall a. a -> Int -> IntMap a -> a
findWithDefault a
0 (Kahn a -> Int
forall a. Kahn a -> Int
varId Kahn a
v) IntMap a
ys) f (Kahn a)
xs
unbindMapWithDefault :: (Functor f, Num a) => b -> (a -> b -> c) -> f (Kahn a) -> IntMap b -> f c
unbindMapWithDefault :: forall (f :: * -> *) a b c.
(Functor f, Num a) =>
b -> (a -> b -> c) -> f (Kahn a) -> IntMap b -> f c
unbindMapWithDefault b
z a -> b -> c
f f (Kahn a)
xs IntMap b
ys = (Kahn a -> c) -> f (Kahn a) -> f c
forall a b. (a -> b) -> f a -> f b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (\Kahn a
v -> a -> b -> c
f (Kahn a -> a
forall a. Num a => Kahn a -> a
primal Kahn a
v) (b -> c) -> b -> c
forall a b. (a -> b) -> a -> b
$ b -> Int -> IntMap b -> b
forall a. a -> Int -> IntMap a -> a
findWithDefault b
z (Kahn a -> Int
forall a. Kahn a -> Int
varId Kahn a
v) IntMap b
ys) f (Kahn a)
xs