{-# LANGUAGE CPP, BangPatterns, TypeFamilies, UnicodeSyntax, KindSignatures, DataKinds #-}
module Math.Algebra.Polynomial.Monomial.Tensor where
import Data.Typeable
import Data.Either
import Data.Proxy
import GHC.TypeLits
#if MIN_VERSION_base(4,11,0)
import Data.Semigroup
import Data.Monoid
#else
import Data.Monoid
#endif
import Math.Algebra.Polynomial.Class
import Math.Algebra.Polynomial.Pretty
-- | A pair of two values (in this module: two monomials) viewed as their
-- tensor product.
--
-- The type-level @symbol@ carries no data; in this module it is only read by
-- 'tensorSymbol', which supplies the separator used when pretty-printing.
-- Both fields are strict. The derived 'Ord' is lexicographic: first factor,
-- then second factor.
data Tensor (symbol :: Symbol) (a :: *) (b :: *)
  = Tensor !a !b
  deriving (Eq, Ord, Show, Typeable)
-- | Componentwise combination: each factor is combined with the
-- corresponding factor of the other tensor.
instance (Semigroup a, Semigroup b) => Semigroup (Tensor sym a b) where
  Tensor l1 r1 <> Tensor l2 r2 = Tensor (l1 <> l2) (r1 <> r2)
-- | Componentwise monoid structure; the identity is @Tensor mempty mempty@.
instance (Monoid a, Monoid b) => Monoid (Tensor sym a b) where
  mempty = Tensor mempty mempty
#if MIN_VERSION_base(4,11,0)
  -- Semigroup is a superclass of Monoid here, so define mappend canonically
  -- in terms of the componentwise (<>) above. This keeps the two operations
  -- from drifting apart and silences -Wnoncanonical-monoid-instances.
  mappend = (<>)
#else
  -- On older base, (<>) from Data.Monoid is itself defined via mappend, so
  -- mappend must be spelled out componentwise to avoid an infinite loop.
  mappend (Tensor x1 y1) (Tensor x2 y2) = Tensor (x1 `mappend` x2) (y1 `mappend` y2)
#endif
-- | Renders the left factor, then the type-level separator symbol, then the
-- right factor, with no extra spacing inserted.
instance (KnownSymbol sym, Pretty a, Pretty b) => Pretty (Tensor sym a b) where
  pretty t = case t of
    Tensor l r -> concat [pretty l, tensorSymbol t, pretty r]
-- | The separator string for a tensor, read from its type-level symbol.
-- The argument is never inspected; only its type matters.
tensorSymbol :: KnownSymbol sym => Tensor sym a b -> String
tensorSymbol t = symbolVal (proxyFor t)
  where
    -- Recover a proxy for the symbol index without ScopedTypeVariables.
    proxyFor :: Tensor s x y -> Proxy s
    proxyFor _ = Proxy
-- | Swap the two factors, keeping the separator symbol.
--
-- NOTE: this name clashes with 'Prelude.flip'. Unqualified uses of @flip@,
-- in this module or in modules importing it unqualified, are ambiguous
-- unless one of the two is hidden or qualified.
flip :: Tensor sym a b -> Tensor sym b a
flip t = case t of
  Tensor l r -> Tensor r l
-- | Embed a value as the left factor, padding the right factor with 'mempty'.
injLeft :: Monoid b => a -> Tensor sym a b
injLeft = (`Tensor` mempty)
-- | Embed a value as the right factor, padding the left factor with 'mempty'.
injRight :: Monoid a => b -> Tensor sym a b
injRight = Tensor mempty
-- | The left factor of a tensor.
projLeft :: Tensor sym a b -> a
projLeft t = case t of
  Tensor l _ -> l
-- | The right factor of a tensor.
projRight :: Tensor sym a b -> b
projRight t = case t of
  Tensor _ r -> r
-- | Differentiate a tensor @k@ times with respect to one variable.
--
-- A 'Left' variable is differentiated in the left factor and a 'Right'
-- variable in the right factor; the other factor is carried through
-- unchanged. Returns 'Nothing' exactly when the component 'diffM' does, and
-- otherwise the new tensor paired with the coefficient from 'diffM'.
diffTensor :: (Monomial a, Monomial b, Num c) => Either (VarM a) (VarM b) -> Int -> Tensor sym a b -> Maybe (Tensor sym a b, c)
diffTensor which k (Tensor lhs rhs) = either viaLeft viaRight which
  where
    viaLeft v = diffM v k lhs >>= \(lhs', c) -> Just (Tensor lhs' rhs, c)
    viaRight v = diffM v k rhs >>= \(rhs', c) -> Just (Tensor lhs rhs', c)
-- | Monomials on the disjoint union of two variable sets: 'Left' variables
-- live in the first factor and 'Right' variables in the second. Almost every
-- operation acts on the two factors independently.
instance (KnownSymbol sym, Monomial a, Monomial b) => Monomial (Tensor sym a b) where
  type VarM (Tensor sym a b) = Either (VarM a) (VarM b)

  normalizeM (Tensor l r) = Tensor (normalizeM l) (normalizeM r)

  isNormalM (Tensor l r) = isNormalM l && isNormalM r

  -- Route each (variable, exponent) pair to the factor its variable tag names.
  fromListM pairs = Tensor (fromListM ls) (fromListM rs)
    where
      (ls, rs) = partitionEithers (map distEither pairs)

  -- Left-factor entries first, then right-factor entries, each re-tagged.
  toListM (Tensor l r) =
    map (\(v, e) -> (Left v, e)) (toListM l)
      ++ map (\(v, e) -> (Right v, e)) (toListM r)

  emptyM = Tensor emptyM emptyM

  isEmptyM (Tensor l r) = isEmptyM l && isEmptyM r

  variableM =
    either (\v -> Tensor (variableM v) emptyM)
           (\v -> Tensor emptyM (variableM v))

  singletonM which k =
    either (\v -> Tensor (singletonM v k) emptyM)
           (\v -> Tensor emptyM (singletonM v k))
           which

  mulM (Tensor l1 r1) (Tensor l2 r2) = Tensor (mulM l1 l2) (mulM r1 r2)

  productM ts =
    Tensor (productM (map (\(Tensor l _) -> l) ts))
           (productM (map (\(Tensor _ r) -> r) ts))

  powM (Tensor l r) n = Tensor (powM l n) (powM r n)

  -- Divisible only when both factors are divisible.
  divM (Tensor l1 r1) (Tensor l2 r2) = Tensor <$> divM l1 l2 <*> divM r1 r2

  diffM = diffTensor

  maxDegM (Tensor l r) = max (maxDegM l) (maxDegM r)

  totalDegM (Tensor l r) = totalDegM l + totalDegM r

  -- The value of a product is the product of the values of its factors.
  evalM f (Tensor l r) =
    evalM (\v -> f (Left v)) l * evalM (\v -> f (Right v)) r

  -- NOTE: partial. The substitution must map Left variables to Left variables
  -- and Right to Right; moving a variable across factors raises an error.
  varSubsM f (Tensor l r) =
    Tensor (varSubsM (\v -> unsafeFromLeft (f (Left v))) l)
           (varSubsM (\v -> unsafeFromRight (f (Right v))) r)

  -- Substitute in each factor separately (starting from coefficient 1) and
  -- multiply the incoming coefficient by both resulting coefficients.
  termSubsM f (Tensor l r, c) = (Tensor l' r', c * cl * cr)
    where
      (l', cl) = termSubsM (f . Left) (l, 1)
      (r', cr) = termSubsM (f . Right) (r, 1)
-- | Push a shared payload into whichever side of an 'Either' is present.
distEither :: (Either a b, c) -> Either (a,c) (b,c)
distEither (tagged, payload) =
  either (\l -> Left (l, payload)) (\r -> Right (r, payload)) tagged
-- | Extract a 'Left' payload; calls 'error' on a 'Right'.
unsafeFromLeft :: Either a b -> a
unsafeFromLeft = either id (\_ -> error "unsafeFromLeft: Right")
-- | Extract a 'Right' payload; calls 'error' on a 'Left'.
unsafeFromRight :: Either a b -> b
unsafeFromRight = either (const (error "unsafeFromRight: Left")) id