{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Triangular (
   Triangular, MatrixShape.UpLo,
   Diagonal, FlexDiagonal,
   Upper, FlexUpper, UnitUpper,
   Lower, FlexLower, UnitLower,
   Symmetric, FlexSymmetric,
   size,
   fromList, autoFromList,
   diagonalFromList, autoDiagonalFromList,
   lowerFromList, autoLowerFromList,
   upperFromList, autoUpperFromList,
   symmetricFromList, autoSymmetricFromList,
   asDiagonal, asLower, asUpper, asSymmetric,
   requireUnitDiagonal, requireNonUnitDiagonal,
   relaxUnitDiagonal, strictNonUnitDiagonal,
   identity,
   diagonal,
   takeDiagonal,
   transpose,
   adjoint,

   stackDiagonal,
   stackLower,
   stackUpper,
   stackSymmetric,
   splitDiagonal,
   splitLower,
   splitUpper,
   splitSymmetric,
   takeTopLeft,
   takeTopRight,
   takeBottomLeft,
   takeBottomRight,

   toSquare,
   takeLower,
   takeUpper,

   fromLowerRowMajor, toLowerRowMajor,
   fromUpperRowMajor, toUpperRowMajor,
   forceOrder, adaptOrder,

   add, sub,

   Tri.PowerDiag,
   multiplyVector,
   square, squareGeneric,
   multiply,
   multiplyFull,

   solve,
   inverse,
   inverseGeneric,
   determinant,

   eigenvalues,
   eigensystem,
   ) where

import qualified Numeric.LAPACK.Matrix.Triangular.Eigen as Eigen
import qualified Numeric.LAPACK.Matrix.Triangular.Linear as Linear
import qualified Numeric.LAPACK.Matrix.Triangular.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Triangular.Private as Tri

import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import Numeric.LAPACK.Matrix.Array.Triangular (
   Triangular,
   Diagonal,  FlexDiagonal,
   Lower,     FlexLower,     UnitLower,
   Upper,     FlexUpper,     UnitUpper,
   Symmetric, FlexSymmetric,
   )
import Numeric.LAPACK.Matrix.Array (Full, General, Square)
import Numeric.LAPACK.Matrix.Shape.Private (NonUnit, Unit, Order)
import Numeric.LAPACK.Matrix.Private (ZeroInt)
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)
import Data.Array.Comfort.Shape ((:+:))

import Foreign.Storable (Storable)

import Data.Tuple.HT (mapTriple)


{- |
Shape @sh@ of the triangular matrix.
Rows and columns share this shape (compare the 'Square' result of 'toSquare').
-}
size :: Triangular lo diag up sh a -> sh
size m = MatrixShape.triangularSize (ArrMatrix.shape m)

{- |
Transpose a triangular matrix.
The result type swaps @lo@ and @up@,
so a lower triangle becomes an upper one and vice versa.
-}
transpose ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    MatrixShape.TriDiag diag) =>
   Triangular lo diag up sh a -> Triangular up diag lo sh a
transpose m = ArrMatrix.lift1 Basic.transpose m

{- |
Adjoint (conjugate transpose) of a triangular matrix.
Like 'transpose', it swaps @lo@ and @up@ in the result type.
Unlike 'transpose', it needs 'Class.Floating' because it delegates to
'Basic.adjoint', which presumably conjugates complex elements.
-}
adjoint ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    MatrixShape.TriDiag diag, Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Triangular up diag lo sh a
adjoint m = ArrMatrix.lift1 Basic.adjoint m

{- |
Build a triangular matrix of shape @sh@ from a list of elements.
The list is laid out according to the given 'Order'.
The triangle kind (@lo@, @up@) is chosen by the result type.
The diagonal is always typed 'NonUnit'.
-}
fromList ::
   (MatrixShape.Content lo, MatrixShape.Content up, Shape.C sh, Storable a) =>
   Order -> sh -> [a] -> Triangular lo NonUnit up sh a
fromList order sh xs = ArrMatrix.lift0 (Basic.fromList order sh xs)

{- | 'fromList' with the result fixed to a 'Lower' matrix. -}
lowerFromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Lower sh a
lowerFromList order sh = fromList order sh

{- | 'fromList' with the result fixed to an 'Upper' matrix. -}
upperFromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Upper sh a
upperFromList order sh = fromList order sh

{- | 'fromList' with the result fixed to a 'Symmetric' matrix. -}
symmetricFromList ::
   (Shape.C sh, Storable a) => Order -> sh -> [a] -> Symmetric sh a
symmetricFromList order sh = fromList order sh

{- | 'fromList' with the result fixed to a 'Diagonal' matrix. -}
diagonalFromList ::
   (Shape.C sh, Storable a) => Order -> sh -> [a] -> Diagonal sh a
diagonalFromList order sh = fromList order sh


{- |
Like 'fromList', but without a shape argument.
The result has shape 'ZeroInt', presumably sized from the length of the list.
-}
autoFromList ::
   (MatrixShape.Content lo, MatrixShape.Content up, Storable a) =>
   Order -> [a] -> Triangular lo NonUnit up ZeroInt a
autoFromList order xs = ArrMatrix.lift0 (Basic.autoFromList order xs)

{- | 'autoFromList' with the result fixed to a 'Lower' matrix. -}
autoLowerFromList :: (Storable a) => Order -> [a] -> Lower ZeroInt a
autoLowerFromList order = autoFromList order

{- | 'autoFromList' with the result fixed to an 'Upper' matrix. -}
autoUpperFromList :: (Storable a) => Order -> [a] -> Upper ZeroInt a
autoUpperFromList order = autoFromList order

{- | 'autoFromList' with the result fixed to a 'Symmetric' matrix. -}
autoSymmetricFromList :: (Storable a) => Order -> [a] -> Symmetric ZeroInt a
autoSymmetricFromList order = autoFromList order

{- | 'autoFromList' with the result fixed to a 'Diagonal' matrix. -}
autoDiagonalFromList :: (Storable a) => Order -> [a] -> Diagonal ZeroInt a
autoDiagonalFromList order = autoFromList order


{- | Identity function that only restricts the type to a diagonal matrix. -}
asDiagonal :: FlexDiagonal diag sh a -> FlexDiagonal diag sh a
asDiagonal m = m

{- | Identity function that only restricts the type to a lower triangular matrix. -}
asLower :: FlexLower diag sh a -> FlexLower diag sh a
asLower m = m

{- | Identity function that only restricts the type to an upper triangular matrix. -}
asUpper :: FlexUpper diag sh a -> FlexUpper diag sh a
asUpper m = m

{- | Identity function that only restricts the type to a symmetric matrix. -}
asSymmetric :: FlexSymmetric diag sh a -> FlexSymmetric diag sh a
asSymmetric m = m

{- | Identity function that only fixes the diagonal type to 'Unit'. -}
requireUnitDiagonal :: Triangular lo Unit up sh a -> Triangular lo Unit up sh a
requireUnitDiagonal m = m

{- | Identity function that only fixes the diagonal type to 'NonUnit'. -}
requireNonUnitDiagonal ::
   Triangular lo NonUnit up sh a -> Triangular lo NonUnit up sh a
requireNonUnitDiagonal m = m


{- |
Expand a triangular matrix into a full 'Square' matrix of the same shape.
-}
toSquare ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Square sh a
toSquare m = ArrMatrix.lift1 Basic.toSquare m

{- |
Extract the lower triangle of a full matrix as a 'Lower' matrix of shape
@height@.
The vertical extent of the input must be 'Extent.Small'.
-}
takeLower ::
   (Extent.C horiz, Shape.C height, Shape.C width, Class.Floating a) =>
   Full Extent.Small horiz height width a -> Lower height a
takeLower m = ArrMatrix.lift1 Basic.takeLower m

{- |
Extract the upper triangle of a full matrix as an 'Upper' matrix of shape
@width@.
The horizontal extent of the input must be 'Extent.Small'.
-}
takeUpper ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Full vert Extent.Small height width a -> Upper width a
takeUpper m = ArrMatrix.lift1 Basic.takeUpper m

{- |
Wrap a comfort-array lower 'Shape.Triangular' array into a 'Lower' matrix.
The array is assumed to be stored row by row, as the name says.
-}
fromLowerRowMajor ::
   (Shape.C sh, Class.Floating a) =>
   Array (Shape.Triangular Shape.Lower sh) a -> Lower sh a
fromLowerRowMajor arr = ArrMatrix.lift0 (Basic.fromLowerRowMajor arr)

{- |
Wrap a comfort-array upper 'Shape.Triangular' array into an 'Upper' matrix.
The array is assumed to be stored row by row, as the name says.
-}
fromUpperRowMajor ::
   (Shape.C sh, Class.Floating a) =>
   Array (Shape.Triangular Shape.Upper sh) a -> Upper sh a
fromUpperRowMajor arr = ArrMatrix.lift0 (Basic.fromUpperRowMajor arr)

{- |
Counterpart of 'fromLowerRowMajor'.
Convert a 'Lower' matrix into a comfort-array lower 'Shape.Triangular' array
(row-major, as the name says).
-}
toLowerRowMajor ::
   (Shape.C sh, Class.Floating a) =>
   Lower sh a -> Array (Shape.Triangular Shape.Lower sh) a
toLowerRowMajor m = Basic.toLowerRowMajor (ArrMatrix.toVector m)

{- |
Counterpart of 'fromUpperRowMajor'.
Convert an 'Upper' matrix into a comfort-array upper 'Shape.Triangular' array
(row-major, as the name says).
-}
toUpperRowMajor ::
   (Shape.C sh, Class.Floating a) =>
   Upper sh a -> Array (Shape.Triangular Shape.Upper sh) a
toUpperRowMajor m = Basic.toUpperRowMajor (ArrMatrix.toVector m)

{- |
Return the matrix stored in the requested 'Order'.
The type is unchanged; the conversion is delegated to 'Basic.forceOrder'.
-}
forceOrder ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Order -> Triangular lo diag up sh a -> Triangular lo diag up sh a
forceOrder order m = ArrMatrix.lift1 (Basic.forceOrder order) m

{- |
@adaptOrder x y@ contains the data of @y@ with the layout of @x@.
@x@ only contributes its storage order; the conversion is done by
'Basic.adaptOrder'.
-}
adaptOrder ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a ->
   Triangular lo diag up sh a ->
   Triangular lo diag up sh a
adaptOrder layout content = ArrMatrix.lift2 Basic.adaptOrder layout content

{- |
Sum ('add') and difference ('sub') of two triangular matrices of the same kind.
Only 'NonUnit' diagonals are accepted, presumably because adding two unit
diagonals does not yield a unit diagonal.
The @Eq@ constraints presumably serve a run-time shape comparison.
-}
add, sub ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    Eq lo, Eq up, Eq sh, Shape.C sh, Class.Floating a) =>
   Triangular lo NonUnit up sh a ->
   Triangular lo NonUnit up sh a ->
   Triangular lo NonUnit up sh a
add x y = ArrMatrix.lift2 Basic.add x y
sub x y = ArrMatrix.lift2 Basic.sub x y


{- |
Identity matrix of the given order and shape.
The diagonal is typed 'Unit'.
-}
identity ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    Shape.C sh, Class.Floating a) =>
   Order -> sh -> Triangular lo Unit up sh a
identity order sh = ArrMatrix.lift0 (Basic.identity order sh)

{- |
Build a matrix whose diagonal is the given vector, typed with a 'NonUnit'
diagonal.
Presumably all other stored elements are zero.
-}
diagonal ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    Shape.C sh, Class.Floating a) =>
   Order -> Vector sh a -> Triangular lo NonUnit up sh a
diagonal order v = ArrMatrix.lift0 (Basic.diagonal order v)

{- |
Extract the main diagonal as a vector of shape @sh@.
-}
takeDiagonal ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Vector sh a
takeDiagonal m = Basic.takeDiagonal (ArrMatrix.toVector m)

{- |
Forget that the diagonal is known to be 'Unit',
yielding any diagonal type @diag@ with a 'MatrixShape.TriDiag' instance.
-}
relaxUnitDiagonal ::
   (MatrixShape.TriDiag diag) =>
   Triangular lo Unit up sh a -> Triangular lo diag up sh a
relaxUnitDiagonal m = ArrMatrix.lift1 Basic.relaxUnitDiagonal m

{- |
Convert a matrix with any diagonal type into one typed 'NonUnit'.
The work is done by 'Basic.strictNonUnitDiagonal'.
NOTE(review): presumably a 'Unit' input gets explicit ones stored on the
diagonal; confirm in Basic.
-}
strictNonUnitDiagonal ::
   (MatrixShape.TriDiag diag) =>
   Triangular lo diag up sh a -> Triangular lo NonUnit up sh a
strictNonUnitDiagonal m = ArrMatrix.lift1 Basic.strictNonUnitDiagonal m


{- |
Block-diagonal composition.
The first argument becomes the top-left block and the second the bottom-right
block of a matrix of shape @sh0:+:sh1@.
-}
stackDiagonal ::
   (MatrixShape.TriDiag diag, Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexDiagonal diag sh0 a ->
   FlexDiagonal diag sh1 a ->
   FlexDiagonal diag (sh0:+:sh1) a
stackDiagonal topLeft bottomRight =
   ArrMatrix.lift2 Basic.stackDiagonal topLeft bottomRight

{- |
Assemble a lower triangular matrix of shape @sh0:+:sh1@ from three blocks:
the top-left triangle, the bottom-left full block (@sh1@ rows, @sh0@ columns)
and the bottom-right triangle.
-}
stackLower ::
   (MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   FlexLower diag sh0 a ->
   General sh1 sh0 a ->
   FlexLower diag sh1 a ->
   FlexLower diag (sh0:+:sh1) a
stackLower topLeft bottomLeft bottomRight =
   ArrMatrix.lift3 Basic.stackLower topLeft bottomLeft bottomRight

{- |
Assemble an upper triangular matrix of shape @sh0:+:sh1@ from three blocks:
the top-left triangle, the top-right full block (@sh0@ rows, @sh1@ columns)
and the bottom-right triangle.
-}
stackUpper ::
   (MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   FlexUpper diag sh0 a ->
   General sh0 sh1 a ->
   FlexUpper diag sh1 a ->
   FlexUpper diag (sh0:+:sh1) a
stackUpper topLeft topRight bottomRight =
   ArrMatrix.lift3 Basic.stackUpper topLeft topRight bottomRight

{- |
Assemble a symmetric matrix of shape @sh0:+:sh1@ from its top-left block,
its off-diagonal top-right block (@sh0@ rows, @sh1@ columns)
and its bottom-right block.
-}
stackSymmetric ::
   (MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   FlexSymmetric diag sh0 a ->
   General sh0 sh1 a ->
   FlexSymmetric diag sh1 a ->
   FlexSymmetric diag (sh0:+:sh1) a
stackSymmetric topLeft offDiag bottomRight =
   ArrMatrix.lift3 Basic.stackSymmetric topLeft offDiag bottomRight


{- |
Counterpart of 'stackDiagonal'.
Split a block-diagonal matrix of shape @sh0:+:sh1@ into its top-left and
bottom-right blocks.
-}
splitDiagonal ::
   (MatrixShape.TriDiag diag, Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexDiagonal diag (sh0:+:sh1) a ->
   (FlexDiagonal diag sh0 a, FlexDiagonal diag sh1 a)
splitDiagonal m = (topLeft, bottomRight)
   where
      topLeft = takeTopLeft m
      bottomRight = takeBottomRight m

{- |
Counterpart of 'stackLower'.
Split a lower triangular matrix of shape @sh0:+:sh1@ into three parts:
the top-left triangle, the bottom-left full block (@sh1@ rows, @sh0@ columns)
and the bottom-right triangle.

Unlike 'stackLower', no shapes are compared here,
so @sh0@ and @sh1@ need no 'Eq' instance.
-}
splitLower ::
   (MatrixShape.TriDiag diag,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexLower diag (sh0:+:sh1) a ->
   (FlexLower diag sh0 a, General sh1 sh0 a, FlexLower diag sh1 a)
splitLower a = (takeTopLeft a, takeBottomLeft a, takeBottomRight a)

{- |
Counterpart of 'stackUpper'.
Split an upper triangular matrix of shape @sh0:+:sh1@ into three parts:
the top-left triangle, the top-right full block (@sh0@ rows, @sh1@ columns)
and the bottom-right triangle.

Unlike 'stackUpper', no shapes are compared here,
so @sh0@ and @sh1@ need no 'Eq' instance.
-}
splitUpper ::
   (MatrixShape.TriDiag diag,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexUpper diag (sh0:+:sh1) a ->
   (FlexUpper diag sh0 a, General sh0 sh1 a, FlexUpper diag sh1 a)
splitUpper a = (takeTopLeft a, takeTopRight a, takeBottomRight a)

{- |
Counterpart of 'stackSymmetric'.
Split a symmetric matrix of shape @sh0:+:sh1@ into its top-left block,
its off-diagonal top-right block (@sh0@ rows, @sh1@ columns)
and its bottom-right block.

Unlike 'stackSymmetric', no shapes are compared here,
so @sh0@ and @sh1@ need no 'Eq' instance.
-}
splitSymmetric ::
   (MatrixShape.TriDiag diag,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexSymmetric diag (sh0:+:sh1) a ->
   (FlexSymmetric diag sh0 a, General sh0 sh1 a, FlexSymmetric diag sh1 a)
splitSymmetric a = (takeTopLeft a, takeTopRight a, takeBottomRight a)


{- |
Top-left block (shape @sh0@) of a triangular matrix of shape @sh0:+:sh1@.
The triangle kind and diagonal type are preserved.
-}
takeTopLeft ::
   (MatrixShape.Content lo, MatrixShape.TriDiag diag, MatrixShape.Content up,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Triangular lo diag up (sh0:+:sh1) a ->
   Triangular lo diag up sh0 a
takeTopLeft m = ArrMatrix.lift1 Basic.takeTopLeft m

{- |
Bottom-left full block (@sh1@ rows, @sh0@ columns) of a matrix of shape
@sh0:+:sh1@.
Only available when the lower part is 'MatrixShape.Filled'.
-}
takeBottomLeft ::
   (MatrixShape.TriDiag diag, MatrixShape.Content up,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Triangular MatrixShape.Filled diag up (sh0:+:sh1) a ->
   General sh1 sh0 a
takeBottomLeft m = ArrMatrix.lift1 Basic.takeBottomLeft m

{- |
Top-right full block (@sh0@ rows, @sh1@ columns) of a matrix of shape
@sh0:+:sh1@.
Only available when the upper part is 'MatrixShape.Filled'.
-}
takeTopRight ::
   (MatrixShape.Content lo, MatrixShape.TriDiag diag,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Triangular lo diag MatrixShape.Filled (sh0:+:sh1) a ->
   General sh0 sh1 a
takeTopRight m = ArrMatrix.lift1 Basic.takeTopRight m

{- |
Bottom-right block (shape @sh1@) of a triangular matrix of shape @sh0:+:sh1@.
The triangle kind and diagonal type are preserved.
-}
takeBottomRight ::
   (MatrixShape.Content lo, MatrixShape.TriDiag diag, MatrixShape.Content up,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Triangular lo diag up (sh0:+:sh1) a ->
   Triangular lo diag up sh1 a
takeBottomRight m = ArrMatrix.lift1 Basic.takeBottomRight m


{- |
Multiply a triangular matrix by a vector of the same shape @sh@.
-}
multiplyVector ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Eq sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Vector sh a -> Vector sh a
multiplyVector m = Basic.multiplyVector (ArrMatrix.toVector m)

{- |
Square a triangular matrix, keeping its triangle kind and diagonal type.
Restricted by 'MatrixShape.DiagUpLo'; see 'squareGeneric' for the version
that also accepts symmetric matrices.
-}
square ::
   (MatrixShape.DiagUpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh, Eq sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Triangular lo diag up sh a
square m = ArrMatrix.lift1 Basic.square m

{- |
Like 'square', but also accepts symmetric matrices.
Symmetric matrices do not preserve unit diagonals.
For that reason the result's diagonal type is @Tri.PowerDiag lo up diag@
rather than @diag@.
-}
squareGeneric ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Eq sh, Class.Floating a) =>
   Triangular lo diag up sh a ->
   Triangular lo (Tri.PowerDiag lo up diag) up sh a
squareGeneric m = ArrMatrix.lift1 Basic.squareGeneric m

{- |
Product of two triangular matrices of the same kind, shape and diagonal type.
Restricted by 'MatrixShape.DiagUpLo', like 'square'.
-}
multiply ::
   (MatrixShape.DiagUpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh, Eq sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Triangular lo diag up sh a ->
   Triangular lo diag up sh a
multiply x y = ArrMatrix.lift2 Basic.multiply x y

{- |
Multiply a triangular matrix (shape @height@) from the left onto a full matrix
with @height@ rows.
The result has the same type as the full operand.
-}
multiplyFull ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   Triangular lo diag up height a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
multiplyFull tri full = ArrMatrix.lift2 Basic.multiplyFull tri full



{- |
Solve a triangular linear system with a full right-hand side.
The right-hand side has @nrhs@ columns; presumably one solution column is
computed per right-hand-side column, delegated to 'Linear.solve'.
-}
solve ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Triangular lo diag up sh a ->
   Full vert horiz sh nrhs a -> Full vert horiz sh nrhs a
solve tri rhs = ArrMatrix.lift2 Linear.solve tri rhs

{- |
Inverse of a triangular matrix, keeping its triangle kind and diagonal type.
Restricted by 'MatrixShape.DiagUpLo'; see 'inverseGeneric' for the version
without that restriction.
-}
inverse ::
   (MatrixShape.DiagUpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Triangular lo diag up sh a
inverse m = ArrMatrix.lift1 Linear.inverse m

{- |
Like 'inverse', but accepts any @lo@/@up@ combination with
'MatrixShape.Content' instances (e.g. symmetric matrices).
As in 'squareGeneric', the result's diagonal type is
@Tri.PowerDiag lo up diag@, the type family this module exports.
-}
inverseGeneric ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a ->
   Triangular lo (Tri.PowerDiag lo up diag) up sh a
inverseGeneric = ArrMatrix.lift1 Linear.inverseGeneric

{- |
Determinant of a triangular matrix, computed by 'Linear.determinant'.
-}
determinant ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a -> a
determinant m = Linear.determinant $ ArrMatrix.toVector m



{- |
Eigenvalues of a triangular matrix as a vector of shape @sh@,
computed by 'Eigen.values'.
-}
eigenvalues ::
   (MatrixShape.DiagUpLo lo up, Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Vector sh a
eigenvalues m = Eigen.values $ ArrMatrix.toVector m

{- |
@(vr,d,vlAdj) = eigensystem a@

Counterintuitively, @vr@ contains the right eigenvectors as columns,
and @vlAdj@ contains the conjugated left eigenvectors as rows.
The idea is to provide a decomposition of @a@.
If @a@ is diagonalizable, then @vr@ and @vlAdj@
are almost inverse to each other.
More precisely, @vlAdj \<#\> vr@ is a diagonal matrix.
This is because the eigenvectors are not normalized.
With the following scaling, the decomposition becomes perfect:

> let scal = Array.map recip $ takeDiagonal $ vlAdj <#> vr
> a == vr <#> diagonal d <#> diagonal scal <#> vlAdj

If @a@ is non-diagonalizable,
then some columns of @vr@ and corresponding rows of @vlAdj@ are left zero,
and the above property does not hold.
-}
eigensystem ::
   (MatrixShape.DiagUpLo lo up, Shape.C sh, Class.Floating a) =>
   Triangular lo NonUnit up sh a ->
   (Triangular lo NonUnit up sh a, Vector sh a, Triangular lo NonUnit up sh a)
eigensystem m =
   -- wrap both eigenvector arrays back into matrices; eigenvalues stay a vector
   mapTriple (ArrMatrix.lift0, id, ArrMatrix.lift0) $
   Eigen.decompose $ ArrMatrix.toVector m