{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK not-home #-}

-----------------------------------------------------------------------------
-- |
-- Copyright   : (c) Edward Kmett 2010-2021
-- License     : BSD3
-- Maintainer  : ekmett@gmail.com
-- Stability   : experimental
-- Portability : GHC only
--
-- A dense forward AD based on representable functors. This allows for much larger
-- forward mode data types than "Numeric.AD.Internal.Dense", as we only need
-- the ability to compare the representation of a functor for equality, rather
-- than lay the representation out in a straight line as you have to with
-- 'Traversable'.
-----------------------------------------------------------------------------

module Numeric.AD.Internal.Dense.Representable
  ( Repr(..)
  , ds
  , ds'
  , vars
  , apply
  ) where

import Control.Monad (join)
import Data.Functor.Rep
import Data.Typeable ()
import Data.Data ()
import Data.Number.Erf
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | A forward-mode AD value whose tangent is stored in a functor @f@,
-- with one tangent component per position of @f@.
data Repr f a
  -- | A primal value with no tangent. 'ds' substitutes the caller's
  -- default tangent for it, and it is reported as a known constant.
  = Lift !a
  -- | A primal value paired with its tangent (see 'vars' for how the
  -- tangents are seeded). The primal is strict; the tangent field is not.
  | Repr !a (f a)
  -- | The value zero with no tangent; read as @0@ wherever a primal is
  -- needed.
  | Zero

-- | Shows only the primal value; the tangent of a 'Repr' is not displayed
-- and 'Zero' is rendered as @0@.
instance Show a => Show (Repr f a) where
  showsPrec _ Zero         = showString "0"
  showsPrec d (Lift x)     = showsPrec d x
  showsPrec d (Repr x _dx) = showsPrec d x

-- | Extract the tangent of a value. Values that carry no tangent ('Lift'
-- and 'Zero') yield the supplied default @z@ instead.
ds :: f a -> Repr f a -> f a
ds _ (Repr _ da) = da
ds z _           = z
{-# INLINE ds #-}

-- | Extract both the primal and the tangent of a value. Values that carry
-- no tangent ('Lift' and 'Zero') are paired with the supplied default
-- tangent @z@; 'Zero' contributes a primal of @0@.
ds' :: Num a => f a -> Repr f a -> (a, f a)
ds' _ (Repr a da) = (a, da)
ds' z (Lift a)    = (a, z)
ds' z Zero        = (0, z)
{-# INLINE ds' #-}

-- | Seed every position of the input with a one-hot tangent: the value at
-- position @i@ becomes @'Repr' a t@ where @t@ is @1@ at position @i@ and
-- @0@ everywhere else. Positions are compared with '(==)' on @'Rep' f@.
vars :: (Representable f, Eq (Rep f), Num a) => f a -> f (Repr f a)
vars = imapRep $ \i a -> Repr a $ tabulate $ \j -> if i == j then 1 else 0
{-# INLINE vars #-}

-- | Run @f@ on the inputs after seeding each position with its one-hot
-- tangent via 'vars'.
apply :: (Representable f, Eq (Rep f), Num a) => (f (Repr f a) -> b) -> f a -> b
apply f as = f (vars as)
{-# INLINE apply #-}

-- | The primal (value) component of an AD value, reading 'Zero' as @0@.
-- NOTE(review): not exported; presumably used by the instances generated
-- from instances.h - confirm before removing.
primal :: Num a => Repr f a -> a
primal Zero       = 0
primal (Lift a)   = a
primal (Repr a _) = a

-- | Scalar operations for 'Repr'. Scaling acts on the primal and on every
-- tangent component. 'Zero' is returned unchanged by every scaling
-- operation, without evaluating the scalar.
instance (Representable f, Num a) => Mode (Repr f a) where
  type Scalar (Repr f a) = a

  -- Only tangent-free values are known constants; 'Zero' is known to be 0.
  asKnownConstant (Lift a) = Just a
  asKnownConstant Zero     = Just 0
  asKnownConstant _        = Nothing

  isKnownConstant Repr{} = False
  isKnownConstant _      = True

  isKnownZero Zero = True
  isKnownZero _    = False

  auto = Lift
  zero = Zero

  -- Left scaling: multiply the primal and each tangent component by @a@.
  _ *^ Zero      = Zero
  a *^ Lift b    = Lift (a * b)
  a *^ Repr b db = Repr (a * b) $ fmap (a*) db

  -- Right scaling: multiply the primal and each tangent component by @b@.
  Zero      ^* _ = Zero
  Lift a    ^* b = Lift (a * b)
  Repr a da ^* b = Repr (a * b) $ fmap (*b) da

  -- Division by a scalar; note that 'Zero' stays 'Zero' even when @b@ is 0.
  Zero      ^/ _ = Zero
  Lift a    ^/ b = Lift (a / b)
  Repr a da ^/ b = Repr (a / b) $ fmap (/b) da

-- | Addition of AD values: primals are added, 'Zero' is the identity, a
-- lone tangent is kept as-is, and two tangents are summed pointwise.
(<+>) :: (Representable f, Num a) => Repr f a -> Repr f a -> Repr f a
Zero      <+> a         = a
a         <+> Zero      = a
Lift a    <+> Lift b    = Lift (a + b)
Lift a    <+> Repr b db = Repr (a + b) db
Repr a da <+> Lift b    = Repr (a + b) da
Repr a da <+> Repr b db = Repr (a + b) $ liftR2 (+) da db

-- | Chain-rule plumbing for 'Repr'.
--
-- Every method follows the same scheme. When no argument carries a tangent
-- (each is 'Zero' or 'Lift'), the result is a 'Lift' of the scalar function.
-- Otherwise each 'Repr' argument's tangent is scaled by the matching partial
-- derivative. When both arguments carry tangents, the scaled tangents are
-- summed pointwise with 'liftR2'. 'Zero' is passed to the scalar functions
-- as @0@.
instance (Representable f, Num a) => Jacobian (Repr f a) where
  type D (Repr f a) = Id a

  -- Unary op; the derivative @dadb@ is supplied up front.
  unary f _         Zero        = Lift (f 0)
  unary f _         (Lift b)    = Lift (f b)
  unary f (Id dadb) (Repr b db) = Repr (f b) (fmap (dadb *) db)

  -- Unary op; @df@ computes the derivative from the input.
  lift1 f _  Zero        = Lift (f 0)
  lift1 f _  (Lift b)    = Lift (f b)
  lift1 f df (Repr b db) = Repr (f b) (fmap (dadb *) db)
    where
      Id dadb = df (Id b)

  -- Unary op; @df@ receives the result (@f b@) first, then the input.
  lift1_ f _  Zero        = Lift (f 0)
  lift1_ f _  (Lift b)    = Lift (f b)
  lift1_ f df (Repr b db) = Repr a (fmap (dadb *) db)
    where
      a       = f b
      Id dadb = df (Id a) (Id b)

  -- Binary op; the partials w.r.t. the first (@dadb@) and second (@dadc@)
  -- arguments are supplied up front. A partial is only inspected when its
  -- argument carries a tangent.
  binary f _         _         Zero        Zero        = Lift (f 0 0)
  binary f _         _         Zero        (Lift c)    = Lift (f 0 c)
  binary f _         _         (Lift b)    Zero        = Lift (f b 0)
  binary f _         _         (Lift b)    (Lift c)    = Lift (f b c)
  binary f _         (Id dadc) Zero        (Repr c dc) = Repr (f 0 c) $ fmap (* dadc) dc
  binary f _         (Id dadc) (Lift b)    (Repr c dc) = Repr (f b c) $ fmap (* dadc) dc
  binary f (Id dadb) _         (Repr b db) Zero        = Repr (f b 0) $ fmap (dadb *) db
  binary f (Id dadb) _         (Repr b db) (Lift c)    = Repr (f b c) $ fmap (dadb *) db
  binary f (Id dadb) (Id dadc) (Repr b db) (Repr c dc) =
    Repr (f b c) $ liftR2 productRule db dc
    where
      productRule dbi dci = dadb * dbi + dci * dadc

  -- Binary op; @df@ maps the two inputs to the pair of partials
  -- (first argument, second argument). 'Zero' inputs are passed as @0@.
  lift2 f _  Zero        Zero        = Lift (f 0 0)
  lift2 f _  Zero        (Lift c)    = Lift (f 0 c)
  lift2 f _  (Lift b)    Zero        = Lift (f b 0)
  lift2 f _  (Lift b)    (Lift c)    = Lift (f b c)
  lift2 f df Zero        (Repr c dc) = Repr (f 0 c) $ fmap (* dadc) dc
    where
      dadc = runId (snd (df (Id 0) (Id c)))
  lift2 f df (Lift b)    (Repr c dc) = Repr (f b c) $ fmap (* dadc) dc
    where
      dadc = runId (snd (df (Id b) (Id c)))
  lift2 f df (Repr b db) Zero        = Repr (f b 0) $ fmap (dadb *) db
    where
      dadb = runId (fst (df (Id b) (Id 0)))
  lift2 f df (Repr b db) (Lift c)    = Repr (f b c) $ fmap (dadb *) db
    where
      dadb = runId (fst (df (Id b) (Id c)))
  lift2 f df (Repr b db) (Repr c dc) = Repr (f b c) da
    where
      (Id dadb, Id dadc) = df (Id b) (Id c)
      da = liftR2 productRule db dc
      productRule dbi dci = dadb * dbi + dci * dadc

  -- Like 'lift2', but @df@ receives the result first, then the two inputs.
  lift2_ f _  Zero        Zero        = Lift (f 0 0)
  lift2_ f _  Zero        (Lift c)    = Lift (f 0 c)
  lift2_ f _  (Lift b)    Zero        = Lift (f b 0)
  lift2_ f _  (Lift b)    (Lift c)    = Lift (f b c)
  lift2_ f df Zero        (Repr c dc) = Repr a $ fmap (* dadc) dc
    where
      a            = f 0 c
      (_, Id dadc) = df (Id a) (Id 0) (Id c)
  lift2_ f df (Lift b)    (Repr c dc) = Repr a $ fmap (* dadc) dc
    where
      a            = f b c
      (_, Id dadc) = df (Id a) (Id b) (Id c)
  lift2_ f df (Repr b db) Zero        = Repr a $ fmap (dadb *) db
    where
      a            = f b 0
      (Id dadb, _) = df (Id a) (Id b) (Id 0)
  lift2_ f df (Repr b db) (Lift c)    = Repr a $ fmap (dadb *) db
    where
      a            = f b c
      (Id dadb, _) = df (Id a) (Id b) (Id c)
  lift2_ f df (Repr b db) (Repr c dc) = Repr a $ liftR2 productRule db dc
    where
      a                  = f b c
      (Id dadb, Id dadc) = df (Id a) (Id b) (Id c)
      productRule dbi dci = dadb * dbi + dci * dadc

-- | Multiplication via 'lift2', supplying the product-rule partials
-- @d(x*y)/dx = y@ and @d(x*y)/dy = x@.
mul :: (Representable f, Num a) => Repr f a -> Repr f a -> Repr f a
mul = lift2 (*) (\x y -> (y, x))

-- Template parameters for the shared instances.h include (not shown here).
-- BODY1 and BODY2 expand to an instance context that adds Representable f
-- to one or two extra constraints, and HEAD expands to the instance head
-- type Repr f a.
-- NOTE(review): presumably instances.h uses these to generate the numeric
-- class instances for Repr f a; confirm against instances.h.
#define BODY1(x)   (Representable f, x) =>
#define BODY2(x,y) (Representable f, x, y) =>
#define HEAD (Repr f a)
#include "instances.h"