module Multilinear.Generic (
Tensor(..), (!), mergeScalars,
isScalar, isSimple, isFiniteTensor, isInfiniteTensor,
dot, _elemByElem, contractionErr, tensorIndex, _standardize
) where
import Control.DeepSeq
import Data.Bits
import Data.Foldable
import Data.List
import Data.Maybe
import qualified Data.Vector as Boxed
import GHC.Generics
import Multilinear.Class as Multilinear
import qualified Multilinear.Index as Index
import qualified Multilinear.Index.Finite as Finite
import qualified Multilinear.Index.Infinite as Infinite
-- Error messages shared by the tensor operations in this module.
incompatibleTypes, scalarIndices, infiniteIndex, infiniteTensor, indexNotFound :: String

-- Two operands cannot be combined.
incompatibleTypes = "Incompatible tensor types!"

-- An index was requested from a scalar.
scalarIndices = "Scalar has no indices!"

-- A size was requested for an infinite index.
infiniteIndex = "Index is infinitely-dimensional!"

-- Shown in place of an infinite tensor.
infiniteTensor = "This tensor is infinitely-dimensional and cannot be printed!"

-- A named index is absent from the tensor.
indexNotFound = "This tensor has not such index!"
-- | A tensor with elements of type @a@, represented one index at a time:
-- each non-scalar constructor carries its outermost index plus the entries
-- along that index.
data Tensor a =
    -- | A single value carrying no index.
    Scalar {
        scalarVal :: a
    } |
    -- | One finite index whose entries are plain values; 'indices' treats it
    -- as the last level of a tensor.
    SimpleFinite {
        -- the index this level runs over
        tensorFiniteIndex :: Finite.Index,
        -- one value per position of the index
        tensorScalars :: Boxed.Vector a
    } |
    -- | One finite index whose entries are sub-tensors.
    FiniteTensor {
        tensorFiniteIndex :: Finite.Index,
        tensorsFinite :: Boxed.Vector (Tensor a)
    } |
    -- | One infinitely-dimensional index; sub-tensors are held in a lazy list,
    -- which operations combining it with finite tensors truncate with 'take'.
    InfiniteTensor {
        tensorInfiniteIndex :: Infinite.Index,
        tensorsInfinite :: [Tensor a]
    } |
    -- | An error value; most operations in this module pass it through
    -- unchanged.
    Err {
        errMessage :: String
    } deriving (Eq, Generic)
{-# INLINE isScalar #-}
-- | 'True' exactly for the 'Scalar' constructor.
isScalar :: Tensor a -> Bool
isScalar Scalar{} = True
isScalar _ = False
{-# INLINE isSimple #-}
-- | 'True' exactly for the 'SimpleFinite' constructor.
isSimple :: Tensor a -> Bool
isSimple SimpleFinite{} = True
isSimple _ = False
{-# INLINE isFiniteTensor #-}
-- | 'True' exactly for the 'FiniteTensor' constructor (not 'SimpleFinite').
isFiniteTensor :: Tensor a -> Bool
isFiniteTensor FiniteTensor{} = True
isFiniteTensor _ = False
{-# INLINE isInfiniteTensor #-}
-- | 'True' exactly for the 'InfiniteTensor' constructor.
isInfiniteTensor :: Tensor a -> Bool
isInfiniteTensor InfiniteTensor{} = True
isInfiniteTensor _ = False
{-# INLINE isErrTensor #-}
-- | 'True' exactly for the 'Err' constructor.
isErrTensor :: Tensor a -> Bool
isErrTensor Err{} = True
isErrTensor _ = False
{-# INLINE tensorIndex #-}
-- | The outermost index of a tensor, converted to the generic 'Index.TIndex'.
-- Calls 'error' for a 'Scalar' (which has no index) and re-raises the message
-- of an 'Err'.
tensorIndex :: Tensor a -> Index.TIndex
tensorIndex (SimpleFinite i _) = Index.toTIndex i
tensorIndex (FiniteTensor i _) = Index.toTIndex i
tensorIndex (InfiniteTensor i _) = Index.toTIndex i
tensorIndex (Scalar _) = error scalarIndices
tensorIndex (Err msg) = error msg
{-# INLINE isEmptyTensor #-}
-- | Whether the outermost level of a tensor holds no entries. Scalars and
-- errors are never considered empty.
isEmptyTensor :: Tensor a -> Bool
isEmptyTensor (SimpleFinite _ vals) = Boxed.null vals
isEmptyTensor (FiniteTensor _ subs) = Boxed.null subs
isEmptyTensor (InfiniteTensor _ subs) = null subs
isEmptyTensor _ = False
{-# INLINE firstElem #-}
-- | The element reached by always following the first entry at every level.
-- Partial: fails on an empty level and re-raises the message of an 'Err'.
firstElem :: Tensor a -> a
firstElem (Scalar v) = v
firstElem (SimpleFinite _ vals) = Boxed.head vals
firstElem (FiniteTensor _ subs) = firstElem (Boxed.head subs)
firstElem (InfiniteTensor _ subs) = firstElem (head subs)
firstElem (Err msg) = error msg
{-# INLINE firstTensor #-}
-- | The first sub-tensor of a 'FiniteTensor' or 'InfiniteTensor'; any other
-- tensor is returned unchanged. Partial on an empty container.
firstTensor :: Tensor a -> Tensor a
firstTensor (FiniteTensor _ subs) = Boxed.head subs
firstTensor (InfiniteTensor _ subs) = Data.List.head subs
firstTensor other = other
{-# INLINE (!) #-}
-- | Select the @i@-th entry along a tensor's outermost index.
--
-- Entries of a 'SimpleFinite' come back wrapped in 'Scalar'. Indexing a
-- 'Scalar' yields an 'Err', and an 'Err' is passed through. For a finite
-- index, a position outside @[0, size)@ raises an error naming the index.
(!) :: Tensor a
    -> Int
    -> Tensor a
t ! i = case t of
  Scalar _ -> Err scalarIndices
  Err msg -> Err msg
  SimpleFinite ind ts
    | outOfBounds ind -> error (boundsMsg ind)
    | otherwise -> Scalar $ ts Boxed.! i
  FiniteTensor ind ts
    | outOfBounds ind -> error (boundsMsg ind)
    | otherwise -> ts Boxed.! i
  InfiniteTensor _ ts -> ts !! i
  where
    -- Negative positions are rejected here as well, so they get this
    -- module's message instead of Data.Vector's.
    outOfBounds ind = i < 0 || i >= Finite.indexSize ind
    boundsMsg ind = "Index " ++ show ind ++ " out of bounds!"
-- | Deep evaluation for tensors. The instance body is empty, so 'rnf' comes
-- from the class default, which deepseq implements via the 'Generic' instance
-- derived for 'Tensor'.
instance NFData a => NFData (Tensor a)
-- | Apply '<<<|' to every contravariant index of a tensor and leave the other
-- indices alone. Indices are visited via a right fold over 'indices', i.e. the
-- last index is shifted first.
_standardize :: Num a => Tensor a -> Tensor a
_standardize tens = foldr' shiftIfContravariant tens (indices tens)
  where
    shiftIfContravariant i t
      | Index.isContravariant i = t <<<| Index.indexName i
      | otherwise = t
-- | Render a tensor after surfacing nested errors and reordering its indices
-- with '_standardize'. Levels over a contravariant index print vertically,
-- other levels print as a bracketed row, and infinite tensors are not printed.
instance (Show a, Num a) => Show (Tensor a) where
  show tensor = render (_standardize (surfaceErr tensor))
    where
      -- An 'Err' found among the (recursively processed) sub-tensors of a
      -- finite tensor replaces the whole tensor.
      surfaceErr t@(FiniteTensor _ ts) =
        fromMaybe t (Boxed.find isErrTensor (surfaceErr <$> ts))
      surfaceErr t = t

      render (Scalar v) = show v
      render (SimpleFinite index ts) = show index ++ "S: " ++ layout index ts
      render (FiniteTensor index ts) = show index ++ "T: " ++ layout index ts
      render (InfiniteTensor _ _) = show infiniteTensor
      render (Err msg) = show msg

      layout :: Show e => Finite.Index -> Boxed.Vector e -> String
      layout (Finite.Contravariant _ _) = _showVertical
      layout _ = _showHorizontal
-- | Render the elements of a container one per line, each prefixed by @" |"@,
-- after a leading newline. An empty container renders as a lone newline
-- (previously it crashed on 'tail' of an empty string).
_showVertical :: (Show a, Foldable c) => c a -> String
_showVertical container =
  "\n" ++ Data.List.intercalate "\n" [" |" ++ show e | e <- toList container]
-- | Render the elements of a container as a comma-separated row in square
-- brackets. An empty container renders as @"[]"@ (previously it crashed on
-- 'tail' of an empty string).
_showHorizontal :: (Show a, Foldable c) => c a -> String
_showHorizontal container =
  "[" ++ Data.List.intercalate "," (show <$> toList container) ++ "]"
-- | Map a function over every element, keeping all indices; 'Err' is left
-- untouched.
instance Functor Tensor where
  {-# INLINE fmap #-}
  fmap f (Scalar v) = Scalar (f v)
  fmap f (SimpleFinite index vals) = SimpleFinite index (fmap f vals)
  fmap f (FiniteTensor index subs) = FiniteTensor index (fmap (fmap f) subs)
  fmap f (InfiniteTensor index subs) = InfiniteTensor index (Data.List.map (fmap f) subs)
  fmap _ (Err msg) = Err msg
-- | Ordering of tensors. Tensors built with the same constructor compare by
-- their contents only (indices are ignored). Tensors built with different
-- constructors are ordered Err < Scalar < SimpleFinite < FiniteTensor <
-- InfiniteTensor.
instance Ord a => Ord (Tensor a) where
  {-# INLINE (<=) #-}
  lhs <= rhs =
      case (lhs, rhs) of
        (Err m1, Err m2) -> m1 <= m2
        (Scalar v1, Scalar v2) -> v1 <= v2
        (SimpleFinite _ xs, SimpleFinite _ ys) -> xs <= ys
        (FiniteTensor _ xs, FiniteTensor _ ys) -> xs <= ys
        (InfiniteTensor _ xs, InfiniteTensor _ ys) -> xs <= ys
        _ -> rank lhs <= rank rhs
    where
      rank :: Tensor b -> Int
      rank t = case t of
        Err _ -> 0
        Scalar _ -> 1
        SimpleFinite _ _ -> 2
        FiniteTensor _ _ -> 3
        InfiniteTensor _ _ -> 4
{-# INLINE mergeScalars #-}
-- | Collapse every finite level whose entries are all 'Scalar's into a
-- 'SimpleFinite', recursing into nested 'FiniteTensor's. Other tensors are
-- returned unchanged.
--
-- Empty vectors and vectors mixing scalars with other tensors are kept as a
-- 'FiniteTensor'. Previously the first case crashed on @ts1 Boxed.! 0@, and
-- the second crashed in 'scalarVal'.
mergeScalars :: Tensor a -> Tensor a
mergeScalars x = case x of
  FiniteTensor index1 ts1
    | not (Boxed.null ts1) && Boxed.all isScalar ts1 ->
        SimpleFinite index1 (scalarVal <$> ts1)
    | otherwise ->
        FiniteTensor index1 $ mergeScalars <$> ts1
  _ -> x
{-# INLINE _elemByElem' #-}
-- | Core of '_elemByElem': walk two tensors level by level.
--
-- * A 'Scalar' operand is broadcast over the other operand with 'fmap'
--   (so an 'Err' on the other side stays an 'Err').
-- * When both outermost indices have the same name, the pair is handed to
--   @op@.
-- * Otherwise the walk descends one level into one operand and recurses,
--   rebuilding that level with the same index. For two finite or two infinite
--   tensors, it descends into @t2@ when @t1@'s outer index name occurs
--   somewhere in @t2@, and into @t1@ otherwise.
-- * Two 'SimpleFinite's with different names produce an outer-product-like
--   'FiniteTensor' over @t1@'s index.
-- * 'Err' operands propagate their message.
_elemByElem' :: Num a
             => Tensor a                            -- first operand
             -> Tensor a                            -- second operand
             -> (a -> a -> a)                       -- operation on scalar elements
             -> (Tensor a -> Tensor a -> Tensor a)  -- operation on a pair sharing the outer index name
             -> Tensor a
-- Scalar cases: broadcast.
_elemByElem' (Scalar x1) (Scalar x2) f _ = Scalar $ f x1 x2
_elemByElem' (Scalar x) t f _ = (x `f`) <$> t
_elemByElem' t (Scalar x) f _ = (`f` x) <$> t
-- Finite with finite.
_elemByElem' t1@(FiniteTensor index1 v1) t2@(FiniteTensor index2 v2) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | Index.indexName index1 `Data.List.elem` indicesNames t2 =
        FiniteTensor index2 $ (\x -> _elemByElem' t1 x f op) <$> v2
    | otherwise = FiniteTensor index1 $ (\x -> _elemByElem' x t2 f op) <$> v1
-- Infinite with infinite.
_elemByElem' t1@(InfiniteTensor index1 v1) t2@(InfiniteTensor index2 v2) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | Index.indexName index1 `Data.List.elem` indicesNames t2 =
        InfiniteTensor index2 $ (\x -> _elemByElem' t1 x f op) <$> v2
    | otherwise = InfiniteTensor index1 $ (\x -> _elemByElem' x t2 f op) <$> v1
-- Two simple vectors: with different names, every element of t1 is combined
-- with the whole of t2.
_elemByElem' t1@(SimpleFinite index1 v1) t2@(SimpleFinite index2 _) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | otherwise = FiniteTensor index1 $ (\x -> f x <$> t2) <$> v1
-- Finite with infinite, in both orders.
_elemByElem' t1@(FiniteTensor index1 v1) t2@(InfiniteTensor index2 v2) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | Index.indexName index1 `Data.List.elem` indicesNames t2 =
        InfiniteTensor index2 $ (\x -> _elemByElem' t1 x f op) <$> v2
    | otherwise = FiniteTensor index1 $ (\x -> _elemByElem' x t2 f op) <$> v1
_elemByElem' t1@(InfiniteTensor index1 v1) t2@(FiniteTensor index2 v2) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | Index.indexName index1 `Data.List.elem` indicesNames t2 =
        FiniteTensor index2 $ (\x -> _elemByElem' t1 x f op) <$> v2
    | otherwise = InfiniteTensor index1 $ (\x -> _elemByElem' x t2 f op) <$> v1
-- A simple vector against a deeper tensor: with different names, always
-- descend into the non-simple side.
_elemByElem' t1@(SimpleFinite index1 _) t2@(FiniteTensor index2 v2) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | otherwise = FiniteTensor index2 $ (\x -> _elemByElem' t1 x f op) <$> v2
_elemByElem' t1@(FiniteTensor index1 v1) t2@(SimpleFinite index2 _) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | otherwise = FiniteTensor index1 $ (\x -> _elemByElem' x t2 f op) <$> v1
_elemByElem' t1@(SimpleFinite index1 _) t2@(InfiniteTensor index2 v2) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | otherwise = InfiniteTensor index2 $ (\x -> _elemByElem' t1 x f op) <$> v2
_elemByElem' t1@(InfiniteTensor index1 v1) t2@(SimpleFinite index2 _) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | otherwise = InfiniteTensor index1 $ (\x -> _elemByElem' x t2 f op) <$> v1
-- Errors propagate.
_elemByElem' (Err msg) _ _ _ = Err msg
_elemByElem' _ (Err msg) _ _ = Err msg
{-# INLINE _elemByElem #-}
-- | Apply a binary operation to two tensors index-wise.
--
-- Index names present in both operands are first shifted with '|>>>' in each
-- operand. The shifted operands are combined by '_elemByElem'', with @f@
-- acting on scalars and @op@ on sub-tensors that share their outer index.
-- Finally, all-scalar finite levels are collapsed with 'mergeScalars'.
_elemByElem :: Num a
            => Tensor a
            -> Tensor a
            -> (a -> a -> a)
            -> (Tensor a -> Tensor a -> Tensor a)
            -> Tensor a
_elemByElem t1 t2 f op = mergeScalars (_elemByElem' aligned1 aligned2 f op)
  where
    shared = indicesNames t1 `Data.List.intersect` indicesNames t2
    aligned1 = foldl' (|>>>) t1 shared
    aligned2 = foldl' (|>>>) t2 shared
{-# INLINE zipT #-}
-- | Zip two tensors along their shared outermost index.
--
-- The four functions handle, in order: sub-tensor with sub-tensor,
-- sub-tensor with scalar, scalar with sub-tensor, and scalar with scalar.
-- Two finite operands must carry equal indices, or the result is
-- 'incompatibleTypes'. An infinite operand is truncated to the finite
-- operand's length, and the finite index is kept. Errors propagate, and
-- scalar operands yield 'scalarIndices'.
zipT :: Num a
     => (Tensor a -> Tensor a -> Tensor a)
     -> (Tensor a -> a -> Tensor a)
     -> (a -> Tensor a -> Tensor a)
     -> (a -> a -> a)
     -> Tensor a
     -> Tensor a
     -> Tensor a
zipT _ _ _ onScalars (SimpleFinite i1 xs) (SimpleFinite i2 ys)
  | i1 == i2 = SimpleFinite i1 (Boxed.zipWith onScalars xs ys)
  | otherwise = Err incompatibleTypes
zipT onTensors _ _ _ (FiniteTensor i1 xs) (FiniteTensor i2 ys)
  | i1 == i2 = FiniteTensor i1 (Boxed.zipWith onTensors xs ys)
  | otherwise = Err incompatibleTypes
zipT onTensors _ _ _ (InfiniteTensor i1 xs) (InfiniteTensor i2 ys)
  | i1 == i2 = InfiniteTensor i1 (Data.List.zipWith onTensors xs ys)
  | otherwise = Err incompatibleTypes
zipT onTensors _ _ _ (InfiniteTensor _ xs) (FiniteTensor i2 ys) =
  FiniteTensor i2 (Boxed.zipWith onTensors (Boxed.fromListN (Boxed.length ys) xs) ys)
zipT onTensors _ _ _ (FiniteTensor i1 xs) (InfiniteTensor _ ys) =
  FiniteTensor i1 (Boxed.zipWith onTensors xs (Boxed.fromListN (Boxed.length xs) ys))
zipT _ onTensorScalar _ _ (FiniteTensor i1 xs) (SimpleFinite i2 ys)
  | i1 == i2 = FiniteTensor i1 (Boxed.zipWith onTensorScalar xs ys)
  | otherwise = Err incompatibleTypes
zipT _ _ onScalarTensor _ (SimpleFinite i1 xs) (FiniteTensor i2 ys)
  | i1 == i2 = FiniteTensor i1 (Boxed.zipWith onScalarTensor xs ys)
  | otherwise = Err incompatibleTypes
zipT _ onTensorScalar _ _ (InfiniteTensor _ xs) (SimpleFinite i2 ys) =
  FiniteTensor i2 (Boxed.zipWith onTensorScalar (Boxed.fromListN (Boxed.length ys) xs) ys)
zipT _ _ onScalarTensor _ (SimpleFinite i1 xs) (InfiniteTensor _ ys) =
  FiniteTensor i1 (Boxed.zipWith onScalarTensor xs (Boxed.fromListN (Boxed.length xs) ys))
zipT _ _ _ _ (Err msg) _ = Err msg
zipT _ _ _ _ _ (Err msg) = Err msg
zipT _ _ _ _ _ _ = Err scalarIndices
{-# INLINE dot #-}
-- | Combine two tensors whose outermost indices have the same name. This is
-- the @op@ the 'Num' instance passes to '_elemByElem' for @*@.
--
-- * Covariant first, contravariant second, equal sizes: contraction, i.e. the
--   sum of element-wise products (a 'Scalar' for two 'SimpleFinite's).
-- * Same variance on both sides, equal sizes: element-wise product that keeps
--   the first finite operand's index.
-- * Differing sizes, or any other combination (contravariant first with
--   covariant second, 'Indifferent' indices, scalars): 'contractionErr'.
--
-- When one operand is infinite, only as many of its entries as the finite
-- operand's size are used.
-- NOTE(review): same-variance results with an infinite operand are wrapped in
-- 'InfiniteTensor' even though the zipped list is finite; confirm intent.
dot :: Num a
    => Tensor a
    -> Tensor a
    -> Tensor a
-- Two simple vectors.
dot (SimpleFinite i1@(Finite.Covariant count1 _) ts1') (SimpleFinite i2@(Finite.Contravariant count2 _) ts2')
    | count1 == count2 =
        Scalar $ Boxed.sum $ Boxed.zipWith (*) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
dot (SimpleFinite i1@(Finite.Contravariant count1 _) ts1') (SimpleFinite i2@(Finite.Contravariant count2 _) ts2')
    | count1 == count2 =
        SimpleFinite i1 $ Boxed.zipWith (*) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
dot (SimpleFinite i1@(Finite.Covariant count1 _) ts1') (SimpleFinite i2@(Finite.Covariant count2 _) ts2')
    | count1 == count2 =
        SimpleFinite i1 $ Boxed.zipWith (*) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
-- Two finite tensors of sub-tensors: products recurse through the Num instance.
dot (FiniteTensor i1@(Finite.Covariant count1 _) ts1') (FiniteTensor i2@(Finite.Contravariant count2 _) ts2')
    | count1 == count2 = Boxed.sum $ Boxed.zipWith (*) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
dot (FiniteTensor i1@(Finite.Contravariant count1 _) ts1') (FiniteTensor i2@(Finite.Contravariant count2 _) ts2')
    | count1 == count2 = FiniteTensor i1 $ Boxed.zipWith (*) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
dot (FiniteTensor i1@(Finite.Covariant count1 _) ts1') (FiniteTensor i2@(Finite.Covariant count2 _) ts2')
    | count1 == count2 = FiniteTensor i1 $ Boxed.zipWith (*) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
-- Simple vector first, finite tensor second: scalars scale sub-tensors via (*.).
dot (SimpleFinite i1@(Finite.Covariant count1 _) ts1') (FiniteTensor i2@(Finite.Contravariant count2 _) ts2')
    | count1 == count2 = Boxed.sum $ Boxed.zipWith (*.) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
dot (SimpleFinite i1@(Finite.Contravariant count1 _) ts1') (FiniteTensor i2@(Finite.Contravariant count2 _) ts2')
    | count1 == count2 = FiniteTensor i1 $ Boxed.zipWith (*.) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
dot (SimpleFinite i1@(Finite.Covariant count1 _) ts1') (FiniteTensor i2@(Finite.Covariant count2 _) ts2')
    | count1 == count2 = FiniteTensor i1 $ Boxed.zipWith (*.) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
-- Finite tensor first, simple vector second: sub-tensors scaled via (.*).
dot (FiniteTensor i1@(Finite.Covariant count1 _) ts1') (SimpleFinite i2@(Finite.Contravariant count2 _) ts2')
    | count1 == count2 = Boxed.sum $ Boxed.zipWith (.*) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
dot (FiniteTensor i1@(Finite.Contravariant count1 _) ts1') (SimpleFinite i2@(Finite.Contravariant count2 _) ts2')
    | count1 == count2 = FiniteTensor i1 $ Boxed.zipWith (.*) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
dot (FiniteTensor i1@(Finite.Covariant count1 _) ts1') (SimpleFinite i2@(Finite.Covariant count2 _) ts2')
    | count1 == count2 = FiniteTensor i1 $ Boxed.zipWith (.*) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
-- Simple vector with infinite tensor: the infinite side is truncated with take.
dot (SimpleFinite (Finite.Covariant count1 _) ts1') (InfiniteTensor (Infinite.Contravariant _) ts2') =
    Boxed.sum $ Boxed.zipWith (*.) ts1' (Boxed.fromList $ take count1 ts2')
dot (SimpleFinite (Finite.Contravariant count1 _) ts1') (InfiniteTensor i2@(Infinite.Contravariant _) ts2') =
    InfiniteTensor i2 $ Boxed.toList $ Boxed.zipWith (*.) ts1' (Boxed.fromList $ take count1 ts2')
dot (SimpleFinite (Finite.Covariant count1 _) ts1') (InfiniteTensor i2@(Infinite.Covariant _) ts2') =
    InfiniteTensor i2 $ Boxed.toList $ Boxed.zipWith (*.) ts1' (Boxed.fromList $ take count1 ts2')
dot (InfiniteTensor (Infinite.Covariant _) ts1') (SimpleFinite (Finite.Contravariant count2 _) ts2') =
    Boxed.sum $ Boxed.zipWith (.*) (Boxed.fromList $ take count2 ts1') ts2'
dot (InfiniteTensor i1@(Infinite.Contravariant _) ts1') (SimpleFinite (Finite.Contravariant count2 _) ts2') =
    InfiniteTensor i1 $ Boxed.toList $ Boxed.zipWith (.*) (Boxed.fromList $ take count2 ts1') ts2'
dot (InfiniteTensor i1@(Infinite.Covariant _) ts1') (SimpleFinite (Finite.Covariant count2 _) ts2') =
    InfiniteTensor i1 $ Boxed.toList $ Boxed.zipWith (.*) (Boxed.fromList $ take count2 ts1') ts2'
-- Finite tensor with infinite tensor: same truncation, sub-tensor products.
dot (FiniteTensor (Finite.Covariant count1 _) ts1') (InfiniteTensor (Infinite.Contravariant _) ts2') =
    Boxed.sum $ Boxed.zipWith (*) ts1' (Boxed.fromList $ take count1 ts2')
dot (FiniteTensor (Finite.Contravariant count1 _) ts1') (InfiniteTensor i2@(Infinite.Contravariant _) ts2') =
    InfiniteTensor i2 $ Boxed.toList $ Boxed.zipWith (*) ts1' (Boxed.fromList $ take count1 ts2')
dot (FiniteTensor (Finite.Covariant count1 _) ts1') (InfiniteTensor i2@(Infinite.Covariant _) ts2') =
    InfiniteTensor i2 $ Boxed.toList $ Boxed.zipWith (*) ts1' (Boxed.fromList $ take count1 ts2')
dot (InfiniteTensor (Infinite.Covariant _) ts1') (FiniteTensor (Finite.Contravariant count2 _) ts2') =
    Boxed.sum $ Boxed.zipWith (*) (Boxed.fromList $ take count2 ts1') ts2'
dot (InfiniteTensor i1@(Infinite.Contravariant _) ts1') (FiniteTensor (Finite.Contravariant count2 _) ts2') =
    InfiniteTensor i1 $ Boxed.toList $ Boxed.zipWith (*) (Boxed.fromList $ take count2 ts1') ts2'
dot (InfiniteTensor i1@(Infinite.Covariant _) ts1') (FiniteTensor (Finite.Covariant count2 _) ts2') =
    InfiniteTensor i1 $ Boxed.toList $ Boxed.zipWith (*) (Boxed.fromList $ take count2 ts1') ts2'
-- Anything else cannot be contracted. tensorIndex errors lazily for scalars
-- and errors once the message is forced.
dot t1' t2' = contractionErr (tensorIndex t1') (tensorIndex t2')
{-# INLINE bitDot #-}
-- | Bitwise analogue of the contraction case of 'dot'.
--
-- A covariant first index meets a contravariant second index, and finite
-- indices must have equal sizes. Corresponding entries are combined with
-- '.&.' and the results are reduced with '.|.' starting from 0. An infinite
-- operand is truncated to the finite operand's size. Every other combination
-- yields 'contractionErr'.
--
-- The Finite/Infinite clauses previously used an arithmetic sum of products
-- instead of this bitwise reduction.
bitDot :: (
    Num a, Bits a
    ) => Tensor a
    -> Tensor a
    -> Tensor a
bitDot (FiniteTensor i1@(Finite.Covariant count1 _) ts1') (FiniteTensor i2@(Finite.Contravariant count2 _) ts2')
    | count1 == count2 = Data.Foldable.foldl' (.|.) 0 $ Boxed.zipWith (.&.) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
bitDot (SimpleFinite i1@(Finite.Covariant count1 _) ts1') (SimpleFinite i2@(Finite.Contravariant count2 _) ts2')
    | count1 == count2 =
        let dotProduct v1 v2 = Data.Foldable.foldl' (.|.) 0 $ Boxed.zipWith (.&.) v1 v2
        in Scalar $ dotProduct ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
bitDot (SimpleFinite i1@(Finite.Covariant count1 _) ts1') (FiniteTensor i2@(Finite.Contravariant count2 _) ts2')
    | count1 == count2 = Data.Foldable.foldl' (.|.) 0 $ Boxed.zipWith (\e t -> (e .&.) <$> t) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
bitDot (FiniteTensor i1@(Finite.Covariant count1 _) ts1') (SimpleFinite i2@(Finite.Contravariant count2 _) ts2')
    | count1 == count2 = Data.Foldable.foldl' (.|.) 0 $ Boxed.zipWith (\t e -> (.&. e) <$> t) ts1' ts2'
    | otherwise = contractionErr (Index.toTIndex i1) (Index.toTIndex i2)
bitDot (SimpleFinite (Finite.Covariant count1 _) ts1') (InfiniteTensor (Infinite.Contravariant _) ts2') =
    Data.Foldable.foldl' (.|.) 0 $ Boxed.zipWith (\e t -> (e .&.) <$> t) ts1' (Boxed.fromList $ take count1 ts2')
bitDot (InfiniteTensor (Infinite.Covariant _) ts1') (SimpleFinite (Finite.Contravariant count2 _) ts2') =
    Data.Foldable.foldl' (.|.) 0 $ Boxed.zipWith (\t e -> (.&. e) <$> t) (Boxed.fromList $ take count2 ts1') ts2'
bitDot (FiniteTensor (Finite.Covariant count1 _) ts1') (InfiniteTensor (Infinite.Contravariant _) ts2') =
    Data.Foldable.foldl' (.|.) 0 $ Boxed.zipWith (.&.) ts1' (Boxed.fromList $ take count1 ts2')
bitDot (InfiniteTensor (Infinite.Covariant _) ts1') (FiniteTensor (Finite.Contravariant count2 _) ts2') =
    Data.Foldable.foldl' (.|.) 0 $ Boxed.zipWith (.&.) (Boxed.fromList $ take count2 ts1') ts2'
bitDot t1' t2' = contractionErr (tensorIndex t1') (tensorIndex t2')
{-# INLINE contractionErr #-}
-- | The 'Err' produced when two indices cannot be contracted; the message
-- names both offending indices.
contractionErr :: Index.TIndex
               -> Index.TIndex
               -> Tensor a
contractionErr i1' i2' =
  Err $ Data.List.concat
    [ "Tensor product: ", incompatibleTypes
    , " - index1 is ", show i1'
    , " and index2 is ", show i2'
    ]
-- | Arithmetic on tensors. @+@ and @-@ combine entries of shared indices via
-- 'zipT', @*@ contracts shared indices via 'dot', and the unary operations
-- act element-wise. Integer literals become scalars.
instance Num a => Num (Tensor a) where
  {-# INLINE (+) #-}
  (+) t1 t2 = _elemByElem t1 t2 (+) (zipT (+) (.+) (+.) (+))
  {-# INLINE (-) #-}
  (-) t1 t2 = _elemByElem t1 t2 (-) (zipT (-) (.-) (-.) (-))
  {-# INLINE (*) #-}
  (*) t1 t2 = _elemByElem t1 t2 (*) dot
  {-# INLINE abs #-}
  abs = fmap abs
  {-# INLINE signum #-}
  signum = fmap signum
  {-# INLINE fromInteger #-}
  fromInteger = Scalar . fromInteger
-- | Bitwise operations on tensors. '.|.' and 'xor' combine entries of shared
-- indices via 'zipT', '.&.' contracts via 'bitDot', and the remaining
-- operations act element-wise. Size and sign queries look at the first
-- element.
instance (
    Num a, Bits a
    ) => Bits (Tensor a) where
  {-# INLINE (.|.) #-}
  t1 .|. t2 = _elemByElem t1 t2 (.|.) $ zipT (.|.) (\t e -> (.|. e) <$> t) (\e t -> (e .|.) <$> t) (.|.)
  {-# INLINE (.&.) #-}
  t1 .&. t2 = _elemByElem t1 t2 (.&.) bitDot
  {-# INLINE xor #-}
  t1 `xor` t2 = _elemByElem t1 t2 xor $ zipT xor (\t e -> (`xor` e) <$> t) (\e t -> (e `xor`) <$> t) xor
  {-# INLINE complement #-}
  complement = Multilinear.map complement
  {-# INLINE shift #-}
  shift t n = Multilinear.map (`shift` n) t
  {-# INLINE rotate #-}
  rotate t n = Multilinear.map (`rotate` n) t
  {-# INLINE bitSize #-}
  bitSize (Scalar x) = fromMaybe (-1) $ bitSizeMaybe x
  bitSize (Err _) = -1
  bitSize t =
    if isEmptyTensor t
      then (-1)
      else fromMaybe (-1) $ bitSizeMaybe $ firstElem t
  {-# INLINE bitSizeMaybe #-}
  bitSizeMaybe (Scalar x) = bitSizeMaybe x
  bitSizeMaybe (Err _) = Nothing
  bitSizeMaybe t =
    if isEmptyTensor t
      then Nothing
      else bitSizeMaybe $ firstElem t
  {-# INLINE isSigned #-}
  isSigned (Scalar x) = isSigned x
  isSigned (Err _) = False
  isSigned t =
    not (isEmptyTensor t) &&
    isSigned (firstElem t)
  {-# INLINE bit #-}
  bit i = Scalar (bit i)
  -- A scalar answers for its wrapped value, so that testBit (bit i) i holds.
  -- There is no single answer for a whole tensor, so other tensors report False.
  {-# INLINE testBit #-}
  testBit (Scalar x) i = testBit x i
  testBit _ _ = False
  {-# INLINE popCount #-}
  popCount = popCountDefault
-- | Division where at least one side is a scalar; the scalar divides (or is
-- divided by) every element of the other side. Errors propagate, and
-- division of two non-scalar tensors is not implemented (reported as an 'Err').
instance Fractional a => Fractional (Tensor a) where
  {-# INLINE (/) #-}
  lhs / rhs =
    case (lhs, rhs) of
      (Scalar a, Scalar b) -> Scalar (a / b)
      (Scalar a, _) -> fmap (a /) rhs
      (_, Scalar b) -> fmap (/ b) lhs
      (Err msg, _) -> Err msg
      (_, Err msg) -> Err msg
      _ -> Err "TODO"
  {-# INLINE fromRational #-}
  fromRational = Scalar . fromRational
-- | Element-wise lifting of 'Floating' to tensors; 'pi' is a scalar.
--
-- 'asinh' previously applied 'acosh'.
-- 'tan' and 'tanh' are lifted explicitly because the class defaults divide
-- two tensors, which the 'Fractional' instance reports as an 'Err' for
-- non-scalars.
-- 'sqrt' is lifted directly instead of going through the default @exp@/@log@.
instance Floating a => Floating (Tensor a) where
  {-# INLINE pi #-}
  pi = Scalar pi
  {-# INLINE exp #-}
  exp t = exp <$> t
  {-# INLINE log #-}
  log t = log <$> t
  {-# INLINE sqrt #-}
  sqrt t = sqrt <$> t
  {-# INLINE sin #-}
  sin t = sin <$> t
  {-# INLINE cos #-}
  cos t = cos <$> t
  {-# INLINE tan #-}
  tan t = tan <$> t
  {-# INLINE asin #-}
  asin t = asin <$> t
  {-# INLINE acos #-}
  acos t = acos <$> t
  {-# INLINE atan #-}
  atan t = atan <$> t
  {-# INLINE sinh #-}
  sinh t = sinh <$> t
  {-# INLINE cosh #-}
  cosh t = cosh <$> t
  {-# INLINE tanh #-}
  tanh t = tanh <$> t
  {-# INLINE asinh #-}
  asinh t = asinh <$> t
  {-# INLINE acosh #-}
  acosh t = acosh <$> t
  {-# INLINE atanh #-}
  atanh t = atanh <$> t
-- | Tensors over a numeric element type as a 'Multilinear' structure.
--
-- '.-.' previously added instead of subtracting, and '.*.' passed @(+)@ as
-- the scalar operation. Both now match the corresponding 'Num' operators.
instance Num a => Multilinear Tensor a where

  -- Tensor/scalar arithmetic: apply the scalar operation to every element.
  {-# INLINE (.+) #-}
  t .+ x = (+x) <$> t

  {-# INLINE (.-) #-}
  t .- x = (\p -> p - x) <$> t

  {-# INLINE (.*) #-}
  t .* x = (*x) <$> t

  {-# INLINE (+.) #-}
  x +. t = (x+) <$> t

  {-# INLINE (-.) #-}
  x -. t = (x-) <$> t

  {-# INLINE (*.) #-}
  x *. t = (x*) <$> t

  -- Tensor/tensor arithmetic, mirroring the Num instance.
  {-# INLINE (.+.) #-}
  t1 .+. t2 = _elemByElem t1 t2 (+) $ zipT (+) (.+) (+.) (+)

  {-# INLINE (.-.) #-}
  t1 .-. t2 = _elemByElem t1 t2 (-) $ zipT (-) (.-) (-.) (-)

  {-# INLINE (.*.) #-}
  t1 .*. t2 = _elemByElem t1 t2 (*) dot

  -- Indices from the outermost inwards, read along the first sub-tensor at
  -- each level. An empty container contributes only its own index instead
  -- of crashing on head.
  {-# INLINE indices #-}
  indices x = case x of
    Scalar _ -> []
    FiniteTensor i ts -> Index.toTIndex i : maybe [] indices (ts Boxed.!? 0)
    InfiniteTensor i ts -> Index.toTIndex i : maybe [] indices (listToMaybe ts)
    SimpleFinite i _ -> [Index.toTIndex i]
    Err _ -> []

  -- (contravariant, covariant) index counts; (-1,-1) for an error.
  {-# INLINE order #-}
  order x = case x of
    Scalar _ -> (0,0)
    SimpleFinite index _ -> case index of
      Finite.Contravariant _ _ -> (1,0)
      Finite.Covariant _ _ -> (0,1)
      Finite.Indifferent _ _ -> (0,0)
    Err _ -> (-1,-1)
    _ ->
      let (cnvr, covr) = order $ firstTensor x
      in case tensorIndex x of
        Index.Contravariant _ _ -> (cnvr+1,covr)
        Index.Covariant _ _ -> (cnvr,covr+1)
        Index.Indifferent _ _ -> (cnvr,covr)

  -- Size of the named index; errors for scalars, infinite tensors, errors,
  -- and names that are not found.
  {-# INLINE size #-}
  size t iname = case t of
    Scalar _ -> error scalarIndices
    SimpleFinite index _ ->
      if Index.indexName index == iname
        then Finite.indexSize index
        else error indexNotFound
    FiniteTensor index _ ->
      if Index.indexName index == iname
        then Finite.indexSize index
        else size (firstTensor t) iname
    InfiniteTensor _ _ -> error infiniteIndex
    Err msg -> error msg

  -- Rename indices level by level: contravariant levels take names from the
  -- first list, covariant levels from the second.
  {-# INLINE ($|) #-}
  Scalar x $| _ = Scalar x
  SimpleFinite (Finite.Contravariant isize _) ts $| (u:_, _) = SimpleFinite (Finite.Contravariant isize [u]) ts
  SimpleFinite (Finite.Covariant isize _) ts $| (_, d:_) = SimpleFinite (Finite.Covariant isize [d]) ts
  FiniteTensor (Finite.Contravariant isize _) ts $| (u:us, ds) = FiniteTensor (Finite.Contravariant isize [u]) $ ($| (us,ds)) <$> ts
  FiniteTensor (Finite.Covariant isize _) ts $| (us, d:ds) = FiniteTensor (Finite.Covariant isize [d]) $ ($| (us,ds)) <$> ts
  InfiniteTensor (Infinite.Contravariant _) ts $| (u:us, ds) = InfiniteTensor (Infinite.Contravariant [u]) $ ($| (us,ds)) <$> ts
  InfiniteTensor (Infinite.Covariant _) ts $| (us, d:ds) = InfiniteTensor (Infinite.Covariant [d]) $ ($| (us,ds)) <$> ts
  Err msg $| _ = Err msg
  t $| _ = t

  -- Make the named index contravariant.
  {-# INLINE (/\) #-}
  Scalar x /\ _ = Scalar x
  FiniteTensor index ts /\ n
    | Index.indexName index == n =
        FiniteTensor (Finite.Contravariant (Finite.indexSize index) n) $ (/\ n) <$> ts
    | otherwise =
        FiniteTensor index $ (/\ n) <$> ts
  InfiniteTensor index ts /\ n
    | Index.indexName index == n =
        InfiniteTensor (Infinite.Contravariant n) $ (/\ n) <$> ts
    | otherwise =
        InfiniteTensor index $ (/\ n) <$> ts
  t1@(SimpleFinite index ts) /\ n
    | Index.indexName index == n =
        SimpleFinite (Finite.Contravariant (Finite.indexSize index) n) ts
    | otherwise = t1
  Err msg /\ _ = Err msg

  -- Make the named index covariant.
  {-# INLINE (\/) #-}
  Scalar x \/ _ = Scalar x
  FiniteTensor index ts \/ n
    | Index.indexName index == n =
        FiniteTensor (Finite.Covariant (Finite.indexSize index) n) $ (\/ n) <$> ts
    | otherwise =
        FiniteTensor index $ (\/ n) <$> ts
  InfiniteTensor index ts \/ n
    | Index.indexName index == n =
        InfiniteTensor (Infinite.Covariant n) $ (\/ n) <$> ts
    | otherwise =
        InfiniteTensor index $ (\/ n) <$> ts
  t1@(SimpleFinite index ts) \/ n
    | Index.indexName index == n =
        SimpleFinite (Finite.Covariant (Finite.indexSize index) n) ts
    | otherwise = t1
  Err msg \/ _ = Err msg

  -- Swap covariant and contravariant at every level; indifferent indices stay.
  {-# INLINE transpose #-}
  transpose (Scalar x) = Scalar x
  transpose (FiniteTensor (Finite.Covariant count name) ts) =
    FiniteTensor (Finite.Contravariant count name) (Multilinear.transpose <$> ts)
  transpose (FiniteTensor (Finite.Contravariant count name) ts) =
    FiniteTensor (Finite.Covariant count name) (Multilinear.transpose <$> ts)
  transpose (FiniteTensor (Finite.Indifferent count name) ts) =
    FiniteTensor (Finite.Indifferent count name) (Multilinear.transpose <$> ts)
  transpose (InfiniteTensor (Infinite.Covariant name) ts) =
    InfiniteTensor (Infinite.Contravariant name) (Multilinear.transpose <$> ts)
  transpose (InfiniteTensor (Infinite.Contravariant name) ts) =
    InfiniteTensor (Infinite.Covariant name) (Multilinear.transpose <$> ts)
  transpose (InfiniteTensor (Infinite.Indifferent name) ts) =
    InfiniteTensor (Infinite.Indifferent name) (Multilinear.transpose <$> ts)
  transpose (SimpleFinite (Finite.Covariant count name) ts) =
    SimpleFinite (Finite.Contravariant count name) ts
  transpose (SimpleFinite (Finite.Contravariant count name) ts) =
    SimpleFinite (Finite.Covariant count name) ts
  transpose (SimpleFinite (Finite.Indifferent count name) ts) =
    SimpleFinite (Finite.Indifferent count name) ts
  transpose (Err msg) = Err msg

  -- Move the named index one level inwards, swapping it with the index
  -- directly below it by transposing the two outer levels.
  {-# INLINE shiftRight #-}
  Err msg `shiftRight` _ = Err msg
  Scalar x `shiftRight` _ = Scalar x
  t1@(SimpleFinite _ _) `shiftRight` _ = t1
  t1@(FiniteTensor index1 ts1) `shiftRight` ind
    | Data.List.length (indicesNames t1) > 1 && Index.indexName index1 /= ind =
        FiniteTensor index1 $ (|>> ind) <$> ts1
    | Data.List.length (indicesNames t1) > 1 && Index.indexName index1 == ind =
        let index2 = tensorFiniteIndex (ts1 Boxed.! 0)
            dane = if isSimple (ts1 Boxed.! 0)
              then (Scalar <$>) <$> (tensorScalars <$> ts1)
              else tensorsFinite <$> ts1
            daneList = Boxed.toList <$> Boxed.toList dane
            transposedList = Data.List.transpose daneList
            transposed = Boxed.fromList <$> Boxed.fromList transposedList
        in mergeScalars $ FiniteTensor index2 $ FiniteTensor index1 <$> transposed
    | otherwise = t1
  t1@(InfiniteTensor index1 ts1) `shiftRight` ind
    | Data.List.length (indicesNames t1) > 1 && Index.indexName index1 /= ind =
        InfiniteTensor index1 $ (|>> ind) <$> ts1
    | Data.List.length (indicesNames t1) > 1 && Index.indexName index1 == ind =
        let index2 = tensorInfiniteIndex (head ts1)
            dane = if isSimple (head ts1)
              then (Scalar <$>) <$> (Boxed.toList . tensorScalars <$> ts1)
              else tensorsInfinite <$> ts1
            transposed = Data.List.transpose dane
        in mergeScalars $ InfiniteTensor index2 $ InfiniteTensor index1 <$> transposed
    | otherwise = t1
-- | Split (name, value) pairs into the value bound to the index called
-- @name@ (names are compared as one-element lists) and the remaining pairs.
_splitIndexValue :: Eq c => [c] -> [(c, v)] -> (Maybe v, [(c, v)])
_splitIndexValue name pairs =
  (snd <$> Data.List.find matches pairs, Data.List.filter (not . matches) pairs)
  where
    matches (n, _) = [n] == name

-- | Element access and index-aware mapping for tensors.
instance Num a => Accessible Tensor a where

  -- Fix each index named in @inds@ to the position given at the same place
  -- in @vals@. Levels whose index is not mentioned are kept and descended
  -- into; each named level is removed by selecting the entry with '!'.
  {-# INLINE el #-}
  el (Scalar x) _ = Scalar x
  el t1@(SimpleFinite index1 _) (inds, vals) =
    case fst (_splitIndexValue (Index.indexName index1) (zip inds vals)) of
      Just v -> t1 ! v
      Nothing -> t1
  el t1@(FiniteTensor index1 v1) (inds, vals) =
    case _splitIndexValue (Index.indexName index1) (zip inds vals) of
      (Just v, rest) -> el (t1 ! v) (Data.List.unzip rest)
      (Nothing, _) -> FiniteTensor index1 $ (\t -> el t (inds, vals)) <$> v1
  el t1@(InfiniteTensor index1 v1) (inds, vals) =
    case _splitIndexValue (Index.indexName index1) (zip inds vals) of
      (Just v, rest) -> el (t1 ! v) (Data.List.unzip rest)
      (Nothing, _) -> InfiniteTensor index1 $ (\t -> el t (inds, vals)) <$> v1
  el (Err msg) _ = Err msg

  -- Map @f@ over every element, passing it the element's position as the list
  -- of entry numbers from the outermost level inwards. The path used to start
  -- from an infinite list of zeros, so appended positions never reached @f@;
  -- it now starts empty.
  {-# INLINE iMap #-}
  iMap f t = go t []
    where
      go (Scalar x) path =
        Scalar $ f path x
      go (SimpleFinite index ts) path =
        SimpleFinite index $ Boxed.imap (\i e -> f (path ++ [i]) e) ts
      go (FiniteTensor index ts) path =
        FiniteTensor index $ Boxed.imap (\i e -> go e (path ++ [i])) ts
      go (InfiniteTensor index ts) path =
        InfiniteTensor index $ Data.List.zipWith (\i e -> go e (path ++ [i])) [0..] ts
      go (Err msg) _ =
        Err msg