{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Orthogonal.Plain (
   leastSquares,
   minimumNorm,
   leastSquaresMinimumNormRCond,
   pseudoInverseRCond,

   determinantAbsolute,
   complement,
   ) where

import qualified Numeric.LAPACK.Orthogonal.Private as HH

import qualified Numeric.LAPACK.Matrix.Square.Basic as Square
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor,ColumnMajor))
import Numeric.LAPACK.Matrix.Private (Full, Tall, ShapeInt, shapeInt)
import Numeric.LAPACK.Scalar (RealOf, zero, absolute)
import Numeric.LAPACK.Private
         (lacgv, peekCInt,
          copySubMatrix, copyToTemp, copyToColumnMajor, copyToSubColumnMajor,
          withAutoWorkspaceInfo, rankMsg, errorCodeMsg, createHigherArray)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.LAPACK.FFI.Real as LapackReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import System.IO.Unsafe (unsafePerformIO)

import Foreign.Marshal.Array (pokeArray)
import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Complex (Complex)
import Data.Tuple.HT (mapSnd)


-- | Solve @a * x = b@ in the least-squares sense with LAPACK's @gels@.
--
-- The coefficient matrix has shape @height x width@ with its horizontal
-- extent fixed to 'Extent.Small'; the right-hand sides have shape
-- @height x nrhs@ and the result has shape @width x nrhs@.
-- Calls 'error' when the height shapes of both arguments do not match.
leastSquares ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   Full horiz Extent.Small height width a ->
   Full vert horiz height nrhs a ->
   Full vert horiz width nrhs a
leastSquares
   (Array shapeA@(MatrixShape.Full orderA extentA) a)
   (Array shapeB@(MatrixShape.Full orderB extentB) b) =
   case Extent.fuse (Extent.generalizeWide $ Extent.transpose extentA) extentB of
      Nothing -> error "leastSquares: height shapes mismatch"
      Just extent ->
         Array.unsafeCreate (MatrixShape.Full ColumnMajor extent) $ \xPtr ->
            let (rowsB, colsB) = Extent.dimensions extentB
                (m, n) = MatrixShape.dimensions shapeA
                nrhs = Shape.size colsB
                ldb = Shape.size rowsB
                ldx = Shape.size $ Extent.width extentA
            in evalContT $ do
                  mPtr <- Call.cint m
                  nPtr <- Call.cint n
                  nrhsPtr <- Call.cint nrhs
                  (transPtr, aPtr) <- adjointA orderA (m*n) a
                  ldaPtr <- Call.leadingDim m
                  ldbPtr <- Call.leadingDim ldb
                  -- The solver works in a column-major scratch copy of b;
                  -- the result is read back from that buffer afterwards.
                  scratchPtr <- Call.allocaArray (Shape.size shapeB)
                  bPtr <- ContT $ withForeignPtr b
                  liftIO $ copyToColumnMajor orderB ldb nrhs bPtr scratchPtr
                  liftIO $ withAutoWorkspaceInfo rankMsg "gels" $
                     LapackGen.gels transPtr
                        mPtr nPtr nrhsPtr aPtr ldaPtr scratchPtr ldbPtr
                  -- Transfer the solution block from the scratch buffer
                  -- (leading dimension ldb) into the result array
                  -- (leading dimension ldx).
                  liftIO $ copySubMatrix ldx nrhs ldb scratchPtr ldx xPtr

-- | Solve @a * x = b@ with LAPACK's @gels@, writing the right-hand sides
-- directly into the result buffer so that the solver works in place.
--
-- The coefficient matrix has shape @height x width@ with its vertical
-- extent fixed to 'Extent.Small'; the right-hand sides have shape
-- @height x nrhs@ and the result has shape @width x nrhs@.
-- Calls 'error' when the height shapes of both arguments do not match.
minimumNorm ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   Full Extent.Small vert height width a ->
   Full vert horiz height nrhs a ->
   Full vert horiz width nrhs a
minimumNorm
   (Array shapeA@(MatrixShape.Full orderA extentA) a)
   (Array        (MatrixShape.Full orderB extentB) b) =
   case Extent.fuse (Extent.generalizeTall $ Extent.transpose extentA) extentB of
      Nothing -> error "minimumNorm: height shapes mismatch"
      Just extent ->
         Array.unsafeCreate (MatrixShape.Full ColumnMajor extent) $ \xPtr ->
            let (rowsB, colsB) = Extent.dimensions extentB
                (m, n) = MatrixShape.dimensions shapeA
                nrhs = Shape.size colsB
                ldb = Shape.size rowsB
                ldx = Shape.size $ Extent.width extentA
            in evalContT $ do
                  mPtr <- Call.cint m
                  nPtr <- Call.cint n
                  nrhsPtr <- Call.cint nrhs
                  (transPtr, aPtr) <- adjointA orderA (m*n) a
                  ldaPtr <- Call.leadingDim m
                  ldxPtr <- Call.leadingDim ldx
                  bPtr <- ContT $ withForeignPtr b
                  -- Place b into the result buffer, which has leading
                  -- dimension ldx, then let gels operate on that buffer.
                  liftIO $ copyToSubColumnMajor orderB ldb nrhs bPtr ldx xPtr
                  liftIO $ withAutoWorkspaceInfo rankMsg "gels" $
                     LapackGen.gels transPtr
                        mPtr nPtr nrhsPtr aPtr ldaPtr xPtr ldxPtr


-- | Prepare a coefficient matrix for @gels@: copy its @size@ elements into
-- a temporary buffer and choose the matching @trans@ flag.
--
-- For 'ColumnMajor' the copy is used as is with flag @\'N\'@.
-- For 'RowMajor' every element of the copy is conjugated in place via
-- 'lacgv' and the flag is taken from 'HH.invChar'.
--
-- Returns the pointer to the flag character and the pointer to the copy.
adjointA ::
   Class.Floating a =>
   Order -> Int -> ForeignPtr a -> ContT r IO (Ptr CChar, Ptr a)
adjointA order size a = do
   tmpPtr <- copyToTemp size a
   trans <-
      case order of
         ColumnMajor -> return 'N'
         RowMajor -> do
            countPtr <- Call.cint size
            stridePtr <- Call.cint 1
            liftIO $ lacgv countPtr tmpPtr stridePtr
            return $ HH.invChar a
   fmap (\transPtr -> (transPtr, tmpPtr)) $ Call.char trans


-- | Solve @a * x = b@ through the @gelsy@ driver and return the rank
-- reported by LAPACK together with the solution.
--
-- The first argument is the @rcond@ value handed to @gelsy@.
-- Degenerate sizes are treated without touching the right-hand sides:
--
-- * height of size zero: rank @0@ and an all-zero result, no LAPACK call;
--
-- * no right-hand side columns: the rank is still computed by running the
--   solver against a single zero column, the result itself is all-zero.
--
-- Calls 'error' when the height shapes of both arguments do not match.
leastSquaresMinimumNormRCond ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   RealOf a ->
   Full horiz vert height width a ->
   Full vert horiz height nrhs a ->
   (Int, Full vert horiz width nrhs a)
leastSquaresMinimumNormRCond rcond
      (Array (MatrixShape.Full orderA extentA) a)
      (Array (MatrixShape.Full orderB extentB) b) =
   case Extent.fuse (Extent.transpose extentA) extentB of
      Nothing -> error "leastSquaresMinimumNormRCond: height shapes mismatch"
      Just extent -> solveFor $ MatrixShape.Full ColumnMajor extent
   where
      widthA = Extent.width extentA
      (height, widthB) = Extent.dimensions extentB
      m = Shape.size height
      n = Shape.size widthA
      nrhs = Shape.size widthB

      solveFor shapeX
         | m == 0 = (0, Vector.zero shapeX)
         | nrhs == 0 = (rankWithDummyRhs, Vector.zero shapeX)
         | otherwise =
              unsafePerformIO $
              leastSquaresMinimumNormIO rcond shapeX
                 orderA a orderB b m n nrhs

      -- Rank obtained from a run with one zero-filled right-hand side
      -- column; the solution of that run is thrown away.
      rankWithDummyRhs =
         fst $ unsafePerformIO $
         case Vector.zero height of
            Array _ zeroRhs ->
               leastSquaresMinimumNormIO rcond
                  (MatrixShape.general ColumnMajor widthA ())
                  orderA a orderB zeroRhs m n 1

-- | Run @gelsy@ for an @m x n@ coefficient matrix and @nrhs@ right-hand
-- sides. The output array of shape @shapeX@ is provided by
-- 'createHigherArray', which also supplies the working buffer and its
-- leading dimension.
--
-- Returns the rank read back from LAPACK paired with the created array.
leastSquaresMinimumNormIO ::
   (Shape.C sh, Class.Floating a) =>
   RealOf a -> sh ->
   Order -> ForeignPtr a ->
   Order -> ForeignPtr a ->
   Int -> Int -> Int -> IO (Int, Array sh a)
leastSquaresMinimumNormIO rcond shapeX orderA a orderB b m n nrhs =
   createHigherArray shapeX m n nrhs $ \(xPtr, ldx) -> evalContT $ do
      -- Column-major working copy of the coefficient matrix.
      aPtr <- ContT $ withForeignPtr a
      aCopyPtr <- Call.allocaArray (m*n)
      liftIO $ copyToColumnMajor orderA m n aPtr aCopyPtr
      -- The right-hand sides are placed into the buffer that gelsy
      -- subsequently works on.
      bPtr <- ContT $ withForeignPtr b
      liftIO $ copyToSubColumnMajor orderB m nrhs bPtr ldx xPtr
      -- The pivot array is passed to gelsy filled with zeros.
      pivotPtr <- Call.allocaArray n
      liftIO $ pokeArray pivotPtr (replicate n 0)
      ldaPtr <- Call.leadingDim m
      ldxPtr <- Call.leadingDim ldx
      rankPtr <- Call.alloca
      gelsy m n nrhs aCopyPtr ldaPtr xPtr ldxPtr pivotPtr rcond rankPtr
      liftIO $ peekCInt rankPtr


-- | Common signature of the real and the complex @gelsy@ wrapper.
-- In the order used by 'gelsyReal' and 'gelsyComplex' the arguments are:
-- @m@, @n@, @nrhs@, pointer to the matrix, pointer to its leading
-- dimension, pointer to the right-hand sides, pointer to their leading
-- dimension, pointer to the pivot array, the @rcond@ value of type @ar@
-- and a pointer that receives the rank.
type GELSY_ r ar a =
   Int -> Int -> Int -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr CInt -> ar -> Ptr CInt -> ContT r IO ()

-- | Wraps 'GELSY_' so that the element type @a@ is the last type argument,
-- with the @rcond@ type fixed to @RealOf a@. This is the form in which
-- 'gelsy' passes the wrappers to 'Class.switchFloating'.
newtype GELSY r a = GELSY {getGELSY :: GELSY_ r (RealOf a) a}

-- | @gelsy@ for every element type of class 'Class.Floating'.
-- 'Class.switchFloating' selects 'gelsyReal' for the first two cases and
-- 'gelsyComplex' for the last two cases.
gelsy :: (Class.Floating a) => GELSY_ r (RealOf a) a
gelsy =
   getGELSY
      (Class.switchFloating
         (GELSY gelsyReal) (GELSY gelsyReal)
         (GELSY gelsyComplex) (GELSY gelsyComplex))

-- | Real-valued @gelsy@ call. The plain 'Int' dimensions and the @rcond@
-- value are marshalled to temporary storage; workspace and info handling
-- is delegated to 'withAutoWorkspaceInfo'.
gelsyReal :: (Class.Real a) => GELSY_ r a a
gelsyReal m n nrhs aPtr ldaPtr bPtr ldbPtr jpvtPtr rcond rankPtr = do
   rcondRef <- Call.real rcond
   mRef <- Call.cint m
   nRef <- Call.cint n
   nrhsRef <- Call.cint nrhs
   liftIO $
      withAutoWorkspaceInfo errorCodeMsg "gelsy" $
         \workPtr lworkPtr infoPtr ->
            LapackReal.gelsy mRef nRef nrhsRef
               aPtr ldaPtr bPtr ldbPtr jpvtPtr rcondRef rankPtr
               workPtr lworkPtr infoPtr

-- | Complex-valued @gelsy@ call. Compared to 'gelsyReal' the routine takes
-- an additional real work array, allocated here with @2*n@ elements
-- (the size LAPACK documents for @RWORK@ of @cgelsy@/@zgelsy@).
gelsyComplex :: (Class.Real a) => GELSY_ r a (Complex a)
gelsyComplex m n nrhs aPtr ldaPtr bPtr ldbPtr jpvtPtr rcond rankPtr = do
   rcondRef <- Call.real rcond
   realWorkPtr <- Call.allocaArray (2*n)
   mRef <- Call.cint m
   nRef <- Call.cint n
   nrhsRef <- Call.cint nrhs
   liftIO $
      withAutoWorkspaceInfo errorCodeMsg "gelsy" $
         \workPtr lworkPtr infoPtr ->
            LapackComplex.gelsy mRef nRef nrhsRef
               aPtr ldaPtr bPtr ldbPtr jpvtPtr rcondRef rankPtr
               workPtr lworkPtr realWorkPtr infoPtr


-- | Pseudo-inverse computed by solving against an identity matrix with
-- 'leastSquaresMinimumNormRCond'; the first component of the result is the
-- rank returned by that solver.
--
-- When 'Basic.caseTallWide' yields 'Right', the matrix is solved directly
-- against the identity of its height shape. When it yields 'Left', the
-- transposed matrix is solved against the identity of the width shape and
-- the solution is transposed back.
pseudoInverseRCond ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   RealOf a ->
   Full vert horiz height width a ->
   (Int, Full horiz vert width height a)
pseudoInverseRCond rcond a =
   let shape = Array.shape a
   in  case Basic.caseTallWide a of
          Right _ ->
             leastSquaresMinimumNormRCond rcond a
                (Square.toFull
                   (Square.identity (MatrixShape.fullHeight shape)))
          Left _ ->
             mapSnd Basic.transpose
                (leastSquaresMinimumNormRCond rcond (Basic.transpose a)
                   (Square.toFull
                      (Square.identity (MatrixShape.fullWidth shape))))


-- | Absolute value of the determinant.
--
-- If 'Basic.caseTallWide' classifies the matrix as 'Left', the value comes
-- from 'HH.determinantR' applied to 'HH.fromMatrix' of the matrix;
-- in the 'Right' case the result is the absolute value of 'zero'.
determinantAbsolute ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a -> RealOf a
determinantAbsolute a =
   absolute $
   case Basic.caseTallWide a of
      Left tall -> HH.determinantR $ HH.fromMatrix tall
      Right _ -> zero


-- | Take the square matrix returned by 'HH.extractQ' for the decomposition
-- of @a@, relabel its width as a 'ShapeInt' of the same size and drop as
-- many columns as @a@ has (via 'Basic.dropColumns').
--
-- NOTE(review): judging by the name, the remaining columns are meant to
-- span the orthogonal complement of the column space of @a@ - confirm
-- against 'HH.extractQ' and 'Basic.dropColumns'.
complement ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Tall height width a -> Tall height ShapeInt a
complement a =
   let numCols = Shape.size (MatrixShape.fullWidth (Array.shape a))
   in  Basic.dropColumns numCols
          (Basic.mapWidth (shapeInt . Shape.size)
             (Square.toFull (HH.extractQ (HH.fromMatrix a))))