{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK not-home #-}

-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Unsafe and often partial combinators intended for internal usage.
--
-- Handle with care.
-----------------------------------------------------------------------------

module Numeric.AD.Internal.Forward
  ( Forward(..)
  , primal
  , tangent
  , bundle
  , unbundle
  , apply
  , bind
  , bind'
  , bindWith
  , bindWith'
  , transposeWith
  ) where


import Control.Monad (join)
import Data.Foldable (toList)
import Data.Traversable (mapAccumL)
import Data.Data
import Data.Number.Erf
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | 'Forward' mode AD
--
-- A value paired with its directional derivative (tangent). Two constructors
-- represent constants so that arithmetic on them can skip tangent work:
--
-- * @'Forward' p t@ carries a primal @p@ (strict) and a tangent @t@ (lazy).
--
-- * @'Lift' p@ is a constant with primal @p@ and an implicit tangent of @0@.
--
-- * 'Zero' is the constant whose primal and tangent are both @0@.
data Forward a
  = Forward !a a
  | Lift !a
  | Zero
  deriving (Show, Data, Typeable)

-- | Calculate the 'tangent' using forward mode AD.
--
-- A 'Forward' value returns its stored tangent. The constant forms
-- ('Lift' and 'Zero') have a tangent of @0@.
tangent :: Num a => Forward a -> a
tangent (Forward _ da) = da
tangent _ = 0
{-# INLINE tangent #-}

-- | Split a value into its @(primal, tangent)@ pair.
--
-- Constants report a tangent of @0@, and 'Zero' also reports a primal of @0@.
unbundle :: Num a => Forward a -> (a, a)
unbundle (Forward a da) = (a, da)
unbundle Zero = (0, 0)
unbundle (Lift a) = (a, 0)
{-# INLINE unbundle #-}

-- | Build a 'Forward' value from a primal and a tangent, in that order.
bundle :: a -> a -> Forward a
bundle = Forward
{-# INLINE bundle #-}

-- | Apply @f@ to @a@ seeded with a tangent of @1@.
--
-- The tangent carried through @f@ is therefore the derivative with respect
-- to @a@.
apply :: Num a => (Forward a -> b) -> a -> b
apply f a = f (bundle a 1)
{-# INLINE apply #-}

-- | Extract the primal (value) component. 'Zero' has a primal of @0@.
primal :: Num a => Forward a -> a
primal (Forward a _) = a
primal (Lift a) = a
primal Zero = 0
-- Matches the INLINE pragmas on the sibling accessors in this module.
{-# INLINE primal #-}

-- | Scalars are the underlying numeric type. 'auto' embeds a scalar as the
-- constant 'Lift', and 'zero' is the constant 'Zero'.
instance Num a => Mode (Forward a) where
  type Scalar (Forward a) = a

  auto = Lift
  zero = Zero

  -- Only the symbolic 'Zero' constructor is recognised as zero.
  isKnownZero Zero = True
  isKnownZero _    = False

  asKnownConstant Zero     = Just 0
  asKnownConstant (Lift a) = Just a
  asKnownConstant _        = Nothing

  -- Anything that is not a 'Forward' node carries no tangent.
  isKnownConstant Forward{} = False
  isKnownConstant _         = True

  -- Scaling applies to both the primal and the tangent. 'Zero' stays 'Zero'.
  a *^ Forward b db = Forward (a * b) (a * db)
  a *^ Lift b       = Lift (a * b)
  _ *^ Zero         = Zero

  Forward a da ^* b = Forward (a * b) (da * b)
  Lift a       ^* b = Lift (a * b)
  Zero         ^* _ = Zero

  Forward a da ^/ b = Forward (a / b) (da / b)
  Lift a       ^/ b = Lift (a / b)
  Zero         ^/ _ = Zero

-- | Add two values, summing primals and tangents.
--
-- 'Zero' is treated as the additive identity. Adding two constants ('Lift')
-- stays a constant, so no tangent is materialised.
(<+>) :: Num a => Forward a -> Forward a -> Forward a
Zero         <+> a            = a
a            <+> Zero         = a
Forward a da <+> Forward b db = Forward (a + b) (da + db)
Forward a da <+> Lift b       = Forward (a + b) da
Lift a       <+> Forward b db = Forward (a + b) db
Lift a       <+> Lift b       = Lift (a + b)

-- | Partial derivatives are carried as 'Id' scalars.
--
-- Each method short-circuits when every argument is a constant ('Lift' or
-- 'Zero'), producing a 'Lift' with no tangent work. Otherwise the output
-- tangent is the chain rule: each input's tangent is multiplied by the
-- corresponding partial derivative. An argument that is a constant
-- contributes nothing to the tangent.
instance Num a => Jacobian (Forward a) where
  type D (Forward a) = Id a

  -- Partial derivative supplied directly.
  unary f (Id dadb) (Forward b db) = Forward (f b) (dadb * db)
  unary f _         (Lift b)       = Lift (f b)
  unary f _         Zero           = Lift (f 0)

  -- Partial derivative computed from the input.
  lift1 f _  Zero           = Lift (f 0)
  lift1 f _  (Lift b)       = Lift (f b)
  lift1 f df (Forward b db) = Forward (f b) (dadb * db) where
    Id dadb = df (Id b)

  -- Partial derivative computed from both the result and the input.
  lift1_ f _  Zero           = Lift (f 0)
  lift1_ f _  (Lift b)       = Lift (f b)
  lift1_ f df (Forward b db) = Forward a da where
    a = f b
    Id da = df (Id a) (Id b) ^* db

  -- Both partial derivatives supplied directly.
  binary f _         _         Zero           Zero           = Lift (f 0 0)
  binary f _         _         Zero           (Lift c)       = Lift (f 0 c)
  binary f _         _         (Lift b)       Zero           = Lift (f b 0)
  binary f _         _         (Lift b)       (Lift c)       = Lift (f b c)
  binary f _         (Id dadc) Zero           (Forward c dc) = Forward (f 0 c) $ dc * dadc
  binary f _         (Id dadc) (Lift b)       (Forward c dc) = Forward (f b c) $ dc * dadc
  binary f (Id dadb) _         (Forward b db) Zero           = Forward (f b 0) $ dadb * db
  binary f (Id dadb) _         (Forward b db) (Lift c)       = Forward (f b c) $ dadb * db
  binary f (Id dadb) (Id dadc) (Forward b db) (Forward c dc) = Forward (f b c) $ dadb * db + dc * dadc

  -- Both partial derivatives computed from the inputs.
  lift2 f _  Zero           Zero           = Lift (f 0 0)
  lift2 f _  Zero           (Lift c)       = Lift (f 0 c)
  lift2 f _  (Lift b)       Zero           = Lift (f b 0)
  lift2 f _  (Lift b)       (Lift c)       = Lift (f b c)
  lift2 f df Zero           (Forward c dc) = Forward (f 0 c) $ dc * runId (snd (df (Id 0) (Id c)))
  lift2 f df (Lift b)       (Forward c dc) = Forward (f b c) $ dc * runId (snd (df (Id b) (Id c)))
  lift2 f df (Forward b db) Zero           = Forward (f b 0) $ runId (fst (df (Id b) (Id 0))) * db
  lift2 f df (Forward b db) (Lift c)       = Forward (f b c) $ runId (fst (df (Id b) (Id c))) * db
  lift2 f df (Forward b db) (Forward c dc) = Forward a da where
    a = f b c
    (Id dadb, Id dadc) = df (Id b) (Id c)
    da = dadb * db + dc * dadc

  -- Both partial derivatives computed from the result and the inputs.
  lift2_ f _  Zero           Zero           = Lift (f 0 0)
  lift2_ f _  Zero           (Lift c)       = Lift (f 0 c)
  lift2_ f _  (Lift b)       Zero           = Lift (f b 0)
  lift2_ f _  (Lift b)       (Lift c)       = Lift (f b c)
  lift2_ f df Zero           (Forward c dc) = Forward a $ dc * runId (snd (df (Id a) (Id 0) (Id c))) where
    a = f 0 c
  lift2_ f df (Lift b)       (Forward c dc) = Forward a $ dc * runId (snd (df (Id a) (Id b) (Id c))) where
    a = f b c
  lift2_ f df (Forward b db) Zero           = Forward a $ runId (fst (df (Id a) (Id b) (Id 0))) * db where
    a = f b 0
  lift2_ f df (Forward b db) (Lift c)       = Forward a $ runId (fst (df (Id a) (Id b) (Id c))) * db where
    a = f b c
  lift2_ f df (Forward b db) (Forward c dc) = Forward a da where
    a = f b c
    (Id dadb, Id dadc) = df (Id a) (Id b) (Id c)
    da = dadb * db + dc * dadc

#define HEAD (Forward a)
#include "instances.h"

-- | Run @f@ once per element of @as@. On the @k@-th run, the @k@-th element
-- is paired with a tangent of @1@ (via 'bundle') and every other element is
-- lifted as a constant (via 'auto'). The results come back in the shape of
-- @as@.
bind :: (Traversable f, Num a) => (f (Forward a) -> b) -> f a -> f b
bind f as = snd (mapAccumL perturb (0 :: Int) as)
  where
    -- the element itself is ignored here; only its position matters
    perturb !k _ = (k + 1, f (seedAt k))
    -- a copy of @as@ in which only position @k@ carries a unit tangent
    seedAt k = snd (mapAccumL (seed k) 0 as)
    seed !k !m x = (m + 1, if m == k then bundle x 1 else auto x)

-- | Like 'bind', but also returns one extra @b@ alongside the per-element
-- results. That extra value is the result of the last perturbed run of @f@.
-- When @as@ is empty, it is instead @f@ applied to @as@ with every element
-- lifted by 'auto'.
bind' :: (Traversable f, Num a) => (f (Forward a) -> b) -> f a -> (b, f b)
bind' f as = case mapAccumL perturb (0 :: Int, allConst) as of
    ((_, lastRun), runs) -> (lastRun, runs)
  where
    -- only observed when @as@ is empty; otherwise overwritten by each run
    allConst = f (auto <$> as)
    perturb (!k, _) _ = let r = f (seedAt k) in ((k + 1, r), r)
    -- a copy of @as@ in which only position @k@ carries a unit tangent
    seedAt k = snd (mapAccumL (seed k) (0 :: Int) as)
    seed !k !m x = (m + 1, if m == k then bundle x 1 else auto x)

-- | Like 'bind', but combines each original element with the result of its
-- run using @g@. The @k@-th output is @g x (f seeded)@, where @x@ is the
-- @k@-th element of @as@ and @seeded@ is @as@ with only position @k@ paired
-- with a tangent of @1@ (all other positions lifted with 'auto').
bindWith :: (Traversable f, Num a) => (a -> b -> c) -> (f (Forward a) -> b) -> f a -> f c
bindWith g f as = snd (mapAccumL perturb (0 :: Int) as)
  where
    perturb !k x = (k + 1, g x (f (seedAt k)))
    -- a copy of @as@ in which only position @k@ carries a unit tangent
    seedAt k = snd (mapAccumL (seed k) 0 as)
    seed !k !m x = (m + 1, if m == k then bundle x 1 else auto x)

-- | Like 'bindWith', but also returns one extra @b@ alongside the combined
-- results. That extra value is the result of the last perturbed run of @f@.
-- When @as@ is empty, it is instead @f@ applied to @as@ with every element
-- lifted by 'auto'.
bindWith' :: (Traversable f, Num a) => (a -> b -> c) -> (f (Forward a) -> b) -> f a -> (b, f c)
bindWith' g f as = case mapAccumL perturb (0 :: Int, allConst) as of
    ((_, lastRun), results) -> (lastRun, results)
  where
    -- only observed when @as@ is empty; otherwise overwritten by each run
    allConst = f (auto <$> as)
    perturb (!k, _) x = let r = f (seedAt k) in ((k + 1, r), g x r)
    -- a copy of @as@ in which only position @k@ carries a unit tangent
    seedAt k = snd (mapAccumL (seed k) (0 :: Int) as)
    seed !k !m x = (m + 1, if m == k then bundle x 1 else auto x)

-- | Transpose a container of rows @as@ into the shape of a skeleton,
-- combining each skeleton element with its column via @f@.
--
-- An arbitrary 'Traversable' cannot be constructed out of whole cloth, and
-- the outer container could be empty, so the caller supplies a skeleton
-- whose shape the result takes. The @k@-th element of the result is
-- @f b col@, where @b@ is the @k@-th skeleton element. @col@ holds the
-- @k@-th element of every row, taken in 'toList' order of each row.
--
-- Columns are built lazily. If the skeleton is longer than some row, an
-- error is raised only when the missing entry is actually demanded. Row
-- elements beyond the skeleton's length are ignored.
transposeWith :: (Functor f, Foldable f, Traversable g) => (b -> f a -> c) -> f (g a) -> g b -> g c
transposeWith f as = snd . mapAccumL step rows0
  where
    rows0 = toList <$> as
    -- peel one column off every row; fmap keeps each head/tail lazy
    step rows b = (dropCol <$> rows, f b (takeCol <$> rows))
    takeCol (x : _) = x
    takeCol [] = error "transposeWith: skeleton is longer than a row"
    dropCol (_ : xs) = xs
    dropCol [] = error "transposeWith: skeleton is longer than a row"

-- | Multiplication of two 'Forward' values, expressed through 'lift2'.
-- The primal is the product. The second argument to 'lift2' maps the
-- inputs @x@ and @y@ to @(y, x)@, the partial derivatives of @x * y@ with
-- respect to each input.
mul :: (Num a) => Forward a -> Forward a -> Forward a
mul = lift2 (*) (flip (,))