{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK not-home #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Unsafe and often partial combinators intended for internal usage.
--
-- Handle with care.
-----------------------------------------------------------------------------
module Numeric.AD.Internal.Sparse
  ( Monomial(..)
  , emptyMonomial
  , addToMonomial
  , indices
  , Sparse(..)
  , apply
  , vars
  , d, d', ds
  , skeleton
  , spartial
  , partial
  , vgrad
  , vgrad'
  , vgrads
  , Grad(..)
  , Grads(..)
  , terms
  , primal
  ) where

import Prelude hiding (lookup)
import Control.Comonad.Cofree
import Control.Monad (join, guard)
import Data.Data
import Data.IntMap (IntMap, unionWith, findWithDefault, singleton, lookup)
import qualified Data.IntMap as IntMap
import Data.Number.Erf
import Data.Traversable
import Data.Typeable ()
import Numeric
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Sparse.Common
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | We only store partials in sorted order, so the map contained in a partial
-- will only contain partials whose keys are greater than or equal to the key
-- under which that partial was found. This should be key to efficiently
-- computing sparse Hessians: there are only @n + k - 1@ choose @k@ distinct
-- @k@th partial derivatives of a function with @n@ inputs.
data Sparse a
  = Sparse !a (IntMap (Sparse a))
    -- ^ A primal value together with its first-order partials, keyed by
    -- variable index. Each partial is itself a 'Sparse' value whose own map
    -- holds the next order of partials.
  | Zero
    -- ^ A value whose primal and all partials are zero.
  deriving (Show, Data, Typeable)

-- | Turn every element of a 'Traversable' container into an independent
-- variable. The element at traversal position @n@ (counting from 0) becomes a
-- 'Sparse' value whose only first-order partial is @1@, stored under index @n@.
vars :: (Traversable f, Num a) => f a -> f (Sparse a)
vars = snd . mapAccumL var 0 where
  var !n a = (n + 1, Sparse a $ singleton n $ auto 1)
{-# INLINE vars #-}

-- | Apply a function to the 'Sparse' variables built from the given inputs
-- by 'vars'.
apply :: (Traversable f, Num a) => (f (Sparse a) -> b) -> f a -> b
apply f = f . vars
{-# INLINE apply #-}

-- | Extract the first-order partials of a result. Only the shape of the first
-- argument is used: its element at traversal position @n@ is replaced by the
-- primal of the partial stored under index @n@, or by @0@ when there is none.
d :: (Traversable f, Num a) => f b -> Sparse a -> f a
d fs Zero = 0 <$ fs
d fs (Sparse _ da) = snd $ mapAccumL step 0 fs where
  step !n _ = (n + 1, maybe 0 primal $ lookup n da)
{-# INLINE d #-}

-- | Like 'd', but also returns the primal value of the result.
-- The constructor is still matched here so that, as before, the 'Sparse'
-- argument is forced before the pair is built.
d' :: (Traversable f, Num a) => f a -> Sparse a -> (a, f a)
d' fs Zero             = (0, 0 <$ fs)
d' fs s@(Sparse a _)   = (a, d fs s)
{-# INLINE d' #-}

-- | Build the tower of all partial derivatives as a 'Cofree' tree. The root
-- holds the primal. Each child is computed by 'partial' at the monomial
-- extended with that child's entry from @skeleton fs@; presumably 'skeleton'
-- labels each position with its index (TODO confirm in Sparse.Common).
-- Missing partials are reported as @0@. For 'Zero', the result is an
-- infinite, knot-tied tree of zeros.
ds :: (Traversable f, Num a) => f b -> Sparse a -> Cofree f a
ds fs Zero = r where r = 0 :< (r <$ fs)
ds fs as@(Sparse a _) = a :< (go emptyMonomial <$> fns) where
  fns = skeleton fs
  -- go :: Monomial -> Int -> Cofree f a
  -- (left as a comment: it mentions the outer @f@ and @a@, which would
  -- require ScopedTypeVariables)
  go ix i = partial (indices ix') as :< (go ix' <$> fns) where
    ix' = addToMonomial i ix
{-# INLINE ds #-}

-- | Follow a path of variable indices down the derivative tower and return
-- the 'Sparse' sub-tower found there, or 'Zero' if any step is missing.
partialS :: Num a => [Int] -> Sparse a -> Sparse a
partialS []     a             = a
partialS (n:ns) (Sparse _ da) = partialS ns $ findWithDefault Zero n da
partialS _      Zero          = Zero
{-# INLINE partialS #-}

-- | Look up a (possibly higher-order) partial by its path of variable
-- indices. The empty path yields the primal, and any missing step yields @0@.
-- A missing key defaults to 'Zero' (as in 'partialS'), which evaluates to @0@
-- for any remaining path without allocating a fresh constant.
partial :: Num a => [Int] -> Sparse a -> a
partial []     (Sparse a _)  = a
partial (n:ns) (Sparse _ da) = partial ns $ findWithDefault Zero n da
partial _      Zero          = 0
{-# INLINE partial #-}

-- | Like 'partial', but returns 'Nothing' when the partial is not stored,
-- including whenever a 'Zero' is reached, even for the empty path.
spartial :: Num a => [Int] -> Sparse a -> Maybe a
spartial []     (Sparse a _)  = Just a
spartial (n:ns) (Sparse _ da) = lookup n da >>= spartial ns
spartial _      Zero          = Nothing
{-# INLINE spartial #-}

-- | The primal (zeroth-order) value of a 'Sparse' number; 'Zero' has primal
-- @0@.
primal :: Num a => Sparse a -> a
primal (Sparse a _) = a
primal Zero         = 0

instance Num a => Mode (Sparse a) where
  type Scalar (Sparse a) = a

  -- A lifted constant: the given primal with an empty partial map.
  auto a = Sparse a IntMap.empty

  zero = Zero

  isKnownZero Zero = True
  isKnownZero _    = False

  -- A value is a known constant when it carries no partials at all.
  isKnownConstant Zero         = True
  isKnownConstant (Sparse _ m) = null m

  asKnownConstant Zero         = Just 0
  asKnownConstant (Sparse a m) = a <$ guard (null m)

  -- Scaling applies to the primal and, recursively, to every stored partial.
  -- 'Zero' stays 'Zero'.
  Zero        ^* _ = Zero
  Sparse a as ^* b = Sparse (a * b) $ fmap (^* b) as

  _ *^ Zero        = Zero
  a *^ Sparse b bs = Sparse (a * b) $ fmap (a *^) bs

  Zero        ^/ _ = Zero
  Sparse a as ^/ b = Sparse (a / b) $ fmap (^/ b) as

infixr 6 <+>

-- | Add two 'Sparse' values. Primals are summed, and the partial maps are
-- merged key by key, with partials present on both sides added recursively.
-- 'Zero' is the identity on either side.
(<+>) :: Num a => Sparse a -> Sparse a -> Sparse a
Zero        <+> a           = a
a           <+> Zero        = a
Sparse a as <+> Sparse b bs = Sparse (a + b) $ unionWith (<+>) as bs

-- The instances for Jacobian for Sparse and Tower are almost identical;
-- could easily be made exactly equal by small changes.
instance Num a => Jacobian (Sparse a) where
  type D (Sparse a) = Sparse a
  unary :: (Scalar (Sparse a) -> Scalar (Sparse a))
-> D (Sparse a) -> Sparse a -> Sparse a
unary Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a)
_ Sparse a
Zero = Scalar (Sparse a) -> Sparse a
forall t. Mode t => Scalar t -> t
auto (Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
0)
  unary Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a)
dadb (Sparse a
pb IntMap (Sparse a)
bs) = a -> IntMap (Sparse a) -> Sparse a
forall a. a -> IntMap (Sparse a) -> Sparse a
Sparse (Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
pb) (IntMap (Sparse a) -> Sparse a) -> IntMap (Sparse a) -> Sparse a
forall a b. (a -> b) -> a -> b
$ (Sparse a -> Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a b. (a -> b) -> IntMap a -> IntMap b
IntMap.map (D (Sparse a) -> D (Sparse a) -> D (Sparse a)
forall a. Num a => a -> a -> a
* D (Sparse a)
dadb) IntMap (Sparse a)
bs

  lift1 :: (Scalar (Sparse a) -> Scalar (Sparse a))
-> (D (Sparse a) -> D (Sparse a)) -> Sparse a -> Sparse a
lift1 Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a) -> D (Sparse a)
_ Sparse a
Zero = Scalar (Sparse a) -> Sparse a
forall t. Mode t => Scalar t -> t
auto (Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
0)
  lift1 Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a) -> D (Sparse a)
df b :: Sparse a
b@(Sparse a
pb IntMap (Sparse a)
bs) = a -> IntMap (Sparse a) -> Sparse a
forall a. a -> IntMap (Sparse a) -> Sparse a
Sparse (Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
pb) (IntMap (Sparse a) -> Sparse a) -> IntMap (Sparse a) -> Sparse a
forall a b. (a -> b) -> a -> b
$ (Sparse a -> Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a b. (a -> b) -> IntMap a -> IntMap b
IntMap.map (D (Sparse a) -> D (Sparse a) -> D (Sparse a)
forall a. Num a => a -> a -> a
* D (Sparse a) -> D (Sparse a)
df D (Sparse a)
Sparse a
b) IntMap (Sparse a)
bs

  lift1_ :: (Scalar (Sparse a) -> Scalar (Sparse a))
-> (D (Sparse a) -> D (Sparse a) -> D (Sparse a))
-> Sparse a
-> Sparse a
lift1_ Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a) -> D (Sparse a) -> D (Sparse a)
_  Sparse a
Zero = Scalar (Sparse a) -> Sparse a
forall t. Mode t => Scalar t -> t
auto (Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
0)
  lift1_ Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a) -> D (Sparse a) -> D (Sparse a)
df b :: Sparse a
b@(Sparse a
pb IntMap (Sparse a)
bs) = Sparse a
a where
    a :: Sparse a
a = a -> IntMap (Sparse a) -> Sparse a
forall a. a -> IntMap (Sparse a) -> Sparse a
Sparse (Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
pb) (IntMap (Sparse a) -> Sparse a) -> IntMap (Sparse a) -> Sparse a
forall a b. (a -> b) -> a -> b
$ (Sparse a -> Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a b. (a -> b) -> IntMap a -> IntMap b
IntMap.map (D (Sparse a) -> D (Sparse a) -> D (Sparse a)
df D (Sparse a)
Sparse a
a D (Sparse a)
Sparse a
b Sparse a -> Sparse a -> Sparse a
forall a. Num a => a -> a -> a
*) IntMap (Sparse a)
bs

  binary :: (Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a))
-> D (Sparse a) -> D (Sparse a) -> Sparse a -> Sparse a -> Sparse a
binary Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a)
_    D (Sparse a)
_    Sparse a
Zero           Sparse a
Zero           = Scalar (Sparse a) -> Sparse a
forall t. Mode t => Scalar t -> t
auto (Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
0 a
Scalar (Sparse a)
0)
  binary Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a)
_    D (Sparse a)
dadc Sparse a
Zero           (Sparse a
pc IntMap (Sparse a)
dc) = a -> IntMap (Sparse a) -> Sparse a
forall a. a -> IntMap (Sparse a) -> Sparse a
Sparse (Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
0  a
Scalar (Sparse a)
pc) (IntMap (Sparse a) -> Sparse a) -> IntMap (Sparse a) -> Sparse a
forall a b. (a -> b) -> a -> b
$ (Sparse a -> Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a b. (a -> b) -> IntMap a -> IntMap b
IntMap.map (D (Sparse a)
Sparse a
dadc Sparse a -> Sparse a -> Sparse a
forall a. Num a => a -> a -> a
*) IntMap (Sparse a)
dc
  binary Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a)
dadb D (Sparse a)
_    (Sparse a
pb IntMap (Sparse a)
db) Sparse a
Zero           = a -> IntMap (Sparse a) -> Sparse a
forall a. a -> IntMap (Sparse a) -> Sparse a
Sparse (Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
pb a
Scalar (Sparse a)
0 ) (IntMap (Sparse a) -> Sparse a) -> IntMap (Sparse a) -> Sparse a
forall a b. (a -> b) -> a -> b
$ (Sparse a -> Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a b. (a -> b) -> IntMap a -> IntMap b
IntMap.map (D (Sparse a)
Sparse a
dadb Sparse a -> Sparse a -> Sparse a
forall a. Num a => a -> a -> a
*) IntMap (Sparse a)
db
  binary Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a)
dadb D (Sparse a)
dadc (Sparse a
pb IntMap (Sparse a)
db) (Sparse a
pc IntMap (Sparse a)
dc) = a -> IntMap (Sparse a) -> Sparse a
forall a. a -> IntMap (Sparse a) -> Sparse a
Sparse (Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
pb a
Scalar (Sparse a)
pc) (IntMap (Sparse a) -> Sparse a) -> IntMap (Sparse a) -> Sparse a
forall a b. (a -> b) -> a -> b
$
    (Sparse a -> Sparse a -> Sparse a)
-> IntMap (Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a. (a -> a -> a) -> IntMap a -> IntMap a -> IntMap a
unionWith Sparse a -> Sparse a -> Sparse a
forall a. Num a => Sparse a -> Sparse a -> Sparse a
(<+>)  ((Sparse a -> Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a b. (a -> b) -> IntMap a -> IntMap b
IntMap.map (D (Sparse a)
Sparse a
dadb Sparse a -> Sparse a -> Sparse a
forall a. Num a => a -> a -> a
*) IntMap (Sparse a)
db) ((Sparse a -> Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a b. (a -> b) -> IntMap a -> IntMap b
IntMap.map (D (Sparse a)
Sparse a
dadc Sparse a -> Sparse a -> Sparse a
forall a. Num a => a -> a -> a
*) IntMap (Sparse a)
dc)

  lift2 :: (Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a))
-> (D (Sparse a) -> D (Sparse a) -> (D (Sparse a), D (Sparse a)))
-> Sparse a
-> Sparse a
-> Sparse a
lift2 Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a) -> D (Sparse a) -> (D (Sparse a), D (Sparse a))
_  Sparse a
Zero             Sparse a
Zero = Scalar (Sparse a) -> Sparse a
forall t. Mode t => Scalar t -> t
auto (Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
0 a
Scalar (Sparse a)
0)
  lift2 Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a) -> D (Sparse a) -> (D (Sparse a), D (Sparse a))
df Sparse a
Zero c :: Sparse a
c@(Sparse a
pc IntMap (Sparse a)
dc) = a -> IntMap (Sparse a) -> Sparse a
forall a. a -> IntMap (Sparse a) -> Sparse a
Sparse (Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
0 a
Scalar (Sparse a)
pc) (IntMap (Sparse a) -> Sparse a) -> IntMap (Sparse a) -> Sparse a
forall a b. (a -> b) -> a -> b
$ (Sparse a -> Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a b. (a -> b) -> IntMap a -> IntMap b
IntMap.map (Sparse a
dadc Sparse a -> Sparse a -> Sparse a
forall a. Num a => a -> a -> a
*) IntMap (Sparse a)
dc where dadc :: Sparse a
dadc = (Sparse a, Sparse a) -> Sparse a
forall a b. (a, b) -> b
snd (D (Sparse a) -> D (Sparse a) -> (D (Sparse a), D (Sparse a))
df D (Sparse a)
Sparse a
forall t. Mode t => t
zero D (Sparse a)
Sparse a
c)
  lift2 Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a) -> D (Sparse a) -> (D (Sparse a), D (Sparse a))
df b :: Sparse a
b@(Sparse a
pb IntMap (Sparse a)
db) Sparse a
Zero = a -> IntMap (Sparse a) -> Sparse a
forall a. a -> IntMap (Sparse a) -> Sparse a
Sparse (Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
pb a
Scalar (Sparse a)
0) (IntMap (Sparse a) -> Sparse a) -> IntMap (Sparse a) -> Sparse a
forall a b. (a -> b) -> a -> b
$ (Sparse a -> Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a b. (a -> b) -> IntMap a -> IntMap b
IntMap.map (Sparse a -> Sparse a -> Sparse a
forall a. Num a => a -> a -> a
* Sparse a
dadb) IntMap (Sparse a)
db where dadb :: Sparse a
dadb = (Sparse a, Sparse a) -> Sparse a
forall a b. (a, b) -> a
fst (D (Sparse a) -> D (Sparse a) -> (D (Sparse a), D (Sparse a))
df D (Sparse a)
Sparse a
b D (Sparse a)
Sparse a
forall t. Mode t => t
zero)
  lift2 Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f D (Sparse a) -> D (Sparse a) -> (D (Sparse a), D (Sparse a))
df b :: Sparse a
b@(Sparse a
pb IntMap (Sparse a)
db) c :: Sparse a
c@(Sparse a
pc IntMap (Sparse a)
dc) = a -> IntMap (Sparse a) -> Sparse a
forall a. a -> IntMap (Sparse a) -> Sparse a
Sparse (Scalar (Sparse a) -> Scalar (Sparse a) -> Scalar (Sparse a)
f a
Scalar (Sparse a)
pb a
Scalar (Sparse a)
pc) IntMap (Sparse a)
da where
    (D (Sparse a)
dadb, D (Sparse a)
dadc) = D (Sparse a) -> D (Sparse a) -> (D (Sparse a), D (Sparse a))
df D (Sparse a)
Sparse a
b D (Sparse a)
Sparse a
c
    da :: IntMap (Sparse a)
da = (Sparse a -> Sparse a -> Sparse a)
-> IntMap (Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a. (a -> a -> a) -> IntMap a -> IntMap a -> IntMap a
unionWith Sparse a -> Sparse a -> Sparse a
forall a. Num a => Sparse a -> Sparse a -> Sparse a
(<+>) ((Sparse a -> Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a b. (a -> b) -> IntMap a -> IntMap b
IntMap.map (D (Sparse a)
Sparse a
dadb Sparse a -> Sparse a -> Sparse a
forall a. Num a => a -> a -> a
*) IntMap (Sparse a)
db) ((Sparse a -> Sparse a) -> IntMap (Sparse a) -> IntMap (Sparse a)
forall a b. (a -> b) -> IntMap a -> IntMap b
IntMap.map (D (Sparse a)
Sparse a
dadc Sparse a -> Sparse a -> Sparse a
forall a. Num a => a -> a -> a
*) IntMap (Sparse a)
dc)

  -- Lift a binary scalar function @f@ to 'Sparse' values.  The partial
  -- derivative rule @df@ receives the result being defined as its first
  -- argument (the result @r@ is knot-tied through the local binding), followed
  -- by the two operands; it returns the partials with respect to each operand.
  -- A 'Zero' operand contributes the scalar @0@ to @f@ and is passed to @df@
  -- as 'zero'.
  lift2_ f _ Zero Zero = auto (f 0 0)
  lift2_ f df b@(Sparse pb db) Zero = r
    where
      -- only the left operand carries derivatives
      r = Sparse (f pb 0) (IntMap.map (k *) db)
      k = fst (df r b zero)
  lift2_ f df Zero c@(Sparse pc dc) = r
    where
      -- only the right operand carries derivatives
      r = Sparse (f 0 pc) (IntMap.map (* k) dc)
      k = snd (df r zero c)
  lift2_ f df b@(Sparse pb db) c@(Sparse pc dc) = r
    where
      -- chain rule: scale each operand's partials and merge them pointwise
      r = Sparse (f pb pc) (unionWith (<+>) (scaleBy dadb db) (scaleBy dadc dc))
      (dadb, dadc) = df r b c
      scaleBy s = IntMap.map (s *)

#define HEAD (Sparse a)
#include "instances.h"

-- | Variadic plumbing behind 'vgrad' and 'vgrad''.
--
-- @i@ is a curried function of 'Sparse' arguments ending in a 'Sparse'
-- result.  @o@ is the matching curried function of plain scalars that yields
-- the gradient list.  @o'@ is the same, but also yields the primal value.
-- Each of the three types determines the others.
class Num a => Grad i o o' a
  | i -> a o o'
  , o -> a i o'
  , o' -> a i o
  where
    -- | Feed successive list elements to the curried function @i@.
    pack :: i -> [Sparse a] -> Sparse a
    -- | Gather curried scalar arguments into a list for the continuation.
    unpack :: ([a] -> [a]) -> o
    -- | Like 'unpack', for a continuation that also returns the primal.
    unpack' :: ([a] -> (a, [a])) -> o'

-- Base case: the function is already a 'Sparse' result.  Any remaining
-- arguments are ignored by 'pack', and the gathered argument list is closed
-- with @[]@ before being handed to the continuation.
instance Num a => Grad (Sparse a) [a] (a, [a]) a where
  pack = const
  unpack k = k []
  unpack' k = k []

-- Inductive case: peel off one 'Sparse' argument.  'pack' consumes one list
-- element per arrow.  'unpack' / 'unpack'' append each scalar argument, in
-- order, to the list that will eventually reach the continuation.
instance Grad i o o' a => Grad (Sparse a -> i) (a -> o) (a -> o') a where
  pack g args = case args of
    x : rest -> pack (g x) rest
    []       -> error "Grad.pack: logic error"
  unpack k x = unpack (k . (x :))
  unpack' k x = unpack' (k . (x :))

-- | Gradient of a function of any number of 'Sparse' arguments.  The result
-- is a curried function of the corresponding scalar inputs that returns the
-- list of first-order partials.
vgrad :: Grad i o o' a => i -> o
vgrad i = unpack gradAt
  where
    -- evaluate the packed function on the argument list, then read off
    -- the partial derivatives with 'd'
    gradAt xs = d xs (apply (pack i) xs)
{-# INLINE vgrad #-}

-- | Like 'vgrad', but the resulting curried function returns the primal
-- value paired with the gradient list.
vgrad' :: Grad i o o' a => i -> o'
vgrad' i = unpack' valueAndGradAt
  where
    -- 'd'' yields both the primal and the first-order partials
    valueAndGradAt xs = d' xs (apply (pack i) xs)
{-# INLINE vgrad' #-}

-- | Variadic plumbing behind 'vgrads'.
--
-- @i@ is a curried function of 'Sparse' arguments.  @o@ is the matching
-- curried function of scalars that returns a 'Cofree' tower of derivatives.
class Num a => Grads i o a
  | i -> a o
  , o -> a i
  where
    -- | Feed successive list elements to the curried function @i@.
    packs :: i -> [Sparse a] -> Sparse a
    -- | Gather curried scalar arguments into a list for the continuation.
    unpacks :: ([a] -> Cofree [] a) -> o

-- Base case: the function is already a 'Sparse' result.  Remaining arguments
-- are ignored, and the gathered list is closed with @[]@.
instance Num a => Grads (Sparse a) (Cofree [] a) a where
  packs = const
  unpacks k = k []

-- Inductive case: peel off one 'Sparse' argument.  'packs' consumes one list
-- element per arrow.  'unpacks' appends each scalar argument, in order, to the
-- list that eventually reaches the continuation.
instance Grads i o a => Grads (Sparse a -> i) (a -> o) a where
  packs f (a:as) = packs (f a) as
  -- Running out of arguments should be impossible: 'vgrads' builds the list
  -- from exactly the arguments 'unpacks' collected.  The message names this
  -- method (it was previously copy-pasted as "Grad.pack").
  packs _ [] = error "Grads.packs: logic error"
  unpacks f a = unpacks (f . (a:))

-- | All derivatives of a function of any number of 'Sparse' arguments.  The
-- result is a curried function of the scalar inputs that returns a 'Cofree'
-- tower built by 'ds'.
vgrads :: Grads i o a => i -> o
vgrads i = unpacks towerAt
  where
    -- evaluate the packed function, then unfold its derivative tower
    towerAt xs = ds xs (apply (packs i) xs)
{-# INLINE vgrads #-}

-- | 'True' exactly for the 'Zero' constructor.  A 'Sparse' node is never
-- treated as zero, even if its primal happens to be @0@.
isZero :: Sparse a -> Bool
isZero s = case s of
  Zero -> True
  _    -> False

-- | Product of two 'Sparse' values, including all higher-order partials.
--
-- If either factor is 'Zero', the product is 'Zero'.  Otherwise the primal is
-- the product of the primals.  The partial for each multi-index
-- ('Monomial') is the sum, over the triples returned by 'terms', of the
-- integer coefficient times the primals of the matching partials of each
-- factor.  Presumably this is the Leibniz product-rule expansion; see
-- 'terms' to confirm.
--
-- Child keys of a node keyed @v@ range over @[v .. kMax]@, so deeper maps
-- only hold keys at or above their parent's key.  @kMax@ is the largest key
-- in either factor's top-level derivative map.  A child is omitted when every
-- term has a 'Zero' partial on at least one side.
mul :: Num a => Sparse a -> Sparse a -> Sparse a
mul Zero _ = Zero
mul _ Zero = Zero
mul f@(Sparse _ am) g@(Sparse _ bm) = Sparse (primal f * primal g) (derivs 0 emptyMonomial)
  where
    -- largest variable index with a first-order partial in either factor,
    -- or -1 when neither has any (which makes every range below empty)
    kMax = topKey am `max` topKey bm
    topKey m = case IntMap.maxViewWithKey m of
      Nothing          -> -1
      Just ((k, _), _) -> k

    -- derivatives below the multi-index @mi@, for keys from @v@ upward
    derivs v mi = IntMap.unions [ child w | w <- [v .. kMax] ]
      where
        child w
          | and zeros = IntMap.empty
          | otherwise = IntMap.singleton w (Sparse (sum vals) (derivs w mi'))
          where
            mi' = addToMonomial w mi
            (zeros, vals) = unzip [ contribution t | t <- terms mi' ]
            -- one term: (does either side vanish?, coefficient * f-part * g-part)
            contribution (bin, mif, mig) =
              let fder = partialS (indices mif) f
                  gder = partialS (indices mig) g
              in ( isZero fder || isZero gder
                 , fromIntegral bin * primal fder * primal gder
                 )