{-# LANGUAGE TypeOperators, MultiParamTypeClasses, UndecidableInstances
, TypeSynonymInstances, FlexibleInstances
, FlexibleContexts, TypeFamilies
, ScopedTypeVariables, CPP #-}
{-# OPTIONS_GHC -Wall #-}
module Data.Maclaurin
(
(:>)(D), powVal, derivative, derivAtBasis
, (:~>), pureD
, fmapD, (<$>>), liftD2, liftD3
, idD, fstD, sndD
, linearD, distrib
, (>-<)
, pairD, unpairD, tripleD, untripleD
)
where
import Data.Function (on)
import Data.VectorSpace
import Data.NumInstances ()
import Data.MemoTrie
import Data.Basis
import Data.LinearMap
import Data.Boolean
#if MIN_VERSION_base(4,8,0)
import Prelude hiding ((<*))
#endif
-- Fixity for the 'D' constructor when written infix, e.g. @value \`D\` deriv@.
infixr 9 `D`
-- | A derivative tower: a value of type @b@ together with its derivative
-- with respect to an input of type @a@.
--
-- The derivative ('derivative') is a linear map ('(:-*)') whose results are
-- themselves towers of type @a :> b@, so second and higher derivatives are
-- reached by following 'derivative' repeatedly.  Neither field carries a
-- strictness annotation, so the recursive tower is only built as far as it
-- is inspected.
data a :> b = D { powVal :: b, derivative :: a :-* (a :> b) }
-- | A differentiable function: maps each input of type @a@ to a derivative
-- tower (see ':>') over that same input type.
type a :~> b = a -> (a:>b)
-- | Raise an error reporting that the named operation has no definition for
-- derivative towers.  The argument is the operation's name as it should
-- appear in the message, e.g. @"(==)"@.
noOv :: String -> a
noOv opName = error (concat [opName, ": not defined on a :> b"])
-- | Lift a constant into a derivative tower: the value is the given
-- constant and the derivative is the zero linear map.
pureD :: (AdditiveGroup b, HasBasis a, HasTrie (Basis a)) => b -> a:>b
pureD b = D b zeroV
infixl 4 <$>>
-- | Apply a function at every level of a derivative tower: to the value
-- and, recursively, to every entry of every derivative.  No chain rule is
-- applied, so this is appropriate for linear functions such as 'fst',
-- 'snd' or 'negateV'.  '(<$>>)' is an infix synonym for 'fmapD'.
fmapD, (<$>>) :: HasTrie (Basis a) => (b -> c) -> (a :> b) -> (a :> c)
fmapD f = go
  where
    go (D x dx) = D (f x) (inLMap (liftL go) dx)
(<$>>) = fmapD
-- | Combine two derivative towers level by level with a binary function.
-- As with 'fmapD', the function is used pointwise on the values and on
-- matching derivative entries; no product or chain rule is applied.
liftD2 :: (HasBasis a, HasTrie (Basis a), AdditiveGroup b, AdditiveGroup c) =>
          (b -> c -> d) -> (a :> b) -> (a :> c) -> (a :> d)
liftD2 f = go
  where
    go (D x dx) (D y dy) = D (f x y) (inLMap2 (liftL2 go) dx dy)
-- | Combine three derivative towers level by level with a ternary
-- function, applied pointwise to the values and to matching derivative
-- entries (no chain rule), like 'liftD2'.
liftD3 :: (HasBasis a, HasTrie (Basis a)
          , AdditiveGroup b, AdditiveGroup c, AdditiveGroup d) =>
          (b -> c -> d -> e)
       -> (a :> b) -> (a :> c) -> (a :> d) -> (a :> e)
liftD3 f = go
  where
    go (D x dx) (D y dy) (D z dz) =
      D (f x y z) (inLMap3 (liftL3 go) dx dy dz)
-- | The identity as a differentiable function: each point maps to itself,
-- built with 'linearD'.
idD :: (VectorSpace u , HasBasis u, HasTrie (Basis u)) =>
       u :~> u
idD = linearD (\ x -> x)
-- | Differentiable version of a linear function @f@ (assumed linear by the
-- caller).  The derivative does not depend on the input point: it is the
-- linear map built from @pureD . f@.  It is bound once, outside the
-- per-point lambda, so every point shares the same map.
linearD :: (HasBasis u, HasTrie (Basis u), AdditiveGroup v) =>
           (u -> v) -> (u :~> v)
linearD f =
  let deriv = linear (\ x -> pureD (f x))
  in  \ u -> D (f u) deriv
-- | Differentiable first projection of a pair (a linear function, lifted
-- with 'linearD').
fstD :: ( HasBasis a, HasTrie (Basis a)
        , HasBasis b, HasTrie (Basis b)
        , Scalar a ~ Scalar b
        ) => (a,b) :~> a
fstD = linearD (\ (x, _) -> x)
-- | Differentiable second projection of a pair (a linear function, lifted
-- with 'linearD').
sndD :: ( HasBasis a, HasTrie (Basis a)
        , HasBasis b, HasTrie (Basis b)
        , Scalar a ~ Scalar b
        ) => (a,b) :~> b
sndD = linearD (\ (_, y) -> y)
-- | Lift a binary operation to derivative towers using the product rule.
-- The value is @p0 \`op\` q0@.  The derivative is the sum of two terms:
-- differentiating the left factor (right one fixed) and differentiating
-- the right factor (left one fixed).  It is applied recursively at every
-- level.  Used in this module with '(*)', '(*^)' and '(<.>)'; the product
-- rule is only valid when @op@ is bilinear.
distrib :: forall a b c u. (HasBasis a, HasTrie (Basis a) , AdditiveGroup u) =>
           (b -> c -> u) -> (a :> b) -> (a :> c) -> (a :> u)
distrib op = go
  where
    go :: (a :> b) -> (a :> c) -> (a :> u)
    go p@(D p0 dp) q@(D q0 dq) = D (p0 `op` q0) (viaLeft ^+^ viaRight)
      where
        -- Each derivative entry of the left factor, combined with the
        -- unchanged right factor.
        viaLeft  = inLMap (liftMS (inTrie ((`go` q) .))) dp
        -- The unchanged left factor, combined with each derivative entry
        -- of the right factor.
        viaRight = inLMap (liftMS (inTrie (go p .))) dq
-- | Shows only the value of the tower; derivatives are elided as @...@.
-- Implemented via 'showsPrec' so the tower is parenthesised when it
-- appears as an argument (e.g. @Just (D 1.0 ...)@ rather than the
-- ambiguous @Just D 1.0 ...@), and the value itself is shown at
-- application precedence (e.g. @D (-1.0) ...@).
instance Show b => Show (a :> b) where
  showsPrec p (D b0 _) =
    showParen (p > 10) $
      showString "D " . showsPrec 11 b0 . showString " ..."
-- | Equality is not supported on derivative towers: any comparison raises
-- an error via 'noOv'.  The default '(/=)' also errors, since it goes
-- through '(==)'.
instance Eq (a :> b) where
  _ == _ = noOv "(==)"
-- | Conditions on a derivative tower use the boolean type of its value type.
type instance BooleanOf (a :> b) = BooleanOf b
-- | Conditional selection between towers: the same condition picks between
-- the two values and between matching derivative entries at every level
-- (via 'liftD2').
instance (AdditiveGroup v, HasBasis u, HasTrie (Basis u), IfB v) =>
         IfB (u :> v) where
  ifB cond = liftD2 (ifB cond)
-- | Boolean-valued comparison of towers looks only at their values
-- ('powVal'); derivatives are ignored.  The other comparisons use the
-- class defaults.
instance OrdB v => OrdB (u :> v) where
  x <* y = powVal x <* powVal y
-- | Ordering compares values ('powVal') only.  'min' and 'max' use
-- 'minB' / 'maxB', so the selected tower's derivatives come along with its
-- value.
instance ( AdditiveGroup b, HasBasis a, HasTrie (Basis a)
         , OrdB b, IfB b, Ord b) => Ord (a :> b) where
  compare x y = compare (powVal x) (powVal y)
  min x y     = minB x y
  max x y     = maxB x y
-- | Towers form an additive group: zero is the constant zero tower, and
-- negation and addition act level by level (addition of the derivatives
-- uses the linear-map addition).
instance (HasBasis a, HasTrie (Basis a), AdditiveGroup u) => AdditiveGroup (a :> u) where
  zeroV   = pureD zeroV
  negateV = (negateV <$>>)
  (^+^) (D x dx) (D y dy) = D (x ^+^ y) (dx ^+^ dy)
-- | Towers over a vector space form a vector space whose scalars are
-- themselves towers.  Scaling obeys the product rule (via 'distrib').
instance (HasBasis a, HasTrie (Basis a), VectorSpace u)
         => VectorSpace (a :> u) where
  type Scalar (a :> u) = (a :> Scalar u)
  s *^ x = distrib (*^) s x
-- | Inner product of towers, differentiated with the product rule (via
-- 'distrib'); the result is a scalar tower.
instance ( InnerSpace u, s ~ Scalar u, AdditiveGroup s
         , HasBasis a, HasTrie (Basis a) ) =>
         InnerSpace (a :> u) where
  x <.> y = distrib (<.>) x y
infix 0 >-<
-- | Chain rule for a function on towers.  @f >-< f'@ applies @f@ to the
-- value.  It scales every entry of the incoming derivative by @f' t@, the
-- derivative of @f@ evaluated on the whole tower @t@.  Because @f'@ sees
-- the full tower, higher derivatives follow recursively.
(>-<) :: (HasBasis a, HasTrie (Basis a), VectorSpace u) =>
         (u -> u) -> ((a :> u) -> (a :> Scalar u))
      -> (a :> u) -> (a :> u)
(>-<) f f' t@(D t0 dt) = D (f t0) (inLMap (liftMS (f' t *^)) dt)
-- | Scalar towers are numbers.  Addition is tower addition, multiplication
-- uses the product rule ('distrib'), and the unary operations use the
-- chain rule ('>-<') with their derivatives written as functions of the
-- whole tower.
instance ( HasBasis a, s ~ Scalar a, HasTrie (Basis a)
         , Num s, VectorSpace s, Scalar s ~ s
         )
         => Num (a:>s) where
  fromInteger n = pureD (fromInteger n)
  (+)    = (^+^)
  (*)    = distrib (*)
  -- d negate = -1 everywhere.
  negate = negate >-< const (-1)
  -- d abs = signum (away from zero).
  abs    = abs    >-< signum
  -- d signum = 0 (away from zero).
  signum = signum >-< const 0
-- | Fractional scalar towers.  Constants lift via 'pureD'.  'recip' uses
-- the chain rule with d(1/x) = -1/x^2.
instance ( HasBasis a, s ~ Scalar a, HasTrie (Basis a)
         , Fractional s, VectorSpace s, Scalar s ~ s)
         => Fractional (a:>s) where
  fromRational q = pureD (fromRational q)
  recip = recip >-< (\ x -> negate (recip (sqr x)))
-- | Square of a number.  Also used as a function-valued term (through the
-- function numeric instances imported from "Data.NumInstances") in the
-- derivative formulas of this module.
sqr :: Num a => a -> a
sqr = \ v -> v * v
-- | Elementary functions on scalar towers.  Each method pairs the scalar
-- function with its derivative, and '>-<' applies the chain rule.  The
-- derivatives are written as functions on towers, using the function
-- numeric instances from "Data.NumInstances": e.g. @recip (1+sqr)@ means
-- @\\x -> recip (1 + sqr x)@.  Methods not listed (such as 'tan' and
-- 'tanh') use the class defaults, which are built from the ones below.
instance ( HasBasis a, s ~ Scalar a, HasTrie (Basis a)
         , Floating s, VectorSpace s, Scalar s ~ s)
         => Floating (a:>s) where
  pi    = pureD pi
  exp   = exp   >-< exp
  log   = log   >-< recip
  sqrt  = sqrt  >-< recip (2 * sqrt)
  sin   = sin   >-< cos
  cos   = cos   >-< - sin
  sinh  = sinh  >-< cosh
  cosh  = cosh  >-< sinh
  asin  = asin  >-< recip (sqrt (1-sqr))
  acos  = acos  >-< recip (- sqrt (1-sqr))
  atan  = atan  >-< recip (1+sqr)
  asinh = asinh >-< recip (sqrt (1+sqr))
  -- d/dx acosh x = 1 / sqrt (x^2 - 1), which is positive on x > 1.
  -- Unlike the acos rule above, there is no minus sign here.
  acosh = acosh >-< recip (sqrt (sqr-1))
  atanh = atanh >-< recip (1-sqr)
-- | The derivative of a tower in the direction of one basis element of
-- the input space; the result is again a tower.
derivAtBasis :: (HasTrie (Basis a), HasBasis a, AdditiveGroup b) =>
                (a :> b) -> (Basis a -> (a :> b))
derivAtBasis = atBasis . derivative
-- | Combine a pair of towers into a tower of pairs, level by level.
pairD :: (HasBasis a, HasTrie (Basis a), VectorSpace b, VectorSpace c)
      => (a:>b,a:>c) -> a:>(b,c)
pairD = uncurry (liftD2 (,))
-- | Split a tower of pairs into a pair of towers (inverse of 'pairD').
unpairD :: HasTrie (Basis a) => (a :> (b,c)) -> (a:>b, a:>c)
unpairD d = (fmapD fst d, fmapD snd d)
-- | Combine a triple of towers into a tower of triples, level by level.
tripleD :: ( HasBasis a, HasTrie (Basis a)
           , VectorSpace b, VectorSpace c, VectorSpace d
           ) => (a:>b,a:>c,a:>d) -> a:>(b,c,d)
tripleD (p, q, r) = liftD3 (\ x y z -> (x, y, z)) p q r
-- | Split a tower of triples into a triple of towers (inverse of
-- 'tripleD').
untripleD :: HasTrie (Basis a) => (a :> (b,c,d)) -> (a:>b, a:>c, a:>d)
untripleD d = (pick1 <$>> d, pick2 <$>> d, pick3 <$>> d)
  where
    pick1 (x, _, _) = x
    pick2 (_, y, _) = y
    pick3 (_, _, z) = z