-- Copyright 2020 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
module Data.Array.Internal(module Data.Array.Internal) where
import Control.DeepSeq
import Data.Data(Data)
import qualified Data.DList as DL
import Data.Kind (Type)
import Data.List(foldl', zipWith4, zipWith5, sortBy, sortOn)
import Data.Proxy
import GHC.Exts(Constraint, build)
import GHC.Generics(Generic)
import GHC.TypeLits(KnownNat, natVal)
import Text.PrettyPrint
import Text.PrettyPrint.HughesPJClass

{- HLINT ignore "Reduce duplication" -}

-- | The 'Vector' class is the interface to the underlying storage for
-- arrays.  The methods are meant to map directly onto the corresponding
-- operations of the storage type.
--
-- Some storage types, like unboxed vectors, can only hold elements that
-- satisfy an extra constraint.  The associated constraint 'VecElem'
-- expresses that requirement; a storage type that needs no such
-- constraint can set it to a dummy class that has an instance for
-- every type.
class Vector v where
  -- | Constraint every element type stored in @v@ must satisfy.
  type VecElem v :: Type -> Constraint

  -- | Element at the given index.
  vIndex :: VecElem v a => v a -> Int -> a
  -- | Number of elements.
  vLength :: VecElem v a => v a -> Int
  -- | All elements, in order.
  vToList :: VecElem v a => v a -> [a]
  -- | Build a vector from a list.
  vFromList :: VecElem v a => [a] -> v a
  -- | A one-element vector.
  vSingleton :: VecElem v a => a -> v a
  -- | @vReplicate n x@ holds @n@ copies of @x@.
  vReplicate :: VecElem v a => Int -> a -> v a
  -- | Apply a function to every element.
  vMap :: (VecElem v a, VecElem v b) => (a -> b) -> v a -> v b
  -- | Elementwise combination of two vectors.
  vZipWith
    :: (VecElem v a, VecElem v b, VecElem v c)
    => (a -> b -> c) -> v a -> v b -> v c
  -- | Elementwise combination of three vectors.
  vZipWith3
    :: (VecElem v a, VecElem v b, VecElem v c, VecElem v d)
    => (a -> b -> c -> d) -> v a -> v b -> v c -> v d
  -- | Elementwise combination of four vectors.
  vZipWith4
    :: (VecElem v a, VecElem v b, VecElem v c, VecElem v d, VecElem v e)
    => (a -> b -> c -> d -> e) -> v a -> v b -> v c -> v d -> v e
  -- | Elementwise combination of five vectors.
  vZipWith5
    :: ( VecElem v a, VecElem v b, VecElem v c
       , VecElem v d, VecElem v e, VecElem v f )
    => (a -> b -> c -> d -> e -> f)
    -> v a -> v b -> v c -> v d -> v e -> v f
  -- | Concatenate two vectors.
  vAppend :: VecElem v a => v a -> v a -> v a
  -- | Concatenate a list of vectors.
  vConcat :: VecElem v a => [v a] -> v a
  -- | Fold with a combining function and a start value
  -- (the list reference instance uses a strict left fold).
  vFold :: VecElem v a => (a -> a -> a) -> a -> v a -> a
  -- | @vSlice o n xs@ is the @n@ elements of @xs@ starting at index @o@.
  vSlice :: VecElem v a => Int -> Int -> v a -> v a
  -- | Sum of all elements.
  vSum :: (VecElem v a, Num a) => v a -> a
  -- | Product of all elements.
  vProduct :: (VecElem v a, Num a) => v a -> a
  -- | Largest element.
  vMaximum :: (VecElem v a, Ord a) => v a -> a
  -- | Smallest element.
  vMinimum :: (VecElem v a, Ord a) => v a -> a
  -- | Replace the elements at the listed indices with the paired values.
  vUpdate :: VecElem v a => v a -> [(Int, a)] -> v a
  -- | @vGenerate n f@ holds @f 0, f 1, ..., f (n-1)@.
  vGenerate :: VecElem v a => Int -> (Int -> a) -> v a
  -- | Whether every element satisfies the predicate.
  vAll :: VecElem v a => (a -> Bool) -> v a -> Bool
  -- | Whether some element satisfies the predicate.
  vAny :: VecElem v a => (a -> Bool) -> v a -> Bool

-- | A class with no methods and an instance for every type.  A 'Vector'
-- instance whose storage accepts any element type can use it as its
-- 'VecElem' constraint (the list instance in this module does).
class None a
instance None a

-- This instance is not used anywhere.  It serves more as a reference semantics.
instance Vector [] where
  type VecElem [] = None
  vIndex = (!!)
  vLength = length
  vToList = id
  vFromList = id
  vSingleton = pure
  vReplicate = replicate
  vMap = map
  vZipWith = zipWith
  vZipWith3 = zipWith3
  vZipWith4 = zipWith4
  vZipWith5 = zipWith5
  vAppend = (++)
  vConcat = concat
  -- Strict left fold.
  vFold = foldl'
  -- Drop the first o elements, then keep the next n.
  vSlice o n = take n . drop o
  vSum = sum
  vProduct = product
  vMaximum = maximum
  vMinimum = minimum
  -- The updates are stably sorted by index and then merged with the list in
  -- a single pass.  An index at or past the end of the list raises
  -- "vUpdate: out of bounds"; a negative index, or an index that appears
  -- more than once, raises "vUpdate: bad index".
  vUpdate xs us = loop xs (sortOn fst us) 0
    where
      -- loop <remaining elements> <remaining sorted updates> <current index>
      loop [] [] _ = []
      loop [] (_:_) _ = error "vUpdate: out of bounds"
      loop ys [] _ = ys
      loop (y:ys) ius@((i, y'):ius') n =
        case compare i n of
          LT -> error "vUpdate: bad index"
          EQ -> y' : loop ys ius' (n + 1)
          GT -> y  : loop ys ius  (n + 1)
  vGenerate n f = map f [0 .. n - 1]
  vAll = all
  vAny = any

-- | Render a value to a 'String' using 'pPrintPrec' at the given detail
-- level and precedence 0.
prettyShowL :: (Pretty a) => PrettyLevel -> a -> String
prettyShowL l = render . pPrintPrec l 0

-- | Type used for strides, offsets and indices in the internal array type.
-- We expect all N to be non-negative, but we use Int for convenience.
type N = Int

-- | The type /T/ is the internal type of arrays.  In general,
-- operations on /T/ do no sanity checking as that should be done
-- at the point of call.
--
-- To avoid copying the data, indexing into the vector that holds it is
-- somewhat indirect.  To find where item /i/ of the outermost dimension
-- starts you calculate vector index @offset + i*strides[0]@.
-- To find where item /i,j/ of the two outermost dimensions is you
-- calculate vector index @offset + i*strides[0] + j*strides[1]@, etc.
data T v a = T
    { strides :: [N]      -- ^ One stride per dimension; the length is the rank.
    , offset  :: !N       -- ^ Offset into the vector of values.
    , values  :: !(v a)   -- ^ The actual values.
    }
    deriving (Show, Generic, Data)

-- | Full evaluation forces the strides, then the offset, then the backing
-- vector.
instance NFData (v a) => NFData (T v a) where
  rnf (T ss o vs) = rnf ss `seq` rnf o `seq` rnf vs

-- | The shape of an array is a list of its dimensions, outermost first.
-- Dimensions are expected to be non-negative; 'badShape' detects shapes
-- that violate this.
type ShapeL = [Int]

-- | True when any dimension of the shape is negative.
badShape :: ShapeL -> Bool
badShape = any (< 0)

-- Equality of two arrays that both have the shape @s@.
-- Fast path: when both arrays have the same strides, offset and backing
-- vector they are trivially equal.  Otherwise both are linearised with
-- toVectorT and the resulting vectors are compared.
equalT :: (Vector v, VecElem v a, Eq a, Eq (v a))
                  => ShapeL -> T v a -> T v a -> Bool
equalT s x y
  | strides x == strides y
      && offset x == offset y
      && values x == values y = True
  | otherwise = toVectorT s x == toVectorT s y

-- Compare two arrays by linearising both with toVectorT and comparing the
-- resulting vectors with their 'Ord' instance.
-- Note this assumes the shape is the same for both Vectors.
compareT :: (Vector v, VecElem v a, Ord a, Ord (v a))
            => ShapeL -> T v a -> T v a -> Ordering
compareT s x y = compare (toVectorT s x) (toVectorT s y)

-- Given the dimensions, return the stride in the underlying vector
-- for each dimension of a densely packed, row-major layout.
-- The result has one more element than the shape: the first element is the
-- total length and the last one is 1, e.g. [2,3,4] gives [24,12,4,1].
{-# INLINE getStridesT #-}
getStridesT :: ShapeL -> [N]
getStridesT = scanr (*) 1

-- Convert an array to a list by indexing through all the elements.
-- The first argument is the array shape.
-- When the layout is canonical (see isCanonicalT) the backing vector is
-- converted directly; otherwise the elements are produced in row-major
-- order inside 'build', so the list can take part in foldr/build fusion.
-- XXX Copy special cases from Tensor.
{-# INLINE toListT #-}
toListT :: (Vector v, VecElem v a) => ShapeL -> T v a -> [a]
toListT sh a@(T ss0 o0 v)
  | isCanonicalT (getStridesT sh) a = vToList v
  | otherwise = build $ \cons nil ->
      -- TODO: because unScalarT uses vIndex, this has unnecessary bounds
      -- checks.  We should expose an unchecked indexing function in the Vector
      -- class, add top-level bounds checks to cover the full range we'll
      -- access, and then do all accesses with the unchecked version.
      let -- go <remaining dims> <remaining strides> <offset> <tail of output>
          go [] ss o rest = cons (unScalarT (T ss o v)) rest
          go (n:ns) ss o rest = foldr
            (\i -> case indexT (T ss o v) i of T ss' o' _ -> go ns ss' o')
            rest
            [0 .. n - 1]
      in  go sh ss0 o0 nil

-- | Check if the strides are canonical, i.e., if the vector has the natural layout.
-- The first argument is the result of 'getStridesT' for the array's shape:
-- the total length followed by the normal strides.
-- XXX Copy special cases from Tensor.
{-# INLINE isCanonicalT #-}
isCanonicalT :: (Vector v, VecElem v a) => [N] -> T v a -> Bool
isCanonicalT (n:ss') (T ss o v) =
    o == 0 &&         -- Vector offset is 0
    ss == ss' &&      -- All strides are normal
    vLength v == n    -- The vector is the right size
isCanonicalT _ _ = error "impossible"

-- Convert a value to a scalar array: no strides, offset 0, and a
-- one-element backing vector.
{-# INLINE scalarT #-}
scalarT :: (Vector v, VecElem v a) => a -> T v a
scalarT = T [] 0 . vSingleton

-- Convert a scalar array to the actual value: the element at the array's
-- offset in the backing vector.  Any strides are ignored.
{-# INLINE unScalarT #-}
unScalarT :: (Vector v, VecElem v a) => T v a -> a
unScalarT (T _ o v) = vIndex v o

-- Make a constant array.  Every stride is 0, so all positions refer to the
-- single stored element.
{-# INLINE constantT #-}
constantT :: (Vector v, VecElem v a) => ShapeL -> a -> T v a
constantT sh x = T (map (const 0) sh) 0 (vSingleton x)

-- TODO: change to return a list of vectors.
-- Convert an array to a vector in the natural (row-major) order.
-- Three strategies are used:
--   * all strides are normal and the backing vector has exactly the right
--     length: return the backing vector unchanged;
--   * the innermost dimension has the normal stride (1): concatenate the
--     largest contiguous slices that can be cut from the backing vector;
--   * otherwise every slice would have length 1, so go via toListT.
{-# INLINE toVectorT #-}
toVectorT :: (Vector v, VecElem v a) => ShapeL -> T v a -> v a
toVectorT sh a@(T ats ao v) =
  let l : ts' = getStridesT sh
      -- Are strides ok from this point?  Element k is True when the actual
      -- strides of dimensions k, k+1, ... all equal the normal ones.  The
      -- final element (no dimensions left) is always True.
      oks = scanr (&&) True (zipWith (==) ats ts')
      allOk : _ = oks
      -- Is the innermost dimension normal?  That is element (rank - 1) of
      -- oks; element rank is the trailing True and carries no information.
      -- A rank-0 array has no innermost dimension and uses the slicing path.
      innerOk = null sh || oks !! (length sh - 1)
      loop _ [] _ o =
        DL.singleton (vSlice o 1 v)
      loop (b:bs) (s:ss) (t:ts) o =
        if b then
          -- All strides normal from this point,
          -- so just take a slice of the underlying vector.
          DL.singleton (vSlice o (s * t) v)
        else
          -- Strides are not normal, collect slices.
          DL.concat [ loop bs ss ts (i * t + o) | i <- [0 .. s - 1] ]
      loop _ _ _ _ = error "impossible"
  in  if allOk && vLength v == l then
        -- All strides are normal, return entire vector
        v
      else if innerOk then  -- Special case for speed.
        -- Innermost dimension is normal, so slices are non-trivial.
        vConcat $ DL.toList $ loop oks sh ats ao
      else
        -- All slices would have length 1, going via a list is faster.
        vFromList $ toListT sh a

-- Convert to a vector containing the right elements,
-- but not necessarily in the right order.
{-# INLINE toUnorderedVectorT #-}
toUnorderedVectorT :: (Vector v, VecElem v a) => ShapeL -> T v a -> v a
toUnorderedVectorT sh a@(T ats ao v) =
  -- Figure out if the array maps onto some contiguous slice of the vector.
  -- Do this by checking if a transposition of the array corresponds to
  -- normal strides.
  -- First sort the strides in descending order, and rearrange the shape the
  -- same way.  Then compute the strides from this rearranged shape; these
  -- will be the normal strides for this shape.  If these strides agree with
  -- the sorted actual strides it is a transposition, and we can just slice
  -- out the relevant piece of the vector.  Otherwise fall back to toVectorT.
  let (ats', sh') = unzip $ sortBy (flip compare) $ zip ats sh
      l : ts' = getStridesT sh'
  in  if ats' == ts'
        then vSlice ao l v
        else toVectorT sh a

-- Convert from a vector, which is taken to hold the elements in row-major
-- order: the result has the normal strides for the shape and offset 0.
-- The vector length is not checked against the shape.
{-# INLINE fromVectorT #-}
fromVectorT :: ShapeL -> v a -> T v a
-- getStridesT always returns a non-empty list; dropping its head (the
-- total length) leaves one stride per dimension.
fromVectorT sh = T (drop 1 $ getStridesT sh) 0

-- Convert from a list of elements in row-major order.
-- The list length is not checked against the shape.
{-# INLINE fromListT #-}
fromListT :: (Vector v, VecElem v a) => [N] -> [a] -> T v a
fromListT sh = fromVectorT sh . vFromList

-- Index into the outermost dimension of an array: drop the outermost stride
-- and advance the offset by index * stride.  The index is not bounds
-- checked; indexing a rank-0 array is an error.
{-# INLINE indexT #-}
indexT :: T v a -> N -> T v a
indexT (T (s : ss) o v) i = T ss (o + i * s) v
indexT _ _ = error "impossible"

-- Stretch the given dimensions to have arbitrary size.
-- The stretched dimensions must have size 1.  Stretching is done by
-- setting the stride to 0, so every index along that dimension reads
-- the same element.  A 'True' in @bs@ marks a dimension to stretch;
-- strides beyond the end of @bs@ are dropped by 'zipWith'.
{-# INLINE stretchT #-}
stretchT :: [Bool] -> T v a -> T v a
stretchT bs (T sts off vec) = T (zipWith pick bs sts) off vec
  where
    pick True  _ = 0
    pick False s = s

-- Map over the array elements.
-- If the backing vector is no longer than the number of array elements,
-- mapping the whole vector costs no more than mapping the elements.  In
-- that case the strides and offset are reused as is.  Otherwise the
-- array is first flattened with 'toVectorT' so unused vector entries
-- are not mapped.
{-# INLINE mapT #-}
mapT :: (Vector v, VecElem v a, VecElem v b) => ShapeL -> (a -> b) -> T v a -> T v b
mapT sh f arr@(T sts off vec)
  | vLength vec <= product sh = T sts off (vMap f vec)
  | otherwise                 = fromVectorT sh (vMap f (toVectorT sh arr))

-- Zip two arrays of shape @sh@ with a function.
-- A backing vector of length 1 means that argument holds a single
-- (broadcast) value, and this is special-cased:
--   * both length 1: apply @f@ once and store a single element,
--     reusing the first array's strides;
--   * one length 1:  turn the zip into a 'mapT' over the other array;
--   * otherwise:     flatten both with 'toVectorT' and 'vZipWith' them.
{-# INLINE zipWithT #-}
zipWithT :: (Vector v, VecElem v a, VecElem v b, VecElem v c) =>
            ShapeL -> (a -> b -> c) -> T v a -> T v b -> T v c
zipWithT sh f ta@(T sts _ va) tb@(T _ _ vb)
  | singleA && singleB = T sts 0 (vSingleton (f (vIndex va 0) (vIndex vb 0)))
  | singleA            = mapT sh (f (vIndex va 0)) tb
  | singleB            = mapT sh (\ x -> f x (vIndex vb 0)) ta
  | otherwise          = fromVectorT sh (vZipWith f (toVectorT sh ta) (toVectorT sh tb))
  where
    singleA = vLength va == 1
    singleB = vLength vb == 1

-- Zip three arrays of shape @sh@ with a function.
-- If every backing vector has length 1, the result is a single element
-- computed directly, stored with the first array's strides.  Otherwise
-- all three are flattened with 'toVectorT' and zipped.
{-# INLINE zipWith3T #-}
zipWith3T :: (Vector v, VecElem v a, VecElem v b, VecElem v c, VecElem v d) =>
             ShapeL -> (a -> b -> c -> d) -> T v a -> T v b -> T v c -> T v d
zipWith3T sh f ta@(T sts _ va) tb@(T _ _ vb) tc@(T _ _ vc)
  | all (== 1) [vLength va, vLength vb, vLength vc] =
      T sts 0 (vSingleton (f (vIndex va 0) (vIndex vb 0) (vIndex vc 0)))
  | otherwise =
      fromVectorT sh (vZipWith3 f (toVectorT sh ta) (toVectorT sh tb) (toVectorT sh tc))

-- Zip four arrays of shape @sh@ with a function.  Every input is
-- flattened with 'toVectorT' and the results are combined with
-- 'vZipWith4'.
{-# INLINE zipWith4T #-}
zipWith4T
  :: (Vector v, VecElem v a, VecElem v b, VecElem v c, VecElem v d, VecElem v e)
  => ShapeL -> (a -> b -> c -> d -> e)
  -> T v a -> T v b -> T v c -> T v d -> T v e
zipWith4T sh f ta tb tc td =
  fromVectorT sh $
    vZipWith4 f (toVectorT sh ta) (toVectorT sh tb)
                (toVectorT sh tc) (toVectorT sh td)

-- Zip five arrays of shape @sh@ with a function.  Every input is
-- flattened with 'toVectorT' and the results are combined with
-- 'vZipWith5'.
{-# INLINE zipWith5T #-}
zipWith5T
  :: (Vector v, VecElem v a, VecElem v b, VecElem v c, VecElem v d, VecElem v e, VecElem v f)
  => ShapeL -> (a -> b -> c -> d -> e -> f)
  -> T v a -> T v b -> T v c -> T v d -> T v e -> T v f
zipWith5T sh f ta tb tc td te =
  fromVectorT sh $
    vZipWith5 f (toVectorT sh ta) (toVectorT sh tb) (toVectorT sh tc)
                (toVectorT sh td) (toVectorT sh te)

-- Do an arbitrary transposition.  The first argument should be a
-- permutation of the dimensions, i.e., the numbers [0..r-1] in some
-- order (where r is the rank of the array).  Only the stride list is
-- permuted (see 'permute'); offset and backing vector are unchanged.
{-# INLINE transposeT #-}
transposeT :: [Int] -> T v a -> T v a
transposeT perm (T sts off vec) = T sts' off vec
  where sts' = permute perm sts

-- Return all subarrays n dimensions down.
-- The shape argument should be a prefix of the array shape.  The
-- subarrays are listed with the last prefix index varying fastest.  An
-- explicit tail accumulator avoids repeated list appends.
{-# INLINE subArraysT #-}
subArraysT :: ShapeL -> T v a -> [T v a]
subArraysT sh ten = collect sh ten []
  where
    collect [] t rest = t : rest
    collect (n : ns) t rest =
      foldr (\ i acc -> collect ns (indexT t i) acc) rest [0 .. n - 1]

-- Reverse the given dimensions (listed by index in @rs@).
-- For a reversed dimension of size m and stride s, the offset moves
-- forward by (m-1)*s and the stride is negated.  No elements are copied.
{-# INLINE reverseT #-}
reverseT :: [N] -> ShapeL -> T v a -> T v a
reverseT rs sh (T ats ao v) = T newStrides newOffset v
  where
    (newOffset, newStrides) = go 0 sh ats
    go :: Int -> [Int] -> [Int] -> (Int, [Int])
    go !_ [] [] = (ao, [])
    go d (m : ms) (st : sts) =
      let (off, sts') = go (d + 1) ms sts
      in  if d `elem` rs
            then (off + (m - 1) * st, negate st : sts')
            else (off, st : sts')
    go _ _ _ = error "reverseT: impossible"

-- Reduction of all array elements.  The array is flattened with
-- 'toVectorT', folded with 'vFold' using @f@ and seed @z@, and the result
-- is wrapped as a rank-0 array with 'scalarT'.
{-# INLINE reduceT #-}
reduceT :: (Vector v, VecElem v a) =>
           ShapeL -> (a -> a -> a) -> a -> T v a -> T v a
reduceT sh f z arr = scalarT (vFold f z (toVectorT sh arr))

-- Right fold over the elements, in the order produced by 'toListT'.
{-# INLINE foldrT #-}
foldrT
  :: (Vector v, VecElem v a) => ShapeL -> (a -> b -> b) -> b -> T v a -> b
foldrT sh f z arr =
  let elems = toListT sh arr
  in  foldr f z elems

-- Traversal via toListT/fromListT: the elements are traversed in list
-- order, and the results are rebuilt into a fresh dense array of the
-- same shape.
{-# INLINE traverseT #-}
traverseT
  :: (Vector v, VecElem v a, VecElem v b, Applicative f)
  => ShapeL -> (a -> f b) -> T v a -> f (T v b)
traverseT sh f arr = fromListT sh <$> traverse f (toListT sh arr)

-- Fast check if all elements are equal.
-- A backing vector with at most one entry is trivially uniform.
-- Otherwise the array is flattened with 'toVectorT' and every element is
-- compared with the first.  If @sh@ has a zero-sized dimension, the
-- flattened vector can be empty even though the backing vector is
-- longer.  The array then has no elements and is uniform, so we must not
-- read index 0 (which the previous strict @!x = vIndex v' 0@ did).
allSameT :: (Vector v, VecElem v a, Eq a) => ShapeL -> T v a -> Bool
allSameT sh t@(T _ _ v)
  | vLength v <= 1 = True
  | otherwise =
    let !v' = toVectorT sh t
    in  vLength v' == 0 ||
        (let !x = vIndex v' 0 in vAll (x ==) v')

-- Pretty-print an array as a 'Doc'.  The elements are rendered with
-- 'prettyShowL' at level @l@ and laid out by 'ppT_'.  The resulting
-- lines are passed through 'box' with 'prettyBoxMode'.  The whole thing
-- is parenthesised when the surrounding precedence @p@ exceeds 10.
ppT
  :: (Vector v, VecElem v a, Pretty a)
  => PrettyLevel -> Rational -> ShapeL -> T v a -> Doc
ppT l p sh arr =
  maybeParens (p > 10) $
    vcat [ text line | line <- box prettyBoxMode (ppT_ (prettyShowL l) sh arr) ]

-- Render an array as text.  Each element is shown with @show_@ and
-- right-aligned (left-padded with spaces) to the width of the widest
-- element.  The padded strings are laid out by 'showsT' and trailing
-- newlines are removed.  The width is only computed if some element is
-- padded, so arrays with no elements never evaluate 'maximum' on an
-- empty list.
ppT_
  :: (Vector v, VecElem v a)
  => (a -> String) -> ShapeL -> T v a -> String
ppT_ show_ sh t = revDropWhile (== '\n') (showsT sh padded "")
  where
    cells = map show_ (toListT sh t)
    width = maximum (map length cells)
    padLeft s = replicate (width - length s) ' ' ++ s
    padded :: T [] String
    padded = T (tail (getStridesT sh)) 0 (map padLeft cells)

-- Lay out an array of already-rendered cells:
--   * outermost dimension 0  -> "EMPTY"
--   * rank 0                 -> the single cell
--   * rank 1                 -> cells separated by single spaces
--   * higher rank            -> each outermost slice rendered recursively,
--                               each followed by a newline
showsT :: [N] -> T [] String -> ShowS
showsT dims t =
  case dims of
    0 : _  -> showString "EMPTY"
    []     -> showString (unScalarT t)
    [_]    -> showString (unwords (toListT dims t))
    n : ns -> foldr (\ i k -> showsT ns (indexT t i) . showString "\n" . k) id [0 .. n - 1]

-- Layout options for 'box'.
data BoxMode = BoxMode
  { _bmBars    :: Bool  -- surround every non-empty line with a bar on each side
  , _bmUnicode :: Bool  -- draw those bars with U+2502 instead of '|'
  , _bmHeader  :: Bool  -- add a "+---+" rule line above and below the text
  }

-- Layout used by 'ppT': no bars, no rules.
prettyBoxMode :: BoxMode
prettyBoxMode = BoxMode False False False

-- Split a string into lines, decorating them as selected by 'BoxMode'.
-- The rule line is "+", then one '-' per character of the first
-- undecorated line, then "+".  Empty input used to crash on 'head' when
-- '_bmHeader' was set; it now yields a zero-width rule ("++").
box :: BoxMode -> String -> [String]
box BoxMode { _bmBars = bars, _bmUnicode = unicode, _bmHeader = header } s
  | header    = [rule] ++ ls' ++ [rule]
  | otherwise = ls'
  where
    ls = lines s
    bar | unicode   = '\x2502'
        | otherwise = '|'
    ls' | bars      = map addBars ls
        | otherwise = ls
    -- Blank lines are left undecorated.
    addBars l | null l    = l
              | otherwise = [bar] ++ l ++ [bar]
    width = case ls of
      l : _ -> length l
      []    -> 0
    rule = "+" ++ replicate width '-' ++ "+"

-- Like 'zipWith', but the result is as long as the second list: once the
-- first list runs out, the remaining elements of the second list are
-- passed through unchanged.  If the second list is shorter, the result
-- stops with it.
zipWithLong2 :: (a -> b -> b) -> [a] -> [b] -> [b]
zipWithLong2 f = go
  where
    go (x : xs) (y : ys) = f x y : go xs ys
    go _        ys       = ys

-- Pad an array with a constant value.  Each @(lo, hi)@ pair in @aps@
-- gives the number of fill slices before and after the corresponding
-- outer dimension.  Dimensions past the end of @aps@ are left as is
-- (see 'zipWithLong2').  Returns the new shape and a dense padded array.
-- A padding list longer than the shape is reported as a rank mismatch.
padT :: forall v a . (Vector v, VecElem v a) => a -> [(Int, Int)] -> ShapeL -> T v a -> ([Int], T v a)
padT fill aps ash at = (newShape, fromVectorT newShape (vConcat (go aps ash innerStrides at)))
  where
    newShape :: [Int]
    newShape = zipWithLong2 (\ (lo, hi) s -> lo + s + hi) aps ash
    -- Strides of the padded result; the element count at the head is unused.
    _ : innerStrides = getStridesT newShape
    -- Emit the padded data as vector chunks.  @n@ is the size of one
    -- slice of the result along the current dimension, so @n * lo@ fill
    -- elements come before the recursively padded slices and @n * hi@
    -- come after.
    go :: [(Int, Int)] -> ShapeL -> [Int] -> T v a -> [v a]
    go [] sh _ t = [toVectorT sh t]
    go ((lo, hi) : ps) (s : sh) (n : ns) t =
      vReplicate (n * lo) fill
        : concatMap (go ps sh ns . indexT t) [0 .. s - 1]
        ++ [vReplicate (n * hi) fill]
    go _ _ _ _ = error $ "pad: rank mismatch: " ++ show (length aps, length ash)

-- Check if a reshape only adds or removes dimensions of size 1.  In that
-- case it can be done by manipulating the strides alone.  Given the old
-- strides, the old shape and the new shape, returns the new strides:
-- each size-1 dimension of the new shape gets stride 0, and the other
-- dimensions take the old strides of the non-1 dimensions, in order.
-- Returns Nothing when the non-1 dimensions of the two shapes differ.
simpleReshape :: [N] -> ShapeL -> ShapeL -> Maybe [N]
simpleReshape osts os ns
  | nonUnit os /= nonUnit ns = Nothing
  | otherwise                = Just (insertZeros ns keptStrides)
  where
    nonUnit :: [Int] -> [Int]
    nonUnit = filter (/= 1)
    keptStrides = [ st | (st, s) <- zip osts os, s /= 1 ]
    insertZeros [] [] = []
    insertZeros (1 : ds) sts        = 0  : insertZeros ds sts
    insertZeros (_ : ds) (st : sts) = st : insertZeros ds sts
    insertZeros _ _ =
      error $ "simpleReshape: shouldn't happen: " ++ show (osts, os, ns)

-- Sum of all elements.  Element order does not matter, so the cheaper
-- 'toUnorderedVectorT' is used.
{-# INLINE sumT #-}
sumT :: (Vector v, VecElem v a, Num a) => ShapeL -> T v a -> a
sumT sh arr = vSum (toUnorderedVectorT sh arr)

-- Product of all elements.  Element order does not matter, so the cheaper
-- 'toUnorderedVectorT' is used.
{-# INLINE productT #-}
productT :: (Vector v, VecElem v a, Num a) => ShapeL -> T v a -> a
productT sh arr = vProduct (toUnorderedVectorT sh arr)

-- Largest element, computed over 'toUnorderedVectorT'.  Behaviour on an
-- array with no elements is whatever 'vMaximum' does for an empty vector.
{-# INLINE maximumT #-}
maximumT :: (Vector v, VecElem v a, Ord a) => ShapeL -> T v a -> a
maximumT sh arr = vMaximum (toUnorderedVectorT sh arr)

-- Smallest element, computed over 'toUnorderedVectorT'.  Behaviour on an
-- array with no elements is whatever 'vMinimum' does for an empty vector.
{-# INLINE minimumT #-}
minimumT :: (Vector v, VecElem v a, Ord a) => ShapeL -> T v a -> a
minimumT sh arr = vMinimum (toUnorderedVectorT sh arr)

-- True if some element satisfies @p@ (checked over 'toUnorderedVectorT').
{-# INLINE anyT #-}
anyT :: (Vector v, VecElem v a) => ShapeL -> (a -> Bool) -> T v a -> Bool
anyT sh p arr = vAny p (toUnorderedVectorT sh arr)

-- True if every element satisfies @p@ (checked over 'toUnorderedVectorT').
{-# INLINE allT #-}
allT :: (Vector v, VecElem v a) => ShapeL -> (a -> Bool) -> T v a -> Bool
allT sh p arr = vAll p (toUnorderedVectorT sh arr)

-- Replace the elements at the given multi-dimensional indices.  The array
-- is flattened with 'toVectorT'.  Each index list is turned into a linear
-- position by a dot product with the dense strides for @sh@, and the
-- result is applied with 'vUpdate'.  Index bounds are not checked here.
{-# INLINE updateT #-}
updateT :: (Vector v, VecElem v a) => ShapeL -> T v a -> [([Int], a)] -> T v a
updateT sh t us =
  T innerStrides 0 (vUpdate (toVectorT sh t) (map (\ (is, a) -> (linear is, a)) us))
  where
    _ : innerStrides = getStridesT sh
    linear is = sum (zipWith (*) is innerStrides)
        ix :: ([Int], a) -> (Int, a)
ix ([Int]
is, a
a) = ([Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (Int -> Int -> Int) -> [Int] -> [Int] -> [Int]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(*) [Int]
is [Int]
ss, a
a)

-- Build an array of shape @sh@ by calling @f@ on each multi-dimensional
-- index.  'vGenerate' produces @total@ elements by linear position.  Each
-- position is turned back into an index list by repeated 'quotRem' with
-- the dense strides.
{-# INLINE generateT #-}
generateT :: (Vector v, VecElem v a) => ShapeL -> ([Int] -> a) -> T v a
generateT sh f = T innerStrides 0 (vGenerate total (f . unravel innerStrides))
  where
    total : innerStrides = getStridesT sh
    unravel [] _ = []
    unravel (st : sts) i =
      let (q, r) = i `quotRem` st
      in  q : unravel sts r

-- | A rank-1 tensor of length @n@ holding @x, f x, f (f x), ...@.
--
-- The first @n@ elements of @'iterate' f x@ are collected into a list and
-- handed to 'fromListT' with shape @[n]@.
{-# INLINE iterateNT #-}
iterateNT :: (Vector v, VecElem v a) => Int -> (a -> a) -> a -> T v a
iterateNT n f x = fromListT [n] elems
  where
    elems = take n (iterate f x)

-- | A rank-1 tensor of length @n@ holding @0, 1, .., n-1@.
--
-- The positions are enumerated as 'Int's and each is converted with
-- 'fromIntegral'.  This way the list always has exactly @n@ elements and
-- agrees with the shape @[n]@.  Enumerating in the element type instead
-- (@[0 .. fromIntegral n - 1]@) would let that type's 'Enum'/'Num' instances
-- decide the length.  That breaks when @fromIntegral n@ overflows; for example,
-- 'Data.Int.Int8' with @n = 200@ gives @[0 .. -57]@, which is empty.
-- The 'Enum' constraint is kept so the signature seen by callers is unchanged.
{-# INLINE iotaT #-}
iotaT :: (Vector v, VecElem v a, Enum a, Num a) => Int -> T v a
iotaT n = fromListT [n] (map fromIntegral [0 .. n - 1])
  -- TODO: building directly in the storage vector would avoid the
  -- intermediate list.

-------

-- | Rearrange a list by position.  Each element of the first argument is an
-- index into the original list, and the result lists the selected elements in
-- that order.  An index outside the list is an error (raised by '!!').
permute :: [Int] -> [a] -> [a]
permute is xs = [xs !! i | i <- is]

-- | Remove the longest suffix whose elements all satisfy the predicate.  This
-- is the mirror image of 'dropWhile'.  The list is reversed twice, so its
-- whole spine is forced.
revDropWhile :: (a -> Bool) -> [a] -> [a]
revDropWhile p xs = reverse kept
  where
    kept = dropWhile p (reverse xs)

-- | True when every element is equal to the first one.  The empty list
-- counts as all the same.
allSame :: (Eq a) => [a] -> Bool
allSame xs = case xs of
  [] -> True
  y : ys -> all (y ==) ys

-- | Reflect a type-level natural number as a value of any 'Num' type.
-- Use with explicit type application, i.e., @valueOf \@42@
{-# INLINE valueOf #-}
valueOf :: forall n i . (KnownNat n, Num i) => i
valueOf = fromInteger nat
  where
    -- @n@ is the scoped type variable bound by the signature's forall.
    nat = natVal (Proxy :: Proxy n)