{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Matrix.Divide where

import qualified Numeric.LAPACK.Matrix.Array.Divide as ArrDivide
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Permutation as PermMatrix
import qualified Numeric.LAPACK.Matrix.Type as Type
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Array (Full)
import Numeric.LAPACK.Matrix.Type (Matrix, scaleWithCheck)
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed,Transposed),
          Inversion(Inverted))
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape

import Data.Semigroup ((<>))


-- | Transpose a full matrix.
--
-- As the result type shows, this swaps height with width and also swaps
-- the vertical and horizontal extent parameters.
transposeFull ::
   (Extent.C vert, Extent.C horiz) =>
   Full vert horiz height width a -> Full horiz vert width height a
transposeFull matrix = ArrMatrix.lift1 Basic.transpose matrix

-- | Matrix types that can act as the divisor when solving against a full
-- right-hand side matrix.
--
-- The superclass equation requires the height and width shape types of
-- @typ@ to coincide.  An instance defines either 'solve', or both
-- 'solveLeft' and 'solveRight'; the missing methods are then derived from
-- the given ones by transposing with 'transposeFull'.
class (Type.Box typ, Type.HeightOf typ ~ Type.WidthOf typ) => Solve typ where
   {-# MINIMAL solve | solveLeft,solveRight #-}
   -- | @solve trans a b@ multiplies @b@ from the left by the inverse of @a@,
   -- or by the inverse of the transpose of @a@ if @trans@ is 'Transposed'.
   solve ::
      (Type.HeightOf typ ~ height, Eq height, Shape.C width,
       Extent.C horiz, Extent.C vert, Class.Floating a) =>
      Transposition -> Matrix typ a ->
      Full vert horiz height width a -> Full vert horiz height width a
   -- Default: the plain case is 'solveRight'.  The transposed case uses
   --   inverse (transpose a) * b = transpose (transpose b * inverse a),
   -- which is 'solveLeft' applied to the transposed right-hand side.
   solve trans a b =
      case trans of
         NonTransposed -> solveRight a b
         Transposed -> transposeFull (solveLeft (transposeFull b) a)

   -- | @solveRight a b@ multiplies @b@ from the left by the inverse of @a@.
   -- It is expected to agree with @'solve' 'NonTransposed'@.
   solveRight ::
      (Type.HeightOf typ ~ height, Eq height, Shape.C width,
       Extent.C horiz, Extent.C vert, Class.Floating a) =>
      Matrix typ a ->
      Full vert horiz height width a -> Full vert horiz height width a
   solveRight a b = solve NonTransposed a b

   -- | @solveLeft b a@ multiplies @b@ from the right by the inverse of @a@.
   solveLeft ::
      (Type.WidthOf typ ~ width, Eq width, Shape.C height,
       Extent.C horiz, Extent.C vert, Class.Floating a) =>
      Full vert horiz height width a ->
      Matrix typ a -> Full vert horiz height width a
   -- Default: b * inverse a = transpose (inverse (transpose a) * transpose b).
   solveLeft b a = transposeFull (solve Transposed a (transposeFull b))

-- | Matrix types whose inverse can be expressed in the same matrix type.
class Solve typ => Inverse typ where
   -- | Invert the matrix; the result has the same matrix type.
   inverse :: Class.Floating a => Matrix typ a -> Matrix typ a

-- Precedence 7, like multiplication.  Left division by a matrix associates
-- to the right, so a chain of divisors is applied from the innermost one.
infixr 7 #\##

-- | Left division: multiplies the full matrix operand on the right by the
-- inverse of the matrix operand on the left.  Infix synonym for
-- 'solveRight'.
(#\##) ::
   (Solve typ, Type.HeightOf typ ~ height, Eq height, Shape.C width,
    Extent.C horiz, Extent.C vert, Class.Floating a) =>
   Matrix typ a ->
   Full vert horiz height width a -> Full vert horiz height width a
divisor #\## numerator = solveRight divisor numerator

-- Right division associates to the left, like Prelude's division.
infixl 7 ##/#

-- | Right division: multiplies the full matrix operand on the left by the
-- inverse of the matrix operand on the right.  Infix synonym for
-- 'solveLeft'.
(##/#) ::
   (Solve typ, Type.WidthOf typ ~ width, Eq width, Shape.C height,
    Extent.C horiz, Extent.C vert, Class.Floating a) =>
   Full vert horiz height width a ->
   Matrix typ a -> Full vert horiz height width a
numerator ##/# divisor = solveLeft numerator divisor


-- | Variant of 'solve' whose right-hand side is a single vector.
--
-- 'ArrMatrix.unliftColumn' adapts the full-matrix 'solve' to vectors,
-- using column-major layout for the single-column view of the vector.
solveVector ::
   (Solve typ, Type.HeightOf typ ~ height, Eq height, Class.Floating a) =>
   Transposition -> Matrix typ a -> Vector height a -> Vector height a
solveVector trans m =
   ArrMatrix.unliftColumn MatrixShape.ColumnMajor (solve trans m)

-- Vector counterparts of the full-matrix division operators, with the same
-- precedence 7: division with the vector on the right associates to the
-- right, division with the vector on the left associates to the left.
infixl 7 -/#
infixr 7 #\|

-- | Left division of a vector: multiplies the (column) vector by the
-- inverse of the matrix.  Same as @'solveVector' 'NonTransposed'@.
(#\|) ::
   (Solve typ, Type.HeightOf typ ~ height, Eq height, Class.Floating a) =>
   Matrix typ a -> Vector height a -> Vector height a
(#\|) = solveVector NonTransposed

-- | Right division of a vector: treats the vector as a row vector and
-- multiplies it from the right by the inverse of the matrix.  This is
-- computed by solving with the transposed matrix.
--
-- The vector is indexed by the matrix width, mirroring the full-matrix
-- right division.  Since 'Solve' requires height and width shape types
-- to coincide, this accepts exactly the same arguments as indexing by
-- height would.
(-/#) ::
   (Solve typ, Type.WidthOf typ ~ width, Eq width, Class.Floating a) =>
   Vector width a -> Matrix typ a -> Vector width a
(-/#) = flip $ solveVector Transposed


-- Dividing by a scaling means scaling by the reciprocal factor.  The
-- transposition argument is ignored, since a scalar multiple of the
-- identity equals its transpose.  'scaleWithCheck' receives the name
-- "Matrix.Scale.solve" and 'Type.height', presumably to check the scale's
-- shape against the right-hand side's height.
instance (Shape.C shape, Eq shape) => Solve (Type.Scale shape) where
   solve _ =
      scaleWithCheck "Matrix.Scale.solve" Type.height
         (\factor -> ArrMatrix.lift1 (Vector.scale (recip factor)))

-- The inverse of a scaling keeps the shape and takes the reciprocal factor.
instance (Shape.C shape, Eq shape) => Inverse (Type.Scale shape) where
   inverse (Type.Scale sh factor) = Type.Scale sh (recip factor)


-- Solving with a permutation multiplies by the permutation with the
-- requested inversion.  The inversion is 'Inverted' combined with the
-- inversion corresponding to the transposition.  Because a permutation's
-- inverse equals its transpose, the two presumably cancel for 'Transposed'.
instance (Shape.C shape) => Solve (PermMatrix.Permutation shape) where
   solve trans = PermMatrix.multiplyFull inversion
      where
         inversion =
            Inverted <> PermMatrix.inversionFromTransposition trans

-- A permutation matrix is orthogonal, so its transpose is its inverse.
instance (Shape.C shape) => Inverse (PermMatrix.Permutation shape) where
   inverse perm = PermMatrix.transpose perm


-- Array-backed matrices delegate every method to the corresponding
-- function of "Numeric.LAPACK.Matrix.Array.Divide", lifted with
-- 'ArrMatrix.lift2'.
instance (ArrDivide.Solve shape) => Solve (ArrMatrix.Array shape) where
   solve trans = ArrMatrix.lift2 (ArrDivide.solve trans)
   solveRight = ArrMatrix.lift2 ArrDivide.solveRight
   solveLeft = ArrMatrix.lift2 ArrDivide.solveLeft

-- Array-backed inverses come from "Numeric.LAPACK.Matrix.Array.Divide",
-- lifted with 'ArrMatrix.lift1'.
instance (ArrDivide.Inverse shape) => Inverse (ArrMatrix.Array shape) where
   inverse m = ArrMatrix.lift1 ArrDivide.inverse m