{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.ShapedS.MatMul(matMul) where

import GHC.TypeLits
import Data.Array.Internal(valueOf)
import Data.Array.ShapedS
import Numeric.LinearAlgebra as N

-- | Matrix product of two statically shaped rank-2 arrays.
--
-- The flat element vector of each argument (obtained with 'toVector') is
-- viewed as an hmatrix matrix via 'N.reshape': the left operand with @n@
-- columns and the right operand with @o@ columns.  The two matrices are
-- multiplied with 'N.<>', and the flattened product is wrapped back into an
-- array of shape @[m, o]@ with 'fromVector'.
--
-- The shared inner dimension @n@ is enforced by the types, so mismatched
-- operands are rejected at compile time.
--
-- NOTE(review): this assumes 'toVector' / 'fromVector' and
-- 'N.reshape' / 'N.flatten' agree on element order (presumably row-major);
-- confirm against both libraries.
--
-- NOTE(review): behaviour when @n@ or @o@ is 0 depends on how 'N.reshape'
-- treats a column count of 0; confirm before relying on empty dimensions.
matMul :: forall m n o a . (N.Numeric a, KnownNat m, KnownNat n, KnownNat o)
       => Array '[m, n] a -> Array '[n, o] a -> Array '[m, o] a
matMul x y = fromVector (N.flatten (lhs N.<> rhs))
  where
    -- Column counts recovered from the type-level dimensions.
    innerDim, outCols :: Int
    innerDim = valueOf @n
    outCols  = valueOf @o

    -- hmatrix views of the operands; the row count is implied by the
    -- vector length and the column count given to 'N.reshape'.
    lhs = N.reshape innerDim (toVector x)
    rhs = N.reshape outCols (toVector y)