{-|
Module        : Data.Array.Accelerate.Matrix
Description   : Functions for plain and dependently typed matrix math.
Copyright     : (c) Noah Martin Williams 2024
License       : BSD-3-Clause
Maintainer    : noahmartinwilliams@gmail.com
Stability     : experimental
Portability   : POSIX

This module contains functions for doing matrix math such as addition, subtraction, and multiplication
for both plain and dependently typed matrices.
-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
module Data.Array.Accelerate.Matrix(mMul, matMul, identMat, Mat(..), AccMat(..), matTransp, matAdd, mAdd, mSub, matSub, useMat, matScale) where

import Prelude as P
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Control.Lens.Shape

-- |Create a square identity matrix with the dimension provided.
--
-- The result has @n@ rows and @n@ columns; an element is 1 where its row
-- index equals its column index and 0 everywhere else.
identMat :: Exp Int -> Acc (Matrix Int)
identMat n = generate (index2 n n) (\(I2 a b) -> a A.== b A.? (constant 1, constant 0))

-- These two lines were blatantly stolen from here: https://hackage.haskell.org/package/accelerate-1.3.0.0/docs/Data-Array-Accelerate.html
-- Replicate an array along a new innermost dimension of extent n, raising
-- its rank by one.
rep0 :: (Shape sh, Elt e) => Exp Int -> Acc (Array sh e) -> Acc (Array (sh:.Int) e)
rep0 n a = A.replicate (lift (Any:.n)) a

{-
    Example:
    let m1 = A.generate (constant (Z:.3:.5)) (\(I2 x y) -> (A.fromIntegral x :: Exp Double) * 5.0 + (A.fromIntegral y :: Exp Double))
        m2 = A.generate (constant (Z:.5:.3)) (\(I2 x y) -> (A.fromIntegral x :: Exp Double) * 5.0 + (A.fromIntegral y :: Exp Double))
        m1 `mMul` m2 = 
   [150   160   170
    400   435   470
    650   710   770]
-}

-- |Multiply two matrices together without dependent types.
--
-- Both operands are expanded into rank-3 arrays indexed by
-- (row of @m1@, column of @m2@, shared inner index), multiplied
-- element-wise, and then summed along the innermost dimension.
-- The inner dimensions of the two operands are not checked against each
-- other here.
mMul :: A.Num e => Acc (Matrix e) -> Acc (Matrix e) -> Acc (Matrix e)
mMul m1 m2 = A.sum (A.zipWith (*) m1Tensor m2Tensor)
  where
    -- Only the outer extent of m1 and the inner extent of m2 are needed.
    Z:.m1NumRows:._ = unlift (shape m1 :: Exp (Plain (Z:.Int:.Int))) :: (Z:.Exp Int:.Exp Int)
    Z:._:.m2NumCols = unlift (shape m2 :: Exp (Plain (Z:.Int:.Int))) :: (Z:.Exp Int:.Exp Int)
    -- m1 copied once per column of m2, with the copy axis moved to the
    -- middle so that the shared inner index ends up innermost.
    m1Tensor = transposeOn _1 _2 (rep0 m2NumCols m1)
    -- Transposed m2 copied once per row of m1, then permuted so that its
    -- axes line up with those of m1Tensor.
    m2Tensor = transposeOn _2 _3 (transposeOn _1 _2 (rep0 m1NumRows (A.transpose m2)))

-- |Add two matrices element-wise without dependent types.
--
-- The shapes of the operands are not checked against each other here.
mAdd :: A.Num e => Acc (Matrix e) -> Acc (Matrix e) -> Acc (Matrix e)
mAdd left right = A.zipWith (+) left right

-- |Subtract the second matrix from the first, element-wise, without
-- dependent types.
--
-- The shapes of the operands are not checked against each other here.
mSub :: A.Num e => Acc (Matrix e) -> Acc (Matrix e) -> Acc (Matrix e)
mSub left right = A.zipWith (-) left right

-- |Dependent type for accelerate matrices.
--
-- The type parameters @a@ and @b@ act as tags carried alongside the matrix;
-- functions such as 'matMul' use them in their signatures to require that
-- the tags of two operands line up.
data AccMat e a b where 
    -- |Dependently typed accelerated matrix which forces two types to line up.
    -- Holds the embedded matrix together with one value of each tag type.
    AccMat :: (Elt e, A.Num e) => Acc (Matrix e) -> a -> b -> AccMat e a b

-- |Dependent type for plain matrices.
--
-- Mirrors 'AccMat', but wraps a plain 'Matrix' rather than an 'Acc'
-- computation; 'useMat' converts a 'Mat' into an 'AccMat'.
data Mat e a b where
    -- |Dependently typed plain matrix for passing to compiled functions.
    -- Holds the matrix together with one value of each tag type.
    Mat :: (Elt e, A.Num e) => Matrix e -> a -> b -> Mat e a b

-- |Change the type of a dependently typed matrix from 'Mat' to 'AccMat'.
--
-- The plain matrix is embedded with 'use'; both tag values are carried over
-- unchanged.
useMat :: Mat e a b -> AccMat e a b
useMat (Mat mat a b) = AccMat (use mat) a b

-- |Multiply two dependently typed matrices together.
--
-- The second tag of the left operand must be the same type as the first tag
-- of the right operand. The result keeps the first tag value of the left
-- operand and the second tag value of the right operand.
--
-- For example:
--
-- @
-- data A = A
-- data B = B
-- data C = C
--
-- let m1 = AccMat (use (fromList (Z:.10:.12) [0..] :: Matrix Int)) A B
-- let m2 = AccMat (use (fromList (Z:.12:.13) [0..] :: Matrix Int)) B C
-- let mResult = m1 \`matMul\` m2
-- @
matMul :: A.Num e => AccMat e a b -> AccMat e b c -> AccMat e a c
matMul (AccMat left a _) (AccMat right _ c) = AccMat (mMul left right) a c

-- |Add two dependently typed matrices.
--
-- Both operands must carry the same tag types; the tag values of the left
-- operand are kept in the result.
matAdd :: A.Num e => AccMat e a b -> AccMat e a b -> AccMat e a b
matAdd (AccMat left a b) (AccMat right _ _) = AccMat (mAdd left right) a b

-- |Subtract one dependently typed matrix from another.
--
-- The second operand is subtracted from the first. Both operands must carry
-- the same tag types; the tag values of the left operand are kept in the
-- result.
matSub :: A.Num e => AccMat e a b -> AccMat e a b -> AccMat e a b
matSub (AccMat left a b) (AccMat right _ _) = AccMat (mSub left right) a b

-- |Transpose a dependently typed matrix.
--
-- The two tags swap places along with the dimensions of the matrix.
matTransp :: AccMat e a b -> AccMat e b a
matTransp (AccMat mat a b) = AccMat (A.transpose mat) b a

-- |Scale a dependently typed matrix.
--
-- Every element is multiplied by the scalar @s@ (with the scalar on the
-- left); the tags are unchanged.
matScale :: A.Num e => Exp e -> AccMat e a b -> AccMat e a b
matScale s (AccMat m a b) = AccMat (A.map (s *) m) a b