{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ViewPatterns #-}
module Language.Egison.Tensor
( TensorComponent (..)
, tref
, enumTensorIndices
, tTranspose
, tTranspose'
, tFlipIndices
, appendDF
, removeDF
, tMap
, tMap2
, tProduct
, tContract
, tContract'
, tConcat'
) where
import Prelude hiding (foldr, mappend, mconcat)
import Control.Monad.Except (mzero, throwError, zipWithM)
import Data.List (delete, intersect, partition, (\\))
import qualified Data.Vector as V
import Control.Egison
import qualified Control.Egison as M
import Language.Egison.Data
import Language.Egison.Data.Utils
import Language.Egison.IExpr (Index (..), extractSupOrSubIndex)
import Language.Egison.Math
import Language.Egison.RState
-- | Matcher (in the 'M.Matcher' sense) for 'Index' values. The wrapped
-- matcher @m@ is the one handed out by 'subM', 'supM' and 'supsubM' for the
-- content of a matched index.
data IndexM m = IndexM m
instance M.Matcher m a => M.Matcher (IndexM m) (Index a)
-- | Pattern for the subscript constructor 'Sub': yields the wrapped value,
-- and yields nothing ('mzero') for every other 'Index' constructor.
sub :: M.Matcher m a => M.Pattern (PP a) (IndexM m) (Index a) a
sub _ _ idx =
  case idx of
    Sub x -> pure x
    _     -> mzero

-- | Matcher for the value extracted by 'sub': the matcher held by 'IndexM'.
subM :: M.Matcher m a => IndexM m -> Index a -> m
subM (IndexM inner) _ = inner
-- | Pattern for the superscript constructor 'Sup': yields the wrapped value,
-- and yields nothing ('mzero') for every other 'Index' constructor.
sup :: M.Matcher m a => M.Pattern (PP a) (IndexM m) (Index a) a
sup _ _ idx =
  case idx of
    Sup x -> pure x
    _     -> mzero

-- | Matcher for the value extracted by 'sup': the matcher held by 'IndexM'.
supM :: M.Matcher m a => IndexM m -> Index a -> m
supM (IndexM inner) _ = inner
-- | Pattern for the 'SupSub' index constructor: yields the wrapped value,
-- and yields nothing ('mzero') for every other 'Index' constructor.
supsub :: M.Matcher m a => M.Pattern (PP a) (IndexM m) (Index a) a
supsub _ _ idx =
  case idx of
    SupSub x -> pure x
    _        -> mzero

-- | Matcher for the value extracted by 'supsub': the matcher held by 'IndexM'.
supsubM :: M.Matcher m a => IndexM m -> Index a -> m
supsubM (IndexM inner) _ = inner
-- | Conversion between a value type @v@ and the @Tensor e@ representation
-- its tensors are stored in. The functional dependency means the value
-- type alone determines the component type @e@.
class TensorComponent v e | v -> e where
  -- | View a value as a tensor. (The instances in this module wrap
  -- non-tensor values in a 'Scalar'.)
  toTensor :: v -> EvalM (Tensor e)
  -- | Turn a tensor back into a value.
  fromTensor :: Tensor e -> EvalM v
-- | Fully evaluated values use themselves as components: a proper tensor is
-- wrapped in / unwrapped from 'TensorData', and any other value is treated
-- as a 'Scalar'.
instance TensorComponent EgisonValue EgisonValue where
  fromTensor tensor =
    case tensor of
      Scalar v -> return v
      Tensor{} -> return (TensorData tensor)

  toTensor val =
    case val of
      TensorData t -> return t
      _            -> return (Scalar val)
-- | WHNF data keeps its tensor components as 'ObjectRef's. Turning a
-- 'Scalar' back into data evaluates its reference with 'evalRef'; turning
-- non-tensor data into a tensor allocates a reference with
-- 'newEvaluatedObjectRef'.
instance TensorComponent WHNFData ObjectRef where
  fromTensor (Scalar ref) = evalRef ref
  fromTensor t@Tensor{}   = return (ITensor t)

  toTensor (ITensor t) = return t
  toTensor whnf        = fmap Scalar (newEvaluatedObjectRef whnf)
-- | The dimensions of a tensor; a 'Scalar' has the empty shape.
tShape :: Tensor a -> Shape
tShape t =
  case t of
    Scalar _        -> []
    Tensor dims _ _ -> dims
-- | The flat component vector of a tensor; a 'Scalar' becomes a
-- one-element vector.
tToVector :: Tensor a -> V.Vector a
tToVector (Scalar x)         = V.singleton x
tToVector (Tensor _ elems _) = elems
-- | The index list attached to a tensor; a 'Scalar' carries none.
tIndex :: Tensor a -> [Index EgisonValue]
tIndex t =
  case t of
    Scalar _        -> []
    Tensor _ _ idxs -> idxs
-- | Fix the outermost dimension of a tensor to the 1-based position @i@.
--
-- * Order-1 tensor: the selected component is returned as a 'Scalar'.
-- * Higher order: the contiguous block of @product ns@ components for that
--   position becomes a tensor of shape @ns@, and the first entry of the
--   index list is dropped (via 'cdr').
--
-- Throws 'TensorIndexOutOfBounds' when @i@ is not in @[1, n]@, and a
-- 'Default' error when the tensor has no dimension left to index.
tIntRef' :: Integer -> Tensor a -> EvalM (Tensor a)
tIntRef' i (Tensor [n] xs _)
  | 1 <= i && i <= n = return (Scalar (xs V.! fromIntegral (i - 1)))
  | otherwise        = throwErrorWithTrace (TensorIndexOutOfBounds i n)
tIntRef' i (Tensor (n:ns) xs js)
  | 1 <= i && i <= n =
      let width  = fromIntegral (product ns)
          offset = width * fromIntegral (i - 1)
       in return (Tensor ns (V.take width (V.drop offset xs)) (cdr js))
  | otherwise = throwErrorWithTrace (TensorIndexOutOfBounds i n)
tIntRef' _ _ = throwError $ Default "More indices than the order of the tensor"
-- | Apply a sequence of 1-based integer indices, one at a time, with
-- 'tIntRef''.
--
-- Once the indices are exhausted:
--
-- * a rank-0 'Tensor' holding exactly one component collapses to a 'Scalar';
-- * a rank-0 'Tensor' holding any other number of components is reported as
--   an internal bug;
-- * anything else is returned unchanged.
tIntRef :: [Integer] -> Tensor a -> EvalM (Tensor a)
tIntRef (m:ms) t = tIntRef' m t >>= tIntRef ms
tIntRef [] t =
  case t of
    Tensor [] xs _
      | V.length xs == 1 -> return (Scalar (V.head xs))
      | otherwise        -> throwErrorWithTrace (EgisonBug "sevaral elements in scalar tensor")
    _ -> return t
-- | Like 'tIntRef', but returns the single component that remains after all
-- indices have been applied. The result must be a 'Scalar', or a rank-0
-- 'Tensor' with exactly one component; anything else is reported as an
-- internal bug.
tIntRef1 :: [Integer] -> Tensor a -> EvalM a
tIntRef1 (m:ms) t = tIntRef' m t >>= tIntRef1 ms
tIntRef1 [] t =
  case t of
    Scalar x -> return x
    Tensor [] xs _ | V.length xs == 1 -> return (V.head xs)
    _ -> throwErrorWithTrace (EgisonBug "sevaral elements in scalar tensor")
-- | Match-only pattern synonym that succeeds whenever 'extractSupOrSubIndex'
-- returns a value, and binds that value.
pattern SupOrSubIndex :: a -> Index a
pattern SupOrSubIndex x <- (extractSupOrSubIndex -> Just x)
-- | Index a tensor with a list of sub/superscript indices, one per
-- dimension starting from the outermost.
--
-- * No indices left: behaves like @tIntRef []@ (a one-component rank-0
--   tensor collapses to a 'Scalar').
-- * A single symbol: the outermost dimension is kept. Each slice is indexed
--   with the remaining indices, and the results are joined with 'tConcat'
--   under the original index.
-- * An integer term: delegated to 'tIntRef'' (1-based).
-- * A zero expression: rejected as out of bounds.
-- * A two-element tuple @(m, n)@: selects positions @[m..n]@ of the
--   outermost dimension, and re-indexes them with a freshly generated
--   symbol of the same sub/sup/supsub kind. An empty range (@m > n@) yields
--   an empty tensor.
--
-- Any other index is rejected with a 'Default' error.
tref :: [Index EgisonValue] -> Tensor a -> EvalM (Tensor a)
tref [] t = tIntRef [] t
tref (s@(SupOrSubIndex (ScalarData (SingleSymbol _))):ms) (Tensor (_:ns) xs js) =
  mapM (\slice -> tref ms (Tensor ns slice (cdr js))) (split (product ns) xs)
    >>= tConcat s
tref (SupOrSubIndex (ScalarData (SingleTerm m [])):ms) t =
  tIntRef' m t >>= tref ms
tref (SupOrSubIndex (ScalarData ZeroExpr):_) _ =
  throwError $ Default "tensor index out of bounds: 0"
tref (s@(SupOrSubIndex (Tuple [loVal, hiVal])):ms) t@(Tensor is _ _) = do
  lower <- fromEgison loVal
  upper <- fromEgison hiVal
  if lower > upper
    then return (Tensor (replicate (length is) 0) V.empty [])
    else do
      parts <- mapM (\i -> tIntRef' i t >>= tref ms) [lower .. upper]
      symId <- fresh
      -- A fresh symbol names the dimension produced by the range selection.
      let rangeIndex = symbolScalarData "" (":::" ++ symId)
      case s of
        Sub{}    -> tConcat (Sub rangeIndex) parts
        Sup{}    -> tConcat (Sup rangeIndex) parts
        SupSub{} -> tConcat (SupSub rangeIndex) parts
tref (_:_) _ =
  throwError $ Default "Tensor index must be an integer or a single symbol."
-- | All 1-based index tuples of a tensor with the given shape, in row-major
-- order (the last index varies fastest). The empty shape yields a single
-- empty tuple.
enumTensorIndices :: Shape -> [[Integer]]
enumTensorIndices = mapM (\n -> [1 .. n])
-- | For each index in @js@, return the entry of @ns@ at the position where
-- that index occurs in @is@. The lookup pairs @is@ with @ns@ positionally.
-- Fails with a 'Default' error if some index of @js@ does not occur in @is@.
transIndex :: [Index EgisonValue] -> [Index EgisonValue] -> Shape -> EvalM Shape
transIndex is js ns = mapM extentOf js
 where
  extents = zip is ns
  extentOf :: Index EgisonValue -> EvalM Integer
  extentOf j =
    maybe (throwError $ Default "cannot transpose becuase of the inconsitent symbolic tensor indices")
          return
          (lookup j extents)
-- | Reorder the dimensions of a tensor so that its index list becomes @is@.
--
-- The tensor's first @length is@ indices are matched against @is@, and both
-- sides are extended with @complementWithDF ns is@ for the remaining
-- dimensions. The new shape and, for every result position, the matching
-- source position are computed with 'transIndex'. Each component is then
-- fetched with 'tIntRef1'.
--
-- The tensor is returned unchanged if more target indices are given than it
-- has indices, or if it is a 'Scalar' (which has no indices to permute).
tTranspose :: [Index EgisonValue] -> Tensor a -> EvalM (Tensor a)
tTranspose is t@(Tensor _ _ js) | length is > length js =
  return t
tTranspose is t@(Tensor ns _ js) = do
  let js' = take (length is) js
      ds  = complementWithDF ns is
  -- Extent of each target index, looked up among the current indices.
  ns' <- transIndex (js' ++ ds) (is ++ ds) ns
  -- For every position of the result, the corresponding position in t.
  srcPositions <- mapM (transIndex (is ++ ds) (js' ++ ds)) (enumTensorIndices ns')
  xs' <- mapM (`tIntRef1` t) (V.fromList srcPositions)
  return $ Tensor ns' xs' is
-- Previously missing: calling tTranspose on a scalar crashed with a
-- non-exhaustive pattern-match error.
tTranspose _ t@(Scalar _) = return t
-- | Transpose using index /contents/ rather than full 'Index' values.
--
-- For each value in @is@, the first index of the tensor that is a 'Sub',
-- 'Sup' or 'SupSub' wrapping an equal value is located. The indices found
-- are then passed to 'tTranspose'.
--
-- The tensor is returned unchanged if any value cannot be located, or if it
-- is a 'Scalar' (which has no indices). The 'Scalar' case previously crashed
-- with a non-exhaustive pattern-match error.
tTranspose' :: [EgisonValue] -> Tensor a -> EvalM (Tensor a)
tTranspose' _ t@(Scalar _) = return t
tTranspose' is t@(Tensor _ _ js) =
  case mapM (\i -> f i js) is of
    Nothing  -> return t
    Just is' -> tTranspose is' t
 where
  -- The parameter must stay named @i@: the pattern refers to it as #i.
  f :: EgisonValue -> [Index EgisonValue] -> Maybe (Index EgisonValue)
  f i indices =
    match dfs indices (List (IndexM Eql))
      [ [mc| _ ++ ($j & (sub #i | sup #i | supsub #i)) : _ -> Just j |]
      , [mc| _ -> Nothing |]
      ]
-- | Apply 'reverseIndex' to every index of a tensor. The shape and
-- components are left untouched.
--
-- A 'Scalar' has no indices and is returned as is. This case previously
-- crashed with a non-exhaustive pattern-match error.
tFlipIndices :: Tensor a -> EvalM (Tensor a)
tFlipIndices (Tensor ns xs js) = return $ Tensor ns xs (map reverseIndex js)
tFlipIndices t@(Scalar _)      = return t
-- | Extend a tensor's index list with @DF dfId 1, DF dfId 2, ...@, one entry
-- for each dimension beyond the indices it already has. Nothing is added
-- when there are no such dimensions.
--
-- Applies to 'ITensor' data and to 'Value' ('TensorData') tensors; all
-- other data is returned unchanged.
appendDF :: Integer -> WHNFData -> WHNFData
appendDF dfId whnf =
  case whnf of
    ITensor (Tensor s xs is)            -> ITensor (Tensor s xs (padded s is))
    Value (TensorData (Tensor s xs is)) -> Value (TensorData (Tensor s xs (padded s is)))
    _                                   -> whnf
 where
  padded :: Shape -> [Index EgisonValue] -> [Index EgisonValue]
  padded shape idxs =
    idxs ++ map (DF dfId) [1 .. fromIntegral (length shape - length idxs)]
-- | Move every 'DF' index of a tensor to the end with 'tTranspose', then
-- drop them from the index list. The shape keeps all dimensions, in the
-- transposed order.
--
-- Applies to 'ITensor' data and to 'Value' ('TensorData') tensors; all
-- other data is returned unchanged.
removeDF :: WHNFData -> EvalM WHNFData
removeDF whnf =
  case whnf of
    ITensor (Tensor s xs is)            -> ITensor <$> dropDF s xs is
    Value (TensorData (Tensor s xs is)) -> Value . TensorData <$> dropDF s xs is
    _                                   -> return whnf
 where
  dropDF :: Shape -> V.Vector b -> [Index EgisonValue] -> EvalM (Tensor b)
  dropDF shape xs is = do
    let (dfIdx, rest) = partition isDF is
    Tensor shape' ys _ <- tTranspose (rest ++ dfIdx) (Tensor shape xs is)
    return (Tensor shape' ys rest)

  isDF :: Index a -> Bool
  isDF DF{} = True
  isDF _    = False
-- | Apply a monadic function to every component of a tensor, keeping its
-- shape. For a 'Tensor' with shape @ns@ and indices @js@, the result's
-- index list is @js ++ complementWithDF ns js@.
tMap :: (a -> EvalM b) -> Tensor a -> EvalM (Tensor b)
tMap f t =
  case t of
    Scalar x -> fmap Scalar (f x)
    Tensor shape xs idxs -> do
      ys <- V.mapM f xs
      return (Tensor shape ys (idxs ++ complementWithDF shape idxs))
tMap2 :: (a -> b -> EvalM c) -> Tensor a -> Tensor b -> EvalM (Tensor c)
tMap2 :: (a -> b -> EvalM c) -> Tensor a -> Tensor b -> EvalM (Tensor c)
tMap2 a -> b -> EvalM c
f (Tensor Shape
ns1 Vector a
xs1 [Index EgisonValue]
js1') (Tensor Shape
ns2 Vector b
xs2 [Index EgisonValue]
js2') = do
let js1 :: [Index EgisonValue]
js1 = [Index EgisonValue]
js1' [Index EgisonValue] -> [Index EgisonValue] -> [Index EgisonValue]
forall a. [a] -> [a] -> [a]
++ Shape -> [Index EgisonValue] -> [Index EgisonValue]
forall a. Shape -> [Index a] -> [Index a]
complementWithDF Shape
ns1 [Index EgisonValue]
js1'
let js2 :: [Index EgisonValue]
js2 = [Index EgisonValue]
js2' [Index EgisonValue] -> [Index EgisonValue] -> [Index EgisonValue]
forall a. [a] -> [a] -> [a]
++ Shape -> [Index EgisonValue] -> [Index EgisonValue]
forall a. Shape -> [Index a] -> [Index a]
complementWithDF Shape
ns2 [Index EgisonValue]
js2'
let cjs :: [Index EgisonValue]
cjs = [Index EgisonValue]
js1 [Index EgisonValue] -> [Index EgisonValue] -> [Index EgisonValue]
forall a. Eq a => [a] -> [a] -> [a]
`intersect` [Index EgisonValue]
js2
Tensor a
t1' <- [Index EgisonValue] -> Tensor a -> EvalM (Tensor a)
forall a. [Index EgisonValue] -> Tensor a -> EvalM (Tensor a)
tTranspose ([Index EgisonValue]
cjs [Index EgisonValue] -> [Index EgisonValue] -> [Index EgisonValue]
forall a. [a] -> [a] -> [a]
++ ([Index EgisonValue]
js1 [Index EgisonValue] -> [Index EgisonValue] -> [Index EgisonValue]
forall a. Eq a => [a] -> [a] -> [a]
\\ [Index EgisonValue]
cjs)) (Shape -> Vector a -> [Index EgisonValue] -> Tensor a
forall a. Shape -> Vector a -> [Index EgisonValue] -> Tensor a
Tensor Shape
ns1 Vector a
xs1 [Index EgisonValue]
js1)
Tensor b
t2' <- [Index EgisonValue] -> Tensor b -> EvalM (Tensor b)
forall a. [Index EgisonValue] -> Tensor a -> EvalM (Tensor a)
tTranspose ([Index EgisonValue]
cjs [Index EgisonValue] -> [Index EgisonValue] -> [Index EgisonValue]
forall a. [a] -> [a] -> [a]
++ ([Index EgisonValue]
js2 [Index EgisonValue] -> [Index EgisonValue] -> [Index EgisonValue]
forall a. Eq a => [a] -> [a] -> [a]
\\ [Index EgisonValue]
cjs)) (Shape -> Vector b -> [Index EgisonValue] -> Tensor b
forall a. Shape -> Vector a -> [Index EgisonValue] -> Tensor a
Tensor Shape
ns2 Vector b
xs2 [Index EgisonValue]
js2)
let cns :: Shape
cns = Int -> Shape -> Shape
forall a. Int -> [a] -> [a]
take ([Index EgisonValue] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Index EgisonValue]
cjs) (Tensor a -> Shape
forall a. Tensor a -> Shape
tShape Tensor a
t1')
[Tensor a]
rts1 <- (Shape -> EvalM (Tensor a))
-> [Shape]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [Tensor a]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Shape -> Tensor a -> EvalM (Tensor a)
forall a. Shape -> Tensor a -> EvalM (Tensor a)
`tIntRef` Tensor a
t1') (Shape -> [Shape]
enumTensorIndices Shape
cns)
[Tensor b]
rts2 <- (Shape -> EvalM (Tensor b))
-> [Shape]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [Tensor b]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Shape -> Tensor b -> EvalM (Tensor b)
forall a. Shape -> Tensor a -> EvalM (Tensor a)
`tIntRef` Tensor b
t2') (Shape -> [Shape]
enumTensorIndices Shape
cns)
[Tensor c]
rts' <- (Tensor a -> Tensor b -> EvalM (Tensor c))
-> [Tensor a]
-> [Tensor b]
-> StateT EvalState (ExceptT EgisonError RuntimeM) [Tensor c]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM ((a -> b -> EvalM c) -> Tensor a -> Tensor b -> EvalM (Tensor c)
forall a b c.
(a -> b -> EvalM c) -> Tensor a -> Tensor b -> EvalM (Tensor c)
tProduct a -> b -> EvalM c
f) [Tensor a]
rts1 [Tensor b]
rts2
let ret :: Tensor c
ret = Shape -> Vector c -> [Index EgisonValue] -> Tensor c
forall a. Shape -> Vector a -> [Index EgisonValue] -> Tensor a
Tensor (Shape
cns Shape -> Shape -> Shape
forall a. [a] -> [a] -> [a]
++ Tensor c -> Shape
forall a. Tensor a -> Shape
tShape ([Tensor c] -> Tensor c
forall a. [a] -> a
head [Tensor c]
rts')) ([Vector c] -> Vector c
forall a. [Vector a] -> Vector a
V.concat ((Tensor c -> Vector c) -> [Tensor c] -> [Vector c]
forall a b. (a -> b) -> [a] -> [b]
map Tensor c -> Vector c
forall a. Tensor a -> Vector a
tToVector [Tensor c]
rts')) ([Index EgisonValue]
cjs [Index EgisonValue] -> [Index EgisonValue] -> [Index EgisonValue]
forall a. [a] -> [a] -> [a]
++ Tensor c -> [Index EgisonValue]
forall a. Tensor a -> [Index EgisonValue]
tIndex ([Tensor c] -> Tensor c
forall a. [a] -> a
head [Tensor c]
rts'))
[Index EgisonValue] -> Tensor c -> EvalM (Tensor c)
forall a. [Index EgisonValue] -> Tensor a -> EvalM (Tensor a)
tTranspose ([Index EgisonValue] -> [Index EgisonValue]
uniq ([Index EgisonValue] -> [Index EgisonValue]
tDiagIndex ([Index EgisonValue]
js1 [Index EgisonValue] -> [Index EgisonValue] -> [Index EgisonValue]
forall a. [a] -> [a] -> [a]
++ [Index EgisonValue]
js2))) Tensor c
ret
where
uniq :: [Index EgisonValue] -> [Index EgisonValue]
uniq :: [Index EgisonValue] -> [Index EgisonValue]
uniq [] = []
uniq (Index EgisonValue
x:[Index EgisonValue]
xs) = Index EgisonValue
xIndex EgisonValue -> [Index EgisonValue] -> [Index EgisonValue]
forall a. a -> [a] -> [a]
:[Index EgisonValue] -> [Index EgisonValue]
uniq (Index EgisonValue -> [Index EgisonValue] -> [Index EgisonValue]
forall a. Eq a => a -> [a] -> [a]
delete Index EgisonValue
x [Index EgisonValue]
xs)
tMap2 a -> b -> EvalM c
f t :: Tensor a
t@Tensor{} (Scalar b
x) = (a -> EvalM c) -> Tensor a -> EvalM (Tensor c)
forall a b. (a -> EvalM b) -> Tensor a -> EvalM (Tensor b)
tMap (a -> b -> EvalM c
`f` b
x) Tensor a
t
tMap2 a -> b -> EvalM c
f (Scalar a
x) t :: Tensor b
t@Tensor{} = (b -> EvalM c) -> Tensor b -> EvalM (Tensor c)
forall a b. (a -> EvalM b) -> Tensor a -> EvalM (Tensor b)
tMap (a -> b -> EvalM c
f a
x) Tensor b
t
tMap2 a -> b -> EvalM c
f (Scalar a
x1) (Scalar b
x2) = c -> Tensor c
forall a. a -> Tensor a
Scalar (c -> Tensor c) -> EvalM c -> EvalM (Tensor c)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> a -> b -> EvalM c
f a
x1 b
x2
-- | Take the diagonal over every superscript index that has a matching
-- subscript index in the same tensor.
--
-- The paired superscripts and their subscript partners are moved to the
-- front with 'tTranspose'. The diagonal slices are then read with 'tIntRef'
-- at the repeated position list @is ++ is@. In the result, each pair is
-- replaced by one 'SupSub' index, followed by the untouched indices.
-- Tensors without such pairs, and non-'Tensor' values, are returned as is.
tDiag :: Tensor a -> EvalM (Tensor a)
tDiag t@(Tensor _ _ js)
  | null paired = return t
  | otherwise = do
      let partners = map reverseIndex paired
          rest     = js \\ (paired ++ partners)
      moved <- tTranspose (paired ++ partners ++ rest) t
      let (diagShape, afterDiag) = splitAt (length paired) (tShape moved)
          restShape              = drop (length paired) afterDiag
      slices <- mapM (\is -> tIntRef (is ++ is) moved) (enumTensorIndices diagShape)
      return $ Tensor (diagShape ++ restShape)
                      (V.concat (map tToVector slices))
                      (map toSupSub paired ++ rest)
  where
    -- Superscript indices whose subscript twin also occurs in js.
    paired :: [Index EgisonValue]
    paired = filter (\j -> any (hasSubTwin j) js) js

    hasSubTwin :: Index EgisonValue -> Index EgisonValue -> Bool
    hasSubTwin (Sup i) (Sub j) = i == j
    hasSubTwin _ _             = False
tDiag t = return t
-- | Repeatedly find a @sup k@ / @sub k@ pair (in either order) in the index
-- list. Remove both members and put a single @SupSub k@ at the front.
-- Stop when no such pair is left.
tDiagIndex :: [Index EgisonValue] -> [Index EgisonValue]
tDiagIndex indices =
  match dfs indices (List (IndexM Eql))
    [ [mc| $pre ++ sup $k : $mid ++ sub #k : $post ->
             tDiagIndex (SupSub k : pre ++ mid ++ post) |]
    , [mc| $pre ++ sub $k : $mid ++ sup #k : $post ->
             tDiagIndex (SupSub k : pre ++ mid ++ post) |]
    , [mc| _ -> indices |]
    ]
-- | Product of two tensors whose components are combined with @f@.
--
-- Each operand is first padded with @DF@ placeholder indices for its
-- unlabelled dimensions ('complementWithDF'). The left operand's indices
-- whose opposite-variance twin (@Sup i@ / @Sub i@) occurs in the right
-- operand are the /shared/ indices.
--
-- * No shared indices: the outer product is built element by element over
--   the shape @ns1 ++ ns2@ and then passed to 'tContract''.
-- * Shared indices: both operands are transposed so the shared indices come
--   first. 'tProduct' is applied recursively to the sub-tensors selected by
--   'tIntRef' at each position of the shared dimensions. The partial results
--   are stacked, the shared indices become 'SupSub' indices, and the result
--   is transposed with 'tTranspose'.
--
-- If one operand is a 'Scalar', @f@ is mapped over the other operand.
tProduct :: (a -> b -> EvalM c) -> Tensor a -> Tensor b -> EvalM (Tensor c)
tProduct f (Tensor ns1 xs1 js1') (Tensor ns2 xs2 js2') =
  case shared1 of
    [] -> do
      cells <- mapM outerCell (enumTensorIndices (ns1 ++ ns2))
      tContract' (Tensor (ns1 ++ ns2) (V.fromList cells) (js1 ++ js2))
    _ -> do
      lhs <- tTranspose (shared1 ++ rest1) t1
      rhs <- tTranspose (shared2 ++ rest2) t2
      let sharedShape = take (length shared1) (tShape lhs)
      blocks <- mapM (innerBlock lhs rhs) (enumTensorIndices sharedShape)
      let firstBlock = head blocks
          stacked    = Tensor (sharedShape ++ tShape firstBlock)
                              (V.concat (map tToVector blocks))
                              (map toSupSub shared1 ++ tIndex firstBlock)
      tTranspose (dedupe (map toSupSub shared1 ++ rest1 ++ rest2)) stacked
  where
    -- Explicit indices, padded with DF placeholders for unlabelled dimensions.
    js1 = js1' ++ complementWithDF ns1 js1'
    js2 = js2' ++ complementWithDF ns2 js2'
    t1 = Tensor ns1 xs1 js1
    t2 = Tensor ns2 xs2 js2

    -- Left indices with a twin on the right, those twins, and the leftovers.
    shared1 = filter (\j -> any (isTwin j) js2) js1
    shared2 = map reverseIndex shared1
    rest1   = js1 \\ shared1
    rest2   = js2 \\ shared2

    -- One component of the outer product; @is@ addresses the shape ns1 ++ ns2.
    outerCell is = do
      let (is1, tailIs) = splitAt (length ns1) is
      x1 <- tIntRef1 is1 t1
      x2 <- tIntRef1 (take (length ns2) tailIs) t2
      f x1 x2

    -- Recursive product of the sub-tensors at one position of the shared dims.
    innerBlock u v is = do
      l <- tIntRef is u
      r <- tIntRef is v
      tProduct f l r

    isTwin :: Index EgisonValue -> Index EgisonValue -> Bool
    isTwin (Sup i) (Sub j) = i == j
    isTwin (Sub i) (Sup j) = i == j
    isTwin _ _             = False

    -- For each element kept, removes only its next later occurrence
    -- ('delete' drops a single match), so this is not the same as 'nub'.
    dedupe :: [Index EgisonValue] -> [Index EgisonValue]
    dedupe []       = []
    dedupe (x : xs) = x : dedupe (delete x xs)
tProduct f (Scalar x) t@Tensor{} = tMap (\y -> f x y) t
tProduct f t@Tensor{} (Scalar y) = tMap (\x -> f x y) t
tProduct f (Scalar x1) (Scalar x2) = fmap Scalar (f x1 x2)
-- | Contract a tensor into a list of tensors.
--
-- The tensor is first diagonalised with 'tDiag'. If its leading index is
-- then a 'SupSub' index of size @n@, the tensor is sliced at positions
-- @1 .. n@ with 'tIntRef''. All slices are taken first, then each slice is
-- contracted recursively and the results are concatenated. Otherwise the
-- diagonalised value is returned as a singleton list.
tContract :: Tensor a -> EvalM [Tensor a]
tContract t = tDiag t >>= splitLeading
  where
    splitLeading t'@(Tensor (n : _) _ (SupSub _ : _)) = do
      slices <- mapM (\k -> tIntRef' k t') [1 .. n]
      concat <$> mapM tContract slices
    splitLeading t' = return [t']
-- | Collapse repeated indices of the same kind: two equal @Sup@s, two equal
-- @Sub@s, or two identical @DF@ indices.
--
-- The depth-first match looks for such a pair. For each @k@ in
-- @1 .. ns !! m@, where @m@ is the position of the first index of the pair,
-- both positions are fixed to the integer subscript @k@ with 'tref'. The
-- slices are stacked by 'tConcat' under the first index of the pair and put
-- back in order with 'tTranspose'. The process then repeats until no pair is
-- left. Non-'Tensor' values are returned unchanged.
tContract' :: Tensor a -> EvalM (Tensor a)
tContract' t@(Tensor ns _ js) =
  match dfs js (List M.Something)
    [ [mc| $before ++ $ix : $between ++ ?(sameIndex ix) : $after -> do
        slices <- mapM (\k -> tref (before ++ (intSub k : between) ++ (intSub k : after)) t)
                       [1 .. (ns !! length before)]
        tConcat ix slices >>= tTranspose (before ++ ix : between ++ after) >>= tContract' |]
    , [mc| _ -> return t |]
    ]
  where
    -- An integer subscript addressing position k of a dimension.
    intSub :: Integer -> Index EgisonValue
    intSub k = Sub (ScalarData (SingleTerm k []))

    sameIndex :: Index EgisonValue -> Index EgisonValue -> Bool
    sameIndex (Sup i) (Sup j)       = i == j
    sameIndex (Sub i) (Sub j)       = i == j
    sameIndex (DF i1 j1) (DF i2 j2) = i1 == i2 && j1 == j2
    sameIndex _ _                   = False
tContract' val = return val
-- | Stack tensors or scalars along a new leading dimension labelled @s@.
--
-- * First element is a tensor whose shape @ns@ starts with 0: the result is
--   an empty tensor of shape @0 : ns@ with indices @s : js@.
--   NOTE(review): the new leading size here is 0, not @length ts@.
-- * First element is any other tensor: the result has shape
--   @length ts : ns@ and the concatenated components of all elements. It
--   keeps the first tensor's indices behind @s@.
-- * Otherwise every element must be a 'Scalar' (see 'getScalar'). The result
--   is a vector of length @length ts@ indexed by @s@ alone.
tConcat :: Index EgisonValue -> [Tensor a] -> EvalM (Tensor a)
tConcat s ts = case ts of
  Tensor ns@(0 : _) _ js : _ ->
    return $ Tensor (0 : ns) V.empty (s : js)
  Tensor ns _ js : _ ->
    return $ Tensor (count : ns) (V.concat (map tToVector ts)) (s : js)
  _ -> do
    scalars <- mapM getScalar ts
    return $ Tensor [count] (V.fromList scalars) [s]
  where
    count :: Integer
    count = fromIntegral (length ts)
-- | Like 'tConcat', but the resulting tensor carries no index labels.
--
-- * First element is a tensor whose shape @ns@ starts with 0: the result is
--   an empty tensor of shape @0 : ns@.
-- * First element is any other tensor: the result has shape
--   @length ts : ns@ and the concatenated components of all elements.
-- * Otherwise every element must be a 'Scalar' (see 'getScalar'). The result
--   is a vector of length @length ts@.
tConcat' :: [Tensor a] -> EvalM (Tensor a)
tConcat' ts = case ts of
  Tensor ns@(0 : _) _ _ : _ -> return (Tensor (0 : ns) V.empty [])
  Tensor ns _ _ : _ -> return (Tensor (count : ns) (V.concat (map tToVector ts)) [])
  _ -> (\scalars -> Tensor [count] (V.fromList scalars) []) <$> mapM getScalar ts
  where
    count :: Integer
    count = fromIntegral (length ts)
-- | Total version of 'tail': drop the first element, mapping @[]@ to @[]@.
cdr :: [a] -> [a]
cdr = drop 1
-- | Cut a vector into consecutive chunks of @w@ elements. The last chunk may
-- be shorter, and an empty vector yields no chunks.
--
-- NOTE(review): for @w <= 0@ and a non-empty vector, this produces an
-- infinite list of empty chunks.
split :: Integer -> V.Vector a -> [V.Vector a]
split w = chunks
  where
    width :: Int
    width = fromIntegral w

    chunks xs
      | V.null xs = []
      | otherwise = V.take width xs : chunks (V.drop width xs)
-- | Extract the value held by a 'Scalar'.
--
-- Any other tensor raises a 'Default' error. Callers such as 'tConcat' use
-- this to reject lists that mix tensors and scalars.
getScalar :: Tensor a -> EvalM a
getScalar (Scalar x) = return x
getScalar _          = throwError $ Default "Inconsistent Tensor order"
-- | Flip the variance of an index: a superscript becomes a subscript and
-- vice versa. Every other index kind is returned unchanged.
reverseIndex :: Index a -> Index a
reverseIndex idx = case idx of
  Sup i -> Sub i
  Sub i -> Sup i
  _     -> idx
-- | Turn a plain superscript or subscript index into a 'SupSub' index.
--
-- Every other index kind (including one that is already 'SupSub') is
-- returned unchanged, so the function is total and mirrors 'reverseIndex'.
toSupSub :: Index a -> Index a
toSupSub (Sup i) = SupSub i
toSupSub (Sub i) = SupSub i
toSupSub idx     = idx
-- | Placeholder indices for the dimensions of shape @ns@ that the explicit
-- index list @js'@ does not cover.
--
-- Produces one @DF 0 k@ for each @k@ in @1 .. length ns - length js'@. The
-- result is empty when @js'@ is at least as long as @ns@.
complementWithDF :: Shape -> [Index a] -> [Index a]
complementWithDF ns js' = [DF 0 k | k <- [1 .. missing]]
  where
    missing :: Integer
    missing = fromIntegral (length ns - length js')