{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-}

-- | Arrays with a fixed shape (known shape at compile time).
module NumHask.Array.Fixed
  ( -- $usage
    Array (..),

    -- * Conversion
    with,
    shape,
    toDynamic,

    -- * Operators
    takes,
    reshape,
    transpose,
    indices,
    ident,
    sequent,
    diag,
    undiag,
    singleton,
    selects,
    selectsExcept,
    folds,
    extracts,
    extractsExcept,
    joins,
    maps,
    concatenate,
    insert,
    append,
    reorder,
    expand,
    expandr,
    apply,
    contract,
    dot,
    mult,
    slice,
    squeeze,

    -- * Scalar

    --
    -- Scalar specialisations
    fromScalar,
    toScalar,

    -- * Vector

    --
    -- Vector specialisations.
    Vector,
    sequentv,

    -- * Matrix

    --
    -- Matrix specialisations.
    Matrix,
    col,
    row,
    safeCol,
    safeRow,
    mmult,
    chol,
    invtri,
  )
where

import Data.Distributive (Distributive (..))
import Data.Functor.Rep
import Data.List ((!!))
import Data.Proxy
import Data.Vector qualified as V
import GHC.Exts (IsList (..))
import GHC.Show (Show (..))
import GHC.TypeLits
import NumHask.Array.Dynamic qualified as D
import NumHask.Array.Shape
import NumHask.Prelude as P hiding (sequence, toList)

-- $setup
--
-- >>> :set -XDataKinds
-- >>> :set -XOverloadedLists
-- >>> :set -XTypeFamilies
-- >>> :set -XFlexibleContexts
-- >>> :set -XRebindableSyntax
-- >>> import NumHask.Prelude
-- >>> import GHC.TypeLits (Nat)
-- >>> import Data.Proxy
-- >>> import Data.Functor.Rep
-- >>> let s = [1] :: Array ('[] :: [Nat]) Int -- scalar
-- >>> let v = [1,2,3] :: Array '[3] Int       -- vector
-- >>> let t = [0..3] :: Array '[2,2] Int     -- square matrix
-- >>> let m = [0..11] :: Array '[3,4] Int     -- matrix
-- >>> let a = [1..24] :: Array '[2,3,4] Int

-- $usage
--
-- >>> :set -XDataKinds
-- >>> :set -XOverloadedLists
-- >>> :set -XTypeFamilies
-- >>> :set -XFlexibleContexts
-- >>> :set -XRebindableSyntax
-- >>> import NumHask.Prelude
-- >>> import NumHask.Array.Fixed
-- >>> import GHC.TypeLits (Nat)
-- >>> let s = [1] :: Array ('[] :: [Nat]) Int -- scalar
-- >>> let v = [1,2,3] :: Array '[3] Int       -- vector
-- >>> let m = [0..11] :: Array '[3,4] Int     -- matrix
-- >>> let a = [1..24] :: Array '[2,3,4] Int

-- | a multidimensional array with a type-level shape
--
-- Elements are held flat in a boxed 'V.Vector' (see 'unArray'); the shape @s@
-- exists only at the type level and is recovered via 'HasShape'. As the first
-- example shows, the last dimension varies fastest.
--
-- >>> :set -XDataKinds
-- >>> [1..24] :: Array '[2,3,4] Int
-- [[[1, 2, 3, 4],
--   [5, 6, 7, 8],
--   [9, 10, 11, 12]],
--  [[13, 14, 15, 16],
--   [17, 18, 19, 20],
--   [21, 22, 23, 24]]]
--
-- >>> [1,2,3] :: Array '[2,2] Int
-- *** Exception: NumHaskException {errorMessage = "shape mismatch"}
-- [[
newtype Array s a = Array {unArray :: V.Vector a}
  deriving
    ( Eq,
      Ord,
      Functor,
      Foldable,
      Generic,
      Traversable
    )

-- | Arrays are rendered by converting to a dynamic array ('toDynamic') and
-- reusing its 'Show' instance.
instance (HasShape s, Show a) => Show (Array s a) where
  show = GHC.Show.show . toDynamic

-- | 'distribute' comes for free from the 'Representable' instance below.
instance
  (HasShape s) =>
  Data.Distributive.Distributive (Array s)
  where
  distribute = distributeRep
  {-# INLINE distribute #-}

-- | An array is a function from a multi-dimensional index (a list of 'Int's,
-- one per dimension) to an element.
instance
  forall s.
  (HasShape s) =>
  Representable (Array s)
  where
  type Rep (Array s) = [Int]

  -- Generate every flat position, turning each one back into its
  -- multi-dimensional index with 'shapen' before handing it to @f@.
  tabulate f = Array (V.generate (size extent) (f . shapen extent))
    where
      extent = shapeVal (toShape @s)
  {-# INLINE tabulate #-}

  -- The multi-dimensional index is flattened with 'flatten'.
  -- NOTE(review): 'V.unsafeIndex' performs no bounds check, so an index that
  -- does not fit the shape is not detected here: it may pick a wrong element
  -- or read past the end of the vector.
  index (Array v) ix = V.unsafeIndex v (flatten extent ix)
    where
      extent = shapeVal (toShape @s)
  {-# INLINE index #-}

-- * NumHask hierarchy

-- | Element-wise addition; 'zero' is the array filled with 'zero'.
instance
  ( Additive a,
    HasShape s
  ) =>
  Additive (Array s a)
  where
  x + y = liftR2 (+) x y

  zero = pureRep zero

-- | Element-wise negation.
instance
  ( Subtractive a,
    HasShape s
  ) =>
  Subtractive (Array s a)
  where
  negate xs = fmapRep negate xs

-- | Multiply every element by a scalar. Each element is computed as
-- @s * x@, with the scalar on the left.
instance
  (Multiplicative a) =>
  MultiplicativeAction (Array s a)
  where
  type Scalar (Array s a) = a
  r |* s = fmap (s *) r

-- | Add a scalar to every element. Each element is computed as @s + x@.
instance (Additive a) => AdditiveAction (Array s a) where
  type AdditiveScalar (Array s a) = a
  r |+ s = fmap (s +) r

-- | Subtract a scalar from every element (@x - s@).
instance
  (Subtractive a) =>
  SubtractiveAction (Array s a)
  where
  r |- s = fmap minusScalar r
    where
      -- A named helper rather than the section @(- s)@, which would parse
      -- as negation.
      minusScalar x = x - s

-- | Divide every element by a scalar (@x / s@).
instance
  (Divisive a) =>
  DivisiveAction (Array s a)
  where
  r |/ s = fmap (/ s) r

-- | Element-wise join.
instance (HasShape s, JoinSemiLattice a) => JoinSemiLattice (Array s a) where
  x \/ y = liftR2 (\/) x y

-- | Element-wise meet.
instance (HasShape s, MeetSemiLattice a) => MeetSemiLattice (Array s a) where
  x /\ y = liftR2 (/\) x y

-- | 'epsilon' for an array is the element 'epsilon' repeated at every index.
instance (HasShape s, Subtractive a, Epsilon a) => Epsilon (Array s a) where
  epsilon = pureRep epsilon

-- | Build an array from a flat list of elements (laid out as 'V.fromList'
-- stores them), or flatten an array back into a list.
--
-- 'fromList' throws @NumHaskException "shape mismatch"@ when the list length
-- does not fit the type-level shape.
instance
  (HasShape s) =>
  IsList (Array s a)
  where
  type Item (Array s a) = a

  fromList xs =
    bool (throw (NumHaskException "shape mismatch")) (Array (V.fromList xs)) fits
    where
      extent = shapeVal (toShape @s)
      n = length xs
      -- A scalar (empty shape) accepts exactly one element; otherwise the
      -- list must supply exactly 'size' elements.
      fits = (n == 1 && null extent) || n == size extent

  toList = V.toList . unArray

-- | Get shape of an Array as a value.
--
-- The array argument is never inspected; the shape comes from the type.
--
-- >>> shape a
-- [2,3,4]
shape :: forall a s. (HasShape s) => Array s a -> [Int]
shape = const (shapeVal (toShape @s))
{-# INLINE shape #-}

-- | Convert to a dynamic array, moving the shape from the type level to the
-- value level. The elements are passed through as a flat list.
toDynamic :: (HasShape s) => Array s a -> D.Array a
toDynamic arr = D.fromFlatList dims elems
  where
    dims = shape arr
    elems = toList arr

-- | Use a dynamic array in a fixed context.
--
-- The dynamic array's shape must equal the type-level shape @s@. If it does
-- not, evaluating the result throws @NumHaskException "shape mismatch"@ (the
-- same error 'fromList' raises). Without this check, a mis-shaped vector
-- would be indexed without bounds checks.
--
-- >>> import qualified NumHask.Array.Dynamic as D
-- >>> with (D.fromFlatList [2,3,4] [1..24]) (selects (Proxy :: Proxy '[0,1]) [1,1] :: Array '[2,3,4] Int -> Array '[4] Int)
-- [17, 18, 19, 20]
with ::
  forall a r s.
  (HasShape s) =>
  D.Array a ->
  (Array s a -> r) ->
  r
with (D.Array ds v) f =
  bool
    (throw (NumHaskException "shape mismatch"))
    (f (Array v))
    (ds == shapeVal (toShape @s))

-- | Takes the top-most elements according to the new dimension.
--
-- The element at index @i@ of the result is the source element at the same
-- index @i@. The target shape is expected to fit inside the source shape;
-- that is not checked here.
--
-- >>> takes a :: Array '[2,2,3] Int
-- [[[1, 2, 3],
--   [5, 6, 7]],
--  [[13, 14, 15],
--   [17, 18, 19]]]
takes ::
  forall s s' a.
  ( HasShape s,
    HasShape s'
  ) =>
  Array s a ->
  Array s' a
takes a = tabulate (index a)

-- | Reshape an array (with the same number of elements).
--
-- Each target index is flattened against the new shape, then expanded
-- against the old shape, so the flat element order is preserved.
--
-- >>> reshape a :: Array '[4,3,2] Int
-- [[[1, 2],
--   [3, 4],
--   [5, 6]],
--  [[7, 8],
--   [9, 10],
--   [11, 12]],
--  [[13, 14],
--   [15, 16],
--   [17, 18]],
--  [[19, 20],
--   [21, 22],
--   [23, 24]]]
reshape ::
  forall a s s'.
  ( Size s ~ Size s',
    HasShape s,
    HasShape s'
  ) =>
  Array s a ->
  Array s' a
reshape a = tabulate (\ix -> index a (shapen fromShape (flatten toShape' ix)))
  where
    fromShape = shapeVal (toShape @s)
    toShape' = shapeVal (toShape @s')

-- | Reverse indices eg transposes the element A/ijk/ to A/kji/.
--
-- >>> index (transpose a) [1,0,0] == index a [0,0,1]
-- True
transpose :: forall a s. (HasShape s, HasShape (Reverse s)) => Array s a -> Array (Reverse s) a
transpose a = tabulate (\ix -> index a (reverse ix))

-- | Indices of an Array: every element is its own multi-dimensional index.
--
-- >>> indices :: Array '[3,3] [Int]
-- [[[0,0], [0,1], [0,2]],
--  [[1,0], [1,1], [1,2]],
--  [[2,0], [2,1], [2,2]]]
indices :: forall s. (HasShape s) => Array s [Int]
indices = tabulate (\ix -> ix)

-- | The identity array: 'one' where every index component is equal, 'zero'
-- elsewhere. Non-square shapes are allowed.
--
-- >>> ident :: Array '[3,2] Int
-- [[1, 0],
--  [0, 1],
--  [0, 0]]
ident :: forall a s. (HasShape s, Additive a, Multiplicative a) => Array s a
ident = tabulate (\ix -> bool zero one (onDiagonal ix))
  where
    onDiagonal :: [Int] -> Bool
    onDiagonal [] = True
    onDiagonal (x : xs) = all (x ==) xs

-- | An array of sequential Ints: on the diagonal (all index components
-- equal) the element is that component; everywhere else it is 'zero'.
--
-- >>> sequent :: Array '[3] Int
-- [0, 1, 2]
--
-- >>> sequent :: Array '[3,3] Int
-- [[0, 0, 0],
--  [0, 1, 0],
--  [0, 0, 2]]
sequent :: forall s. (HasShape s) => Array s Int
sequent = tabulate diagValue
  where
    diagValue :: [Int] -> Int
    diagValue [] = zero
    diagValue (i : rest) = bool zero i (all (i ==) rest)

-- | Extract the diagonal of an array.
--
-- Element @i@ of the result is the source element whose index components
-- all equal @i@. The result length is the smallest dimension of @s@.
--
-- >>> diag (ident :: Array '[3,2] Int)
-- [1, 1]
diag ::
  forall a s.
  ( HasShape s,
    HasShape '[Minimum s]
  ) =>
  Array s a ->
  Array '[Minimum s] a
diag a = tabulate pick
  where
    srcRank = length (shapeVal (toShape @s))
    pick :: [Int] -> a
    pick [] = throw (NumHaskException "Rank Underflow")
    pick (i : _) = index a (replicate srcRank i)

-- | Expand the array to form a diagonal array
--
-- The result has shape @s ++ s@. An index whose components are all equal
-- reads the source element at the first @rank s@ of those components;
-- every other position is 'zero'.
--
-- >>> undiag ([1,1] :: Array '[2] Int)
-- [[1, 0],
--  [0, 1]]
undiag ::
  forall a s.
  ( HasShape s,
    Additive a,
    HasShape ((++) s s)
  ) =>
  Array s a ->
  Array ((++) s s) a
undiag a = tabulate go
  where
    srcRank = length (shapeVal (toShape @s))
    go :: [Int] -> a
    go [] = throw (NumHaskException "Rank Underflow")
    -- The result index has twice the rank of @a@. Index @a@ with only its
    -- first @srcRank@ components rather than relying on 'flatten' to ignore
    -- the surplus ones.
    go xs@(x : xs') = bool zero (index a (take srcRank xs)) (all (x ==) xs')

-- | Create an array composed of a single value repeated at every index.
--
-- >>> singleton one :: Array '[3,2] Int
-- [[1, 1],
--  [1, 1],
--  [1, 1]]
singleton :: (HasShape s) => a -> Array s a
singleton a = tabulate (\_ -> a)

-- | Select an array along dimensions.
--
-- The index list supplies positions for the dimensions named in @ds@; the
-- remaining dimensions form the shape of the result.
--
-- >>> let s = selects (Proxy :: Proxy '[0,1]) [1,1] a
-- >>> :t s
-- s :: Array '[4] Int
--
-- >>> s
-- [17, 18, 19, 20]
selects ::
  forall ds s s' a.
  ( HasShape s,
    HasShape ds,
    HasShape s',
    s' ~ DropIndexes s ds
  ) =>
  Proxy ds ->
  [Int] ->
  Array s a ->
  Array s' a
selects _ i a = tabulate (\ix -> index a (addIndexes ix selected i))
  where
    selected = shapeVal (toShape @ds)

-- | Select an index /except/ along specified dimensions.
--
-- The index list supplies positions for every dimension /not/ named in @ds@;
-- the dimensions in @ds@ form the shape of the result.
--
-- >>> let s = selectsExcept (Proxy :: Proxy '[2]) [1,1] a
-- >>> :t s
-- s :: Array '[4] Int
--
-- >>> s
-- [17, 18, 19, 20]
selectsExcept ::
  forall ds s s' a.
  ( HasShape s,
    HasShape ds,
    HasShape s',
    s' ~ TakeIndexes s ds
  ) =>
  Proxy ds ->
  [Int] ->
  Array s a ->
  Array s' a
selectsExcept _ i a = tabulate (\ix -> index a (addIndexes i kept ix))
  where
    kept = shapeVal (toShape @ds)

-- | Fold along specified dimensions.
--
-- Each result element at index @i@ (over the dimensions in @ds@) is @f@
-- applied to the sub-array returned by 'selects' at @i@.
--
-- >>> folds sum (Proxy :: Proxy '[1]) a
-- [68, 100, 132]
folds ::
  forall ds st si so a b.
  ( HasShape st,
    HasShape ds,
    HasShape si,
    HasShape so,
    si ~ DropIndexes st ds,
    so ~ TakeIndexes st ds
  ) =>
  (Array si a -> b) ->
  Proxy ds ->
  Array st a ->
  Array so b
folds f d a = tabulate (\ix -> f (selects d ix a))

-- | Extracts dimensions to an outer layer.
--
-- The outer array is indexed by the dimensions in @ds@; each of its elements
-- is the inner sub-array produced by 'selects'.
--
-- >>> let e = extracts (Proxy :: Proxy '[1,2]) a
-- >>> :t e
-- e :: Array [3, 4] (Array '[2] Int)
extracts ::
  forall ds st si so a.
  ( HasShape st,
    HasShape ds,
    HasShape si,
    HasShape so,
    si ~ DropIndexes st ds,
    so ~ TakeIndexes st ds
  ) =>
  Proxy ds ->
  Array st a ->
  Array so (Array si a)
extracts d a = tabulate (\ix -> selects d ix a)

-- | Extracts /except/ dimensions to an outer layer.
--
-- The outer array is indexed by the dimensions /not/ in @ds@; each of its
-- elements is the inner sub-array produced by 'selectsExcept'.
--
-- >>> let e = extractsExcept (Proxy :: Proxy '[1,2]) a
-- >>> :t e
-- e :: Array '[2] (Array [3, 4] Int)
extractsExcept ::
  forall ds st si so a.
  ( HasShape st,
    HasShape ds,
    HasShape si,
    HasShape so,
    so ~ DropIndexes st ds,
    si ~ TakeIndexes st ds
  ) =>
  Proxy ds ->
  Array st a ->
  Array so (Array si a)
extractsExcept d a = tabulate (\ix -> selectsExcept d ix a)

-- | Join inner and outer dimension layers.
--
-- For a result index, the components at the positions in @ds@ select the
-- outer element and the remaining components index into that inner array.
--
-- >>> let e = extracts (Proxy :: Proxy '[1,0]) a
--
-- >>> :t e
-- e :: Array [3, 2] (Array '[4] Int)
--
-- >>> let j = joins (Proxy :: Proxy '[1,0]) e
--
-- >>> :t j
-- j :: Array [2, 3, 4] Int
--
-- >>> a == j
-- True
joins ::
  forall ds si st so a.
  ( HasShape st,
    HasShape ds,
    st ~ AddIndexes si ds so,
    HasShape si,
    HasShape so
  ) =>
  Proxy ds ->
  Array so (Array si a) ->
  Array st a
joins _ a = tabulate pick
  where
    outerDims = shapeVal (toShape @ds)
    pick ix =
      let inner = index a (takeIndexes ix outerDims)
       in index inner (dropIndexes ix outerDims)

-- | Maps a function along specified dimensions.
--
-- The array is split with 'extracts', @f@ is applied to every inner
-- sub-array, and the pieces are reassembled with 'joins'.
--
-- >>> :t maps (transpose) (Proxy :: Proxy '[1]) a
-- maps (transpose) (Proxy :: Proxy '[1]) a :: Array [4, 3, 2] Int
maps ::
  forall ds st st' si si' so a b.
  ( HasShape st,
    HasShape st',
    HasShape ds,
    HasShape si,
    HasShape si',
    HasShape so,
    si ~ DropIndexes st ds,
    so ~ TakeIndexes st ds,
    st' ~ AddIndexes si' ds so,
    st ~ AddIndexes si ds so
  ) =>
  (Array si a -> Array si' b) ->
  Proxy ds ->
  Array st a ->
  Array st' b
maps f d a = joins d mapped
  where
    mapped = fmapRep f (extracts d a)

-- | Concatenate along a dimension.
--
-- Indices whose component along dimension @d@ is below the first array's
-- extent read from the first array. The rest read from the second array,
-- with that component shifted down by the first array's extent.
--
-- >>> :t concatenate (Proxy :: Proxy 1) a a
-- concatenate (Proxy :: Proxy 1) a a :: Array [2, 6, 4] Int
concatenate ::
  forall a s0 s1 d s.
  ( CheckConcatenate d s0 s1 s,
    Concatenate d s0 s1 ~ s,
    HasShape s0,
    HasShape s1,
    HasShape s,
    KnownNat d
  ) =>
  Proxy d ->
  Array s0 a ->
  Array s1 a ->
  Array s a
concatenate _ s0 s1 = tabulate pick
  where
    axis :: Int
    axis = fromIntegral (natVal @d Proxy)
    boundary = shapeVal (toShape @s0) !! axis
    pick :: [Int] -> a
    pick ix =
      let pos = ix !! axis
       in bool
            (index s0 ix)
            (index s1 (addIndex (dropIndex ix axis) axis (pos - boundary)))
            (pos >= boundary)

-- | Insert along a dimension at a position.
--
-- Along dimension @d@ of the result, position @i@ is taken from the second
-- array, positions before @i@ come from the first array unchanged, and
-- positions after @i@ come from the first array shifted back by one.
--
-- >>> insert (Proxy :: Proxy 2) (Proxy :: Proxy 0) a ([100..105])
-- [[[100, 1, 2, 3, 4],
--   [101, 5, 6, 7, 8],
--   [102, 9, 10, 11, 12]],
--  [[103, 13, 14, 15, 16],
--   [104, 17, 18, 19, 20],
--   [105, 21, 22, 23, 24]]]
insert ::
  forall a s s' d i.
  ( DropIndex s d ~ s',
    CheckInsert d i s,
    KnownNat i,
    KnownNat d,
    HasShape s,
    HasShape s',
    HasShape (Insert d s)
  ) =>
  Proxy d ->
  Proxy i ->
  Array s a ->
  Array s' a ->
  Array (Insert d s) a
insert _ _ a b = tabulate pick
  where
    pick :: [Int] -> a
    pick ix = case compare (ix !! dim) pos of
      LT -> index a ix
      EQ -> index b (dropIndex ix dim)
      GT -> index a (decAt dim ix)
    dim :: Int
    dim = fromIntegral $ natVal @d Proxy
    pos :: Int
    pos = fromIntegral $ natVal @i Proxy

-- | Insert along a dimension at the end.
--
-- The second array becomes the new last slice along dimension @d@; every
-- slice of the first array keeps its position.
--
-- >>>  :t append (Proxy :: Proxy 0) a
-- append (Proxy :: Proxy 0) a
--   :: Array [3, 4] Int -> Array [3, 3, 4] Int
--
-- >>> append (Proxy :: Proxy 0) v (toScalar 4)
-- [1, 2, 3, 4]
append ::
  forall a d s s'.
  ( DropIndex s d ~ s',
    CheckInsert d (Dimension s d - 1) s,
    KnownNat (Dimension s d - 1),
    KnownNat d,
    HasShape s,
    HasShape s',
    HasShape (Insert d s)
  ) =>
  Proxy d ->
  Array s a ->
  Array s' a ->
  Array (Insert d s) a
append _ a b = tabulate go
  where
    -- Delegating to 'insert' at position @Dimension s d - 1@ would place the
    -- new slice *before* the last existing slice, so the appended slot is
    -- addressed directly at position @Dimension s d@ instead.
    go :: [Int] -> a
    go s
      | s !! dim == end = index b (dropIndex s dim)
      | otherwise = index a s
    -- the dimension being extended
    dim :: Int
    dim = fromIntegral $ natVal @d Proxy
    -- original extent of that dimension: the first position past the first array
    end :: Int
    end = shapeVal (toShape @s) !! dim

-- | Change the order of dimensions.
--
-- Each result index is mapped back to a source index with 'addIndexes',
-- placing its components at the positions listed in @ds@.
--
-- >>> let r = reorder (Proxy :: Proxy '[2,0,1]) a
-- >>> :t r
-- r :: Array [4, 2, 3] Int
reorder ::
  forall a ds s.
  ( HasShape ds,
    HasShape s,
    HasShape (Reorder s ds),
    CheckReorder ds s
  ) =>
  Proxy ds ->
  Array s a ->
  Array (Reorder s ds) a
reorder _ a = tabulate (index a . addIndexes [] order)
  where
    order :: [Int]
    order = shapeVal (toShape @ds)

-- | Product two arrays using the supplied binary function.
--
-- For context, if the function is multiply, and the arrays are tensors,
-- then this can be interpreted as a tensor product.
--
-- < https://en.wikipedia.org/wiki/Tensor_product>
--
-- The concept of a tensor product is a dense crossroad, and a complete treatment is elsewhere.  To quote:
--
-- ... the tensor product can be extended to other categories of mathematical objects in addition to vector spaces, such as to matrices, tensors, algebras, topological vector spaces, and modules. In each such case the tensor product is characterized by a similar universal property: it is the freest bilinear operation. The general concept of a "tensor product" is captured by monoidal categories; that is, the class of all things that have a tensor product is a monoidal category.
--
-- >>> expand (*) v v
-- [[1, 2, 3],
--  [2, 4, 6],
--  [3, 6, 9]]
--
-- Alternatively, expand can be understood as representing the permutation of element pairs of two arrays, so like the Applicative List instance.
--
-- >>> i2 = indices :: Array '[2,2] [Int]
-- >>> expand (,) i2 i2
-- [[[[([0,0],[0,0]), ([0,0],[0,1])],
--    [([0,0],[1,0]), ([0,0],[1,1])]],
--   [[([0,1],[0,0]), ([0,1],[0,1])],
--    [([0,1],[1,0]), ([0,1],[1,1])]]],
--  [[[([1,0],[0,0]), ([1,0],[0,1])],
--    [([1,0],[1,0]), ([1,0],[1,1])]],
--   [[([1,1],[0,0]), ([1,1],[0,1])],
--    [([1,1],[1,0]), ([1,1],[1,1])]]]]
expand ::
  forall s s' a b c.
  ( HasShape s,
    HasShape s',
    HasShape ((++) s s')
  ) =>
  (a -> b -> c) ->
  Array s a ->
  Array s' b ->
  Array ((++) s s') c
expand f a b = tabulate go
  where
    -- the leading coordinates address the first array, the rest the second
    go :: [Int] -> c
    go ix =
      let outer = take r ix
          inner = drop r ix
       in f (index a outer) (index b inner)
    r :: Int
    r = rank (shape a)

-- | Like expand, but permutes the first array first, rather than the second.
--
-- The result index is split at the rank of the first array; the first array
-- is looked up with the trailing coordinates and the second array with the
-- leading ones. Because of this, the lookups are only in range (and the
-- result only meaningful) when @s@ and @s'@ are the same shape.
--
-- >>> expand (,) v (v |+ 3)
-- [[(1,4), (1,5), (1,6)],
--  [(2,4), (2,5), (2,6)],
--  [(3,4), (3,5), (3,6)]]
--
-- >>> expandr (,) v (v |+ 3)
-- [[(1,4), (2,4), (3,4)],
--  [(1,5), (2,5), (3,5)],
--  [(1,6), (2,6), (3,6)]]
expandr ::
  forall s s' a b c.
  ( HasShape s,
    HasShape s',
    HasShape ((++) s s')
  ) =>
  (a -> b -> c) ->
  Array s a ->
  Array s' b ->
  Array ((++) s s') c
expandr f a b = tabulate go
  where
    go :: [Int] -> c
    go ix =
      let front = take r ix
          back = drop r ix
       in f (index a back) (index b front)
    r :: Int
    r = rank (shape a)

-- | Apply an array of functions to each array of values.
--
-- This is in the spirit of the applicative functor operation (\<*\>).
--
-- > expand f a b == apply (fmap f a) b
--
-- >>> apply ((*) <$> v) v
-- [[1, 2, 3],
--  [2, 4, 6],
--  [3, 6, 9]]
--
-- Fixed Arrays can't be applicative functors because the changes in shape are reflected in the types.
--
-- > :t apply
-- > apply
-- >   :: (HasShape s, HasShape s', HasShape (s ++ s')) =>
-- >      Array s (a -> b) -> Array s' a -> Array (s ++ s') b
-- > :t (<*>)
-- > (<*>) :: Applicative f => f (a -> b) -> f a -> f b
--
-- >>> let b = [1..6] :: Array '[2,3] Int
-- >>> contract sum (Proxy :: Proxy '[1,2]) (apply (fmap (*) b) (transpose b))
-- [[14, 32],
--  [32, 77]]
apply ::
  forall s s' a b.
  ( HasShape s,
    HasShape s',
    HasShape ((++) s s')
  ) =>
  Array s (a -> b) ->
  Array s' a ->
  Array ((++) s s') b
apply f a = tabulate go
  where
    -- leading coordinates pick the function, trailing ones pick its argument
    go :: [Int] -> b
    go ix =
      let g = index f (take r ix)
       in g (index a (drop r ix))
    r :: Int
    r = rank (shape f)

-- | Contract an array by applying the supplied (folding) function on diagonal elements of the dimensions.
--
-- This generalises a tensor contraction by allowing the number of contracting diagonals to be other than 2, and allowing a binary operator other than multiplication.
--
-- The sub-arrays spanning the dimensions @ds@ are taken with 'extractsExcept',
-- their diagonals are taken with 'diag', and the folding function reduces
-- each diagonal to a single value.
--
-- >>> let b = [1..6] :: Array '[2,3] Int
-- >>> contract sum (Proxy :: Proxy '[1,2]) (expand (*) b (transpose b))
-- [[14, 32],
--  [32, 77]]
contract ::
  forall a b s ss s' ds.
  ( KnownNat (Minimum (TakeIndexes s ds)),
    HasShape (TakeIndexes s ds),
    HasShape s,
    HasShape ds,
    HasShape ss,
    HasShape s',
    s' ~ DropIndexes s ds,
    ss ~ '[Minimum (TakeIndexes s ds)]
  ) =>
  (Array ss a -> b) ->
  Proxy ds ->
  Array s a ->
  Array s' b
contract f xs a = fmap (\sub -> f (diag sub)) (extractsExcept xs a)

-- | A generalisation of a dot operation, which is a multiplicative expansion of two arrays and sum contraction along the middle two dimensions.
--
-- The arrays are combined with 'expand' using the element function, and the
-- result is 'contract'ed over the last dimension of the first array and the
-- first dimension of the second.
--
-- matrix multiplication
--
-- >>> let b = [1..6] :: Array '[2,3] Int
-- >>> dot sum (*) b (transpose b)
-- [[14, 32],
--  [32, 77]]
--
-- inner product
--
-- >>> let v = [1..3] :: Array '[3] Int
-- >>> :t dot sum (*) v v
-- dot sum (*) v v :: Array '[] Int
--
-- >>> dot sum (*) v v
-- 14
--
-- matrix-vector multiplication
-- (Note how the vector doesn't need to be converted to a row or column vector)
--
-- >>> dot sum (*) v b
-- [9, 12, 15]
--
-- >>> dot sum (*) b v
-- [14, 32]
--
-- Array elements don't have to be numbers:
--
-- >>> x1 = (show <$> [1..4]) :: Array '[2,2] String
-- >>> x2 = (show <$> [5..8]) :: Array '[2,2] String
-- >>> x1
-- [["1", "2"],
--  ["3", "4"]]
--
-- >>> x2
-- [["5", "6"],
--  ["7", "8"]]
--
-- >>> import Data.List (intercalate)
-- >>> dot (intercalate "+" . toList) (\a b -> a <> "*" <> b) x1 x2
-- [["1*5+2*7", "1*6+2*8"],
--  ["3*5+4*7", "3*6+4*8"]]
--
-- 'dot' allows operation on mis-shaped matrices. The algorithm ignores excess positions within the contracting dimension(s):
--
-- >>> let m23 = [1..6] :: Array '[2,3] Int
-- >>> let m12 = [1,2] :: Array '[1,2] Int
-- >>> shape $ dot sum (*) m23 m12
-- [2,2]
--
-- Find instances of a vector in a matrix
--
-- >>> let cs = fromList ("abacbaab" :: [Char]) :: Array '[4,2] Char
-- >>> let v = fromList ("ab" :: [Char]) :: Vector 2 Char
-- >>> dot (all id) (==) cs v
-- [True, False, False, True]
dot ::
  forall a b c d sa sb s' ss se.
  ( HasShape sa,
    HasShape sb,
    HasShape (sa ++ sb),
    se ~ TakeIndexes (sa ++ sb) '[Rank sa - 1, Rank sa],
    HasShape se,
    KnownNat (Minimum se),
    KnownNat (Rank sa - 1),
    KnownNat (Rank sa),
    ss ~ '[Minimum se],
    HasShape ss,
    s' ~ DropIndexes (sa ++ sb) '[Rank sa - 1, Rank sa],
    HasShape s'
  ) =>
  (Array ss c -> d) ->
  (a -> b -> c) ->
  Array sa a ->
  Array sb b ->
  Array s' d
dot f g a b = contract f middle (expand g a b)
  where
    -- the two adjoining dimensions: last of the first array, first of the second
    middle :: Proxy '[Rank sa - 1, Rank sa]
    middle = Proxy

-- | Array multiplication.
--
-- 'dot' specialised to summing element-wise products.
--
-- matrix multiplication
--
-- >>> let b = [1..6] :: Array '[2,3] Int
-- >>> mult b (transpose b)
-- [[14, 32],
--  [32, 77]]
--
-- inner product
--
-- >>> let v = [1..3] :: Array '[3] Int
-- >>> :t mult v v
-- mult v v :: Array '[] Int
--
-- >>> mult v v
-- 14
--
-- matrix-vector multiplication
--
-- >>> mult v b
-- [9, 12, 15]
--
-- >>> mult b v
-- [14, 32]
mult ::
  forall a sa sb s' ss se.
  ( Additive a,
    Multiplicative a,
    HasShape sa,
    HasShape sb,
    HasShape (sa ++ sb),
    se ~ TakeIndexes (sa ++ sb) '[Rank sa - 1, Rank sa],
    HasShape se,
    KnownNat (Minimum se),
    KnownNat (Rank sa - 1),
    KnownNat (Rank sa),
    ss ~ '[Minimum se],
    HasShape ss,
    s' ~ DropIndexes (sa ++ sb) '[Rank sa - 1, Rank sa],
    HasShape s'
  ) =>
  Array sa a ->
  Array sb a ->
  Array s' a
mult x y = dot sum (*) x y

-- | Select elements along positions in every dimension.
--
-- @pss@ lists, per dimension, the source positions to keep; each result
-- coordinate selects one entry of the corresponding list.
--
-- >>> let s = slice (Proxy :: Proxy '[[0,1],[0,2],[1,2]]) a
-- >>> :t s
-- s :: Array [2, 2, 2] Int
--
-- >>> s
-- [[[2, 3],
--   [10, 11]],
--  [[14, 15],
--   [22, 23]]]
--
-- >>> let s = squeeze $ slice (Proxy :: Proxy '[ '[0], '[0], '[0]]) a
-- >>> :t s
-- s :: Array '[] Int
--
-- >>> s
-- 1
slice ::
  forall (pss :: [[Nat]]) s s' a.
  ( HasShape s,
    HasShape s',
    KnownNatss pss,
    KnownNat (Rank pss),
    s' ~ Ranks pss
  ) =>
  Proxy pss ->
  Array s a ->
  Array s' a
slice pss a = tabulate (index a . zipWith (!!) picks)
  where
    -- per-dimension source positions, as values
    picks :: [[Int]]
    picks = natValss pss

-- | Remove single dimensions.
--
-- Only the type-level shape changes; the underlying flat storage is reused.
--
-- >>> let a = [1..24] :: Array '[2,1,3,4,1] Int
-- >>> a
-- [[[[[1],
--     [2],
--     [3],
--     [4]],
--    [[5],
--     [6],
--     [7],
--     [8]],
--    [[9],
--     [10],
--     [11],
--     [12]]]],
--  [[[[13],
--     [14],
--     [15],
--     [16]],
--    [[17],
--     [18],
--     [19],
--     [20]],
--    [[21],
--     [22],
--     [23],
--     [24]]]]]
-- >>> squeeze a
-- [[[1, 2, 3, 4],
--   [5, 6, 7, 8],
--   [9, 10, 11, 12]],
--  [[13, 14, 15, 16],
--   [17, 18, 19, 20],
--   [21, 22, 23, 24]]]
--
-- >>> squeeze ([1] :: Array '[1,1] Double)
-- 1.0
squeeze ::
  forall s t a.
  (t ~ Squeeze s) =>
  Array s a ->
  Array t a
squeeze (Array flat) = Array flat

-- $scalar
-- Scalar specialisations

-- | Unwrapping scalars is probably a performance bottleneck.
--
-- Reads the single element of a rank-zero array at the empty index.
--
-- >>> let s = [3] :: Array ('[] :: [Nat]) Int
-- >>> fromScalar s
-- 3
fromScalar :: (HasShape ('[] :: [Nat])) => Array ('[] :: [Nat]) a -> a
fromScalar = (`index` ([] :: [Int]))

-- | Convert a number to a scalar.
--
-- Builds a rank-zero array from a one-element list.
--
-- >>> :t toScalar 2
-- toScalar 2 :: FromInteger a => Array '[] a
toScalar :: (HasShape ('[] :: [Nat])) => a -> Array ('[] :: [Nat]) a
toScalar x = fromList (pure x)

-- | <https://en.wikipedia.org/wiki/Vector_(mathematics_and_physics) Wiki Vector>
--
-- A rank-one fixed array: @Vector s a@ is an 'Array' of shape @'[s]@.
type Vector s a = Array '[s] a

-- | Vector specialisation of 'sequent'.
--
-- 'sequent' instantiated at the one-dimensional shape @'[n]@.
sequentv :: forall n. (KnownNat n) => Vector n Int
sequentv = sequent @'[n]

-- | <https://en.wikipedia.org/wiki/Matrix_(mathematics) Wiki Matrix>
--
-- A rank-two fixed array: @Matrix m n a@ is an 'Array' of shape @'[m, n]@,
-- i.e. @m@ rows of @n@ elements each (the layout 'row' and 'mmult' index).
type Matrix m n a = Array '[m, n] a

-- | Square matrices multiply with 'mmult'; the unit is 'ident'.
instance
  ( Multiplicative a,
    P.Distributive a,
    Subtractive a,
    KnownNat m,
    HasShape '[m, m]
  ) =>
  Multiplicative (Matrix m m a)
  where
  x * y = mmult x y

  one = ident

-- | Square-matrix reciprocal via the Cholesky factor.
--
-- With @L = chol a@ this computes @invtri (transpose L) * invtri L@, which is
-- the inverse of @L * transpose L@.
-- NOTE(review): relies on 'chol', so presumably only meaningful for symmetric
-- positive-definite matrices — confirm before relying on it for other inputs.
instance
  ( Multiplicative a,
    P.Distributive a,
    Subtractive a,
    Eq a,
    ExpField a,
    KnownNat m,
    HasShape '[m, m]
  ) =>
  Divisive (Matrix m m a)
  where
  recip a = invtri (transpose factor) * invtri factor
    where
      factor = chol a

-- | <https://math.stackexchange.com/questions/1003801/inverse-of-an-invertible-upper-triangular-matrix-of-order-3 Inverse of a triangular> matrix.
--
-- Splits the input into its diagonal part @D@ and off-diagonal part @N@ and
-- evaluates @(sum over k of (-(D^-1 * N))^k) * D^-1@, taking the exponents
-- @k@ from 'sequentv'. For triangular input @D^-1 * N@ is nilpotent, so the
-- finite series is the exact inverse.
-- NOTE(review): assumes 'sequentv' yields the exponents 0 .. n-1 and that the
-- diagonal has no zeros — confirm.
invtri :: forall a n. (KnownNat n, ExpField a, Eq a) => Array '[n, n] a -> Array '[n, n] a
invtri a = series * dInv
  where
    -- D^-1: reciprocal of the diagonal, as a diagonal matrix
    dInv = undiag (recip <$> diag a)
    -- N: everything off the diagonal
    offDiag = a - undiag (diag a)
    -- -(D^-1 * N)
    step = negate (dInv * offDiag)
    series = sum ((step ^) <$> (sequentv :: Vector n Int))

-- | cholesky decomposition
--
-- Uses the <https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky_algorithm Cholesky-Crout> algorithm.
--
-- Returns the lower-triangular factor @L@:
--
-- * diagonal: @L[j,j] = sqrt (A[j,j] - sum_{k<j} L[j,k]^2)@
--
-- * below the diagonal: @L[i,j] = (A[i,j] - sum_{k<j} L[i,k] * L[j,k]) / L[j,j]@
--
-- * above the diagonal: exactly 'zero'
--
-- Entries are defined in terms of earlier entries of the same (lazily
-- tabulated) array. Only the lower triangle of the input is read.
chol :: (KnownNat n, ExpField a) => Array '[n, n] a -> Array '[n, n] a
chol a =
  let l =
        tabulate
          ( \[i, j] ->
              case compare i j of
                -- Upper triangle. Previously this evaluated the lower-triangle
                -- formula too, leaving rounding noise (or garbage for
                -- non-symmetric input) above the diagonal. 'invtri' requires a
                -- genuinely triangular factor.
                LT -> zero
                EQ ->
                  sqrt
                    ( index a [j, j]
                        - sum
                          ( (\k -> index l [j, k] ^ 2)
                              <$> ([zero .. (j - one)] :: [Int])
                          )
                    )
                GT ->
                  one
                    / index l [j, j]
                    * ( index a [i, j]
                          - sum
                            ( (\k -> index l [i, k] * index l [j, k])
                                <$> ([zero .. (j - one)] :: [Int])
                            )
                      )
          )
   in l

-- | Extract specialised to a matrix.
--
-- Returns row @i@ as a contiguous slice of the flat storage. An out-of-range
-- row is rejected by the bounds check in 'V.slice'.
--
-- >>> row 1 m
-- [4, 5, 6, 7]
row :: forall m n a. (KnownNat m, KnownNat n, HasShape '[m, n]) => Int -> Matrix m n a -> Vector n a
row i (Array flat) = Array (V.slice start width flat)
  where
    -- number of columns, i.e. the length of one row
    width :: Int
    width = fromIntegral (natVal @n Proxy)
    start = i * width

-- | Row extraction checked at type level.
--
-- The row index is a type-level 'Nat'. It must satisfy @CheckIndex j m@, so
-- an out-of-range row is a compile error.
--
-- >>> safeRow (Proxy :: Proxy 1) m
-- [4, 5, 6, 7]
--
-- >>> safeRow (Proxy :: Proxy 3) m
-- ...
-- ... index outside range
-- ...
safeRow :: forall m n a j. ('True ~ CheckIndex j m, KnownNat j, KnownNat m, KnownNat n, HasShape '[m, n]) => Proxy j -> Matrix m n a -> Vector n a
safeRow _ (Array flat) = Array (V.slice start width flat)
  where
    width :: Int
    width = fromIntegral (natVal @n Proxy)
    rowIx :: Int
    rowIx = fromIntegral (natVal @j Proxy)
    start = rowIx * width

-- | Extract specialised to a matrix.
--
-- Returns column @i@: one element from each of the @m@ rows, so the result
-- is a vector of length @m@. Throws 'NumHaskException' when @i@ is not in
-- @[0, n)@.
--
-- >>> col 1 m
-- [1, 5, 9]
col :: forall m n a. (KnownNat m, KnownNat n, HasShape '[m, n]) => Int -> Matrix m n a -> Vector m a
col i (Array a)
  -- 'V.unsafeIndex' below does no bounds checking, so the runtime column
  -- index must be validated first.
  | i < 0 || i >= n = throw (NumHaskException "col index out of range")
  | otherwise = Array $ V.generate m (\x -> V.unsafeIndex a (i + x * n))
  where
    -- number of rows (length of the result)
    m :: Int
    m = fromIntegral $ natVal @m Proxy
    -- number of columns (stride between consecutive rows)
    n :: Int
    n = fromIntegral $ natVal @n Proxy

-- | Column extraction checked at type level.
--
-- The column index is a type-level 'Nat'. It must satisfy @CheckIndex j n@,
-- so an out-of-range column is a compile error. The result holds one element
-- from each of the @m@ rows, so its length is @m@.
--
-- >>> safeCol (Proxy :: Proxy 1) m
-- [1, 5, 9]
--
-- >>> safeCol (Proxy :: Proxy 4) m
-- ...
-- ... index outside range
-- ...
safeCol :: forall m n a j. ('True ~ CheckIndex j n, KnownNat j, KnownNat m, KnownNat n, HasShape '[m, n]) => Proxy j -> Matrix m n a -> Vector m a
safeCol _j (Array a) = Array $ V.generate m (\x -> V.unsafeIndex a (j + x * n))
  where
    -- number of rows (length of the result)
    m :: Int
    m = fromIntegral $ natVal @m Proxy
    -- number of columns (stride between consecutive rows)
    n :: Int
    n = fromIntegral $ natVal @n Proxy
    -- column index; the CheckIndex constraint keeps it below n, so the
    -- unchecked indexing above stays within the m * n storage
    j :: Int
    j = fromIntegral $ natVal @j Proxy

-- | Matrix multiplication.
--
-- This is dot sum (*) specialised to matrices
--
-- Each result entry @[i, j]@ is the sum of the element-wise product of row
-- @i@ of the first matrix (a contiguous slice) and column @j@ of the second
-- (a strided gather).
--
-- >>> let a = [1, 2, 3, 4] :: Array '[2, 2] Int
-- >>> let b = [5, 6, 7, 8] :: Array '[2, 2] Int
-- >>> a
-- [[1, 2],
--  [3, 4]]
--
-- >>> b
-- [[5, 6],
--  [7, 8]]
--
-- >>> mmult a b
-- [[19, 22],
--  [43, 50]]
mmult ::
  forall m n k a.
  ( KnownNat k,
    KnownNat m,
    KnownNat n,
    HasShape [m, n],
    Ring a
  ) =>
  Array [m, k] a ->
  Array [k, n] a ->
  Array [m, n] a
mmult (Array x) (Array y) = tabulate go
  where
    go :: [Int] -> a
    go (i : j : _) = sum (V.zipWith (*) (rowOf i) (colOf j))
    go _ = throw (NumHaskException "Needs two dimensions")
    -- row i of the left operand: k consecutive elements
    rowOf :: Int -> V.Vector a
    rowOf i = V.slice (i * k) k x
    -- column j of the right operand: k elements spaced n apart
    colOf :: Int -> V.Vector a
    colOf j = V.generate k (\r -> y V.! (j + r * n))
    n :: Int
    n = fromIntegral $ natVal @n Proxy
    k :: Int
    k = fromIntegral $ natVal @k Proxy
{-# INLINE mmult #-}