{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Array.Internal.RankedG(
Array(..), Vector, VecElem,
size, shapeL, rank,
toList, fromList, toVector, fromVector,
normalize,
scalar, unScalar, constant,
reshape, stretch, stretchOuter, transpose,
index, pad,
mapA, zipWithA, zipWith3A,
append, concatOuter,
ravel, unravel,
window, stride, rotate,
slice, rerank, rerank2, rev,
reduce, foldrA, traverseA,
allSameA,
sumA, productA, maximumA, minimumA,
anyA, allA,
broadcast,
generate, iterateN, iota,
) where
import Control.Monad(replicateM)
import Control.DeepSeq
import Data.Data(Data)
import Data.List(sort)
import GHC.Generics(Generic)
import GHC.Stack
import GHC.TypeLits(Nat, type (+), KnownNat, type (<=))
import Test.QuickCheck hiding (generate)
import Text.PrettyPrint.HughesPJClass hiding ((<>))
import Data.Array.Internal
-- | An array whose rank @n@ is carried at the type level.
-- The constructor pairs the shape list with the underlying storage @T v a@,
-- where @v@ is the vector type used to hold elements of type @a@.
data Array (n :: Nat) v a = A ShapeL (T v a)
  deriving (Generic, Data)
-- | Renders as @fromList \<shape\> \<elements\>@, parenthesised when the
-- surrounding precedence is above function application (10).
instance (Vector v, Show a, VecElem v a) => Show (Array n v a) where
  showsPrec d arr@(A sh _) = showParen (d > 10) body
    where
      body =
        showString "fromList "
          . showsPrec 11 sh
          . showString " "
          . showsPrec 11 (toList arr)
-- | Parses the @fromList \<shape\> \<elements\>@ syntax.
-- A parse is accepted only when the shape has exactly @n@ dimensions and the
-- product of the shape equals the number of elements read.
instance (KnownNat n, Vector v, Read a, VecElem v a) => Read (Array n v a) where
  readsPrec d = readParen (d > 10) parseArray
    where
      parseArray input =
        [ (fromList sh elems, rest)
        | ("fromList", afterTag) <- lex input
        , (sh, afterShape) <- readsPrec 11 afterTag
        , (elems, rest) <- readsPrec 11 afterShape
        , length sh == valueOf @n
        , product sh == length elems
        ]
-- | Two arrays are equal when their shapes are equal and 'equalT' reports
-- the underlying storage as equal for that shape.
instance (Vector v, Eq a, VecElem v a, Eq (v a)) => Eq (Array n v a) where
  A sh1 t1 == A sh2 t2 = sh1 == sh2 && equalT sh1 t1 t2
  {-# INLINE (==) #-}
-- | Orders first by shape, then (for equal shapes) by 'compareT' on the
-- underlying storage.
instance (Vector v, Ord a, Ord (v a), VecElem v a) => Ord (Array n v a) where
  compare (A sh1 t1) (A sh2 t2) = compare sh1 sh2 <> compareT sh1 t1 t2
  {-# INLINE compare #-}
-- | Pretty-printing is delegated to 'ppT', passing along the pretty level,
-- the precedence, the shape and the underlying storage.
instance (Vector v, Pretty a, VecElem v a) => Pretty (Array n v a) where
  pPrintPrec lvl prec (A shape storage) = ppT lvl prec shape storage
-- | Fully evaluates the shape first, then the underlying storage.
instance (NFData (v a)) => NFData (Array n v a) where
  rnf (A shape storage) = rnf shape `seq` rnf storage
{-# INLINE size #-}
-- | The product of the dimensions in the array's shape.
size :: Array n v a -> Int
size arr = product (shapeL arr)
{-# INLINE shapeL #-}
-- | The shape list stored in the array.
shapeL :: Array n v a -> ShapeL
shapeL arr = case arr of
  A sh _ -> sh
{-# INLINE rank #-}
-- | The rank of the array, read from the type-level index @n@ rather than
-- from the stored shape. The array is still matched against its constructor.
rank :: forall n v a . (KnownNat n) => Array n v a -> Int
rank arr = case arr of
  A _ _ -> valueOf @n
{-# INLINE index #-}
-- | Select position @i@ along the first dimension, producing an array of
-- rank one lower whose shape is the remaining dimensions (storage is obtained
-- through 'indexT').
--
-- Calls 'error' when @i@ lies outside @[0, s)@ for first dimension @s@, or
-- when the stored shape is empty.
index :: (Vector v, HasCallStack) => Array (1+n) v a -> Int -> Array n v a
index (A (outer:inner) storage) i
  | 0 <= i && i < outer = A inner (indexT storage i)
  | otherwise = error ("index: out of bounds " ++ show (i, outer))
index (A [] _) _ = error "index: scalar"
{-# INLINE toList #-}
-- | The elements of the array as a list, as produced by 'toListT' for the
-- array's shape and storage.
toList :: (Vector v, VecElem v a) => Array n v a -> [a]
toList (A shape storage) = toListT shape storage
{-# INLINE toVector #-}
-- | The elements of the array as a vector, as produced by 'toVectorT' for
-- the array's shape and storage.
toVector :: (Vector v, VecElem v a) => Array n v a -> v a
toVector (A shape storage) = toVectorT shape storage
{-# INLINE fromList #-}
-- | Build an array of the given shape from a list of elements.
--
-- Calls 'error' when the head of @getStridesT ss@ (presumably the total
-- element count implied by the shape) differs from the list length, or when
-- the shape's length differs from the type-level rank @n@. The size check is
-- performed before the rank check.
fromList :: forall n v a . (HasCallStack, Vector v, VecElem v a, KnownNat n) =>
            ShapeL -> [a] -> Array n v a
fromList ss vs
  | total /= len = error ("fromList: size mismatch " ++ show (total, len))
  | shapeRank /= typeRank =
      error ("fromList: rank mismatch " ++ show (shapeRank, typeRank))
  | otherwise = A ss (T strs 0 (vFromList vs))
  where
    -- getStridesT yields the size followed by the per-dimension strides.
    total : strs = getStridesT ss
    len = length vs
    shapeRank = length ss
    typeRank = valueOf @n :: Int
{-# INLINE fromVector #-}
-- | Build an array of the given shape from a vector of elements; the vector
-- is used as the storage directly, with offset 0.
--
-- Calls 'error' when the head of @getStridesT ss@ (presumably the total
-- element count implied by the shape) differs from the vector length, or
-- when the shape's length differs from the type-level rank @n@. The size
-- check is performed before the rank check.
fromVector :: forall n v a . (HasCallStack, Vector v, VecElem v a, KnownNat n) =>
              ShapeL -> v a -> Array n v a
fromVector ss v
  -- The message now ends in a space, matching 'fromList', so the tuple is
  -- not glued to the text.
  | n /= l = error $ "fromVector: size mismatch " ++ show (n, l)
  | length ss /= valueOf @n =
      error $ "fromVector: rank mismatch " ++ show (length ss, valueOf @n :: Int)
  | otherwise = A ss $ T st 0 v
  where
    -- getStridesT yields the size followed by the per-dimension strides.
    n : st = getStridesT ss
    l = vLength v
{-# INLINE normalize #-}
-- | Rebuild the array from its own shape and the vector produced by
-- 'toVector', so the result has the storage layout 'fromVector' creates.
normalize :: (Vector v, VecElem v a, KnownNat n) => Array n v a -> Array n v a
normalize arr = fromVector (shapeL arr) (toVector arr)
{-# INLINE reshape #-}
-- | Give an array a new shape @sh@, possibly of a different rank @n'@.
--
-- Calls 'error' when the size implied by @sh@ differs from the product of
-- the current shape, or when the length of @sh@ differs from the target
-- type-level rank @n'@.
--
-- Otherwise the storage is reused where possible:
--
-- * a one-element backing vector is shared with all strides set to 0;
-- * if 'simpleReshape' finds new strides, the vector and offset are kept;
-- * failing that, the elements are copied out with 'toVectorT' and laid out
--   with the strides from 'getStridesT'.
reshape :: forall n n' v a . (HasCallStack, Vector v, VecElem v a, KnownNat n, KnownNat n') =>
           ShapeL -> Array n v a -> Array n' v a
reshape sh (A sh' t@(T ost oo v))
  | n /= n' = error $ "reshape: size mismatch " ++ show (sh, sh')
  -- Report the rank that was actually checked (the target rank n'), not the
  -- source rank n.
  | length sh /= valueOf @n' =
      error $ "reshape: rank mismatch " ++ show (length sh, valueOf @n' :: Int)
  | vLength v == 1 = A sh $ T (map (const 0) sh) 0 v
  | Just nst <- simpleReshape ost sh' sh = A sh $ T nst oo v
  | otherwise = A sh $ T st 0 $ toVectorT sh' t
  where
    -- getStridesT yields the size followed by the per-dimension strides.
    n : st = getStridesT sh
    n' = product sh'
{-# INLINE stretch #-}
-- | Stretch an array to the shape @sh@. Each dimension must either already
-- have the requested extent or have extent 1; dimensions of extent 1 that
-- differ from the request are flagged for stretching and handed to
-- 'stretchT'.
--
-- Calls 'error' when the shapes differ in length or some dimension is
-- neither equal to the request nor 1.
stretch :: (HasCallStack) => ShapeL -> Array n v a -> Array n v a
stretch sh (A sh' vs) =
  case flags sh sh' of
    Just bs -> A sh (stretchT bs vs)
    Nothing -> error ("stretch: incompatible " ++ show (sh, sh'))
  where
    -- One flag per dimension: True when that dimension must be stretched.
    flags [] [] = Just []
    flags (want:wants) (have:haves)
      | want == have = (False :) <$> flags wants haves
      | have == 1 = (True :) <$> flags wants haves
    flags _ _ = Nothing
{-# INLINE stretchOuter #-}
-- | Stretch the first dimension, which must have extent 1, to extent @s@.
-- Only the first dimension is flagged for 'stretchT'.
--
-- Calls 'error' when the first dimension is not 1.
stretchOuter :: (HasCallStack, 1 <= n) =>
                Int -> Array n v a -> Array n v a
stretchOuter s arr =
  case arr of
    A (1:rest) vs -> A (s : rest) (stretchT (True : (False <$ strides vs)) vs)
    _ -> error "stretchOuter: needs outermost dimension of size 1"
{-# INLINE scalar #-}
-- | Wrap a single value as a rank-0 array (empty shape, payload built by
-- 'scalarT').
scalar :: (Vector v, VecElem v a) => a -> Array 0 v a
scalar x = A [] (scalarT x)
{-# INLINE unScalar #-}
-- | Extract the value held by a rank-0 array.  The stored shape is ignored;
-- the payload is read with 'unScalarT'.
unScalar :: (Vector v, VecElem v a) => Array 0 v a -> a
unScalar arr =
  case arr of
    A _ payload -> unScalarT payload
{-# INLINE constant #-}
-- | Build an array of the given shape from a single element value, using
-- 'constantT' for the payload.
--
-- Calls 'error' if 'badShape' rejects the shape, or if the length of the
-- shape differs from the type-level rank @n@.
constant :: forall n v a . (Vector v, VecElem v a, KnownNat n) =>
            ShapeL -> a -> Array n v a
constant sh
  | badShape sh                = error ("constant: bad shape: " ++ show sh)
  | length sh /= expectedRank  = error "constant: rank mismatch"
  | otherwise                  = \x -> A sh (constantT sh x)
  where
    expectedRank :: Int
    expectedRank = valueOf @n
{-# INLINE mapA #-}
-- | Map a function over an array.  The shape is kept unchanged and the
-- payload is transformed with 'mapT'.
mapA :: (Vector v, VecElem v a, VecElem v b) =>
        (a -> b) -> Array n v a -> Array n v b
mapA fn (A shape payload) = A shape $ mapT shape fn payload
{-# INLINE zipWithA #-}
-- | Combine two arrays of identical shape with a binary function (via
-- 'zipWithT').  Calls 'error', reporting both shapes, when they differ.
zipWithA :: (Vector v, VecElem v a, VecElem v b, VecElem v c) =>
            (a -> b -> c) -> Array n v a -> Array n v b -> Array n v c
zipWithA fn (A shL tL) (A shR tR)
  | shL /= shR = error $ "zipWithA: shape mismatch: " ++ show (shL, shR)
  | otherwise  = A shL (zipWithT shL fn tL tR)
{-# INLINE zipWith3A #-}
-- | Combine three arrays of identical shape with a ternary function (via
-- 'zipWith3T').  Calls 'error', reporting all three shapes, unless the
-- second and third shapes both equal the first.
zipWith3A :: (Vector v, VecElem v a, VecElem v b, VecElem v c, VecElem v d) =>
             (a -> b -> c -> d) -> Array n v a -> Array n v b -> Array n v c -> Array n v d
zipWith3A fn (A sh1 t1) (A sh2 t2) (A sh3 t3)
  | sh1 /= sh2 || sh1 /= sh3 =
      error $ "zipWith3A: shape mismatch: " ++ show (sh1, sh2, sh3)
  | otherwise =
      A sh1 (zipWith3T sh1 fn t1 t2 t3)
{-# INLINE pad #-}
-- | Pad an array using 'padT', which receives the fill value, the list of
-- @(Int, Int)@ padding pairs, the current shape and the payload, and returns
-- the new shape together with the new payload.
-- NOTE(review): presumably each pair is (amount before, amount after) for
-- one dimension, outermost first -- confirm against 'padT'.
pad :: forall n a v . (Vector v, VecElem v a) =>
       [(Int, Int)] -> a -> Array n v a -> Array n v a
pad aps v (A ash at) = A paddedShape paddedPayload
  where
    (paddedShape, paddedPayload) = padT v aps ash at
{-# INLINE transpose #-}
-- | Permute the dimensions of an array.
--
-- The argument must be a permutation of @[0 .. l-1]@ where @l@ is its
-- length; it may be shorter than the rank, in which case it is extended with
-- the identity on the remaining dimensions before being passed to 'permute'
-- (for the shape) and 'transposeT' (for the payload).
--
-- Calls 'error' if the list is longer than the rank or is not a permutation.
transpose :: forall n v a . (KnownNat n) =>
             [Int] -> Array n v a -> Array n v a
transpose is (A sh t)
  | given > rankN =
      error "transpose: rank exceeded"
  | sort is /= [0 .. given - 1] =
      error $ "transpose: not a permutation: " ++ show is
  | otherwise =
      A (permute fullPerm sh) (transposeT fullPerm t)
  where
    given :: Int
    given = length is
    rankN :: Int
    rankN = valueOf @n
    -- Untouched trailing dimensions keep their position.
    fullPerm = is ++ [given .. rankN - 1]
{-# INLINE append #-}
-- | Append two arrays along the outermost dimension.
--
-- Both arrays must have at least one dimension and identical inner shapes;
-- the outer sizes are added and the flattened vectors are joined with
-- 'vAppend'.  Any other combination calls 'error'.
append :: (Vector v, VecElem v a, KnownNat n) =>
          Array n v a -> Array n v a -> Array n v a
append a b =
  case a of
    A (sa : sh) _ ->
      case b of
        A (sb : sh') _
          | sh == sh' ->
              fromVector (sa + sb : sh) (vAppend (toVector a) (toVector b))
        _ -> badShapes
    _ -> badShapes
  where
    badShapes = error "append: bad shape"
{-# INLINE concatOuter #-}
-- | Concatenate a non-empty list of arrays along the outermost dimension.
--
-- All arrays must have at least one dimension and agree on every dimension
-- except the outermost; the result's outer size is the sum of the inputs'
-- outer sizes and its data is the 'vConcat' of the flattened inputs.
--
-- Calls 'error' for an empty list, for rank-0 arrays (which have no outer
-- dimension to concatenate along), and for non-conforming inner dimensions.
concatOuter :: (Vector v, VecElem v a, KnownNat n) => [Array n v a] -> Array n v a
concatOuter [] = error "concatOuter: empty list"
concatOuter as =
  case mapM splitOuter shs of
    Nothing ->
      error $ "concatOuter: arrays need an outermost dimension: " ++ show shs
    Just parts@((_, inner) : _)
      | allSame (map snd parts) ->
          fromVector (sum (map fst parts) : inner) $ vConcat $ map toVector as
      | otherwise ->
          error $ "concatOuter: non-conforming inner dimensions: " ++ show shs
    -- Unreachable: 'as' is non-empty here, so 'parts' is too.
    Just [] -> error "concatOuter: empty list"
  where
    shs = map shapeL as
    -- Total replacement for head/tail: split a shape into (outer, inner).
    splitOuter :: ShapeL -> Maybe (Int, ShapeL)
    splitOuter (o : inner) = Just (o, inner)
    splitOuter []          = Nothing
{-# INLINE ravel #-}
-- | Turn a rank-1 array of arrays into a single array with one more
-- dimension.
--
-- The new outermost dimension is the number of element arrays; the remaining
-- dimensions are the (common) shape of the elements.  Calls 'error' when the
-- outer array is empty or when the element shapes differ.
ravel :: (Vector v, Vector v', VecElem v a, VecElem v' (Array n v a), KnownNat (1+n)) =>
         Array 1 v' (Array n v a) -> Array (1+n) v a
ravel aa =
  case toList aa of
    [] -> error "ravel: empty array"
    cells
      | allSame cellShapes ->
          fromVector (length cells : cellShape) (vConcat (map toVector cells))
      | otherwise ->
          error $ "ravel: non-conforming inner dimensions: " ++ show cellShapes
      where
        cellShapes@(cellShape : _) = map shapeL cells
{-# INLINE unravel #-}
-- | Split off the outermost dimension: the result is a rank-1 array whose
-- elements are the rank-@n@ subarrays.  Implemented by applying 'scalar' to
-- every cell below the first dimension with 'rerank'.
unravel :: (Vector v, Vector v', VecElem v a, VecElem v' (Array n v a)) =>
           Array (1+n) v a -> Array 1 v' (Array n v a)
unravel arr = rerank @1 scalar arr
{-# INLINE window #-}
-- | Build a windowed view of an array.
--
-- For each leading dimension of size @s@ with window size @w@, the result
-- has a dimension of size @s - w + 1@; the window sizes themselves follow,
-- and then any dimensions not covered by the window list.  The underlying
-- vector and offset are reused: only the shape and the stride list change
-- (the strides of the windowed dimensions are duplicated in front).
--
-- Calls 'error' if the result rank @n'@ is not the input rank plus the
-- number of window sizes, if a window is larger than its dimension, or if
-- more window sizes are given than there are dimensions.
window :: forall n n' v a . (Vector v, KnownNat n, KnownNat n') =>
          [Int] -> Array n v a -> Array n' v a
window aws arr
  | rankOut /= length aws + rankIn =
      error $ "window: rank mismatch: " ++ show (rankIn, length aws, rankOut)
  | otherwise =
      case arr of
        A ash (T ss o v) ->
          let go (w : ws) (s : sh)
                | w > s     = error $ "window: bad window size : " ++ show (w, s)
                | otherwise = s - w + 1 : go ws sh
              go [] sh = aws ++ sh
              go _ _   = error $ "window: rank mismatch: " ++ show (aws, ash)
              -- One stride per window size, copied from the leading strides.
              windowStrides = zipWith const ss aws
          in A (go aws ash) (T (windowStrides ++ ss) o v)
  where
    rankIn, rankOut :: Int
    rankIn  = valueOf @n
    rankOut = valueOf @n'
{-# INLINE stride #-}
-- | Keep every @t@-th element along the leading dimensions.
--
-- For a dimension of size @s@ and stride @t@ the resulting size is
-- @(s + t - 1) \`quot\` t@, and the stored stride for that dimension is
-- multiplied by @t@.  Dimensions beyond the end of the stride list are left
-- unchanged.  The underlying vector and offset are reused.
--
-- Every stride must be strictly positive: a zero stride would divide by
-- zero, and a negative one would produce negative or oversized dimensions.
-- Calls 'error' for a non-positive stride or when more strides are given
-- than there are dimensions.
stride :: (Vector v) => [Int] -> Array n v a -> Array n v a
stride ats (A ash (T ss o v)) = A (str ats ash) (T (zipWith (*) (ats ++ repeat 1) ss) o v)
  where str (t:ts) (s:sh)
          | t <= 0    = error $ "stride: non-positive stride: " ++ show ats
          | otherwise = (s+t-1) `quot` t : str ts sh
        str [] sh = sh
        str _ _ = error $ "stride: rank mismatch: " ++ show (ats, ash)
-- | Apply a rotation construction to every rank-@p@ cell of the array (the
-- cells below the leading @d@ dimensions, via 'rerank'), adding one new
-- dimension of size @k@ in front of each cell.
--
-- Each cell of shape @h:t@ is repeated @k + 1@ times, flattened, cut into
-- overlapping windows of one cell's length, sampled every @h*m + m@ elements
-- (@m@ being the product of @t@), reshaped to @k:h:t@ and finally reversed
-- along the new dimension.
-- NOTE(review): this appears to yield @k@ successive rotations of the cell
-- along its outermost dimension -- confirm ordering and direction against
-- the tests.  The cell must have at least one dimension (@p >= 1@); the
-- shape pattern below fails otherwise.
rotate :: forall d p v a.
          (KnownNat p, KnownNat d,
          Vector v, VecElem v a,
          (d + (p + 1)) ~ ((p + d) + 1),
          (d + p) ~ (p + d),
          1 <= p + 1,
          KnownNat ((p + d) + 1),
          KnownNat (p + 1),
          KnownNat (1 + (p + 1))
          ) =>
          Int -> Array (p + d) v a -> Array (p + d + 1) v a
rotate k a = rerank @d @p @(p + 1) rotateCell a
  where
    rotateCell :: Array p v a -> Array (p + 1) v a
    rotateCell cell = rev [0] (reshape (k : h : t) sampled)
      where
        h : t = shapeL cell
        -- Elements per outermost row, and per whole cell.
        rowLen = product t
        cellLen = h * rowLen
        withUnitOuter = reshape @p @(p + 1) (1 : h : t) cell
        copies = stretchOuter (k + 1) withUnitOuter
        flat = reshape @(p + 1) @1 [(k + 1) * cellLen] copies
        windows = window @1 @2 [cellLen] flat
        sampled = stride [cellLen + rowLen] windows
{-# INLINE slice #-}
-- | Extract a sub-array.
--
-- Each pair @(k, n)@ applies to one dimension, outermost first: it skips
-- @k@ elements and keeps @n@.  Dimensions beyond the end of the list are
-- kept whole.  Only the offset and the shape change; the stride list and the
-- underlying vector are reused.
--
-- Calls 'error' when a pair has a negative offset or length, or reaches past
-- the end of its dimension, and when more pairs are given than the array has
-- dimensions.
slice :: [(Int, Int)] -> Array n v a -> Array n v a
slice asl (A ash (T ats ao v)) = A rsh (T ats o v)
  where (o, rsh) = slc asl ash ats
        slc ((k,n):sl) (s:sh) (t:ts)
          -- A negative length would otherwise yield a negative dimension.
          | k < 0 || n < 0 || k > s || k+n > s = error "slice: out of bounds"
          | otherwise = (i + k*t, n:ns)
          where (i, ns) = slc sl sh ts
        slc [] sh _ = (ao, sh)
        -- Reached when the slice list is longer than the shape.
        slc _ _ _ = error $ "slice: rank mismatch: " ++ show (asl, ash)
{-# INLINE rerank #-}
-- | Apply a function to every cell below the leading @n@ dimensions.
--
-- The shape is split after @n@ entries; 'subArraysT' enumerates the payloads
-- of the cells, each is wrapped with the inner shape and passed to the
-- function, and the results are reassembled under the outer shape with
-- 'ravelOuter' (which requires all results to have the same shape).
rerank :: forall n i o v v' a b .
          (Vector v, Vector v', VecElem v a, VecElem v' b
          , KnownNat n, KnownNat o, KnownNat (n+o), KnownNat (1+o)) =>
          (Array i v a -> Array o v' b) -> Array (n+i) v a -> Array (n+o) v' b
rerank f (A sh t) =
  ravelOuter outerShape [ f (A cellShape cell) | cell <- subArraysT outerShape t ]
  where
    (outerShape, cellShape) = splitAt (valueOf @n) sh
-- | Assemble a non-empty list of equally shaped arrays into one array whose
-- shape is the given outer shape followed by the common element shape.
-- The data is the 'vConcat' of the flattened elements.
--
-- Calls 'error' for an empty list or when the element shapes differ.
ravelOuter :: (Vector v, VecElem v a, KnownNat m) => ShapeL -> [Array n v a] -> Array m v a
ravelOuter _ [] = error "ravelOuter: empty list"
ravelOuter osh as
  | allSame cellShapes =
      fromVector (osh ++ cellShape) (vConcat (map toVector as))
  | otherwise =
      error $ "ravelOuter: non-conforming inner dimensions: " ++ show cellShapes
  where
    cellShapes@(cellShape : _) = map shapeL as
{-# INLINE rerank2 #-}
-- | Two-argument version of 'rerank': apply a binary function to matching
-- cells below the leading @n@ dimensions of two arrays.
--
-- The leading @n@ dimensions of both arrays must agree, otherwise 'error' is
-- called.  The results are reassembled under that common outer shape with
-- 'ravelOuter'.
rerank2 :: forall n i o a b c v .
           (Vector v, VecElem v a, VecElem v b, VecElem v c,
            KnownNat n, KnownNat o, KnownNat (n+o), KnownNat (1+o)) =>
           (Array i v a -> Array i v b -> Array o v c) -> Array (n+i) v a -> Array (n+i) v b -> Array (n+o) v c
rerank2 f (A sha ta) (A shb tb)
  | outerA /= outerB = error "rerank2: shape mismatch"
  | otherwise =
      ravelOuter outerA
        [ f (A innerA cellA) (A innerB cellB)
        | (cellA, cellB) <- zip (subArraysT outerA ta) (subArraysT outerA tb)
        ]
  where
    outerRank :: Int
    outerRank = valueOf @n
    (outerA, innerA) = splitAt outerRank sha
    (outerB, innerB) = splitAt outerRank shb
{-# INLINE rev #-}
-- | Reverse an array along the listed dimensions (via 'reverseT'); the shape
-- is unchanged.  Every listed dimension must lie in @[0 .. rank-1]@,
-- otherwise 'error' is called.
rev :: [Int] -> Array n v a -> Array n v a
rev rs (A sh t)
  | any outOfRange rs = error "reverse: bad reverse dimension"
  | otherwise         = A sh (reverseT rs sh t)
  where
    dims = length sh
    outOfRange r = r < 0 || r >= dims
{-# INLINE reduce #-}
-- | Reduce all elements of an array to a rank-0 array, using the given
-- combining function and start value (delegates to 'reduceT').
reduce :: (Vector v, VecElem v a) =>
          (a -> a -> a) -> a -> Array n v a -> Array 0 v a
reduce combine start (A shape payload) =
  A [] (reduceT shape combine start payload)
{-# INLINE foldrA #-}
-- | Right fold over the elements of an array (delegates to 'foldrT').
foldrA :: (Vector v, VecElem v a) => (a -> b -> b) -> b -> Array n v a -> b
foldrA step start arr =
  case arr of
    A shape payload -> foldrT shape step start payload
{-# INLINE traverseA #-}
-- | Traverse the elements of an array with an applicative action (via
-- 'traverseT'); the result has the same shape as the input.
traverseA
  :: (Vector v, VecElem v a, VecElem v b, Applicative f)
  => (a -> f b) -> Array n v a -> f (Array n v b)
traverseA act (A shape payload) =
  fmap (A shape) (traverseT shape act payload)
-- | Equality test across the elements of an array; delegates to 'allSameT'.
-- NOTE(review): presumably True exactly when all elements are equal --
-- confirm against 'allSameT', in particular for empty arrays.
allSameA :: (Vector v, VecElem v a, Eq a) => Array r v a -> Bool
allSameA arr =
  case arr of
    A shape payload -> allSameT shape payload
-- | Random arrays for QuickCheck.  A shape of @r@ small positive dimensions
-- is generated (retrying until the total element count is below 10000), and
-- then filled with exactly that many arbitrary elements.
instance (KnownNat r, Vector v, VecElem v a, Arbitrary a) => Arbitrary (Array r v a) where
  arbitrary = do
    dims <- genShape `suchThat` smallEnough
    elems <- vector (product dims)
    pure (fromList dims elems)
    where
      -- One small, strictly positive size per dimension.
      genShape = replicateM (valueOf @r) (getSmall . getPositive <$> arbitrary)
      smallEnough shape = product shape < 10000
-- | Sum of the elements of the array, computed by 'sumT' over the
-- array's shape and underlying storage.
{-# INLINE sumA #-}
sumA :: (Vector v, VecElem v a, Num a) => Array r v a -> a
sumA (A shape payload) = sumT shape payload
-- | Product of the elements of the array, computed by 'productT' over
-- the array's shape and underlying storage.
{-# INLINE productA #-}
productA :: (Vector v, VecElem v a, Num a) => Array r v a -> a
productA (A shape payload) = productT shape payload
-- | Largest element of the array, computed by 'maximumT'.
--
-- Calls 'error' when the array has no elements (its 'size' is not
-- positive), since there is no maximum to return in that case.
{-# INLINE maximumA #-}
maximumA :: (HasCallStack, Vector v, VecElem v a, Ord a) => Array r v a -> a
maximumA arr@(A shape payload)
  | size arr <= 0 = error "maximumA called with empty array"
  | otherwise     = maximumT shape payload
-- | Smallest element of the array, computed by 'minimumT'.
--
-- Calls 'error' when the array has no elements (its 'size' is not
-- positive), since there is no minimum to return in that case.
{-# INLINE minimumA #-}
minimumA :: (HasCallStack, Vector v, VecElem v a, Ord a) => Array r v a -> a
minimumA arr@(A shape payload)
  | size arr <= 0 = error "minimumA called with empty array"
  | otherwise     = minimumT shape payload
-- | Test whether the predicate holds for at least one element of the
-- array.  The traversal is delegated to 'anyT'.
{-# INLINE anyA #-}
anyA :: (Vector v, VecElem v a) => (a -> Bool) -> Array r v a -> Bool
anyA holds (A shape payload) = anyT shape holds payload
-- | Test whether the predicate holds for every element of the array.
--
-- Expressed through 'anyT' by De Morgan's law: all elements satisfy @p@
-- exactly when no element fails it.  Applying 'anyT' to @p@ directly
-- would answer "does some element satisfy @p@", which is 'anyA', not
-- 'allA'.
{-# INLINE allA #-}
allA :: (Vector v, VecElem v a) => (a -> Bool) -> Array r v a -> Bool
allA p (A sh t) = not (anyT sh (not . p) t)
-- | Broadcast an array of rank @r@ into an array of rank @r'@ and shape
-- @sh@.
--
-- @ds@ lists the dimensions of the result that the dimensions of the
-- input are placed at.  The input is first reshaped to a rank-@r'@ shape
-- that has the extent from @sh@ at every position listed in @ds@ and
-- extent 1 everywhere else, and is then stretched to @sh@.
--
-- Calls 'error' (checked in this order) when:
--
-- * @ds@ does not have exactly @r@ entries,
-- * some entry of @ds@ is negative or not less than @r'@,
-- * @ds@ is not strictly ascending,
-- * @sh@ does not have exactly @r'@ entries.
broadcast :: forall r' r v a .
             (HasCallStack, Vector v, VecElem v a, KnownNat r, KnownNat r') =>
             [Int] -> ShapeL -> Array r v a -> Array r' v a
broadcast ds sh a
  | length ds /= valueOf @r        = error "broadcast: wrong number of broadcasts"
  | not (all inRange ds)           = error "broadcast: bad dimension"
  | not strictlyAscending          = error "broadcast: unordered dimensions"
  | length sh /= outRank           = error "broadcast: wrong rank"
  | otherwise                      = stretch sh (reshape unitShape a)
  where
    -- Rank of the result array.
    outRank :: Int
    outRank = valueOf @r'

    -- A target dimension must index into the result shape.
    inRange :: Int -> Bool
    inRange d = 0 <= d && d < outRank

    -- Each entry of ds must be smaller than its successor.
    strictlyAscending :: Bool
    strictlyAscending = and (zipWith (<) ds (drop 1 ds))

    -- Result shape with every dimension not named in ds collapsed to 1.
    unitShape :: [Int]
    unitShape = zipWith pick [0 ..] sh
      where
        pick i s
          | i `elem` ds = s
          | otherwise   = 1
-- | Build an array of the given shape from a function on index lists.
-- The elements are produced by 'generateT'.
--
-- Calls 'error' when the length of the shape differs from the
-- type-level rank @n@; the message shows both numbers.
{-# INLINE generate #-}
generate :: forall n v a .
            (KnownNat n, Vector v, VecElem v a) =>
            ShapeL -> ([Int] -> a) -> Array n v a
generate sh
  | givenRank == wantedRank = A sh . generateT sh
  | otherwise =
      error ("generate: rank mismatch " ++ show (givenRank, wantedRank))
  where
    givenRank, wantedRank :: Int
    givenRank  = length sh
    wantedRank = valueOf @n
-- | Build a rank-1 array with shape @[n]@ whose contents come from
-- 'iterateNT' applied to the count, the step function and the seed.
-- NOTE(review): presumably the elements are @x, f x, f (f x), ...@;
-- confirm against 'iterateNT'.
{-# INLINE iterateN #-}
iterateN :: forall v a .
            (Vector v, VecElem v a) =>
            Int -> (a -> a) -> a -> Array 1 v a
iterateN count step = wrap . iterateNT count step
  where
    -- Attach the one-dimensional shape to the generated storage.
    wrap = A [count]
-- | Build a rank-1 array with shape @[n]@ whose contents come from
-- 'iotaT'.
-- NOTE(review): presumably the elements count up from 0 to @n-1@;
-- confirm against 'iotaT'.
{-# INLINE iota #-}
iota :: forall v a .
        (Vector v, VecElem v a, Enum a, Num a) =>
        Int -> Array 1 v a
iota count = A [count] (iotaT count)