{-# LANGUAGE TupleSections #-}
{-# LANGUAGE CPP, TypeOperators, FlexibleContexts, TypeFamilies
, GeneralizedNewtypeDeriving, StandaloneDeriving, UndecidableInstances #-}
{-# OPTIONS_GHC -Wall -fno-warn-orphans #-}
module Data.LinearMap
( (:-*) , linear, lapply, atBasis, idL, (*.*)
, inLMap, inLMap2, inLMap3
, liftMS, liftMS2, liftMS3
, liftL, liftL2, liftL3
, exlL, exrL, forkL, firstL, secondL
, inlL, inrL, joinL
)
where
#if !(MIN_VERSION_base(4,8,0))
import Control.Applicative (Applicative)
#endif
import Control.Applicative (liftA2, liftA3)
import Control.Arrow (first,second)
import Data.MemoTrie (HasTrie(..),(:->:))
import Data.AdditiveGroup (Sum(..), AdditiveGroup(..))
import Data.VectorSpace (VectorSpace(..))
import Data.Basis (HasBasis(..), linearCombo)
-- | A 'Sum'-wrapped value that may be absent.  Throughout this module
-- 'Nothing' plays the role of zero: helpers such as @fromMS@ and @atZ@
-- turn it into 'zeroV'.
type MSum a = Maybe (Sum a)
-- | Inject a value into the present ('Just') case of 'MSum'.
jsum :: a -> MSum a
jsum x = Just (Sum x)
-- | Internal representation of @u :-* v@: an optional memo trie that
-- maps each basis element of @u@ to its image in @v@.  'Nothing' is the
-- zero map; readers such as @lapply@ and @atBasis@ go through @atZ@,
-- which turns it into 'zeroV'.
type LMap' u v = MSum (Basis u :->: v)
-- Right-associative with low precedence (1), so @a :-* b :-* c@ reads
-- as @a :-* (b :-* c)@.
infixr 1 :-*
-- | Linear map from @u@ to @v@, stored by its values on a basis of @u@.
-- Only the type is exported, not the constructor, so clients build maps
-- with @linear@ and the other combinators of this module.
newtype u :-* v = LMap { unLMap :: LMap' u v }
-- Zero, addition and negation of linear maps are inherited unchanged from
-- the representation type (LMap' u v) via GeneralizedNewtypeDeriving.
deriving instance (HasTrie (Basis u), AdditiveGroup v) => AdditiveGroup (u :-* v)
-- Scalar multiplication rescales every basis image stored in the trie.
-- The zero map (Nothing) is left as Nothing, because liftMS maps over the
-- Maybe layer.
instance (HasTrie (Basis u), VectorSpace v) =>
         VectorSpace (u :-* v) where
  type Scalar (u :-* v) = Scalar v
  s *^ lm = inLMap (liftMS (fmap (s *^))) lm
-- | Linear projection of a pair onto its first component.
exlL :: ( HasBasis a, HasTrie (Basis a), HasBasis b, HasTrie (Basis b)
        , Scalar a ~ Scalar b )
     => (a,b) :-* a
exlL = linear (\ (x, _) -> x)
-- | Linear projection of a pair onto its second component.
exrL :: ( HasBasis a, HasTrie (Basis a), HasBasis b, HasTrie (Basis b)
        , Scalar a ~ Scalar b )
     => (a,b) :-* b
exrL = linear (\ (_, y) -> y)
-- | Pair two linear maps out of the same domain into one map into the
-- product.  The stored basis images are combined pointwise with @(,)@.
forkL :: (HasTrie (Basis a), HasBasis c, HasBasis d)
      => (a :-* c) -> (a :-* d) -> (a :-* (c,d))
forkL f g = inLMap2 (liftL2 (,)) f g
-- | Extend a linear map on the first component of a pair to the whole
-- pair, leaving the second component unchanged.
firstL :: ( HasBasis u, HasBasis u', HasBasis v
          , HasTrie (Basis u), HasTrie (Basis v)
          , Scalar u ~ Scalar v, Scalar u ~ Scalar u'
          ) =>
          (u :-* u') -> ((u,v) :-* (u',v))
firstL f = linear (first (lapply f))
-- | Extend a linear map on the second component of a pair to the whole
-- pair, leaving the first component unchanged.
secondL :: ( HasBasis u, HasBasis v, HasBasis v'
           , HasTrie (Basis u), HasTrie (Basis v)
           , Scalar u ~ Scalar v, Scalar u ~ Scalar v'
           ) =>
           (v :-* v') -> ((u,v) :-* (u,v'))
secondL f = linear (second (lapply f))
-- | Linear injection into the first component of a pair; the second
-- component is 'zeroV'.
inlL :: (HasBasis a, HasTrie (Basis a), HasBasis b)
     => a :-* (a,b)
inlL = linear (\ x -> (x, zeroV))
-- | Linear injection into the second component of a pair; the first
-- component is 'zeroV'.
inrL :: (HasBasis a, HasBasis b, HasTrie (Basis b))
     => b :-* (a,b)
inrL = linear (\ y -> (zeroV, y))
-- | Merge two linear maps into a common codomain into one map out of a
-- pair.  Each map is applied to its own component and the results are
-- added with '^+^'.
joinL :: ( HasBasis a, HasTrie (Basis a)
         , HasBasis b, HasTrie (Basis b)
         , Scalar a ~ Scalar b, Scalar a ~ Scalar c
         , VectorSpace c )
      => (a :-* c) -> (b :-* c) -> ((a,b) :-* c)
joinL f g = linear combine
  where
    combine (a, b) = lapply f a ^+^ lapply g b
-- | Build a linear map from a function by tabulating it on the basis
-- vectors of the domain ('basisValue').  Only those basis images are
-- stored, so the result agrees with the function only if it is linear.
linear :: (HasBasis u, HasTrie (Basis u)) =>
          (u -> v) -> (u :-* v)
linear f = LMap . jsum . trie $ f . basisValue
-- | Eliminate an 'MSum', treating 'Nothing' as zero.  A present value is
-- unwrapped and passed to @f@; otherwise the result is 'zeroV'.
atZ :: AdditiveGroup b => (a -> b) -> (MSum a -> b)
atZ _ Nothing        = zeroV
atZ f (Just (Sum a)) = f a
-- | Turn a transformation of representations into a transformation of
-- linear maps: unwrap, apply, rewrap.
inLMap :: (LMap' r s -> LMap' t u) -> ((r :-* s) -> (t :-* u))
inLMap h = LMap . h . unLMap
-- | Binary version of @inLMap@: unwrap both maps, combine the
-- representations, and rewrap the result.
inLMap2 :: (LMap' r s -> LMap' t u -> LMap' v w)
        -> ((r :-* s) -> (t :-* u) -> (v :-* w))
inLMap2 h x y = LMap (h (unLMap x) (unLMap y))
-- | Ternary version of @inLMap@: unwrap all three maps, combine the
-- representations, and rewrap the result.
inLMap3 :: (LMap' r s -> LMap' t u -> LMap' v w -> LMap' x y)
        -> ((r :-* s) -> (t :-* u) -> (v :-* w) -> (x :-* y))
inLMap3 h x y z = LMap (h (unLMap x) (unLMap y) (unLMap z))
-- | Apply a linear map to a vector.  The zero representation ('Nothing')
-- yields the 'zeroV' function; otherwise the stored trie of basis images
-- is applied via @lapply'@.
lapply :: ( VectorSpace v, Scalar u ~ Scalar v
          , HasBasis u, HasTrie (Basis u) ) =>
          (u :-* v) -> (u -> v)
lapply (LMap m) = atZ lapply' m
atBasis :: (AdditiveGroup v, HasTrie (Basis u)) =>
(u :-* v) -> Basis u -> v
LMap m `atBasis` b = atZ (`untrie` b) m
-- | Apply a trie of basis images to a vector.  The vector is decomposed
-- into (basis element, coefficient) pairs, each basis element is replaced
-- by its stored image, and the images are recombined with 'linearCombo'.
lapply' :: ( VectorSpace v, Scalar u ~ Scalar v
           , HasBasis u, HasTrie (Basis u) ) =>
           (Basis u :->: v) -> (u -> v)
lapply' tr = \ u -> linearCombo [ (look b, s) | ~(b, s) <- decompose u ]
  where
    -- Built once per trie and reused for every vector the result is
    -- applied to.  The lazy pattern above matches the laziness of
    -- Control.Arrow.first on pairs.
    look = untrie tr
-- | The identity linear map, tabulated on the basis of @u@.
idL :: (HasBasis u, HasTrie (Basis u)) =>
       u :-* u
idL = linear (\ u -> u)
-- Same fixity as Prelude function composition (infixr 9).
infixr 9 *.*
-- | Compose linear maps: @vw *.* uv@ applies @uv@ first, then @vw@.
-- Every stored basis image of @uv@ is pushed through @vw@; a zero
-- (Nothing) @uv@ stays zero.
(*.*) :: ( HasTrie (Basis u)
         , HasBasis v, HasTrie (Basis v)
         , VectorSpace w
         , Scalar v ~ Scalar w ) =>
         (v :-* w) -> (u :-* v) -> (u :-* w)
(*.*) vw = inLMap (liftMS (fmap applyVW))
  where
    -- Computed once per left operand and shared across right operands.
    applyVW = lapply vw
-- | Map a function over the value inside an 'MSum'; 'Nothing' (zero) is
-- preserved.
liftMS :: (a -> b) -> (MSum a -> MSum b)
liftMS f = fmap (fmap f)
-- | Combine two 'MSum's with a binary function.  If both are 'Nothing',
-- the result is 'Nothing' and @h@ is never called.  Otherwise a missing
-- argument is replaced by 'zeroV' before @h@ is applied.
-- NOTE(review): this presumes @h zeroV zeroV@ is zero; confirm callers
-- only pass such functions.
liftMS2 :: (AdditiveGroup a, AdditiveGroup b) =>
           (a -> b -> c) ->
           (MSum a -> MSum b -> MSum c)
liftMS2 h ma mb =
  case (ma, mb) of
    (Nothing, Nothing) -> Nothing
    _                  -> jsum (h (fromMS ma) (fromMS mb))
-- | Combine three 'MSum's with a ternary function.  If all three are
-- 'Nothing', the result is 'Nothing' and @h@ is never called.  Otherwise
-- each missing argument is replaced by 'zeroV' before @h@ is applied.
-- NOTE(review): this presumes @h zeroV zeroV zeroV@ is zero; confirm.
liftMS3 :: (AdditiveGroup a, AdditiveGroup b, AdditiveGroup c) =>
           (a -> b -> c -> d) ->
           (MSum a -> MSum b -> MSum c -> MSum d)
liftMS3 h ma mb mc =
  case (ma, mb, mc) of
    (Nothing, Nothing, Nothing) -> Nothing
    _ -> jsum (h (fromMS ma) (fromMS mb) (fromMS mc))
-- | Read an 'MSum' as a plain value, with 'Nothing' standing for 'zeroV'.
fromMS :: AdditiveGroup u => MSum u -> u
fromMS = maybe zeroV getSum
-- | Map a function over every element of a functor (such as a memo trie)
-- held inside an 'MSum'.
liftL :: Functor f => (a -> b) -> MSum (f a) -> MSum (f b)
liftL f = liftMS (fmap f)
-- | Combine two applicative containers held inside 'MSum's elementwise
-- with a binary function; zero handling follows @liftMS2@.
liftL2 :: (Applicative f, AdditiveGroup (f a), AdditiveGroup (f b)) =>
          (a -> b -> c)
       -> (MSum (f a) -> MSum (f b) -> MSum (f c))
liftL2 f = liftMS2 (liftA2 f)
-- | Combine three applicative containers held inside 'MSum's elementwise
-- with a ternary function; zero handling follows @liftMS3@.
liftL3 :: ( Applicative f
          , AdditiveGroup (f a), AdditiveGroup (f b), AdditiveGroup (f c)) =>
          (a -> b -> c -> d)
       -> (MSum (f a) -> MSum (f b) -> MSum (f c) -> MSum (f d))
liftL3 f = liftMS3 (liftA3 f)
-- | Wrap a function with a pre-processing step on its argument and a
-- post-processing step on its result: @(pre ~> post) g@ behaves as
-- @post . g . pre@.
(~>) :: (a' -> a) -> (b -> b') -> ((a -> b) -> (a' -> b'))
(~>) pre post g x = post (g (pre x))