{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
module NumHask.Array.Dynamic
(
Array (..),
fromFlatList,
toFlatList,
index,
tabulate,
takes,
reshape,
transpose,
indices,
ident,
sequent,
diag,
undiag,
singleton,
selects,
selectsExcept,
folds,
extracts,
extractsExcept,
joins,
maps,
concatenate,
insert,
append,
reorder,
expand,
expandr,
apply,
contract,
dot,
mult,
slice,
squeeze,
fromScalar,
toScalar,
col,
row,
mmult,
)
where
import Data.List (intercalate)
import Data.Vector qualified as V
import GHC.Show (Show (..))
import NumHask.Array.Shape
import NumHask.Prelude as P hiding (product)
-- | A multi-dimensional array whose shape is only known at runtime.
--
-- The elements live in a single flat, boxed 'V.Vector'; 'shape' lists the
-- extent of every dimension, and 'flatten' / 'shapen' translate between a
-- multi-dimensional index and a position in that vector.
data Array a = Array
  { -- | Extent of each dimension (an empty list denotes a scalar).
    shape :: [Int],
    -- | The elements, stored flat.
    unArray :: V.Vector a
  }
  deriving (Eq, Ord, Generic)
-- | Applies a function to every element; the shape is left untouched.
instance Functor Array where
  fmap f arr = arr {unArray = V.map f (unArray arr)}
-- | Folds over the elements in flat storage order, ignoring the shape.
instance Foldable Array where
  foldr step z arr = V.foldr step z (unArray arr)
-- | Traverses the elements in flat storage order and rebuilds an array of
-- the original shape from the results.
instance Traversable Array where
  traverse f arr =
    fromFlatList (shape arr) <$> traverse f (V.toList (unArray arr))
-- | Renders an array as nested bracketed lists.
--
-- A rank-0 array shows its single element, a rank-1 array shows as
-- @[x, y, z]@, and higher ranks are split along the first dimension with one
-- sub-array per line, indented so the opening brackets line up.
instance (Show a) => Show (Array a) where
  show arr@(Array sh _) = render (length sh) arr
    where
      -- depth: rank of the outermost array, used to compute indentation
      render depth sub@(Array sh' vals) =
        case length sh' of
          0 -> GHC.Show.show (V.head vals)
          1 -> "[" ++ intercalate ", " (map GHC.Show.show (V.toList vals)) ++ "]"
          r ->
            let separator = ",\n" ++ replicate (depth - r + 1) ' '
                rows = map (render depth) (toFlatList (extracts [0] sub))
             in "[" ++ intercalate separator rows ++ "]"
-- | Builds an array of the given shape from a flat list of elements.
--
-- Only the first @size ds@ elements of the list are used; any surplus is
-- ignored. A list with /fewer/ elements than the shape requires used to
-- produce an array whose vector was shorter than its shape, so every later
-- 'index' read out of bounds via 'V.unsafeIndex'. Such input now throws a
-- 'NumHaskException' when the element vector is forced. The check lives
-- inside the vector field so the constructor (and 'shape') stays lazy, as
-- before.
fromFlatList :: [Int] -> [a] -> Array a
fromFlatList ds l = Array ds checked
  where
    n = size ds
    v = V.fromList (take n l)
    checked
      | V.length v == n = v
      | otherwise =
          throw
            ( NumHaskException
                "fromFlatList: list has fewer elements than the shape requires"
            )
-- | The elements of an array as a flat list, in storage order.
toFlatList :: Array a -> [a]
toFlatList = V.toList . unArray
-- | Looks up the element at a multi-dimensional index.
--
-- The index is converted to a flat position with 'flatten' and read with
-- 'V.unsafeIndex', so no bounds checking is performed: an out-of-range index
-- is undefined behaviour.
index :: Array a -> [Int] -> a
index arr i = unArray arr `V.unsafeIndex` flatten (shape arr) i
-- | Creates an array of the given shape by evaluating a function at every
-- multi-dimensional index.
--
-- Each flat position is turned back into a multi-dimensional index with
-- 'shapen' before being passed to the generating function.
tabulate :: [Int] -> ([Int] -> a) -> Array a
tabulate ds f = Array ds (V.generate (size ds) (\flat -> f (shapen ds flat)))
-- | Takes the leading block of an array with the requested shape: each
-- result element is read from the same multi-dimensional index of the input.
takes ::
  [Int] ->
  Array a ->
  Array a
takes ds a = tabulate ds (index a)
-- | Reinterprets an array under a new shape, keeping flat storage order.
--
-- Each new index is flattened against the new shape and then expanded
-- against the old shape to find its source element.
reshape ::
  [Int] ->
  Array a ->
  Array a
reshape s a = tabulate s (\i -> index a (shapen (shape a) (flatten s i)))
-- | Reverses the order of all dimensions: the result's shape is the reversed
-- shape, and each element is read at the reversed index.
transpose :: Array a -> Array a
transpose a = tabulate (reverse (shape a)) (\i -> index a (reverse i))
-- | An array of the given shape whose every element is its own index.
indices :: [Int] -> Array [Int]
indices ds = tabulate ds (\i -> i)
-- | An identity-style array: 'one' at every index whose components are all
-- equal, 'zero' everywhere else.
ident :: (Additive a, Multiplicative a) => [Int] -> Array a
ident ds = tabulate ds (\i -> bool zero one (allEqual i))
  where
    -- each component equals its successor (vacuously true for 0 or 1 items)
    allEqual xs = and (zipWith (==) xs (drop 1 xs))
-- | An array whose diagonal (all index components equal) holds that shared
-- component value, with 'zero' elsewhere; a scalar shape yields 'zero'.
sequent :: [Int] -> Array Int
sequent ds = tabulate ds diagValue
  where
    diagValue [] = zero
    diagValue (i : js)
      | all (i ==) js = i
      | otherwise = zero
-- | Extracts the main diagonal as a rank-1 array whose length is the
-- smallest dimension of the input.
diag ::
  Array a ->
  Array a
diag a = tabulate [NumHask.Array.Shape.minimum (shape a)] onDiagonal
  where
    -- repeat the diagonal position once per input dimension
    onDiagonal (k : _) = index a (replicate (rank (shape a)) k)
    -- defensive: tabulate over a rank-1 shape never supplies an empty index
    onDiagonal [] = throw (NumHaskException "Rank Underflow")
-- | Expands a rank-1 array into a rank-@r@ array with the input along the
-- main diagonal and 'zero' elsewhere.
--
-- The diagonal source is one-dimensional, so it is indexed with the single
-- shared component @[x]@. The previous version passed the whole rank-@r@
-- index, which only worked if 'flatten' silently ignored the extra
-- components.
undiag ::
  (Additive a) =>
  Int ->
  Array a ->
  Array a
undiag r a = tabulate (replicate r (head (shape a))) go
  where
    go [] = throw (NumHaskException "Rank Underflow")
    go (x : xs') = bool zero (index a [x]) (all (x ==) xs')
-- | An array of the given shape with every element set to the same value.
singleton :: [Int] -> a -> Array a
singleton ds a = tabulate ds (\_ -> a)
-- | Fixes the dimensions listed in @ds@ at the positions @i@, returning the
-- sub-array spanned by the remaining dimensions.
selects ::
  [Int] ->
  [Int] ->
  Array a ->
  Array a
selects ds i a =
  tabulate (dropIndexes (shape a) ds) (\s -> index a (addIndexes s ds i))
-- | Like 'selects', but @ds@ names the dimensions to /keep/; all other
-- dimensions are fixed at the positions @i@.
selectsExcept ::
  [Int] ->
  [Int] ->
  Array a ->
  Array a
selectsExcept ds i a = selects fixed i a
  where
    fixed = exclude (rank (shape a)) ds
-- | For every position of the dimensions in @ds@, applies @f@ to the
-- sub-array selected there; the result has the shape of those dimensions.
folds ::
  (Array a -> b) ->
  [Int] ->
  Array a ->
  Array b
folds f ds a = tabulate (takeIndexes (shape a) ds) (\s -> f (selects ds s a))
-- | Splits an array along the dimensions in @ds@: the outer array ranges
-- over those dimensions and each element is the sub-array selected there.
extracts ::
  [Int] ->
  Array a ->
  Array (Array a)
extracts ds a = tabulate (takeIndexes (shape a) ds) (\s -> selects ds s a)
-- | Like 'extracts', but splits along every dimension /not/ listed in @ds@.
extractsExcept ::
  [Int] ->
  Array a ->
  Array (Array a)
extractsExcept ds a = extracts outer a
  where
    outer = exclude (rank (shape a)) ds
-- | Inverse of 'extracts': merges an array of arrays into one array, placing
-- the outer dimensions at the positions listed in @ds@.
--
-- The inner shape is taken from the element at the all-zero outer index, so
-- every inner array is assumed to share that shape.
joins ::
  [Int] ->
  Array (Array a) ->
  Array a
joins ds a = tabulate (addIndexes innerShape ds outerShape) element
  where
    outerShape = shape a
    innerShape = shape (index a (replicate (rank outerShape) 0))
    element s = index (index a (takeIndexes s ds)) (dropIndexes s ds)
-- | Applies an array transformation to every sub-array obtained by
-- extracting along @ds@, then joins the results back along the same
-- dimensions.
maps ::
  (Array a -> Array b) ->
  [Int] ->
  Array a ->
  Array b
maps f ds a = joins ds (f <$> extracts ds a)
-- | Concatenates two arrays along dimension @d@.
--
-- Positions below the extent of @a0@ along @d@ come from @a0@; the rest come
-- from @a1@, shifted back by that extent.
concatenate ::
  Int ->
  Array a ->
  Array a ->
  Array a
concatenate d a0 a1 = tabulate (concatenate' d (shape a0) (shape a1)) pick
  where
    edge = shape a0 !! d
    pick s
      | sd >= edge = index a1 (addIndex (dropIndex s d) d (sd - edge))
      | otherwise = index a0 s
      where
        sd = s !! d
-- | Inserts the array @b@ (one rank lower than @a@) into @a@ as a new slice
-- at position @i@ along dimension @d@, shifting later slices up by one.
insert ::
  Int ->
  Int ->
  Array a ->
  Array a ->
  Array a
insert d i a b = tabulate (incAt d (shape a)) pick
  where
    pick s = case compare (s !! d) i of
      EQ -> index b (dropIndex s d)
      LT -> index a s
      GT -> index a (decAt d s)
-- | Appends @b@ as a new final slice of @a@ along dimension @d@
-- ('insert' at the current extent of that dimension).
append ::
  Int ->
  Array a ->
  Array a ->
  Array a
append d a = insert d (dimension (shape a) d) a
-- | Permutes the dimensions of an array according to @ds@.
reorder ::
  [Int] ->
  Array a ->
  Array a
reorder ds a = tabulate (reorder' (shape a) ds) (\s -> index a (addIndexes [] ds s))
-- | Outer product: combines every element of @a@ with every element of @b@.
--
-- The result shape is @shape a ++ shape b@; the leading index components
-- address @a@ and the trailing ones address @b@.
expand ::
  (a -> b -> c) ->
  Array a ->
  Array b ->
  Array c
expand f a b = tabulate (shape a ++ shape b) pick
  where
    r = rank (shape a)
    pick i = f (index a (take r i)) (index b (drop r i))
-- | Outer product with the second array's dimensions first.
--
-- The result shape is @shape b ++ shape a@: the leading @rank (shape b)@
-- index components address @b@ and the trailing ones address @a@.
--
-- The previous version built shape @shape a ++ shape b@ but still used the
-- leading components, which range over @shape a@, to index @b@ (and the
-- trailing ones to index @a@). Whenever the shapes differed this read out of
-- range through 'V.unsafeIndex'. When @shape a == shape b@ the two versions
-- agree exactly.
expandr ::
  (a -> b -> c) ->
  Array a ->
  Array b ->
  Array c
expandr f a b = tabulate (shape b ++ shape a) pick
  where
    r = rank (shape b)
    pick i = f (index a (drop r i)) (index b (take r i))
-- | Applies every function in the first array to every element of the
-- second. The result shape is @shape f ++ shape a@.
apply ::
  Array (a -> b) ->
  Array a ->
  Array b
apply f a = tabulate (shape f ++ shape a) pick
  where
    r = rank (shape f)
    pick i = index f (take r i) (index a (drop r i))
-- | Contracts the dimensions listed in @xs@: for each position of the other
-- dimensions, takes the 'diag' of the selected sub-array and reduces it with
-- @f@.
contract ::
  (Array a -> b) ->
  [Int] ->
  Array a ->
  Array b
contract f xs a = fmap (f . diag) (extractsExcept xs a)
-- | Generalised inner product: forms the outer product with @g@, then
-- contracts the last dimension of @a@ against the first dimension of @b@,
-- reducing each contracted diagonal with @f@.
dot ::
  (Array c -> d) ->
  (a -> b -> c) ->
  Array a ->
  Array b ->
  Array d
dot f g a b = contract f [r - 1, r] (expand g a b)
  where
    r = rank (shape a)
-- | Ordinary array multiplication: 'dot' with multiplication as the combiner
-- and 'sum' as the reduction.
mult ::
  ( Additive a,
    Multiplicative a
  ) =>
  Array a ->
  Array a ->
  Array a
mult a b = dot sum (*) a b
-- | Selects elements by listing, for every dimension, the positions to keep.
-- The result extent of each dimension is the length of its position list.
slice ::
  [[Int]] ->
  Array a ->
  Array a
slice pss a = tabulate (ranks pss) (\s -> index a (zipWith (!!) pss s))
-- | Rewrites the shape with 'squeeze'' (dropping unit dimensions); the
-- element storage is reused unchanged.
squeeze ::
  Array a ->
  Array a
squeeze a = a {shape = squeeze' (shape a)}
-- | Reads the single element of a scalar (rank-0) array via the empty index.
fromScalar :: Array a -> a
fromScalar arr = index arr []
-- | Wraps a single value as a scalar (empty-shape) array.
toScalar :: a -> Array a
toScalar x = fromFlatList [] [x]
-- | Extracts row @i@ of an array with at least two dimensions as a rank-1
-- array.
--
-- The row is a contiguous 'V.slice' of the storage, so an out-of-range row
-- index is rejected by 'V.slice'.
row :: Int -> Array a -> Array a
row i (Array sh vals) = Array [width] (V.slice (i * width) width vals)
  where
    (_ : width : _) = sh
-- | Extracts column @i@ of an array with at least two dimensions as a rank-1
-- array.
--
-- Elements are gathered with stride @n@ (the second dimension) using
-- 'V.unsafeIndex'. Without a check, an out-of-range @i@ silently read
-- elements from other rows or ran past the end of the vector. It now throws
-- a 'NumHaskException' instead, matching 'row', whose 'V.slice' is
-- bounds-checked.
col :: Int -> Array a -> Array a
col i (Array s a)
  | i < 0 || i >= n = throw (NumHaskException "col: column index out of range")
  | otherwise = Array [m] (V.generate m (\x -> V.unsafeIndex a (i + x * n)))
  where
    (m : n : _) = s
-- | Matrix multiplication of two arrays, each with at least two dimensions
-- (only the first two dimensions of each shape are consulted).
--
-- For shapes @[m, k, ..]@ and @[k', n, ..]@ the result has shape @[m, n]@.
-- Throws 'NumHaskException' when either input has fewer than two dimensions,
-- or when the inner dimensions disagree (@k /= k'@). Previously the rank
-- check sat in unreachable branches of the per-element function, and a
-- mismatched inner dimension silently produced a wrong product.
mmult ::
  (Ring a) =>
  Array a ->
  Array a ->
  Array a
mmult (Array sx x) (Array sy y) =
  case (sx, sy) of
    (m : k : _, k' : n : _)
      | k == k' -> tabulate [m, n] (entry k n)
      | otherwise -> throw (NumHaskException "Inner dimensions do not match")
    _ -> throw (NumHaskException "Needs two dimensions")
  where
    -- row i of x (contiguous, width k) dotted with column j of y (stride n)
    entry k n (i : j : _) =
      sum
        ( V.zipWith
            (*)
            (V.slice (i * k) k x)
            (V.generate k (\l -> y V.! (j + l * n)))
        )
    -- defensive: tabulate over a rank-2 shape never supplies a shorter index
    entry _ _ _ = throw (NumHaskException "Needs two dimensions")
{-# INLINE mmult #-}