{-# LANGUAGE TypeFamilies          #-}

{-# LANGUAGE MultiParamTypeClasses #-}

{-# LANGUAGE TypeSynonymInstances  #-}           

{-# LANGUAGE DataKinds             #-}

{-# LANGUAGE KindSignatures        #-}

{-# LANGUAGE TypeOperators         #-}

{-# LANGUAGE FlexibleInstances     #-}

{-# LANGUAGE FlexibleContexts      #-}           

{-# LANGUAGE ScopedTypeVariables   #-}

{-# LANGUAGE DeriveGeneric         #-}

{-# LANGUAGE DeriveAnyClass        #-}

{-# LANGUAGE TemplateHaskell       #-}

{-# LANGUAGE StandaloneDeriving    #-}           

{-# LANGUAGE UndecidableInstances  #-}

{-# LANGUAGE TypeApplications      #-}

{-# LANGUAGE TypeInType            #-}

{-# LANGUAGE AllowAmbiguousTypes   #-}

{-# LANGUAGE ConstraintKinds       #-}

{-# LANGUAGE RankNTypes            #-}

{-# LANGUAGE GADTs                 #-}



{-# OPTIONS_GHC -fno-solve-constant-dicts #-} -- See https://ghc.haskell.org/trac/ghc/ticket/13943#comment:2



-----------------------------------------------------------------------------

-- |

-- Module      :  Data.Matrix.Static

-- Copyright   :  (C) 2017 Alexey Vagarenko

-- License     :  BSD-style (see LICENSE)

-- Maintainer  :  Alexey Vagarenko (vagarenko@gmail.com)

-- Stability   :  experimental

-- Portability :  non-portable

--

----------------------------------------------------------------------------



module Data.Matrix.Static (

    -- * Matrix

      Matrix

    , MatrixConstructor

    , IsMatrix

    -- * Matrix construction

    , matrix

    , identity

    , Identity

    -- * Matrix elements

    -- ** Rows

    , row

    , Row

    , getRowElems

    , GetRowElems

    , setRowElems

    , SetRowElems

    , mapRowElems

    , MapRowElems

    -- ** Columns

    , col

    , Col

    , getColElems

    , GetColElems

    , setColElems

    , SetColElems

    , mapColElems

    , MapColElems

    -- * Matrix multiplication

    , MatrixMultDims

    , MatrixMult(..)

    -- * Matrix operations

    , transpose

    , Transpose

    , minorMatrix

    , MinorMatrix

    , Determinant(..)

    , minor

    , Minor

    , cofactor

    , Cofactor

    , cofactorMatrix

    , CofactorMatrix

    , adjugateMatrix

    , AdjugateMatrix

    , inverse

    , Inverse

    -- * Generating matrix instances

    , genMatrixInstance

) where



import Control.Lens             (Lens')

import Data.Kind                (Constraint)

import Data.Proxy               (Proxy(..))

import Data.Singletons          (type (~>))

import Data.Singletons.TH       (genDefunSymbols)

import Data.Tensor.Static       ( IsTensor(..), Tensor, TensorConstructor, NormalizeDims

                                , generate, Generate

                                , subtensor, SubtensorCtx, getSubtensorElems, GetSubtensorElems, setSubtensorElems, SetSubtensorElems

                                , mapSubtensorElems, MapSubtensorElems

                                , slice, Slice, getSliceElems, GetSliceElems, setSliceElems, SetSliceElems

                                , mapSliceElems, MapSliceElems

                                , withTensor

                                , NatsFromTo

                                , scale, Scale)

import Data.Tensor.Static.TH    (genTensorInstance)

import Data.Vector.Static       (Vector)

import GHC.TypeLits             (Nat, type (<=), type (<=?), type (-), type (+), TypeError, ErrorMessage(..))

import Language.Haskell.TH      (Q, Name, Dec)

import Type.List                (DemoteWith(..))



import qualified Data.List.NonEmpty as N

import qualified Data.List.Unrolled as U



---------------------------------------------------------------------------------------------------

-- | Matrix with @m@ rows, @n@ columns
--
-- A plain alias for a 'Tensor' with dimensions @[m, n]@.
type Matrix m n e = Tensor '[m, n] e

-- | Type of matrix data constructor.
--
-- Alias for 'TensorConstructor' specialised to dimensions @[m, n]@.
type MatrixConstructor m n e = TensorConstructor '[m, n] e

-- | Matrix constraint.
--
-- Alias for 'IsTensor' specialised to dimensions @[m, n]@; required, for
-- example, by 'matrix' and 'transpose'.
type IsMatrix m n e = IsTensor '[m, n] e



---------------------------------------------------------------------------------------------------

-- | Alias for a concrete matrix data constructor: 'tensor' specialised to an
-- @m@x@n@ matrix with elements of type @e@.
matrix :: forall m n e. (IsMatrix m n e) => MatrixConstructor m n e
matrix = tensor @'[m, n] @e
{-# INLINE matrix #-}



-- | Identity matrix of size @m*m@: ones on the main diagonal, zeros elsewhere.
--
-- Each element is produced by 'IdentityWrk' for its type-level index, so the
-- diagonal is picked out by instance resolution rather than by comparing
-- indices at run time.
identity :: forall m e.
    (Identity m e)
    => Matrix m m e -- ^ The identity matrix.
identity = generate @'[m, m] @e @([Nat] -> Constraint) @(IdentityWrk e) go
    where
        -- Element at position @index@: 1 on the diagonal, 0 elsewhere.
        go :: forall (index :: [Nat]).
              (IdentityWrk e index) =>
              Proxy index -> e
        go _ = identityWrk @e @index
        {-# INLINE go #-}
{-# INLINE identity #-}

-- | Constraints for 'identity' function.
type Identity m e =
    ( IsMatrix m m e
    , Generate '[m, m] e ([Nat] -> Constraint) (IdentityWrk e)
    , Num e
    )



-- | Type-level helper for 'identity': the element at position @index@ of an
-- identity matrix.
class IdentityWrk e (index :: [Nat]) where
    -- | The element value for this index.
    identityWrk :: e

-- Off-diagonal position @[i, j]@: zero. Marked OVERLAPPABLE so that the more
-- specific diagonal instance below wins whenever both components are the same.
instance {-# OVERLAPPABLE #-} (Num e) => IdentityWrk e '[i, j] where
    identityWrk = 0
    {-# INLINE identityWrk #-}

-- Diagonal position @[i, i]@: one.
instance {-# OVERLAPPING #-} (Num e) => IdentityWrk e '[i, i] where
    identityWrk = 1
    {-# INLINE identityWrk #-}



---------------------------------------------------------------------------------------------------

-- | Lens for the row number @r@ of the matrix @m@x@n@.
--
-- Rows are numbered from 0; the focused row is a length-@n@ 'Vector'.
--
-- >>> matrix @2 @2 @Float 0 1 2 3 ^. row @0
-- Tensor'2 [0.0,1.0]
--
-- >>> set (row @1) (vector @2 @Float 20 30) (matrix @2 @2 @Float 0 1 2 3)
-- Tensor'2'2 [[0.0,1.0],[20.0,30.0]]
row :: forall (r :: Nat) (m :: Nat) (n :: Nat) e.
    (Row r m n e)
    => Lens' (Matrix m n e) (Vector n e)    -- ^ Lens from the matrix to row @r@.
row = subtensor @'[r] @'[m, n] @e
{-# INLINE row #-}

-- | Constraints for 'row' function.
--
-- @r <= m - 1@ restricts @r@ to a valid row index.
type Row (r :: Nat) (m :: Nat) (n :: Nat) e =
    ( SubtensorCtx '[r] '[m, n] e
    , r <= m - 1                   -- TODO: Why do I need this constraint?
    , NormalizeDims '[n] ~ '[n]    -- TODO: Why do I need this constraint?
    )



-- | Elements of row @r@ (counting from 0) of an @m@x@n@ matrix, as a plain list.
--
-- >>> getRowElems @0 (matrix @2 @2 @Float 0 1 2 3)
-- [0.0,1.0]
getRowElems :: forall (r :: Nat) (m :: Nat) (n :: Nat) e.
    (GetRowElems r m n e)
    => Matrix m n e         -- ^ Matrix to read the row from.
    -> [e]
getRowElems mat = getSubtensorElems @'[r] @'[m, n] @e mat
{-# INLINE getRowElems #-}

-- | Constraints for 'getRowElems' function.
type GetRowElems (r :: Nat) (m :: Nat) (n :: Nat) e =
    GetSubtensorElems '[r] '[m, n] e



-- | Replace row @r@ with elements taken from a list. The list must have
-- enough elements; when it is too short the result is 'Nothing'.
--
-- >>> setRowElems @1 (matrix @2 @2 @Float 0 1 2 3) [20, 30]
-- Just Tensor'2'2 [[0.0,1.0],[20.0,30.0]]
--
-- >>> setRowElems @1 (matrix @2 @2 @Float 0 1 2 3) [20]
-- Nothing
setRowElems :: forall (r :: Nat) (m :: Nat) (n :: Nat) e.
    (SetRowElems r m n e)
    => Matrix m n e             -- ^ The matrix.
    -> [e]                      -- ^ New row elements.
    -> Maybe (Matrix m n e)
setRowElems mat newElems = setSubtensorElems @'[r] @'[m, n] @e mat newElems
{-# INLINE setRowElems #-}

-- | Constraints for 'setRowElems' function.
type SetRowElems (r :: Nat) (m :: Nat) (n :: Nat) e =
    SetSubtensorElems '[r] '[m, n] e



-- | Apply a function to every element of row @r@, leaving the other rows as
-- they are.
--
-- >>> mapRowElems @1 (matrix @2 @2 @Float 0 1 2 3) (* 100)
-- Tensor'2'2 [[0.0,1.0],[200.0,300.0]]
mapRowElems :: forall (r :: Nat) (m :: Nat) (n :: Nat) e.
    (MapRowElems r m n e)
    => Matrix m n e         -- ^ The matrix.
    -> (e -> e)             -- ^ The mapping function.
    -> Matrix m n e
mapRowElems mat f = mapSubtensorElems @'[r] @'[m, n] @e mat f
{-# INLINE mapRowElems #-}

-- | Constraints for 'mapRowElems' function.
type MapRowElems (r :: Nat) (m :: Nat) (n :: Nat) e =
    MapSubtensorElems '[r] '[m, n] e



---------------------------------------------------------------------------------------------------

-- | Lens for the column number @c@ of the matrix @m@x@n@.
--
-- Columns are numbered from 0. The column is the @m@x1 slice starting at
-- @[0, c]@, exposed as a length-@m@ 'Vector'.
--
-- >>> matrix @2 @2 @Float 0 1 2 3 ^. col @0
-- Tensor'2 [0.0,2.0]
--
-- >>> set (col @1) (vector @2 @Float 10 30) (matrix @2 @2 @Float 0 1 2 3)
-- Tensor'2'2 [[0.0,10.0],[2.0,30.0]]
col :: forall (c :: Nat) (m :: Nat) (n :: Nat) e.
    (Col c m n e)
    => Lens' (Matrix m n e) (Vector m e)        -- ^ Lens from the matrix to column @c@.
col = slice @'[0, c] @'[m, 1] @'[m, n] @e
{-# INLINE col #-}

-- | Constraints for 'col' function.
--
-- The 'NormalizeDims' equation states that the @m@x1 slice has the same
-- dimensions as a length-@m@ vector.
type Col (c :: Nat) (m :: Nat) (n :: Nat) e =
    ( Slice '[0, c] '[m, 1] '[m, n] e
    , NormalizeDims '[m, 1] ~ '[m]          -- TODO: Why do I need this constraint?
    )



-- | Elements of column @c@ (counting from 0) of an @m@x@n@ matrix, read
-- through the @m@x1 slice at @[0, c]@.
--
-- >>> getColElems @0 (matrix @2 @2 @Float 0 1 2 3)
-- [0.0,2.0]
getColElems :: forall (c :: Nat) (m :: Nat) (n :: Nat) e.
    (GetColElems c m n e)
    => Matrix m n e             -- ^ Matrix to read the column from.
    -> [e]
getColElems mat = getSliceElems @'[0, c] @'[m, 1] @'[m, n] @e mat
{-# INLINE getColElems #-}

-- | Constraints for 'getColElems' function.
type GetColElems (c :: Nat) (m :: Nat) (n :: Nat) e =
    GetSliceElems '[0, c] '[m, 1] '[m, n] e



-- | Put elements of the list into column number @c@. The list must have enough elements.
--
-- When the list is too short the result is 'Nothing' (see the second example).
--
-- >>> setColElems @1 (matrix @2 @2 @Float 0 1 2 3) [10, 30]
-- Just Tensor'2'2 [[0.0,10.0],[2.0,30.0]]
--
-- >>> setColElems @1 (matrix @2 @2 @Float 0 1 2 3) [10]
-- Nothing
setColElems :: forall (c :: Nat) (m :: Nat) (n :: Nat) e.
    (SetColElems c m n e)
    => Matrix m n e         -- ^ The matrix.
    -> [e]                  -- ^ New column elements.
    -> Maybe (Matrix m n e)
setColElems = setSliceElems @'[0, c] @'[m, 1] @'[m, n] @e
{-# INLINE setColElems #-}

-- | Constraints for 'setColElems' function.
type SetColElems (c :: Nat) (m :: Nat) (n :: Nat) e =
    SetSliceElems '[0, c] '[m, 1] '[m, n] e



-- | Apply a function to every element of column @c@, leaving the other
-- columns as they are.
--
-- >>> mapColElems @1 (matrix @2 @2 @Float 0 1 2 3) (* 100)
-- Tensor'2'2 [[0.0,100.0],[2.0,300.0]]
mapColElems :: forall (c :: Nat) (m :: Nat) (n :: Nat) e.
    (MapColElems c m n e)
    => Matrix m n e         -- ^ The matrix.
    -> (e -> e)             -- ^ Function applied to each element of the column.
    -> Matrix m n e
mapColElems mat f = mapSliceElems @'[0, c] @'[m, 1] @'[m, n] @e mat f
{-# INLINE mapColElems #-}

-- | Constraints for 'mapColElems' function.
type MapColElems (c :: Nat) (m :: Nat) (n :: Nat) e =
    MapSliceElems '[0, c] '[m, 1] '[m, n] e



---------------------------------------------------------------------------------------------------

-- | Swap the two components of a 2-element index: @[i, j]@ becomes @[j, i]@.
-- Indices of any other length have no matching equation and stay unreduced.
type family ReverseIndex (index :: [Nat]) :: [Nat] where
    ReverseIndex '[i, j] = '[j, i]



-- | Per-element constraint used by 'transpose': element @index@ of the result
-- is read from a 1x1 slice of the @m@x@n@ source at the reversed index.
type TransposeGo m n e index = GetSliceElems (ReverseIndex index) [1, 1] [m, n] e

-- Defunctionalisation symbols (TransposeGoSym3 is passed to generate in transpose).
$(genDefunSymbols [''TransposeGo])



-- | Transpose a matrix.
--
-- Builds an @n@x@m@ matrix whose element @[i, j]@ is element @[j, i]@ of the input.
transpose :: forall m n e.
    (Transpose m n e)
    => Matrix m n e         -- ^ Matrix to transpose.
    -> Matrix n m e
transpose m = generate @'[n, m] @e @([Nat] ~> Constraint) @(TransposeGoSym3 m n e) go
    where
        -- Element of the result at position index: the element of the 1x1
        -- slice of the input at the reversed index. A 1x1 slice should hold
        -- exactly one element, which is why head is used here.
        go :: forall (index :: [Nat]).
            (TransposeGo m n e index)
            => Proxy index -> e
        go _ = head $ getSliceElems @(ReverseIndex index) @[1, 1] m
        {-# INLINE go #-}
{-# INLINE transpose #-}

-- | Constraints for 'transpose' function.
type Transpose m n e =
    ( IsMatrix m n e
    , IsMatrix n m e
    , Generate '[n, m] e ([Nat] ~> Constraint) (TransposeGoSym3 m n e)
    )



---------------------------------------------------------------------------------------------------

-- Matrix multiplication.

---------------------------------------------------------------------------------------------------

-- | Shape of the result of matrix multiplication.
--
-- Covers matrix-matrix, vector-matrix and matrix-vector products. Any other
-- pair of shapes, including concrete shapes with mismatched inner dimensions,
-- falls through to the last equation and reports a custom 'TypeError'.
type family MatrixMultDims (dims0 :: [Nat]) (dims1 :: [Nat]) :: [Nat] where
    MatrixMultDims '[m, n] '[n, o] = '[m, o]  -- matrix m*n mult by matrix n*o makes matrix m*o
    MatrixMultDims '[n   ] '[n, o] = '[o   ]  -- vector n   mult by matrix n*o makes vector o
    MatrixMultDims '[m, n] '[n   ] = '[m   ]  -- matrix m*n mult by vector n   makes vector m
    MatrixMultDims a       b       =
        TypeError (
            'Text "Matrices of shapes "
            ':<>: 'ShowType a
            ':<>: 'Text " and "
            ':<>: 'ShowType b
            ':<>: 'Text " are incompatible for multiplication.")



-- | Matrix multiplication.
--
-- This module provides instances for the three shape combinations accepted by
-- 'MatrixMultDims': matrix by matrix, vector by matrix and matrix by vector.
class MatrixMult (dims0 :: [Nat]) (dims1 :: [Nat]) e where
    -- | Multiply two matrices, or matrix and vector. Matrices (or matrix and vector) must have compatible dimensions.
    mult ::
        ( IsTensor dims0 e
        , IsTensor dims1 e
        , IsTensor (MatrixMultDims dims0 dims1) e
        )
        => Tensor dims0 e                           -- ^ Left operand.
        -> Tensor dims1 e                           -- ^ Right operand.
        -> Tensor (MatrixMultDims dims0 dims1) e



-- | Get 0-th element of an index (its first component).
type family Index0 (index :: [Nat]) :: Nat where
    Index0 (i ': _) = i

-- | Get the element at position 1 of an index (its second component,
-- counting from 0).
type family Index1 (index :: [Nat]) :: Nat where
    Index1 (_ ': j ': _ ) = j



-------------------------------------------------------------------------------

-- | Per-element constraint for multiplying an @m@x@n@ matrix by an @n@x@o@
-- matrix: element @[i, j]@ of the product needs row @i@ of the left operand,
-- column @j@ of the right operand, and the unrolled 'U.Sum' and 'U.ZipWith'
-- over @n@ elements.
type MultMatMatGo (m :: Nat) (n :: Nat) (o :: Nat) e (index :: [Nat]) =
    ( GetRowElems (Index0 index) m n e
    , GetColElems (Index1 index) n o e
    , U.Sum n e
    , U.ZipWith n
    )

-- Defunctionalisation symbols (MultMatMatGoSym4 is used by the instance below).
$(genDefunSymbols [''MultMatMatGo])

-- | Multiply two matrices.
--
-- Element @[i, j]@ of the result is the dot product of row @i@ of the first
-- matrix with column @j@ of the second.
instance ( Num e
         , Generate (MatrixMultDims '[m, n] '[n, o]) e ([Nat] ~> Constraint) (MultMatMatGoSym4 m n o e)
         ) =>
         MatrixMult '[m, n] '[n, o] e where
    mult m0 m1 = generate @(MatrixMultDims '[m, n] '[n, o]) @e @([Nat] ~> Constraint) @(MultMatMatGoSym4 m n o e) go
        where
            -- Split the result index into its row and column components.
            go :: forall (index :: [Nat]).
                ( GetRowElems (Index0 index) m n e
                , GetColElems (Index1 index) n o e
                , U.Sum n e
                , U.ZipWith n
                ) =>
                Proxy index -> e
            go _ = go' @(Index0 index) @(Index1 index)
            {-# INLINE go #-}

            -- Dot product of row i of m0 and column j of m1.
            go' :: forall (i :: Nat) (j :: Nat).
                ( GetRowElems i m n e
                , GetColElems j n o e
                , U.Sum n e
                , U.ZipWith n
                ) =>
                e
            go' = U.sum @n (U.zipWith @n (*) irow jcol)
                where
                    irow = getRowElems @i m0
                    jcol = getColElems @j m1
            {-# INLINE go' #-}
    {-# INLINE mult #-}



-------------------------------------------------------------------------------

-- | Per-element constraint for multiplying a length-@n@ vector by an @n@x@o@
-- matrix: element @[c]@ of the result needs column @c@ of the matrix and the
-- unrolled 'U.Sum' and 'U.ZipWith' over @n@ elements.
--
-- NOTE(review): @m@ is not used on the right-hand side (presumably kept for
-- symmetry with MultMatMatGo). The instance below mentions @m@ only in its
-- context, never in its head, so @m@ is never determined. Confirm this is
-- intentional.
type MultVecMatGo (m :: Nat) (n :: Nat) (o :: Nat) e (index :: [Nat]) =
    ( GetColElems (Index0 index) n o e
    , U.Sum n e
    , U.ZipWith n
    )

-- Defunctionalisation symbols (MultVecMatGoSym4 is used by the instance below).
$(genDefunSymbols [''MultVecMatGo])

-- | Multiply vector and matrix.
--
-- Element @[c]@ of the result is the dot product of the vector with column
-- @c@ of the matrix.
instance ( Num e
         , Generate (MatrixMultDims '[n] '[n, o]) e ([Nat] ~> Constraint) (MultVecMatGoSym4 m n o e)
         ) =>
         MatrixMult '[n] '[n, o] e where
    mult v m = generate @(MatrixMultDims '[n] '[n, o]) @e @([Nat] ~> Constraint) @(MultVecMatGoSym4 m n o e) go
        where
            -- The result is a vector, so its index has a single component.
            go :: forall (index :: [Nat]).
                ( GetColElems (Index0 index) n o e
                , U.Sum n e
                , U.ZipWith n
                ) =>
                Proxy index -> e
            go _ = go' @(Index0 index)
            {-# INLINE go #-}

            -- Dot product of the vector v with column c of the matrix.
            go' :: forall (c :: Nat).
                ( GetColElems c n o e
                , U.Sum n e
                , U.ZipWith n
                ) =>
                e
            go' = U.sum @n (U.zipWith @n (*) irow jcol)
                where
                    irow = toList v
                    jcol = getColElems @c m
            {-# INLINE go' #-}
    {-# INLINE mult #-}



-------------------------------------------------------------------------------

-- | Per-element constraint for multiplying an @m@x@n@ matrix by a length-@n@
-- vector: element @[r]@ of the result needs row @r@ of the matrix and the
-- unrolled 'U.Sum' and 'U.ZipWith' over @n@ elements.
--
-- NOTE(review): @o@ is not used on the right-hand side (presumably kept for
-- symmetry with MultMatMatGo). The instance below mentions @o@ only in its
-- context, never in its head, so @o@ is never determined. Confirm this is
-- intentional.
type MultMatVecGo (m :: Nat) (n :: Nat) (o :: Nat) e (index :: [Nat]) =
    ( GetRowElems (Index0 index) m n e
    , U.Sum n e
    , U.ZipWith n
    )

-- Defunctionalisation symbols (MultMatVecGoSym4 is used by the instance below).
$(genDefunSymbols [''MultMatVecGo])

-- | Multiply matrix and vector.
--
-- Element @[r]@ of the result is the dot product of row @r@ of the matrix
-- with the vector.
instance ( Num e
         , Generate (MatrixMultDims '[m, n] '[n]) e ([Nat] ~> Constraint) (MultMatVecGoSym4 m n o e)
         ) =>
         MatrixMult '[m, n] '[n] e where
    mult m v = generate @(MatrixMultDims '[m, n] '[n]) @e @([Nat] ~> Constraint) @(MultMatVecGoSym4 m n o e) go
        where
            -- The result is a vector, so its index has a single component.
            go :: forall (index :: [Nat]).
                ( GetRowElems (Index0 index) m n e
                , U.Sum n e
                , U.ZipWith n
                ) =>
                Proxy index -> e
            go _ = go' @(Index0 index)
            {-# INLINE go #-}

            -- Dot product of row r of the matrix with the vector v.
            go' :: forall (r :: Nat).
                ( GetRowElems r m n e
                , U.Sum n e
                , U.ZipWith n
                ) =>
                e
            go' = U.sum @n (U.zipWith @n (*) irow jcol)
                where
                    irow = getRowElems @r m
                    jcol = toList v
            {-# INLINE go' #-}
    {-# INLINE mult #-}



---------------------------------------------------------------------------------------------------

-- | Map an index of a minor matrix to the corresponding index of the source
-- matrix. @cutIndex@ is the index of the removed row or column: indices below
-- it are kept, all others are shifted up by one.
type family MinorMatrixNewIndex (cutIndex :: Nat) (index :: Nat) :: Nat where
    -- Nothing lies below 0, so every index shifts. This is a separate
    -- equation, presumably so that ci - 1 below is never computed for ci = 0.
    MinorMatrixNewIndex 0  i = i + 1
    MinorMatrixNewIndex ci i = MinorMatrixNewIndex' ci i (i <=? ci - 1)

-- | Helper for 'MinorMatrixNewIndex'. The 'Bool' argument is whether
-- @index < cutIndex@ (computed as @index <= cutIndex - 1@).
type family MinorMatrixNewIndex' (cutIndex :: Nat) (index :: Nat) (indexLTcutIndex :: Bool) :: Nat where
    MinorMatrixNewIndex' ci i 'True  = i
    MinorMatrixNewIndex' ci i 'False = i + 1



-- | Per-element constraint used by 'minorMatrix': element @index@ of the minor
-- (row @i@ and column @j@ removed) is read from a 1x1 slice of the @n@x@n@
-- source, at the index translated by 'MinorMatrixNewIndex'.
type MinorMatrixGo (i :: Nat) (j :: Nat) (n :: Nat) e (index :: [Nat]) =
    (GetSliceElems [ (MinorMatrixNewIndex i (Index0 index))
                   , (MinorMatrixNewIndex j (Index1 index))
                   ]
                   [1, 1]
                   [n, n]
                   e
    )

-- Defunctionalisation symbols (MinorMatrixGoSym4 is used by minorMatrix).
$(genDefunSymbols [''MinorMatrixGo])



-- | Minor matrix is a matrix made by deleting @i@-th row and @j@-th column from given square matrix.
--
-- Rows and columns are numbered from 0. Element @[r, c]@ of the result is
-- taken from element @[MinorMatrixNewIndex i r, MinorMatrixNewIndex j c]@ of
-- the source, so rows at or after @i@ and columns at or after @j@ are shifted
-- by one.
minorMatrix :: forall (i :: Nat) (j :: Nat) (n :: Nat) e.
    (MinorMatrix i j n e)
    => Matrix n n e                 -- ^ Square source matrix.
    -> Matrix (n - 1) (n - 1) e     -- ^ The source without row @i@ and column @j@.
minorMatrix m = generate @([n - 1, n - 1]) @e @([Nat] ~> Constraint) @(MinorMatrixGoSym4 i j n e) go
    where
        -- Translate an index of the minor into the matching source index.
        go :: forall (index :: [Nat]).
            (MinorMatrixGo i j n e index) =>
            Proxy index -> e
        go _ = go' @(MinorMatrixNewIndex i (Index0 index)) @(MinorMatrixNewIndex j (Index1 index))
        {-# INLINE go #-}

        -- Read the single element of the 1x1 slice at [r, c] of the source.
        go' :: forall (r :: Nat) (c :: Nat). (GetSliceElems [r, c] [1, 1] [n, n] e) => e
        go' = head $ getSliceElems @[r, c] @[1, 1] @[n, n] @e m
        {-# INLINE go' #-}
{-# INLINE minorMatrix #-}

-- | Constraint for 'minorMatrix' function.
type MinorMatrix (i :: Nat) (j :: Nat) (n :: Nat) e =
    Generate ([n - 1, n - 1]) e ([Nat] ~> Constraint) (MinorMatrixGoSym4 i j n e)



---------------------------------------------------------------------------------------------------

-- | Determinant of a matrix.
--
-- 2x2 and 3x3 matrices have closed-form instances. Every other size uses the
-- generic instance, which expands along row 0 (Laplace expansion).
--
-- NOTE(review): there is no 1x1 instance here. The generic expansion for
-- @n = 1@ would need a 0x0 determinant, so 1x1 matrices are presumably
-- unsupported. Confirm.
class Determinant (n :: Nat) e where
    -- | Compute the determinant of a square @n@x@n@ matrix.
    determinant :: (Num e) => Matrix n n e -> e



-- | Closed-form determinant of a 2x2 matrix.
--
-- @a * d - b * c@ is unchanged by transposition, so the result is correct
-- whether 'withTensor' passes the elements row- or column-major.
instance {-# OVERLAPPING #-}
    (Num e, IsMatrix 2 2 e)
    => Determinant 2 e
    where
    determinant m =
        withTensor m $ \a b c d -> a * d - b * c
    {-# INLINE determinant #-}

-- | Closed-form determinant of a 3x3 matrix (cofactor expansion written out).
--
-- Like the 2x2 case, the value does not depend on whether the elements are
-- passed row- or column-major, because a matrix and its transpose have the
-- same determinant.
instance {-# OVERLAPPING #-}
    (Num e, IsMatrix 3 3 e)
    => Determinant 3 e
    where
    determinant m =
        withTensor m $ \a b c d e f g h i ->
            a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
    {-# INLINE determinant #-}



-- | Sign is positive for even @n@ and negative for odd.
--
-- @sign \@n@ equals @(-1) ^ n@. The instances below multiply by -1 once per
-- step from @n@ down to 0.
class Sign (n :: Nat) where
    -- | Either 1 or -1, depending on the parity of @n@.
    sign :: (Num a) => a

-- Base case: (-1) ^ 0 = 1.
instance {-# OVERLAPPING #-} Sign 0 where
    sign = 1
    {-# INLINE sign #-}

-- Recursive case: one more factor of -1 than the sign for n - 1.
instance {-# OVERLAPPABLE #-} (Sign (n - 1)) => Sign n where
    sign = (-1) * sign @(n - 1)
    {-# INLINE sign #-}



-- | Per-column constraint for the generic 'Determinant' instance. The term for
-- column @j@ needs:
--
-- * the determinant of the @(n - 1)@x@(n - 1)@ minor,
-- * element @[0, j]@ of the matrix,
-- * the minor matrix itself,
-- * the sign for @j@.
type DeterminantGo (n :: Nat) e (j :: Nat) =
    ( Determinant (n - 1) e
    , GetSliceElems [0, j] [1, 1] [n, n] e
    , MinorMatrix 0 j n e
    , Sign j
    )

-- Defunctionalisation symbols (DeterminantGoSym2 is used by the instance below).
$(genDefunSymbols [''DeterminantGo])

-- | Generic determinant by Laplace expansion along row 0: the sum over
-- columns @j@ of @(-1)^j * m[0, j] * det (minor 0 j)@.
instance {-# OVERLAPPABLE #-}
    ( Num e
    , IsMatrix n n e
    , DemoteWith Nat (Nat ~> Constraint) (DeterminantGoSym2 n e) (NatsFromTo 0 (n - 1))
    , U.Sum n e
    )
    => Determinant n e
    where
    -- demoteWith presumably runs go once per column j in NatsFromTo 0 (n - 1);
    -- U.sum then adds the n terms.
    determinant m = U.sum @n $ demoteWith @Nat @(Nat ~> Constraint) @(DeterminantGoSym2 n e) @(NatsFromTo 0 (n - 1)) go
        where
            -- Expansion term for column j.
            go :: forall (j :: Nat).
                (DeterminantGo n e j)
                => Proxy j -> e
            go _ = sign @j * el * determinant (minorMatrix @0 @j @n @e m)
                where
                    -- Element [0, j], read from a 1x1 slice.
                    el = head $ getSliceElems @[0, j] @[1, 1] @[n, n] @e m
            {-# INLINE go #-}
    {-# INLINE determinant #-}



---------------------------------------------------------------------------------------------------

-- | Minor is the determinant of minor matrix.
--
-- Deletes row @i@ and column @j@ of an @n@x@n@ matrix (see 'minorMatrix')
-- and returns the determinant of what remains.
minor :: forall (i :: Nat) (j :: Nat) (n :: Nat) e.
    (Minor i j n e)
    => Matrix n n e         -- ^ Square matrix.
    -> e
minor mat = determinant @(n - 1) @e (minorMatrix @i @j @n @e mat)
{-# INLINE minor #-}

-- | Constraint for 'minor' function.
type Minor (i :: Nat) (j :: Nat) (n :: Nat) e =
    ( MinorMatrix i j n e
    , Determinant (n - 1) e
    , Num e
    )



---------------------------------------------------------------------------------------------------

-- | @'cofactor' \@i \@j@ is the @'minor' \@i \@j@ multiplied by @(-1) ^ (i + j)@.
cofactor :: forall (i :: Nat) (j :: Nat) (n :: Nat) e.
    (Cofactor i j n e)
    => Matrix n n e         -- ^ Square matrix.
    -> e
cofactor mat = signFactor * minorValue
    where
        -- (-1) ^ (i + j), resolved at the type level.
        signFactor = sign @(i + j) :: e
        -- Determinant of the matrix without row i and column j.
        minorValue = minor @i @j @n @e mat
{-# INLINE cofactor #-}

-- | Constraint for 'cofactor' function.
type Cofactor (i :: Nat) (j :: Nat) (n :: Nat) e =
    ( Minor i j n e
    , Sign (i + j)
    )



---------------------------------------------------------------------------------------------------

-- | Per-element constraint used by 'cofactorMatrix': element @[i, j]@ of the
-- result needs the cofactor for row @i@ and column @j@.
type CofactorMatrixGo (n :: Nat) e (index :: [Nat]) =
    (Cofactor (Index0 index) (Index1 index) n e)

-- Defunctionalisation symbols (CofactorMatrixGoSym2 is used by cofactorMatrix).
$(genDefunSymbols [''CofactorMatrixGo])



-- | The matrix formed by all of the cofactors of given square matrix.
--
-- Element @[i, j]@ of the result is @'cofactor' \@i \@j@ of the input.
cofactorMatrix :: forall (n :: Nat) e.
    (CofactorMatrix n e)
    => Matrix n n e         -- ^ Square matrix.
    -> Matrix n n e
cofactorMatrix mat = generate @([n, n]) @e @([Nat] ~> Constraint) @(CofactorMatrixGoSym2 n e) cell
    where
        -- Cofactor for the cell whose row and column are the two index components.
        cell :: forall (index :: [Nat]).
            (CofactorMatrixGo n e index) =>
            Proxy index -> e
        cell _ = cofactor @(Index0 index) @(Index1 index) @n @e mat
        {-# INLINE cell #-}
{-# INLINE cofactorMatrix #-}

-- | Constraint for 'cofactorMatrix' function.
type CofactorMatrix (n :: Nat) e =
    Generate [n, n] e ([Nat] ~> Constraint) (CofactorMatrixGoSym2 n e)



---------------------------------------------------------------------------------------------------

-- | Adjugate matrix of given square matrix is the transpose of its cofactor matrix.
--
-- @adjugateMatrix = transpose . cofactorMatrix@
--
adjugateMatrix :: forall (n :: Nat) e.
    (AdjugateMatrix n e)
    => Matrix n n e         -- ^ Square matrix.
    -> Matrix n n e
adjugateMatrix mat = transpose @n @n @e (cofactorMatrix @n @e mat)
{-# INLINE adjugateMatrix #-}

-- | Constraint for 'adjugateMatrix' function.
type AdjugateMatrix (n :: Nat) e =
    (CofactorMatrix n e, Transpose n n e)



---------------------------------------------------------------------------------------------------

-- | Inverse of the matrix: the adjugate matrix scaled by the reciprocal of the
-- determinant.
--
-- Singular matrices are not detected. A zero determinant yields whatever
-- @1 / 0@ means for @e@: infinities or NaNs for floating-point types, an
-- arithmetic exception for 'Rational'.
inverse :: forall (n :: Nat) e.
    (Inverse n e)
    => Matrix n n e         -- ^ Square matrix to invert.
    -> Matrix n n e
inverse mat = scale adjugate factor
    where
        adjugate = adjugateMatrix @n @e mat
        factor   = 1 / determinant mat
{-# INLINE inverse #-}

-- | Constraint for 'inverse' function.
type Inverse (n :: Nat) e =
    (AdjugateMatrix n e, Determinant n e, Fractional e, Scale '[n, n] e)



---------------------------------------------------------------------------------------------------

-- | Generate instance of a matrix.
--
-- Delegates to 'genTensorInstance' with the two dimensions @[rows, cols]@.
genMatrixInstance :: Int       -- ^ Number of rows.
                  -> Int       -- ^ Number of columns.
                  -> Name      -- ^ Type of elements.
                  -> Q [Dec]
genMatrixInstance rows cols = genTensorInstance (rows N.:| [cols])