{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK not-home #-}
module Numeric.AD.Internal.Forward.Double
( ForwardDouble(..)
, bundle
, unbundle
, apply
, bind
, bind'
, bindWith
, bindWith'
, transposeWith
) where
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative hiding ((<**>))
import Data.Foldable (Foldable, toList)
import Data.Traversable (Traversable, mapAccumL)
#else
import Data.Foldable (toList)
import Data.Traversable (mapAccumL)
#endif
import Control.Monad (join)
import Data.Function (on)
import Data.Number.Erf
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode
-- | A pair of strict, unpacked 'Double's: a value together with the tangent
-- (derivative component) that is propagated alongside it.
data ForwardDouble = ForwardDouble
  { primal  :: {-# UNPACK #-} !Double -- ^ the value itself
  , tangent :: {-# UNPACK #-} !Double -- ^ the derivative carried with the value
  } deriving (Read, Show)
-- | Take a 'ForwardDouble' apart into its @(primal, tangent)@ pair.
unbundle :: ForwardDouble -> (Double, Double)
unbundle ForwardDouble { primal = p, tangent = t } = (p, t)
{-# INLINE unbundle #-}
-- | Build a 'ForwardDouble' from a primal value (first argument) and a
-- tangent (second argument).
bundle :: Double -> Double -> ForwardDouble
bundle = \p t -> ForwardDouble { primal = p, tangent = t }
{-# INLINE bundle #-}
-- | Run a function on the given 'Double' after pairing it with a tangent
-- of @1@, i.e. treating it as the variable being differentiated.
apply :: (ForwardDouble -> b) -> Double -> b
apply f a = f seeded
  where
    seeded = ForwardDouble { primal = a, tangent = 1 }
{-# INLINE apply #-}
-- | 'Mode' operations for 'ForwardDouble'; the scalar type is 'Double'.
instance Mode ForwardDouble where
  type Scalar ForwardDouble = Double

  -- A lifted constant carries a zero tangent.
  auto x = ForwardDouble x 0
  zero = ForwardDouble 0 0

  -- Known zero: both components compare equal to 0.
  isKnownZero (ForwardDouble p t) = p == 0 && t == 0

  -- Known constant: the tangent compares equal to 0.
  isKnownConstant (ForwardDouble _ t) = t == 0

  -- Scaling by a plain scalar acts on both components.
  k *^ ForwardDouble p t = ForwardDouble (k * p) (k * t)
  ForwardDouble p t ^* k = ForwardDouble (p * k) (t * k)
  ForwardDouble p t ^/ k = ForwardDouble (p / k) (t / k)
-- | Component-wise sum: primals are added together and tangents are added
-- together.
(<+>) :: ForwardDouble -> ForwardDouble -> ForwardDouble
(<+>) (ForwardDouble p dp) (ForwardDouble q dq) = ForwardDouble (p + q) (dp + dq)
-- | 'Jacobian' combinators for 'ForwardDouble'. Partial derivatives are
-- passed around wrapped in 'Id'; each combinator applies the primal function
-- and combines the partial derivative(s) with the incoming tangent(s).
instance Jacobian ForwardDouble where
  type D ForwardDouble = Id Double

  -- Unary function whose derivative is already known as a value.
  unary f (Id dadb) (ForwardDouble b db) = ForwardDouble (f b) (dadb * db)

  -- Unary function; the derivative is computed from the input.
  lift1 f df (ForwardDouble b db) =
    let Id dadb = df (Id b)
    in ForwardDouble (f b) (dadb * db)

  -- Unary function; the derivative function receives the output first,
  -- then the input.
  lift1_ f df (ForwardDouble b db) =
    let a = f b
        Id da = df (Id a) (Id b) ^* db
    in ForwardDouble a da

  -- Binary function whose two partial derivatives are already known.
  binary f (Id dadb) (Id dadc) (ForwardDouble b db) (ForwardDouble c dc) =
    ForwardDouble (f b c) (dadb * db + dc * dadc)

  -- Binary function; both partials are computed from the inputs.
  lift2 f df (ForwardDouble b db) (ForwardDouble c dc) =
    let (Id dadb, Id dadc) = df (Id b) (Id c)
    in ForwardDouble (f b c) (dadb * db + dc * dadc)

  -- Binary function; the derivative function receives the output first,
  -- then both inputs.
  lift2_ f df (ForwardDouble b db) (ForwardDouble c dc) =
    let a = f b c
        (Id dadb, Id dadc) = df (Id a) (Id b) (Id c)
    in ForwardDouble a (dadb * db + dc * dadc)
-- | Equality looks only at the primal values; tangents are ignored.
instance Eq ForwardDouble where
  x == y = primal x == primal y
-- | Ordering looks only at the primal values; tangents are ignored.
instance Ord ForwardDouble where
  compare x y = compare (primal x) (primal y)
-- | Arithmetic on 'ForwardDouble', expressed through the 'Jacobian'
-- combinators so that each operation also updates the tangent.
instance Num ForwardDouble where
  -- The literal 0 maps to 'zero'; every other literal is a lifted constant.
  fromInteger n
    | n == 0 = zero
    | otherwise = auto (fromInteger n)

  x + y = x <+> y

  -- Partials of (b - c) are 1 and -1.
  x - y = binary (-) (auto 1) (auto (-1)) x y

  -- Partials of (b * c) are c and b.
  x * y = lift2 (*) (\p q -> (q, p)) x y

  negate = lift1 negate (\_ -> auto (-1))

  abs = lift1 abs signum

  -- signum is given a zero derivative.
  signum = lift1 signum (\_ -> zero)
-- | Division and reciprocals for 'ForwardDouble'.
instance Fractional ForwardDouble where
  -- The literal 0 maps to 'zero'; every other literal is a lifted constant.
  fromRational r
    | r == 0 = zero
    | otherwise = auto (fromRational r)

  -- Division is multiplication by the reciprocal.
  x / y = x * recip y

  -- The derivative is expressed through the result r = recip b as -(r * r).
  recip = lift1_ recip (\r _ -> negate (r * r))
-- | Elementary functions for 'ForwardDouble'. Each method pairs the primal
-- function with its derivative via 'lift1', 'lift1_' or 'lift2_'.
instance Floating ForwardDouble where
  pi = auto pi

  -- exp is its own derivative, so the already-computed result is reused.
  exp = lift1_ exp (\e _ -> e)
  log = lift1 log recip
  logBase b x = log x / log b

  -- Derivative written in terms of the result r = sqrt b: 1 / (2 * r).
  sqrt = lift1_ sqrt (\r _ -> recip (auto 2 * r))

  -- Cases are tried in order: a known-zero base, a known-zero exponent,
  -- an exponent with zero tangent, and finally the general two-variable rule.
  x@(ForwardDouble xp xt) ** y@(ForwardDouble yp yt)
    | xp == 0 && xt == 0 = ForwardDouble (0 ** yp) 0
    | yp == 0 && yt == 0 = ForwardDouble 1 0
    | yt == 0 = lift1 (** yp) (\z -> yp *^ z ** Id (yp - 1)) x
    | otherwise = lift2_ (**) (\z xi yi -> (yi * z / xi, z * log xi)) x y

  sin = lift1 sin cos
  cos = lift1 cos (\x -> negate (sin x))
  tan = lift1 tan (\x -> let c = cos x in recip (c * c))
  asin = lift1 asin (\x -> recip (sqrt (auto 1 - x * x)))
  acos = lift1 acos (\x -> negate (recip (sqrt (1 - x * x))))
  atan = lift1 atan (\x -> recip (1 + x * x))

  sinh = lift1 sinh cosh
  cosh = lift1 cosh sinh
  tanh = lift1 tanh (\x -> let c = cosh x in recip (c * c))
  asinh = lift1 asinh (\x -> recip (sqrt (1 + x * x)))
  acosh = lift1 acosh (\x -> recip (sqrt (x * x - 1)))
  atanh = lift1 atanh (\x -> recip (1 - x * x))
-- | Enumeration for 'ForwardDouble'. The sequences are generated from the
-- primal values and then turned back into 'ForwardDouble's with
-- 'withPrimal' or 'fromBy'.
instance Enum ForwardDouble where
  -- succ and pred are lifted with a derivative of 1.
  succ = lift1 succ (\_ -> 1)
  pred = lift1 pred (\_ -> 1)

  toEnum n = auto (toEnum n)
  fromEnum x = fromEnum (primal x)

  enumFrom a = map (withPrimal a) (enumFrom (primal a))
  enumFromTo a b = map (withPrimal a) (enumFromTo (primal a) (primal b))

  -- Each element is produced by 'fromBy' from the start value, the
  -- difference (b - a) and the element's index.
  enumFromThen a b =
    zipWith (fromBy a step) [0 ..] (enumFromThen (primal a) (primal b))
    where
      step = b - a
  enumFromThenTo a b c =
    zipWith (fromBy a step) [0 ..] (enumFromThenTo (primal a) (primal b) (primal c))
    where
      step = b - a
-- | Conversion to 'Rational' uses the primal value only.
instance Real ForwardDouble where
  toRational x = toRational (primal x)
-- | 'RealFloat' for 'ForwardDouble'. Representation queries are answered
-- from the primal value; 'scaleFloat', 'significand' and 'atan2' also
-- propagate the tangent.
instance RealFloat ForwardDouble where
  floatRadix = floatRadix . primal
  floatDigits = floatDigits . primal
  floatRange = floatRange . primal
  decodeFloat = decodeFloat . primal
  encodeFloat m e = auto (encodeFloat m e)
  isNaN = isNaN . primal
  isInfinite = isInfinite . primal
  isDenormalized = isDenormalized . primal
  isNegativeZero = isNegativeZero . primal
  isIEEE = isIEEE . primal

  -- Delegate to the primal. Defining this as @exponent = exponent@ would
  -- make the method refer to itself and never terminate.
  exponent = exponent . primal

  -- Scaling by 2^n is linear, so the derivative is 2^n itself.
  scaleFloat n = unary (scaleFloat n) (scaleFloat n 1)

  -- significand x == x * 2^(-exponent x), so the local derivative is
  -- 2^(-exponent x); the exponent is read from the primal.
  significand x = unary significand (scaleFloat (negate (exponent (primal x))) 1) x

  -- Partials of atan2 with respect to its first and second argument.
  atan2 = lift2 atan2 $ \vx vy ->
    let r = recip (join (*) vx + join (*) vy)
    in (vy * r, negate vx * r)
-- | Rounding operations for 'ForwardDouble'; the integral results are
-- computed from the primal value.
instance RealFrac ForwardDouble where
  -- The integral part comes from the primal; the fractional part is put
  -- back into the original number with 'withPrimal'.
  properFraction a =
    let (whole, frac) = properFraction (primal a)
    in (whole, withPrimal a frac)

  truncate x = truncate (primal x)
  round x = round (primal x)
  ceiling x = ceiling (primal x)
  floor x = floor (primal x)
-- | Error-function family for 'ForwardDouble'; each method is lifted with
-- 'lift1' together with a derivative expressed in the input.
instance Erf ForwardDouble where
  erf = lift1 erf slope
    where
      slope x = (2 / sqrt pi) * exp (negate x * x)

  erfc = lift1 erfc slope
    where
      slope x = ((-2) / sqrt pi) * exp (negate x * x)

  normcdf = lift1 normcdf slope
    where
      slope x = recip (sqrt (2 * pi)) * exp (negate (x * x / 2))
-- | Inverse error-function family for 'ForwardDouble'. 'lift1_' hands the
-- derivative function the result first; the input is not needed here.
instance InvErf ForwardDouble where
  inverf = lift1_ inverf slope
    where
      slope y _ = sqrt pi / 2 * exp (y * y)

  inverfc = lift1_ inverfc slope
    where
      slope y _ = negate (sqrt pi / 2) * exp (y * y)

  invnormcdf = lift1_ invnormcdf slope
    where
      slope y _ = sqrt (2 * pi) * exp (y * y / 2)
-- | For every position @i@ of the input, run @f@ on a copy of the input in
-- which element @i@ has tangent 1 and every other element is a lifted
-- constant. The results are returned in the shape of the input.
bind :: Traversable f => (f ForwardDouble -> b) -> f Double -> f b
bind f as = snd (mapAccumL step (0 :: Int) as)
  where
    -- The element itself is ignored; only its index matters.
    step !i _ = (i + 1, f (seededAt i))
    -- The input with only position i carrying tangent 1.
    seededAt i = snd (mapAccumL (mark i) (0 :: Int) as)
    mark !i !j x = (j + 1, if i == j then bundle x 1 else auto x)
-- | Like 'bind', but also returns a single result of @f@ alongside the
-- per-position results: the one computed for the last position, or, for an
-- empty input, @f@ applied to the input with every element lifted by 'auto'.
bind' :: Traversable f => (f ForwardDouble -> b) -> f Double -> (b, f b)
bind' f as =
  case mapAccumL step (0 :: Int, initial) as of
    ((_, final), results) -> (final, results)
  where
    -- The accumulator holds the next index and the latest result of f.
    step (!i, _) _ =
      let b = f (seededAt i)
      in ((i + 1, b), b)
    -- The input with only position i carrying tangent 1.
    seededAt i = snd (mapAccumL (mark i) (0 :: Int) as)
    mark !i !j x = (j + 1, if i == j then bundle x 1 else auto x)
    initial = f (fmap auto as)
-- | Like 'bind', but each per-position result of @f@ is combined with the
-- original element at that position using @g@.
bindWith :: Traversable f => (Double -> b -> c) -> (f ForwardDouble -> b) -> f Double -> f c
bindWith g f as = snd (mapAccumL step (0 :: Int) as)
  where
    step !i a = (i + 1, g a (f (seededAt i)))
    -- The input with only position i carrying tangent 1.
    seededAt i = snd (mapAccumL (mark i) (0 :: Int) as)
    mark !i !j x = (j + 1, if i == j then bundle x 1 else auto x)
-- | Like 'bindWith', but also returns a single raw result of @f@: the one
-- computed for the last position, or, for an empty input, @f@ applied to the
-- input with every element lifted by 'auto'.
bindWith' :: Traversable f => (Double -> b -> c) -> (f ForwardDouble -> b) -> f Double -> (b, f c)
bindWith' g f as =
  case mapAccumL step (0 :: Int, initial) as of
    ((_, final), results) -> (final, results)
  where
    -- The accumulator holds the next index and the latest raw result of f;
    -- the emitted value is that result combined with the element via g.
    step (!i, _) a =
      let b = f (seededAt i)
      in ((i + 1, b), g a b)
    -- The input with only position i carrying tangent 1.
    seededAt i = snd (mapAccumL (mark i) (0 :: Int) as)
    mark !i !j x = (j + 1, if i == j then bundle x 1 else auto x)
    initial = f (fmap auto as)
-- | Walk the @g b@ structure from left to right. At the @k@-th element,
-- collect the @k@-th entry of every inner @g a@ into an @f a@ and combine it
-- with that element using the supplied function.
--
-- Entries are taken with 'head' and 'tail', so demanding a result at a
-- position that an inner structure does not have raises their error.
transposeWith :: (Functor f, Foldable f, Traversable g) => (b -> f a -> c) -> f (g a) -> g b -> g c
transposeWith f as = \bs -> snd (mapAccumL step columns bs)
  where
    -- Every inner structure flattened to a list of its remaining entries.
    columns = fmap toList as
    -- Emit the current heads and carry the tails forward.
    step rest b = (fmap tail rest, f b (fmap head rest))