{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.ShapedS.MatMul(matMul) where
import GHC.TypeLits
import Data.Array.Internal(valueOf)
import Data.Array.ShapedS
import Numeric.LinearAlgebra as N

-- | Matrix product of two statically shaped rank-2 arrays.
--
-- The shapes are tracked at the type level: an @[m, n]@ array times an
-- @[n, o]@ array yields an @[m, o]@ array, so a mismatch of the inner
-- dimension is rejected by the type checker.
--
-- The multiplication itself is delegated to hmatrix: each operand is
-- flattened with 'toVector', turned into a 'Matrix' with 'N.reshape',
-- multiplied with 'N.<>', and the product is flattened again and wrapped
-- with 'fromVector'.
--
-- NOTE(review): this relies on 'toVector' / 'fromVector' and
-- 'N.reshape' / 'N.flatten' agreeing on element order (presumably
-- row-major on both sides) -- confirm against the library docs.
matMul :: forall m n o a .
          (N.Numeric a, KnownNat m, KnownNat n, KnownNat o) =>
          Array '[m, n] a -> Array '[n, o] a -> Array '[m, o] a
matMul x y = fromVector (N.flatten (lhs N.<> rhs))
  where
    -- 'N.reshape' takes the number of columns of the matrix to build;
    -- it is read off the type-level dimensions via 'valueOf'.
    lhs = N.reshape (valueOf @n) (toVector x)
    rhs = N.reshape (valueOf @o) (toVector y)