module Numeric.AD.Internal.Reverse
( Reverse(..)
, Tape(..)
, Head(..)
, Cells(..)
, reifyTape
, partials
, partialArrayOf
, partialMapOf
, derivativeOf
, derivativeOf'
, bind
, unbind
, unbindMap
, unbindWith
, unbindMapWithDefault
, var
, varId
, primal
) where
import Data.Functor
import Control.Monad hiding (mapM)
import Control.Monad.ST
import Control.Monad.Trans.State
import Data.Array.ST
import Data.Array
import Data.Array.Unsafe as Unsafe
import Data.IORef
import Data.IntMap (IntMap, fromDistinctAscList, findWithDefault)
import Data.Number.Erf
import Data.Proxy
import Data.Reflection
#if __GLASGOW_HASKELL__ < 710
import Data.Traversable (Traversable, mapM)
#else
import Data.Traversable (mapM)
#endif
import Data.Typeable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode
import Prelude hiding (mapM)
import System.IO.Unsafe (unsafePerformIO)
import Unsafe.Coerce
#ifdef HLINT
#endif
#ifndef HLINT
-- | The operations recorded on a tape, most recently recorded first ('un' and
-- 'bin' push new cells onto the front). The type of the stored partial
-- derivatives is not visible in the type of 'Cells'.
data Cells where
  -- End of the tape.
  Nil :: Cells
  -- One-operand entry: operand index, partial derivative with respect to that
  -- operand, and the older cells.
  Unary :: !Int -> a -> Cells -> Cells
  -- Two-operand entry: both operand indices, one partial derivative per
  -- operand (in the same order), and the older cells.
  Binary :: !Int -> !Int -> a -> a -> Cells -> Cells
#endif
-- | @dropCells n cells@ discards the @n@ most recently recorded cells from the
-- front of @cells@. If the tape holds fewer than @n@ cells the result is 'Nil'.
dropCells :: Int -> Cells -> Cells
dropCells 0 xs = xs
dropCells _ Nil = Nil
-- The decremented count is forced with '$!' so that no chain of pending
-- subtractions accumulates while walking a long tape.
dropCells n (Unary _ _ xs) = (dropCells $! n - 1) xs
dropCells n (Binary _ _ _ _ xs) = (dropCells $! n - 1) xs
-- | The current state of a tape: a strict counter together with the cells
-- recorded so far. 'un' and 'bin' increment the counter by one for every cell
-- they push.
data Head = Head !Int Cells
-- | A mutable tape: an 'IORef' holding the current 'Head'.
newtype Tape = Tape { getTape :: IORef Head }
-- | Tape transition that pushes a 'Unary' cell for operand index @i@ with
-- partial derivative @di@. Returns the updated head and the incremented
-- counter, both evaluated before the pair is handed back.
un :: Int -> a -> Head -> (Head, Int)
un i di (Head count cells) =
  let next = count + 1
      pushed = Head next (Unary i di cells)
  in pushed `seq` next `seq` (pushed, next)
-- | Tape transition that pushes a 'Binary' cell for operand indices @i@ and
-- @j@ with partial derivatives @di@ and @dj@. Returns the updated head and the
-- incremented counter, both evaluated before the pair is handed back.
bin :: Int -> Int -> a -> a -> Head -> (Head, Int)
bin i j di dj (Head count cells) =
  let next = count + 1
      pushed = Head next (Binary i j di dj cells)
  in pushed `seq` next `seq` (pushed, next)
-- | Atomically apply a transition to the 'Head' of the tape reflected from
-- @s@ and return the transition's result.
modifyTape :: Reifies s Tape => p s -> (Head -> (Head, r)) -> IO r
modifyTape p step = atomicModifyIORef tapeRef step
  where
    tapeRef = getTape (reflect p)
-- | Record a one-operand operation on the tape reflected from @s@ and build
-- the resulting node.
--
-- Arguments: the primal function @f@, the partial derivative @di@ with respect
-- to the operand, the operand's tape index @i@, and the operand's primal value
-- @b@. The new node's index is the counter returned by pushing a 'Unary' cell
-- via 'modifyTape'; its value is @f b@, forced by '$!'.
--
-- The tape is mutated through 'unsafePerformIO', so the cell is pushed when
-- the node's (strict) index field is evaluated.
unarily :: forall s a. Reifies s Tape => (a -> a) -> a -> Int -> a -> Reverse s a
unarily f di i b = Reverse (unsafePerformIO (modifyTape (Proxy :: Proxy s) (un i di))) $! f b
-- | Record a two-operand operation on the tape reflected from @s@ and build
-- the resulting node.
--
-- Arguments: the primal function @f@, the partial derivatives @di@ and @dj@
-- with respect to the first and second operand, then each operand's tape index
-- and primal value (@i@, @b@ and @j@, @c@). The new node's index is the counter
-- returned by pushing a 'Binary' cell via 'modifyTape'; its value is @f b c@,
-- forced by '$!'.
--
-- The tape is mutated through 'unsafePerformIO', so the cell is pushed when
-- the node's (strict) index field is evaluated.
binarily :: forall s a. Reifies s Tape => (a -> a -> a) -> a -> a -> Int -> a -> Int -> a -> Reverse s a
binarily f di dj i b j c = Reverse (unsafePerformIO (modifyTape (Proxy :: Proxy s) (bin i j di dj))) $! f b c
#ifndef HLINT
-- | A value tracked against the tape identified by the type parameter @s@.
data Reverse s a where
  -- A known zero ('isKnownZero' is 'True'; 'primal' reads it as 0).
  Zero :: Reverse s a
  -- A constant that is not recorded on the tape.
  Lift :: a -> Reverse s a
  -- A recorded node: its strict tape index and its primal value.
  Reverse :: !Int -> a -> Reverse s a
  deriving (Show, Typeable)
#endif
-- | 'Mode' instance for 'Reverse'. Constants are embedded with 'Lift', the
-- known zero is 'Zero', and scaling by a scalar is recorded through 'lift1'
-- with a constant derivative.
instance (Reifies s Tape, Num a) => Mode (Reverse s a) where
  type Scalar (Reverse s a) = a

  -- Only the dedicated 'Zero' constructor counts as a known zero.
  isKnownZero r = case r of
    Zero -> True
    _ -> False

  -- Everything except a taped node is a constant.
  isKnownConstant r = case r of
    Reverse{} -> False
    _ -> True

  auto = Lift
  zero = Zero

  -- Scaling on the left, scaling on the right, and division by a scalar.
  k *^ x = lift1 (k *) (const (auto k)) x
  x ^* k = lift1 (* k) (const (auto k)) x
  x ^/ k = lift1 (/ k) (const (auto (recip k))) x
-- | Addition of tracked values; the partial derivative with respect to each
-- operand is 1.
(<+>) :: (Reifies s Tape, Num a) => Reverse s a -> Reverse s a -> Reverse s a
x <+> y = binary (+) 1 1 x y
-- | Exponentiation of tracked values.
--
-- The clauses are tried in order:
--
-- * a known-zero base yields the constant @0 ** primal y@;
-- * a known-zero exponent yields the constant 1;
-- * a constant exponent @y@ differentiates only the base, with derivative
--   @y * x ** (y - 1)@;
-- * otherwise both operands are differentiated: the base contributes
--   @y * x ** (y - 1)@ and the exponent contributes @z * log x@, where @z@ is
--   the already computed result.
(<**>) :: (Reifies s Tape, Floating a) => Reverse s a -> Reverse s a -> Reverse s a
Zero <**> y = auto (0 ** primal y)
_ <**> Zero = auto 1
x <**> Lift y = lift1 (**y) (\z -> y *^ z ** Id (y - 1)) x
x <**> y = lift2_ (**) (\z xi yi -> (yi * xi ** (yi - 1), z * log xi)) x y
-- | The plain value carried by a tracked value: 0 for 'Zero', otherwise the
-- stored value.
primal :: Num a => Reverse s a -> a
primal r = case r of
  Zero -> 0
  Lift x -> x
  Reverse _ x -> x
-- | 'Jacobian' instance for 'Reverse'. Operations whose operands are all
-- constants ('Zero' or 'Lift') produce a 'Lift' without touching the tape.
-- Otherwise a cell is recorded for the taped operand(s) via 'unarily' or
-- 'binarily'.
instance (Reifies s Tape, Num a) => Jacobian (Reverse s a) where
  type D (Reverse s a) = Id a

  unary g (Id dx) x = case x of
    Zero -> Lift (g 0)
    Lift c -> Lift (g c)
    Reverse ix px -> unarily g dx ix px

  -- The derivative function sees the operand's primal value.
  lift1 g dg x = unary g (dg (Id (primal x))) x

  -- As 'lift1', but the derivative function also sees the result, which is
  -- computed once and reused as the node's value.
  lift1_ g dg x =
    let px = primal x
        out = g px
    in unary (const out) (dg (Id out) (Id px)) x

  binary g (Id dl) (Id dr) l r = case (l, r) of
    -- Both operands constant: no tape entry.
    (Zero, Zero) -> Lift (g 0 0)
    (Zero, Lift c) -> Lift (g 0 c)
    (Lift b, Zero) -> Lift (g b 0)
    (Lift b, Lift c) -> Lift (g b c)
    -- Only the right operand is taped: record a unary cell for it.
    (Zero, Reverse j c) -> unarily (g 0) dr j c
    (Lift b, Reverse j c) -> unarily (g b) dr j c
    -- Only the left operand is taped: record a unary cell for it.
    (Reverse i b, Zero) -> unarily (`g` 0) dl i b
    (Reverse i b, Lift c) -> unarily (`g` c) dl i b
    -- Both operands taped: record a binary cell.
    (Reverse i b, Reverse j c) -> binarily g dl dr i b j c

  -- The derivative function sees both primal values.
  lift2 g dg l r =
    let (dl, dr) = dg (Id (primal l)) (Id (primal r))
    in binary g dl dr l r

  -- As 'lift2', but the derivative function also sees the result, which is
  -- computed once and reused as the node's value.
  lift2_ g dg l r =
    let pl = primal l
        pr = primal r
        out = g pl pr
        (dl, dr) = dg (Id out) (Id pl) (Id pr)
    in binary (\_ _ -> out) dl dr l r
-- | Multiplication of tracked values: the partial derivative with respect to
-- each operand is the other operand.
mul :: (Reifies s Tape, Num a) => Reverse s a -> Reverse s a -> Reverse s a
mul = lift2 (*) (flip (,))
#define BODY1(x) (Reifies s Tape,x)
#define BODY2(x,y) (Reifies s Tape,x,y)
#define HEAD Reverse s a
#include "instances.h"
-- | The sum of all sensitivities that 'partials' reports for the given value.
-- The proxy only selects the tape; its value is ignored.
derivativeOf :: (Reifies s Tape, Num a) => Proxy s -> Reverse s a -> a
derivativeOf _ r = sum (partials r)
-- | The primal value paired with the result of 'derivativeOf'.
derivativeOf' :: (Reifies s Tape, Num a) => Proxy s -> Reverse s a -> (a, a)
derivativeOf' p r = (value, slope)
  where
    value = primal r
    slope = derivativeOf p r
-- | Walk the tape from newest to oldest cell, pushing sensitivities from each
-- node to its operands.
--
-- @backPropagate k cells ss@ expects the first cell of @cells@ to be the one
-- recorded for node @k@. For every cell it reads the sensitivity of the
-- current node from @ss@, adds that sensitivity times the stored partial
-- derivative to the entry of each operand, and continues with node @k - 1@
-- and the older cells. When the tape is exhausted it returns the index it
-- has reached.
backPropagate :: Num a => Int -> Cells -> STArray s Int a -> ST s Int
backPropagate k Nil _ = return k
backPropagate k (Unary i g xs) ss = do
  da <- readArray ss k
  db <- readArray ss i
  -- 'Cells' hides the type of the stored partials, so they are coerced back
  -- to @a@ here.
  -- NOTE(review): this is only sound if the cell was recorded at this same
  -- type @a@; confirm that a tape is never shared between element types.
  writeArray ss i $! db + unsafeCoerce g * da
  (backPropagate $! k - 1) xs ss
backPropagate k (Binary i j g h xs) ss = do
  da <- readArray ss k
  db <- readArray ss i
  writeArray ss i $! db + unsafeCoerce g * da
  -- The second operand is read only after the first write, so a cell whose
  -- two operands share an index (i == j) accumulates both contributions.
  dc <- readArray ss j
  writeArray ss j $! dc + unsafeCoerce h * da
  (backPropagate $! k - 1) xs ss
-- | The sensitivities of a tracked value with respect to the nodes that
-- precede it on the tape, as a list indexed from 0.
--
-- Constants ('Zero' and 'Lift') yield the empty list. For a taped node @k@
-- the list covers indices 0 up to the index at which 'backPropagate' stops.
partials :: forall s a. (Reifies s Tape, Num a) => Reverse s a -> [a]
partials Zero = []
partials (Lift _) = []
partials (Reverse k _) = map (sensitivities !) [0 .. vs]
  where
    -- Snapshot of the tape reflected from @s@.
    Head n t = unsafePerformIO $ readIORef (getTape (reflect (Proxy :: Proxy s)))
    -- The counter grows by one per recorded cell and a node's index is the
    -- counter value at the time it was recorded, so exactly @n - k@ cells are
    -- newer than node @k@. Dropping them leaves node @k@'s cell in front.
    tk = dropCells (n - k) t
    (vs, sensitivities) = runST $ do
      ss <- newArray (0, k) 0
      -- Seed: the sensitivity of the node with respect to itself is 1.
      writeArray ss k 1
      v <- backPropagate k tk ss
      -- The mutable array is not used after this point, so freezing it
      -- without a copy does not expose later mutation.
      as <- Unsafe.unsafeFreeze ss
      return (v, as)
-- | The sensitivities from 'partials' collected into an array with the given
-- bounds; entries without a sensitivity are 0. The proxy only selects the
-- tape; its value is ignored.
partialArrayOf :: (Reifies s Tape, Num a) => Proxy s -> (Int, Int) -> Reverse s a -> Array Int a
partialArrayOf _ vbounds r = accumArray (+) 0 vbounds indexed
  where
    indexed = zip [0..] (partials r)
-- | The sensitivities from 'partials' collected into an 'IntMap' keyed by
-- position, starting at 0. The proxy only selects the tape; its value is
-- ignored.
partialMapOf :: (Reifies s Tape, Num a) => Proxy s -> Reverse s a -> IntMap a
partialMapOf _ r = fromDistinctAscList indexed
  where
    indexed = zip [0..] (partials r)
-- | Run a computation against a fresh tape.
--
-- The tape starts with its counter set to @vs@ and no recorded cells. It is
-- made available to the continuation through the 'Reifies' constraint.
reifyTape :: Int -> (forall s. Reifies s Tape => Proxy s -> r) -> r
reifyTape vs k = unsafePerformIO $
  (\tapeRef -> reify (Tape tapeRef) k) <$> newIORef (Head vs Nil)
-- | Build a taped node from a primal value and a tape index (note the
-- argument order: value first, index second).
var :: a -> Int -> Reverse s a
var = flip Reverse
-- | The tape index of a taped node. Calls 'error' for 'Zero' and 'Lift',
-- which carry no index.
varId :: Reverse s a -> Int
varId r = case r of
  Reverse ix _ -> ix
  _ -> error "varId: not a Var"
-- | Replace every element of a container by a variable carrying that element
-- as its primal value. Variables are numbered 0, 1, 2, ... in traversal
-- order. The second component is @(0, n)@, where @n@ is the number of
-- elements.
bind :: Traversable f => f a -> (f (Reverse s a), (Int,Int))
bind xs = (bound, (0, count))
  where
    (bound, count) = runState (mapM fresh xs) 0
    -- Hand out the current counter as the variable's index and advance it,
    -- forcing the new counter so the state stays evaluated.
    fresh x = state step
      where
        step ix =
          let next = ix + 1
          in next `seq` (var x ix, next)
-- | Replace every variable in the container by the array entry at the
-- variable's index (see 'varId', which fails for non-variables).
unbind :: Functor f => f (Reverse s a) -> Array Int a -> f a
unbind xs ys = fmap ((ys !) . varId) xs
-- | Replace every variable in the container by @f@ applied to the variable's
-- primal value and the array entry at the variable's index.
unbindWith :: (Functor f, Num a) => (a -> b -> c) -> f (Reverse s a) -> Array Int b -> f c
unbindWith f xs ys = fmap combine xs
  where
    combine v = f (primal v) (ys ! varId v)
-- | Replace every variable in the container by the map entry at the
-- variable's index, or by 0 when the map has no such entry.
unbindMap :: (Functor f, Num a) => f (Reverse s a) -> IntMap a -> f a
unbindMap xs ys = fmap look xs
  where
    look v = findWithDefault 0 (varId v) ys
-- | Replace every variable in the container by @f@ applied to the variable's
-- primal value and the map entry at the variable's index, using @z@ when the
-- map has no such entry.
unbindMapWithDefault :: (Functor f, Num a) => b -> (a -> b -> c) -> f (Reverse s a) -> IntMap b -> f c
unbindMapWithDefault z f xs ys = fmap combine xs
  where
    combine v = f (primal v) (findWithDefault z (varId v) ys)