-- Copyright 2020 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Arrays of static size.  The arrays are polymorphic in the underlying
-- linear data structure used to store the actual values.
module Data.Array.Internal.ShapedG(
  Array(..), Shape(..), Size, Rank, Vector, VecElem,
  Window, Stride, Permute, Permutation, ValidDims,
  Broadcast,
  size, shapeL, rank,
  toList, fromList, toVector, fromVector,
  normalize,
  scalar, unScalar, constant,
  reshape, stretch, stretchOuter, transpose,
  index, pad,
  mapA, zipWithA, zipWith3A,
  append,
  ravel, unravel,
  window, stride,
  slice, rerank, rerank2, rev,
  reduce, foldrA, traverseA,
  allSameA,
  sumA, productA, minimumA, maximumA,
  anyA, allA,
  broadcast,
  generate, iterateN, iota,
  ) where
import Control.DeepSeq
import Data.Data(Data)
import Data.Proxy(Proxy(..))
import GHC.Generics(Generic)
import GHC.Stack(HasCallStack)
import GHC.TypeLits(Nat, type (<=), KnownNat, type (+))
import Test.QuickCheck hiding (generate)
import Text.PrettyPrint.HughesPJClass

import Data.Array.Internal
import Data.Array.Internal.Shape

-- | Arrays stored in a /v/ with values of type /a/.
--
-- The shape @sh@ is a type-level list of dimension sizes.  It appears only in
-- the type: the runtime representation is just the wrapped 'T' value, and
-- the shape is recovered from the type (see 'shapeL').
newtype Array (sh :: [Nat]) v a = A (T v a)
  deriving (Data, Generic)

-- | Renders an array as @fromList \@\<shape\> \<elements\>@, with the elements
-- listed in linearization order.  The output is parenthesised when the
-- surrounding precedence is above 10, as for function application.
instance (Vector v, Show a, VecElem v a, Shape sh, Show (v a)) => Show (Array sh v a) where
  showsPrec p arr =
    showParen (p > 10) $
        showString "fromList @"
      . showsPrec 11 (shapeL arr)
      . showChar ' '
      . showsPrec 11 (toList arr)

-- | Parses the format produced by the 'Show' instance:
-- @fromList \@\<shape\> \<elements\>@, optionally parenthesised.
-- A parse succeeds only when the shape read equals the type-level shape @sh@
-- and the number of elements equals the product of that shape.
instance (Shape sh, Vector v, Read a, VecElem v a) => Read (Array sh v a) where
  readsPrec d = readParen (d > 10) $ \ input -> do
    ("fromList", rest1) <- lex input
    ("@", rest2) <- lex rest1
    (shp, rest3) <- readsPrec 11 rest2
    (elems, rest4) <- readsPrec 11 rest3
    -- Validating here, instead of leaving it to 'fromList' (which calls
    -- 'error' on a size mismatch), turns a bad input into "no parse".
    [ (fromList elems, rest4)
      | shp == shapeP (Proxy :: Proxy sh)
      , product shp == length elems ]

-- | Equality via 'equalT' on the underlying representations.  Both operands
-- share the type-level shape, so the left operand's shape is passed in.
instance (Vector v, Eq a, VecElem v a, Eq (v a), Shape sh)
         => Eq (Array sh v a) where
  (==) lhs@(A tl) (A tr) = equalT (shapeL lhs) tl tr
  {-# INLINE (==) #-}

-- | Ordering via 'compareT' on the underlying representations.  Both operands
-- share the type-level shape, so the left operand's shape is passed in.
instance (Vector v, Ord a, Ord (v a), VecElem v a, Shape sh)
         => Ord (Array sh v a) where
  compare lhs@(A tl) (A tr) = compareT (shapeL lhs) tl tr
  {-# INLINE compare #-}

-- | Pretty-prints through 'ppT', forwarding the pretty level, the precedence
-- and the array's shape together with its underlying 'T'.
instance (Vector v, Pretty a, VecElem v a, Shape sh) => Pretty (Array sh v a) where
  pPrintPrec lvl prec arr@(A t) = ppT lvl prec (shapeL arr) t

-- | Forces the wrapped 'T' to normal form using its own 'NFData' instance.
-- There is nothing else to force: the shape exists only in the type.
instance (NFData (v a)) => NFData (Array sh v a) where
  rnf (A t) = t `deepseq` ()

-- | The number of elements in the array.
-- The result comes from the type-level shape alone; the array argument is
-- never inspected.
{-# INLINE size #-}
size :: forall sh v a . (Shape sh) => Array sh v a -> Int
size _ = sizeP shapeProxy
  where shapeProxy = Proxy :: Proxy sh

-- | The shape of an array, i.e., a list of the sizes of its dimensions.
-- In the linearization of the array, the outermost dimension (the first
-- list element) varies most slowly.
-- The shape is read from the type, so the argument is never inspected.
-- O(1) time.
{-# INLINE shapeL #-}
shapeL :: forall sh v a . (Shape sh) => Array sh v a -> ShapeL
shapeL _ = shapeP shapeProxy
  where shapeProxy = Proxy :: Proxy sh

-- | The rank of an array, i.e., the number of dimensions it has.
-- Obtained from the type-level @Rank sh@; the argument is never inspected.
-- O(1) time.
{-# INLINE rank #-}
rank :: forall sh v a . (Shape sh, KnownNat (Rank sh)) => Array sh v a -> Int
rank _ = r
  where r = valueOf @(Rank sh)

-- | Index into the outermost dimension of an array, dropping that dimension.
-- The type @Array (s:sh)@ guarantees at least one dimension, so a rank-0
-- array is rejected at compile time.  Calls 'error' if the index lies
-- outside @[0, s)@.
-- O(1) time.
{-# INLINE index #-}
index :: forall s sh v a . (HasCallStack, Vector v, KnownNat s) =>
         Array (s:sh) v a -> Int -> Array sh v a
index (A t) i
  -- Report each bound separately: a single message such as "-1 >= 5" would
  -- be false for negative indices.
  | i < 0     = error $ "index: out of bounds " ++ show i ++ " < 0"
  | i >= s    = error $ "index: out of bounds " ++ show i ++ " >= " ++ show s
  | otherwise = A $ indexT t i
  where s :: Int
        s = valueOf @s

-- | Convert to a list with the elements in the linearization order.
-- O(n) time.
{-# INLINE toList #-}
toList :: (Vector v, VecElem v a, Shape sh) => Array sh v a -> [a]
toList arr = case arr of
  A t -> toListT (shapeL arr) t

-- | Convert to a vector with the elements in the linearization order.
-- O(n) or O(1) time (the latter if the vector is already in the
-- linearization order).
{-# INLINE toVector #-}
toVector :: (Vector v, VecElem v a, Shape sh) => Array sh v a -> v a
toVector arr = case arr of
  A t -> toVectorT (shapeL arr) t

-- | Convert from a list with the elements given in the linearization order.
-- Calls 'error' if the type-level shape @sh@ does not describe the same
-- number of elements as the list holds.
-- O(n) time.
{-# INLINE fromList #-}
fromList :: forall sh v a . (HasCallStack, Vector v, VecElem v a, Shape sh) =>
            [a] -> Array sh v a
fromList vs
  | total == len = A (T strides' 0 (vFromList vs))
  | otherwise    = error $ "fromList: size mismatch " ++ show (total, len)
  where
    -- 'getStridesT' yields the element count of the shape followed by the
    -- per-dimension strides for a fresh, offset-0 layout.
    total : strides' = getStridesT (shapeP (Proxy :: Proxy sh))
    len = length vs

-- | Convert from a vector with the elements given in the linearization order.
-- Calls 'error' if the type-level shape @sh@ does not describe the same
-- number of elements as the vector holds.
-- O(1) time.
{-# INLINE fromVector #-}
fromVector :: forall sh v a . (HasCallStack, Vector v, VecElem v a, Shape sh) =>
              v a -> Array sh v a
fromVector v
  -- The trailing space in the message matches 'fromList' and keeps the
  -- (expected, actual) pair from running into the text.
  | n /= l    = error $ "fromVector: size mismatch " ++ show (n, l)
  | otherwise = A $ T st 0 v
  where
    -- The first element is the total element count of the shape; the rest
    -- are the strides of a fresh, offset-0 layout.
    n : st = getStridesT ss
    l = vLength v
    ss = shapeP (Proxy :: Proxy sh)

-- | Make sure the underlying vector is in the linearization order.
-- Semantically this is the identity function, but it can have big
-- performance implications.  It is implemented as a round trip through
-- 'toVector' and 'fromVector'.
-- O(n) or O(1) time.
{-# INLINE normalize #-}
normalize :: (Vector v, VecElem v a, Shape sh) => Array sh v a -> Array sh v a
normalize = fromVector . toVector

-- | Change the shape of an array.  It is a type error if the two shapes have
-- different numbers of elements (enforced by @Size sh ~ Size sh'@).
-- O(n) or O(1) time.
{-# INLINE reshape #-}
reshape :: forall sh' sh v a .
           (Vector v, VecElem v a, Shape sh, Shape sh', Size sh ~ Size sh') =>
           Array sh v a -> Array sh' v a
reshape arr = reshape' target (shapeL arr) arr
  where target = shapeP (Proxy :: Proxy sh')

-- | Worker for 'reshape'.  @reshape' newSh oldSh arr@ reinterprets @arr@,
-- whose current shape is @oldSh@, as an array of shape @newSh@.  The two
-- shapes are expected to describe the same number of elements; 'reshape'
-- guarantees this at the type level.
reshape' :: (Vector v, VecElem v a) =>
            ShapeL -> ShapeL -> Array sh v a -> Array sh' v a
reshape' newSh oldSh (A t@(T oldStrides off vec))
  -- Fast special case: a one-element backing vector can serve any shape
  -- when every stride is zero.
  | vLength vec == 1 = A (T (0 <$ newSh) 0 vec)
  -- If 'simpleReshape' finds strides for the new shape, keep sharing the
  -- same vector and offset.
  | Just newStrides <- simpleReshape oldStrides oldSh newSh =
      A (T newStrides off vec)
  -- Otherwise go through a vector in linearization order.
  | otherwise = A (fromVectorT newSh (toVectorT oldSh t))

-- | Change the size of dimensions that have size 1.  Such dimensions can be
-- changed to any size; all other dimensions must stay the same, as required
-- by the 'ValidStretch' constraint.
-- O(1) time.
{-# INLINE stretch #-}
stretch :: forall sh' sh v a . (Shape sh, Shape sh', ValidStretch sh sh') =>
           Array sh v a -> Array sh' v a
stretch = stretch' flags
  where flags = stretching (Proxy :: Proxy sh) (Proxy :: Proxy sh')

-- | Worker for 'stretch': applies 'stretchT' with the given per-dimension
-- flags to the underlying representation and retags the result type.
stretch' :: [Bool] -> Array sh v a -> Array sh' v a
stretch' flags (A t) = A (stretchT flags t)

-- | Change the size of the outermost dimension (which must be 1) to @s@ by
-- replication.
{-# INLINE stretchOuter #-}
stretchOuter :: forall s sh v a . (Shape sh) =>
                Array (1 : sh) v a -> Array (s : sh) v a
stretchOuter (A t) = A (stretchT flags t)
  where
    -- Stretch only the outermost dimension: 'True' first, then one 'False'
    -- per element of the strides of @t@.
    flags = True : (False <$ strides t)

-- | Convert a value to a scalar (rank 0) array, built by 'scalarT'.
-- O(1) time.
{-# INLINE scalar #-}
scalar :: (Vector v, VecElem v a) => a -> Array '[] v a
scalar = \ x -> A (scalarT x)

-- | Convert a scalar (rank 0) array back to its single value, via 'unScalarT'.
-- O(1) time.
{-# INLINE unScalar #-}
unScalar :: (Vector v, VecElem v a) => Array '[] v a -> a
unScalar arr = case arr of
  A t -> unScalarT t

-- | Make an array whose elements all have the given value; the shape comes
-- from the type @sh@.
-- O(1) time.
{-# INLINE constant #-}
constant :: forall sh v a . (Vector v, VecElem v a, Shape sh) =>
            a -> Array sh v a
constant = A . constantT shp
  where shp = shapeP (Proxy :: Proxy sh)

-- | Apply a function to every element of the array; the shape is unchanged.
-- O(n) time.
{-# INLINE mapA #-}
mapA :: (Vector v, VecElem v a, VecElem v b, Shape sh) =>
        (a -> b) -> Array sh v a -> Array sh v b
mapA f arr = case arr of
  A t -> A (mapT (shapeL arr) f t)

-- | Combine two arrays of the same shape element by element with the given
-- function.  Equal shapes are guaranteed by the types, so the first array's
-- shape is used.
-- O(n) time.
{-# INLINE zipWithA #-}
zipWithA :: (Vector v, VecElem v a, VecElem v b, VecElem v c, Shape sh) =>
            (a -> b -> c) -> Array sh v a -> Array sh v b -> Array sh v c
zipWithA f xs@(A tx) (A ty) = A (zipWithT shp f tx ty)
  where shp = shapeL xs

-- | Combine three arrays of the same shape element by element with the given
-- function.  Equal shapes are guaranteed by the types, so the first array's
-- shape is used.
-- O(n) time.
{-# INLINE zipWith3A #-}
zipWith3A :: (Vector v, VecElem v a, VecElem v b, VecElem v c, VecElem v d, Shape sh) =>
             (a -> b -> c -> d) -> Array sh v a -> Array sh v b -> Array sh v c -> Array sh v d
zipWith3A f xs@(A tx) (A ty) (A tz) = A (zipWith3T shp f tx ty tz)
  where shp = shapeL xs

-- | Pad each dimension on the low and high side with the given value.
-- The padding amounts come from the type-level list @ps@ (presumably one
-- (low, high) pair per padded dimension); the 'Padded' constraint relates
-- @ps@, @sh@ and the result shape @sh'@.
-- O(n) time.
{-# INLINE pad #-}
pad :: forall ps sh' sh a v . (HasCallStack, Vector v, VecElem v a, Padded ps sh sh', Shape sh) =>
       a -> Array sh v a -> Array sh' v a
pad fill arr@(A t) = A padded'
  where
    amounts = padded (Proxy :: Proxy ps) (Proxy :: Proxy sh)
    -- 'padT' also returns a shape; only the padded 'T' is needed because
    -- the result shape is carried by the type.
    (_, padded') = padT fill amounts (shapeL arr) t

-- | Do an arbitrary array transposition.
-- The type-level list @is@ (of length at most the rank r, by
-- @Rank is <= Rank sh@) permutes the outermost dimensions.  The remaining
-- inner dimensions keep their positions: the permutation is extended with
-- @[length is .. r-1]@.  The 'Permutation' constraint governs which @is@ are
-- accepted.
-- O(1) time.
{-# INLINE transpose #-}
transpose :: forall is sh v a .
             (Permutation is, Rank is <= Rank sh, Shape sh, Shape is, KnownNat (Rank sh)) =>
             Array sh v a -> Array (Permute is sh) v a
transpose (A t) = A (transposeT fullPerm t)
  where
    given = shapeP (Proxy :: Proxy is)
    fullPerm = given ++ [length given .. valueOf @(Rank sh) - 1]

-- | Append two arrays along the outermost dimension.
-- All dimensions except the outermost must be the same; the shared @sh@ in
-- the type enforces this.  The result is built by concatenating the two
-- linearized vectors.
-- O(n) time.
{-# INLINE append #-}
append :: (Vector v, VecElem v a, Shape sh, KnownNat m, KnownNat n, KnownNat (m+n)) =>
          Array (m ': sh) v a -> Array (n ': sh) v a -> Array (m+n ': sh) v a
append xs ys = fromVector joined
  where joined = vAppend (toVector xs) (toVector ys)

-- | Turn a rank-1 array of arrays into a single array: the outer array
-- becomes the outermost dimension of the result.  All inner arrays have the
-- same shape @sh@ by construction of the type.  Implemented by linearizing
-- each inner array and concatenating the resulting vectors.
-- O(n) time.
{-# INLINE ravel #-}
ravel :: (Vector v, Vector v', VecElem v a, VecElem v' (Array sh v a)
         , Shape sh, KnownNat s) =>
         Array '[s] v' (Array sh v a) -> Array (s:sh) v a
ravel = fromVector . vConcat . fmap toVector . toList

-- | Turn an array into a nested array; this is the inverse of 'ravel'.
-- I.e., @ravel . unravel == id@.
-- Each subarray one level down is wrapped as a rank-0 element of the
-- resulting rank-1 outer array.
-- O(n) time.
{-# INLINE unravel #-}
unravel :: (Vector v, Vector v', VecElem v a, VecElem v' (Array sh v a)
           , Shape sh, KnownNat s) =>
           Array (s:sh) v a -> Array '[s] v' (Array sh v a)
unravel arr = rerank @1 scalar arr

-- | Make a window of the outermost dimensions.
-- The rank increases with the length of the window list.
-- E.g., if the shape of the array is @[10,12,8]@ and
-- the window size is @[3,3]@ then the resulting array will have shape
-- @[8,10,3,3,8]@.
--
-- E.g., @window \@'[2] (fromList \@'[4] [1,2,3,4]) == fromList \@'[3,2] [1,2, 2,3, 3,4]@
-- O(1) time.
--
-- If the window parameter is @ws = [w1,...,wk]@ and @wa = window \@ws a@ then
-- @wa `index` i1 ... `index` ik@ is the slice of @a@ with offsets
-- @i1,...,ik@ and lengths @w1,...,wk@ in the outermost dimensions.
{-# INLINE window #-}
window :: forall ws sh' sh v a .
          (Window ws sh sh', Vector v, KnownNat (Rank ws)) =>
          Array sh v a -> Array sh' v a
window (A (T sts off vec)) = A (T (outerSts ++ sts) off vec)
  where
    -- The window dimensions step along the same underlying axes as the
    -- outermost dimensions they are cut from, so their strides are
    -- reused.  No data is copied.
    outerSts = take (valueOf @(Rank ws)) sts

-- | Stride the outermost dimensions.
-- E.g., if the array shape is @[10,12,8]@ and the strides are
-- @[2,2]@ then the resulting shape will be @[5,6,8]@.
-- O(1) time.
{-# INLINE stride #-}
stride :: forall ts sh' sh v a .
          (Stride ts sh sh', Vector v, Shape ts) =>
          Array sh v a -> Array sh' v a
stride (A (T sts off vec)) = A (T scaled off vec)
  where
    factors = shapeP (Proxy :: Proxy ts)
    -- Multiply each existing stride by the requested step.  Dimensions
    -- beyond the length of @ts@ keep their stride (factor 1).
    scaled = zipWith (*) sts (factors ++ repeat 1)

-- | Extract a slice of an array.
-- The first type argument is a list of (offset, length) pairs.
-- The length of the slicing argument must not exceed the rank of the array.
-- The extracted slice must fall within the array dimensions.
-- E.g. @slice \@'[ '(1,2)] (fromList \@'[4] [1,2,3,4]) == fromList \@'[2] [2,3]@.
-- O(1) time.
{-# INLINE slice #-}
slice :: forall sl sh' sh v a .
         (Slice sl sh sh') =>
         Array sh v a -> Array sh' v a
slice (A (T sts off vec)) = A (T sts (off + delta) vec)
  where
    -- Per-dimension start positions of the slice, from the type-level spec.
    starts = sliceOffsets (Proxy :: Proxy sl) (Proxy :: Proxy sh)
    -- Linear distance to the first element of the slice; only the offset
    -- moves, while strides and the backing vector are shared.
    delta = sum (zipWith (*) sts starts)

-- | Apply a function to the subarrays /n/ levels down and make
-- the results into an array with the same /n/ outermost dimensions.
-- The /n/ must not exceed the rank of the array.
-- O(n) time.
{-# INLINE rerank #-}
rerank :: forall n i o sh v v' a b .
          (Vector v, Vector v', VecElem v a, VecElem v' b,
           Drop n sh ~ i, Shape sh, KnownNat n, Shape o, Shape (Take n sh ++ o)) =>
          (Array i v a -> Array o v' b) -> Array sh v a -> Array (Take n sh ++ o) v' b
rerank f arr@(A t) = fromVector (vConcat (map onSub (subArraysT outer t)))
  where
    -- The /n/ outermost dimensions that are kept in the result.
    outer = take (valueOf @n) (shapeL arr)
    -- Wrap one subarray, apply @f@, and flatten the result so that all
    -- results can be concatenated into the new backing vector.
    onSub = toVector . f . A

-- | Apply a two-argument function to the subarrays /n/ levels down and make
-- the results into an array with the same /n/ outermost dimensions.
-- The /n/ must not exceed the rank of the array.
-- Both arguments share the same /n/ outermost dimensions (@Take n sh1 ~ r@
-- and @Take n sh2 ~ r@), so their subarrays are paired up one-to-one.
-- O(n) time.
{-# INLINE rerank2 #-}
rerank2 :: forall n i1 i2 o sh1 sh2 r v a b c .
           (Vector v, VecElem v a, VecElem v b, VecElem v c,
            Drop n sh1 ~ i1, Drop n sh2 ~ i2, Shape sh1, Shape sh2,
            Take n sh1 ~ r, Take n sh2 ~ r,
            KnownNat n, Shape o, Shape (r ++ o)) =>
           (Array i1 v a -> Array i2 v b -> Array o v c) -> Array sh1 v a -> Array sh2 v b -> Array (r ++ o) v c
rerank2 f aa@(A ta) (A tb) =
  fromVector . vConcat $ zipWith combine (subArraysT outer ta) (subArraysT outer tb)
  where
    -- The shared outer dimensions, taken from the first argument.
    outer = take (valueOf @n) (shapeL aa)
    -- Apply @f@ to one pair of matching subarrays and flatten the result.
    combine x y = toVector (f (A x) (A y))


-- | Reverse the given dimensions, with the outermost being dimension 0.
-- The dimensions to reverse are given by the type-level list @rs@.
-- O(1) time.
{-# INLINE rev #-}
rev :: forall rs sh v a . (ValidDims rs sh, Shape rs, Shape sh) => Array sh v a -> Array sh v a
rev arr@(A t) = A (reverseT dims (shapeL arr) t)
  where dims = shapeP (Proxy :: Proxy rs)

-- | Reduce all elements of an array into a rank 0 array, using the
-- combining function and the starting value @z@.
-- To reduce parts use 'rerank' and 'transpose' together with 'reduce'.
-- O(n) time.
{-# INLINE reduce #-}
reduce :: (Vector v, VecElem v a, Shape sh) =>
          (a -> a -> a) -> a -> Array sh v a -> Array '[] v a
reduce f z arr@(A t) = A $ reduceT shp f z t
  where shp = shapeL arr

-- | Right fold across all elements of an array, delegating to 'foldrT'
-- with the array's shape.
{-# INLINE foldrA #-}
foldrA
  :: (Vector v, VecElem v a, Shape sh)
  => (a -> b -> b) -> b -> Array sh v a -> b
foldrA f z arr@(A t) = foldrT shp f z t
  where shp = shapeL arr

-- | Constrained version of 'traverse' for 'Array's.
-- The result keeps the static shape @sh@ of the input.
{-# INLINE traverseA #-}
traverseA
  :: (Vector v, VecElem v a, VecElem v b, Applicative f, Shape sh)
  => (a -> f b) -> Array sh v a -> f (Array sh v b)
traverseA f arr@(A t) = fmap A (traverseT (shapeL arr) f t)

-- | Check if all elements of the array are equal (delegates to 'allSameT').
allSameA :: (Shape sh, Vector v, VecElem v a, Eq a) => Array sh v a -> Bool
allSameA arr@(A t) = allSameT shp t
  where shp = shapeL arr

-- | Random arrays of the statically known shape @sh@: exactly as many
-- elements as the shape's size are generated with 'vector' and laid out
-- with 'fromList'.
instance (Shape sh, Vector v, VecElem v a, Arbitrary a) => Arbitrary (Array sh v a) where
  arbitrary = fmap fromList (vector count)
    where count = sizeP (Proxy :: Proxy sh)

-- | Sum of all elements (delegates to 'sumT' with the array's shape).
{-# INLINE sumA #-}
sumA :: (Vector v, VecElem v a, Num a, Shape sh) => Array sh v a -> a
sumA arr@(A t) = sumT shp t
  where shp = shapeL arr

-- | Product of all elements (delegates to 'productT' with the array's shape).
{-# INLINE productA #-}
productA :: (Vector v, VecElem v a, Num a, Shape sh) => Array sh v a -> a
productA arr@(A t) = productT shp t
  where shp = shapeL arr

-- | Maximum of all elements.
-- The @1 <= Size sh@ constraint rules out empty arrays, so a maximum
-- always exists.
{-# INLINE maximumA #-}
maximumA :: (Vector v, VecElem v a, Ord a, Shape sh, 1 <= Size sh) => Array sh v a -> a
maximumA arr@(A t) = maximumT shp t
  where shp = shapeL arr

-- | Minimum of all elements.
-- The @1 <= Size sh@ constraint rules out empty arrays, so a minimum
-- always exists.
{-# INLINE minimumA #-}
minimumA :: (Vector v, VecElem v a, Ord a, Shape sh, 1 <= Size sh) => Array sh v a -> a
minimumA arr@(A t) = minimumT shp t
  where shp = shapeL arr

-- | Test if the predicate holds for any element (delegates to 'anyT').
{-# INLINE anyA #-}
anyA :: (Vector v, VecElem v a, Shape sh) => (a -> Bool) -> Array sh v a -> Bool
anyA p arr@(A t) = anyT shp p t
  where shp = shapeL arr

-- | Test if the predicate holds for all elements (delegates to 'allT').
{-# INLINE allA #-}
allA :: (Vector v, VecElem v a, Shape sh) => (a -> Bool) -> Array sh v a -> Bool
allA p arr@(A t) = allT shp p t
  where shp = shapeL arr

-- | Put the dimensions of the argument into the specified dimensions,
-- and just replicate the data along all other dimensions.
-- The list of dimension indices must have the same length as the rank of
-- the argument array, and it must be strictly ascending.
broadcast :: forall ds sh' sh v a .
             (Shape sh, Shape sh',
              Broadcast ds sh sh',
              Vector v, VecElem v a) =>
             Array sh v a -> Array sh' v a
broadcast a = stretch' flags (reshape' inShape padded a)
  where
    outShape = shapeP (Proxy :: Proxy sh')
    inShape = shapeP (Proxy :: Proxy sh)
    -- One flag per result dimension; True marks a dimension along which
    -- the data is replicated.
    flags = broadcasting @ds @sh @sh'
    -- First reshape the argument to the result's rank, with size 1 in every
    -- broadcast dimension; 'stretch'' then widens those dimensions.
    padded = zipWith (\ s b -> if b then 1 else s) outShape flags

-- | Generate an array with a function that computes the value for each index.
-- The index passed to the function is a list with one entry per dimension
-- of the static shape @sh@.
{-# INLINE generate #-}
generate :: forall sh v a .
            (Vector v, VecElem v a, Shape sh) =>
            ([Int] -> a) -> Array sh v a
generate f = A (generateT shp f)
  where shp = shapeP (Proxy :: Proxy sh)

-- | Iterate a function n times, starting from the given seed, producing a
-- rank-1 array of length @n@ (built by 'iterateNT').
{-# INLINE iterateN #-}
iterateN :: forall n v a .
            (Vector v, VecElem v a, KnownNat n) =>
            (a -> a) -> a -> Array '[n] v a
iterateN f x0 = A (iterateNT (valueOf @n) f x0)

-- | Generate a rank-1 array holding the values from 0 to n-1.
{-# INLINE iota #-}
iota :: forall n v a .
        (Vector v, VecElem v a, KnownNat n, Enum a, Num a) =>
        Array '[n] v a
iota = A (iotaT len)
  where len = valueOf @n