{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Symmetric (
   Symmetric,
   size,
   fromList, autoFromList,
   identity,
   diagonal,
   takeDiagonal,
   transpose,
   adjoint,

   stack, (#%%%#),
   split,

   toSquare,
   fromHermitian,

   gramian,            gramianTransposed,
   congruenceDiagonal, congruenceDiagonalTransposed,
   congruence,         congruenceTransposed,
   anticommutator,     anticommutatorTransposed,
   ) where

import qualified Numeric.LAPACK.Matrix.Symmetric.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular

import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import Numeric.LAPACK.Matrix.Array.Triangular (Symmetric, Hermitian)
import Numeric.LAPACK.Matrix.Array (Full, General, Square)
import Numeric.LAPACK.Matrix.Shape.Private (Order)
import Numeric.LAPACK.Matrix.Private (ShapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (one)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((:+:))

import Foreign.Storable (Storable)


{- |
The shape that indexes both the rows and the columns of the matrix.
It is read from the size component of the underlying triangular shape.
-}
size :: Symmetric sh a -> sh
size = MatrixShape.triangularSize . ArrMatrix.shape

{- |
Transpose, delegated to 'Triangular.transpose'.
The result has the same type as the argument,
since a symmetric matrix equals its own transpose.
-}
transpose :: Symmetric sh a -> Symmetric sh a
transpose = Triangular.transpose

{- |
Adjoint (conjugate transpose), delegated to 'Triangular.adjoint'.
-}
adjoint :: (Shape.C sh, Class.Floating a) => Symmetric sh a -> Symmetric sh a
adjoint = Triangular.adjoint


{- |
Build a symmetric matrix of shape @sh@ from a list of elements
laid out in the given 'Order'.
Delegates to 'Triangular.symmetricFromList'; see there for exactly
which elements the list must contain (presumably only one triangle,
since the other one is implied by symmetry).
-}
fromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Symmetric sh a
fromList = Triangular.symmetricFromList

{- |
Like 'fromList', but the 'ShapeInt' size is not passed explicitly.
It is determined by 'Triangular.autoSymmetricFromList'
(presumably from the length of the list).
-}
autoFromList :: (Storable a) => Order -> [a] -> Symmetric ShapeInt a
autoFromList = Triangular.autoSymmetricFromList


{- |
Convert to a full 'Square' matrix via 'Triangular.toSquare'.
-}
toSquare :: (Shape.C sh, Class.Floating a) => Symmetric sh a -> Square sh a
toSquare = Triangular.toSquare

{- |
Reinterpret a Hermitian matrix as a symmetric one.

Only the shape tag is changed (via 'Array.mapShape');
the stored elements are passed through untouched.
This is valid because the elements are restricted to 'Class.Real',
where Hermitian and symmetric matrices coincide.
-}
fromHermitian :: (Shape.C sh, Class.Real a) => Hermitian sh a -> Symmetric sh a
fromHermitian =
   ArrMatrix.lift1 $ Array.mapShape MatrixShape.symmetricFromHermitian


{- |
Identity matrix of shape @sh@ in the given 'Order'.

It is created as a triangular identity with a unit diagonal.
It is then relaxed to the non-unit diagonal type used by 'Symmetric'.
-}
identity :: (Shape.C sh, Class.Floating a) => Order -> sh -> Symmetric sh a
identity order = Triangular.relaxUnitDiagonal . Triangular.identity order

{- |
Diagonal matrix whose main diagonal is the given vector,
built via 'Triangular.diagonal'.
-}
diagonal ::
   (Shape.C sh, Class.Floating a) => Order -> Vector sh a -> Symmetric sh a
diagonal = Triangular.diagonal

{- |
Extract the main diagonal as a vector, via 'Triangular.takeDiagonal'.
-}
takeDiagonal :: (Shape.C sh, Class.Floating a) => Symmetric sh a -> Vector sh a
takeDiagonal = Triangular.takeDiagonal


{- |
Assemble a symmetric block matrix from three blocks:
the upper-left block @A@, the upper-right block @B@
and the lower-right block @C@.

> [ A    B ]
> [ B^T  C ]

The lower-left block is implied by symmetry.
Delegates to 'Triangular.stackSymmetric'.
-}
stack ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   Symmetric sh0 a ->
   General sh0 sh1 a ->
   Symmetric sh1 a ->
   Symmetric (sh0:+:sh1) a
stack = Triangular.stackSymmetric

infixr 2 #%%%#

{- |
Infix variant of 'stack' that takes the top row of blocks as a pair:

> (a, b) #%%%# c  ==  stack a b c
-}
(#%%%#) ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   (Symmetric sh0 a, General sh0 sh1 a) ->
   Symmetric sh1 a ->
   Symmetric (sh0:+:sh1) a
(#%%%#) = uncurry stack


{- |
Counterpart of 'stack': decompose a symmetric block matrix
into its upper-left block, upper-right block and lower-right block.
The lower-left block is not returned,
as it is the transpose of the upper-right one.
Delegates to 'Triangular.splitSymmetric'.
-}
split ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   Symmetric (sh0:+:sh1) a ->
   (Symmetric sh0 a, General sh0 sh1 a, Symmetric sh1 a)
split = Triangular.splitSymmetric



{- |
gramian A = A^T * A

Lifts 'Basic.gramian' from plain arrays to matrix values.
-}
gramian ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Symmetric width a
gramian = ArrMatrix.lift1 Basic.gramian

{- |
gramianTransposed A = A * A^T = gramian (A^T)

Lifts 'Basic.gramianTransposed' from plain arrays to matrix values.
-}
gramianTransposed ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Symmetric height a
gramianTransposed = ArrMatrix.lift1 Basic.gramianTransposed

{- |
congruenceDiagonal D A = A^T * D * A

The diagonal matrix @D@ is given as the vector of its diagonal elements.
-}
congruenceDiagonal ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> General height width a -> Symmetric width a
-- The diagonal is already a plain vector, so only the matrix argument
-- needs lifting; hence 'lift1' after partial application.
congruenceDiagonal = ArrMatrix.lift1 . Basic.congruenceDiagonal

{- |
congruenceDiagonalTransposed A D = A * D * A^T

The diagonal matrix @D@ is given as the vector of its diagonal elements.
-}
congruenceDiagonalTransposed ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Vector width a -> Symmetric height a
-- The matrix comes first here, so it is unwrapped explicitly with
-- 'ArrMatrix.toVector'. The remaining function of the diagonal vector
-- yields a plain array, which 'lift0' wraps back into a matrix.
congruenceDiagonalTransposed a =
   ArrMatrix.lift0 . Basic.congruenceDiagonalTransposed (ArrMatrix.toVector a)

{- |
congruence B A = A^T * B * A

Lifts 'Basic.congruence' from plain arrays to matrix values.
-}
congruence ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Symmetric height a -> General height width a -> Symmetric width a
congruence = ArrMatrix.lift2 Basic.congruence

{- |
congruenceTransposed A B = A * B * A^T

Lifts 'Basic.congruenceTransposed' from plain arrays to matrix values.
-}
congruenceTransposed ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Symmetric width a -> Symmetric height a
congruenceTransposed = ArrMatrix.lift2 Basic.congruenceTransposed


{- |
anticommutator A B  =  A^T * B + B^T * A

This is not exactly a matrix anticommutator,
which is why it is called a symmetric anticommutator here.
It is implemented by 'Basic.scaledAnticommutator'
with the scaling argument set to 'one'.
-}
anticommutator ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   Full vert horiz height width a -> Symmetric width a
anticommutator = ArrMatrix.lift2 $ Basic.scaledAnticommutator one

{- |
anticommutatorTransposed A B
   = A * B^T + B * A^T
   = anticommutator (transpose A) (transpose B)

It is implemented by 'Basic.scaledAnticommutatorTransposed'
with the scaling argument set to 'one'.
-}
anticommutatorTransposed ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   Full vert horiz height width a -> Symmetric height a
anticommutatorTransposed =
   ArrMatrix.lift2 $ Basic.scaledAnticommutatorTransposed one