{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Matrix.Class (
   SquareShape(toSquare, identityOrder, takeDiagonal),
   identityFrom,
   identityFromHeight,
   identityFromWidth,
   trace,
   Complex(conjugate, fromReal, toComplex),
   ) where

import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Type as Type
import qualified Numeric.LAPACK.Matrix.Plain.Class as Plain
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Permutation as Permutation
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Array (ArrayMatrix)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, ComplexOf)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape


-- | Element-type conversions between real and complex numbers, plus
-- complex conjugation, for the matrix representations of this module.
-- The matrix representation @typ@ is unchanged by every method; only the
-- element type may change.
class Complex typ where
   -- | Complex-conjugate the matrix. The representation @typ@ is kept, so
   -- no transposition takes place.
   conjugate ::
      (Class.Floating a) => Type.Matrix typ a -> Type.Matrix typ a
   -- | Turn a matrix with real elements ('RealOf' @a@) into one with
   -- element type @a@.
   fromReal ::
      (Class.Floating a) => Type.Matrix typ (RealOf a) -> Type.Matrix typ a
   -- | Turn a matrix with element type @a@ into one with complex elements
   -- ('ComplexOf' @a@).
   toComplex ::
      (Class.Floating a) => Type.Matrix typ a -> Type.Matrix typ (ComplexOf a)

-- | Array-backed matrices delegate each conversion to the plain
-- (shape-specific) implementation, lifted over the matrix wrapper.
instance (Plain.Complex sh) => Complex (ArrMatrix.Array sh) where
   conjugate m = ArrMatrix.lift1 Plain.conjugate m
   fromReal m = ArrMatrix.lift1 Plain.fromReal m
   toComplex m = ArrMatrix.lift1 Plain.toComplex m

-- | A scaled identity is stored as a shape plus one scalar factor, so
-- every conversion acts on that factor alone and keeps the shape.
instance (Shape.C shape) => Complex (Type.Scale shape) where
   conjugate s =
      case s of Type.Scale sh m -> Type.Scale sh (Scalar.conjugate m)
   fromReal s =
      case s of Type.Scale sh m -> Type.Scale sh (Scalar.fromReal m)
   toComplex s =
      case s of Type.Scale sh m -> Type.Scale sh (Scalar.toComplex m)

-- | A permutation matrix carries no element values, only the permutation
-- @p@. Conjugation therefore returns the matrix as is, and the element-type
-- conversions just rewrap the same permutation at the new element type.
instance (Shape.C shape) => Complex (Perm.Permutation shape) where
   conjugate m = m
   fromReal m =
      case m of Type.Permutation p -> Type.Permutation p
   toComplex m =
      case m of Type.Permutation p -> Type.Permutation p


-- | Matrix representations indexed by a single shape @sh@, namely
-- 'Type.HeightOf' @typ@. Each such matrix can be expanded to a dense
-- square matrix, has an identity, and exposes its diagonal as a vector.
class SquareShape typ where
   -- | Expand the matrix to a dense 'ArrMatrix.Square' matrix over the
   -- height shape @sh@.
   toSquare ::
      (Type.HeightOf typ ~ sh, Class.Floating a) =>
      Type.Matrix typ a -> ArrMatrix.Square sh a
   -- | Identity matrix of shape @sh@ in the requested storage order.
   -- The 'Type.Scale' and 'Perm.Permutation' instances ignore the order
   -- argument.
   identityOrder ::
      (Type.HeightOf typ ~ sh, Class.Floating a) =>
      MatrixShape.Order -> sh -> Type.Matrix typ a
   -- | Extract the diagonal as a vector indexed by @sh@.
   takeDiagonal ::
      (Type.HeightOf typ ~ sh, Class.Floating a) =>
      Type.Matrix typ a -> Vector sh a

-- | Array-backed square matrices use the plain implementations. The
-- diagonal is read from the underlying storage vector.
instance (ArrMatrix.SquareShape sh) => SquareShape (ArrMatrix.Array sh) where
   toSquare m = ArrMatrix.lift1 Plain.toSquare m
   identityOrder order sh = ArrMatrix.lift0 (Plain.identityOrder order sh)
   takeDiagonal m = Plain.takeDiagonal (ArrMatrix.toVector m)

-- | A scaled identity: shape @sh@ with every diagonal entry equal to the
-- stored scalar. It has no storage order of its own, so 'identityOrder'
-- ignores the order, and 'toSquare' builds the diagonal in row-major order.
instance (Shape.C sh) => SquareShape (Type.Scale sh) where
   toSquare (Type.Scale sh a) =
      (Triangular.toSquare . Triangular.asDiagonal .
       Triangular.diagonal MatrixShape.RowMajor) (Vector.constant sh a)
   identityOrder _order sh = Type.Scale sh Scalar.one
   takeDiagonal s =
      case s of Type.Scale sh a -> Vector.constant sh a

-- | Permutation matrices delegate to the permutation modules. The storage
-- order passed to 'identityOrder' is not used.
instance (Shape.C sh) => SquareShape (Perm.Permutation sh) where
   toSquare p = Permutation.toMatrix p
   identityOrder _order sh = Permutation.identity sh
   takeDiagonal p = Perm.takeDiagonal (Permutation.toPermutation p)


-- | Identity matrix of the same array shape type as @m@. It takes its
-- storage order from @m@'s shape and is sized by @m@'s height.
identityFrom ::
   (Plain.SquareShape shape, ArrMatrix.ShapeOrder shape, Class.Floating a) =>
   ArrayMatrix shape a -> ArrayMatrix shape a
identityFrom m = identityOrder order (Type.height m)
   where order = ArrMatrix.shapeOrder (ArrMatrix.shape m)

-- | Identity matrix of any 'SquareShape' representation. It is sized by
-- the height of @m@ and uses the storage order of @m@'s shape; the
-- 'Type.Scale' and 'Perm.Permutation' instances ignore that order.
identityFromHeight ::
   (ArrMatrix.ShapeOrder shape, MatrixShape.Box shape,
    MatrixShape.HeightOf shape ~ Type.HeightOf typ, SquareShape typ,
    Class.Floating a) =>
   ArrayMatrix shape a -> Type.Matrix typ a
identityFromHeight m = identityOrder order (Type.height m)
   where order = ArrMatrix.shapeOrder (ArrMatrix.shape m)

-- | Identity matrix of any 'SquareShape' representation. It is sized by
-- the width of @m@ and uses the storage order of @m@'s shape; the
-- 'Type.Scale' and 'Perm.Permutation' instances ignore that order.
identityFromWidth ::
   (ArrMatrix.ShapeOrder shape, MatrixShape.Box shape,
    MatrixShape.WidthOf shape ~ Type.HeightOf typ, SquareShape typ,
    Class.Floating a) =>
   ArrayMatrix shape a -> Type.Matrix typ a
identityFromWidth m = identityOrder order (Type.width m)
   where order = ArrMatrix.shapeOrder (ArrMatrix.shape m)

-- | Trace of a square matrix: the sum of the entries returned by
-- 'takeDiagonal'.
trace ::
   (SquareShape typ, Type.HeightOf typ ~ sh, Shape.C sh, Class.Floating a) =>
   Type.Matrix typ a -> a
trace m = Vector.sum (takeDiagonal m)