{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-full-laziness #-}
{-# OPTIONS_HADDOCK not-home #-}

-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- This module provides reverse-mode Automatic Differentiation implementation using
-- linear time topological sorting after the fact.
--
-- For this form of reverse-mode AD we use 'System.Mem.StableName.StableName' to recover
-- sharing information from the tape to avoid combinatorial explosion, and thus
-- run asymptotically faster than it could without such sharing information, but the use
-- of side-effects contained herein is benign.
--
-----------------------------------------------------------------------------

module Numeric.AD.Internal.Kahn
  ( Kahn(..)
  , Tape(..)
  , partials
  , partialArray
  , partialMap
  , derivative
  , derivative'
  , vgrad, vgrad'
  , Grad(..)
  , bind
  , unbind
  , unbindMap
  , unbindWith
  , unbindMapWithDefault
  , primal
  , var
  , varId
  ) where

import Control.Monad.ST
import Control.Monad hiding (mapM)
import Control.Monad.Trans.State
import qualified Data.List as List (foldl')
import Data.Array.ST
import Data.Array
import Data.IntMap (IntMap, fromListWith, findWithDefault)
import Data.Graph (Vertex, transposeG, Graph)
import Data.Number.Erf
import Data.Reify (reifyGraph, MuRef(..))
import qualified Data.Reify.Graph as Reified
import System.IO.Unsafe (unsafePerformIO)
import Data.Data (Data)
import Data.Typeable (Typeable)
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | A @Tape@ records the information needed to back propagate from the output
-- to each input during reverse 'Mode' AD.
--
-- The parameter @a@ is the scalar type. The parameter @t@ is the type used to
-- refer to child nodes: a 'Kahn' while an expression is being built, and an
-- 'Int' vertex id once the expression graph has been reified.
data Tape a t
  = Zero
    -- ^ A value known to be zero.
  | Lift !a
    -- ^ A constant lifted from the scalar type.
  | Var !a {-# UNPACK #-} !Int
    -- ^ An input variable: its primal value and its integer identifier.
  | Binary !a a a t t
    -- ^ Result of a two-argument operation: the primal value, the partial
    -- derivatives with respect to the first and the second argument, and the
    -- two argument nodes.
  | Unary !a a t
    -- ^ Result of a one-argument operation: the primal value, the partial
    -- derivative with respect to the argument, and the argument node.
  deriving (Show, Data, Typeable)

-- | @Kahn@ is a 'Mode' using reverse-mode automatic differentiation that
-- provides fast 'diffFU', 'diff2FU', 'grad', 'grad2' and a fast 'jacobian'
-- when you have a significantly smaller number of outputs than inputs.
--
-- It ties the recursive knot of 'Tape': the children of every node are
-- themselves @Kahn@ values.
newtype Kahn a = Kahn (Tape a (Kahn a)) deriving (Show, Typeable)

-- | Exposes one layer of a 'Kahn' expression to @data-reify@: the node is
-- unwrapped to a 'Tape' whose child positions are filled by applying the
-- supplied function to each child.
instance MuRef (Kahn a) where
  type DeRef (Kahn a) = Tape a

  -- Leaves have no children, so the traversal function is not used.
  mapDeRef _ (Kahn Zero) = pure Zero
  mapDeRef _ (Kahn (Lift a)) = pure (Lift a)
  mapDeRef _ (Kahn (Var a v)) = pure (Var a v)
  -- Interior nodes keep their value and partials; children are visited left to right.
  mapDeRef f (Kahn (Binary a dadb dadc b c)) = Binary a dadb dadc <$> f b <*> f c
  mapDeRef f (Kahn (Unary a dadb b)) = Unary a dadb <$> f b

-- | 'Kahn' as an AD 'Mode' over the scalar type @a@.
instance Num a => Mode (Kahn a) where
  type Scalar (Kahn a) = a

  -- Only the dedicated 'Zero' node counts as a known zero.
  isKnownZero (Kahn Zero) = True
  isKnownZero _    = False

  -- 'Zero' and 'Lift' nodes are constants; every other node is not.
  asKnownConstant (Kahn Zero) = Just 0
  asKnownConstant (Kahn (Lift n)) = Just n
  asKnownConstant _ = Nothing

  isKnownConstant (Kahn Zero) = True
  isKnownConstant (Kahn (Lift _)) = True
  isKnownConstant _ = False

  auto a = Kahn (Lift a)
  zero   = Kahn Zero

  -- Scaling by a scalar: the partial with respect to the 'Kahn' argument is
  -- the scalar itself (or its reciprocal for division), whatever the primal is.
  a *^ b = lift1 (a *) (\_ -> auto a) b
  a ^* b = lift1 (* b) (\_ -> auto b) a
  a ^/ b = lift1 (/ b) (\_ -> auto (recip b)) a

-- | Sum of two 'Kahn' values; both partial derivatives are 1.
(<+>) :: Num a => Kahn a -> Kahn a -> Kahn a
(<+>)  = binary (+) 1 1

-- | The primal (undifferentiated) value stored at the root of a 'Kahn'
-- expression. A 'Zero' node yields @0@; every other node carries its value in
-- its first field.
primal :: Num a => Kahn a -> a
primal (Kahn Zero) = 0
primal (Kahn (Lift a)) = a
primal (Kahn (Var a _)) = a
primal (Kahn (Binary a _ _ _ _)) = a
primal (Kahn (Unary a _ _)) = a

-- | Builds 'Kahn' tape nodes from primal functions and their partial
-- derivatives. Partial derivatives are passed around wrapped in 'Id'.
instance Num a => Jacobian (Kahn a) where
  type D (Kahn a) = Id a

  -- Constant arguments ('Zero', 'Lift') fold to a new constant; anything else
  -- records a 'Unary' node holding the result, the partial and the argument.
  unary f _         (Kahn Zero)     = Kahn (Lift (f 0))
  unary f _         (Kahn (Lift a)) = Kahn (Lift (f a))
  unary f (Id dadb) b               = Kahn (Unary (f (primal b)) dadb b)

  -- The derivative function receives the primal of the argument.
  lift1 f df b = unary f (df (Id pb)) b where
    pb = primal b

  -- As 'lift1', but the derivative function also receives the result @f pb@,
  -- which is then reused as the node value instead of being recomputed.
  lift1_ f df b = unary (const a) (df (Id a) (Id pb)) b where
    pb = primal b
    a = f pb

  -- Equation order matters: the constant/constant cases come first, then the
  -- cases with exactly one constant side (recorded as 'Unary' on the other
  -- side), and finally the general case, recorded as 'Binary'.
  binary f _         _         (Kahn Zero)     (Kahn Zero)     = Kahn (Lift (f 0 0))
  binary f _         _         (Kahn Zero)     (Kahn (Lift c)) = Kahn (Lift (f 0 c))
  binary f _         _         (Kahn (Lift b)) (Kahn Zero)     = Kahn (Lift (f b 0))
  binary f _         _         (Kahn (Lift b)) (Kahn (Lift c)) = Kahn (Lift (f b c))
  binary f _         (Id dadc) (Kahn Zero)     c               = Kahn (Unary (f 0 (primal c)) dadc c)
  binary f _         (Id dadc) (Kahn (Lift b)) c               = Kahn (Unary (f b (primal c)) dadc c)
  binary f (Id dadb) _         b               (Kahn Zero)     = Kahn (Unary (f (primal b) 0) dadb b)
  binary f (Id dadb) _         b               (Kahn (Lift c)) = Kahn (Unary (f (primal b) c) dadb b)
  binary f (Id dadb) (Id dadc) b               c               = Kahn (Binary (f (primal b) (primal c)) dadb dadc b c)

  -- The derivative function receives both primals and returns both partials.
  lift2 f df b c = binary f dadb dadc b c where
    (dadb, dadc) = df (Id (primal b)) (Id (primal c))

  -- As 'lift2', but the derivative function also receives the result
  -- @f pb pc@, which is reused as the node value.
  lift2_ f df b c = binary (\_ _ -> a) dadb dadc b c where
    pb = primal b
    pc = primal c
    a = f pb pc
    (dadb, dadc) = df (Id a) (Id pb) (Id pc)


-- | Product of two 'Kahn' values. The partial derivative with respect to each
-- argument is the primal of the other argument.
mul :: Num a => Kahn a -> Kahn a -> Kahn a
mul = lift2 (*) (\x y -> (y, x))

-- Sets the CPP macro HEAD to the instance head @Kahn a@ and then pulls in
-- instances.h, which is expanded with that macro in effect.
-- NOTE(review): presumably this generates the numeric class instances for
-- Kahn in terms of the Jacobian methods; instances.h is not visible from this
-- file, so confirm against the header.
#define HEAD (Kahn a)
#include <instances.h>

-- | Total derivative of a 'Kahn' expression: the sum of every contribution
-- reported by 'partials', regardless of which variable id it belongs to.
derivative :: Num a => Kahn a -> a
derivative = sum . map snd . partials
{-# INLINE derivative #-}

-- | The primal value of a 'Kahn' expression paired with its 'derivative'.
derivative' :: Num a => Kahn a -> (a, a)
derivative' r = (primal r, derivative r)
{-# INLINE derivative' #-}

-- | Back propagate sensitivities along a tape.
--
-- Looks up vertex @v@ through @vmap@, reads the sensitivity accumulated so far
-- for that node from @ss@, and adds that sensitivity, scaled by the stored
-- partial derivative, to the cell of each child. Nodes without children
-- ('Zero', 'Lift', 'Var') leave the array untouched.
backPropagate :: Num a => (Vertex -> (Tape a Int, Int, [Int])) -> STArray s Int a -> Vertex -> ST s ()
backPropagate vmap ss v = case node of
  Unary _ g b -> do
    da <- readArray ss i
    db <- readArray ss b
    writeArray ss b (db + g*da)
  Binary _ gb gc b c -> do
    da <- readArray ss i
    db <- readArray ss b
    writeArray ss b (db + gb*da)
    -- The second child is read only after the first write, so the update is
    -- still correct when both children are the same vertex.
    dc <- readArray ss c
    writeArray ss c (dc + gc*da)
  _ -> return ()
  where
    -- Only the node and its own index are used; the third component is ignored.
    (node, i, _) = vmap v
    -- this isn't _quite_ right, as it should allow negative zeros to multiply through

-- | Topologically sort the vertices of a graph that, as the name says, is
-- expected to be acyclic.
--
-- The worklist starts with every vertex that has no incoming edge. A vertex
-- taken from the worklist is marked as deleted, and each of its successors
-- whose predecessors are now all deleted is pushed in front of the remaining
-- worklist. The result lists vertices in the order they were taken.
topSortAcyclic :: Graph -> [Vertex]
topSortAcyclic g = reverse $ runST $ do
  -- del ! v is True once v has been emitted.
  del <- newArray (bounds g) False :: ST s (STUArray s Int Bool)
  let tg = transposeG g
      -- Vertices with no predecessors.
      starters = [ n | (n, []) <- assocs tg ]
      -- rs accumulates emitted vertices in reverse order.
      loop [] rs = return rs
      loop (n:ns) rs = do
        writeArray del n True
        let -- Keep the successors that are ready and put them ahead of ns.
            add [] = return ns
            add (m:ms) = do
              b <- ok (tg!m)
              ms' <- add ms
              return $ if b then m : ms' else ms'
            -- True when every listed vertex has already been deleted.
            ok [] = return True
            ok (x:xs) = do
              b <- readArray del x
              if b then ok xs else return False
        ns' <- add (g!n)
        loop ns' (n : rs)
  loop starters []

-- | This returns a list of contributions to the partials.
-- The variable ids returned in the list are likely /not/ unique!
--
-- One entry is produced per 'Var' node of the reified graph, so the same
-- variable id may occur several times; consumers are expected to sum the
-- contributions (as 'partialArray' and 'partialMap' do).
--
-- The tape is reified into an explicit graph, the sensitivity of the output
-- node is seeded with @1@, and sensitivities are then pushed through the
-- graph in the order given by 'topSortAcyclic'.
{-# SPECIALIZE partials :: Kahn Double -> [(Int, Double)] #-}
partials :: forall a. Num a => Kahn a -> [(Int, a)]
partials tape =
  [ v `seq` (ident, v)            -- force the sensitivity before handing it out
  | (ix, Var _ ident) <- xs       -- only variable nodes contribute
  , let v = sensitivities ! ix
  ]
  where
  -- The IO performed here only observes sharing in the tape; the module
  -- header describes this use of side effects as benign.
  Reified.Graph xs start = unsafePerformIO $ reifyGraph tape

  -- Adjacency array: every node mapped to the nodes it depends on.
  g :: Graph
  g = array xsBounds [ (i, successors t) | (i, t) <- xs ]

  vertexMap :: Array Int (Tape a Int)
  vertexMap = array xsBounds xs

  -- Lookup in the shape expected by 'backPropagate'.
  vmap :: Int -> (Tape a Int, Int, [Int])
  vmap i = (vertexMap ! i, i, [])

  xsBounds :: (Int, Int)
  xsBounds = sbounds xs

  sensitivities :: Array Int a
  sensitivities = runSTArray $ do
    ss <- newArray xsBounds 0
    writeArray ss start 1
    forM_ (topSortAcyclic g) $
      backPropagate vmap ss
    return ss

  -- Smallest and largest key of a non-empty association list, computed
  -- strictly in a single pass.
  sbounds :: Ord k => [(k, v)] -> (k, k)
  sbounds ((k0, _) : rest) = List.foldl' widen (k0, k0) rest
    where
      widen (lo, hi) (k, _) =
        let lo' = min lo k
            hi' = max hi k
        in lo' `seq` hi' `seq` (lo', hi')
  -- the graph can't be empty, it contains the output node!
  sbounds [] = error "Numeric.AD.Internal.Kahn.partials: empty reified graph"

  -- Children of a tape node; a binary node whose two children coincide
  -- reports that child only once.
  successors :: Tape a Int -> [Int]
  successors (Unary _ _ b) = [b]
  successors (Binary _ _ _ b c)
    | b == c    = [b]
    | otherwise = [b, c]
  successors _ = []

-- | Return an 'Array' of 'partials' given bounds for the variable IDs.
--
-- Contributions that share a variable id are summed; ids within the bounds
-- that receive no contribution are left at @0@.
partialArray :: Num a => (Int, Int) -> Kahn a -> Array Int a
partialArray vbounds tape = accumArray (+) 0 vbounds (partials tape)
{-# INLINE partialArray #-}

-- | Return an 'IntMap' of sparse partials.
--
-- Contributions that share a variable id are summed; variables that receive
-- no contribution are absent from the map.
partialMap :: Num a => Kahn a -> IntMap a
partialMap = fromListWith (+) . partials
{-# INLINE partialMap #-}

-- | Variadic plumbing behind 'vgrad' and 'vgrad''.
--
-- @i@ is a curried function over 'Kahn' values, @o@ is the matching curried
-- function returning the list of partials, and @o'@ is the variant that
-- additionally returns the value of the function.
class Num a => Grad i o o' a | i -> a o o', o -> a i o', o' -> a i o where
  -- | Feed the curried function its arguments from a list of variables.
  pack :: i -> [Kahn a] -> Kahn a
  -- | Turn a list-based gradient function into its curried form.
  unpack :: ([a] -> [a]) -> o
  -- | As 'unpack', for a function that also yields the result value.
  unpack' :: ([a] -> (a, [a])) -> o'

-- | Base case: no arguments are left, so the function is already a result
-- and the list-based gradient is run on the empty argument list.
instance Num a => Grad (Kahn a) [a] (a, [a]) a where
  pack i _ = i
  unpack f = f []
  unpack' f = f []

-- | Inductive case: peel off one argument and recurse on the rest.
instance Grad i o o' a => Grad (Kahn a -> i) (a -> o) (a -> o') a where
  -- Apply the function to the next variable and continue with the remainder.
  pack f (a:as) = pack (f a) as
  -- Reached only if fewer variables than arguments were supplied.
  pack _ [] = error "Grad.pack: logic error"
  -- Accept one more argument and prepend it to the list handed to @f@.
  unpack f a = unpack (f . (a:))
  unpack' f a = unpack' (f . (a:))

-- | Variadic gradient: given a curried function over 'Kahn' values, produce
-- the curried function that returns the list of its partial derivatives.
vgrad :: Grad i o o' a => i -> o
vgrad i = unpack (unsafeGrad (pack i))
  where
  -- Bind the inputs to fresh variables, run the function on them, and read
  -- each variable's partial back out of the accumulated array.
  unsafeGrad f as = unbind vs (partialArray bds $ f vs)
    where
    (vs, bds) = bind as

-- | Variadic gradient that also returns the value of the function: the
-- result pairs the primal value with the list of partial derivatives.
vgrad' :: Grad i o o' a => i -> o'
vgrad' i = unpack' (unsafeGrad' (pack i))
  where
  -- The result @r@ is shared between the primal value and the
  -- back-propagation.
  unsafeGrad' f as = (primal r, unbind vs (partialArray bds r))
    where
    r = f vs
    (vs, bds) = bind as

-- | Build a variable node carrying the value @a@ and the variable id @v@.
var :: a -> Int -> Kahn a
var a v = Kahn (Var a v)

-- | Extract the variable id of a variable node.
--
-- Partial: calls 'error' when the argument is not a 'Var' node.
varId :: Kahn a -> Int
varId (Kahn (Var _ v)) = v
varId _ = error "varId: not a Var"

-- | Replace every element of a container with a fresh variable, numbering
-- them consecutively from @0@ in traversal order.
--
-- The second component is a pair of bounds @(0, n)@ where @n@ is the number
-- of variables created, so the upper bound is one past the last id in use.
bind :: Traversable f => f a -> (f (Kahn a), (Int,Int))
bind xs = (r, (0, hi))
  where
  (r, hi) = runState (mapM freshVar xs) 0

  -- Hand out the current counter as the id and advance it strictly.
  freshVar :: b -> State Int (Kahn b)
  freshVar a = state $ \s ->
    let s' = s + 1
    in s' `seq` (var a s, s')

-- | Replace each variable in the container with the array entry stored
-- under its variable id.
--
-- Every element must be a 'Var' node (see 'varId') whose id lies within the
-- bounds of the array.
unbind :: Functor f => f (Kahn a) -> Array Int a -> f a
unbind xs ys = fmap (\v -> ys ! varId v) xs

-- | Like 'unbind', but combine each variable's primal value with its array
-- entry using the supplied function.
unbindWith :: (Functor f, Num a) => (a -> b -> c) -> f (Kahn a) -> Array Int b -> f c
unbindWith f xs ys = fmap (\v -> f (primal v) (ys ! varId v)) xs

-- | Replace each variable in the container with the map entry stored under
-- its variable id, using @0@ for ids that are absent from the map.
unbindMap :: (Functor f, Num a) => f (Kahn a) -> IntMap a -> f a
unbindMap xs ys = fmap (\v -> findWithDefault 0 (varId v) ys) xs

-- | Combine each variable's primal value with the map entry stored under its
-- variable id, using the default @z@ for ids that are absent from the map.
unbindMapWithDefault :: (Functor f, Num a) => b -> (a -> b -> c) -> f (Kahn a) -> IntMap b -> f c
unbindMapWithDefault z f xs ys =
  fmap (\v -> f (primal v) $ findWithDefault z (varId v) ys) xs