{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE EmptyDataDecls #-}
module Numeric.LAPACK.Orthogonal.Private where
import qualified Numeric.LAPACK.Matrix.Divide as Divide
import qualified Numeric.LAPACK.Matrix.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Type as Type
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as ExtentPriv
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Split as Split
import qualified Numeric.LAPACK.Shape as ExtShape
import Numeric.LAPACK.Output ((/+/))
import Numeric.LAPACK.Matrix.Plain.Format (formatArray)
import Numeric.LAPACK.Matrix.Type (FormatMatrix(formatMatrix))
import Numeric.LAPACK.Matrix.Triangular.Basic (Upper)
import Numeric.LAPACK.Matrix.Shape.Private
(Order(RowMajor, ColumnMajor), sideSwapFromOrder)
import Numeric.LAPACK.Matrix.Extent.Private (Extent)
import Numeric.LAPACK.Matrix.Modifier
(Transposition(NonTransposed, Transposed),
Conjugation(NonConjugated, Conjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, zero, isZero, absolute, conjugate)
import Numeric.LAPACK.Private
(fill, copySubMatrix, copyBlock, conjugateToTemp, caseRealComplexFunc,
withAutoWorkspaceInfo, errorCodeMsg)
import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative (liftA2)
import qualified Data.List as List
import Data.Monoid ((<>))
-- | Type-level tag selecting the Householder representation in the
-- 'Type.Matrix' data family.  It has no values of its own.
data Hh vert horiz height width

-- | A matrix stored as its Householder decomposition: the packed
-- elementary reflectors together with the triangular factor, plus the
-- scalar factors of the reflectors.
data instance Type.Matrix (Hh vert horiz height width) a =
   Householder
      { tau_ :: Vector (ExtShape.Min height width) a
         -- ^ scalar factors of the elementary reflectors, one per
         -- reflector (@min height width@ of them), as filled in by
         -- LAPACK @geqrf@ / @gelqf@ in 'fromMatrix'
      , split_ ::
         Array
            (MatrixShape.Split MatrixShape.Reflector vert horiz height width) a
         -- ^ packed reflectors and triangular factor in one array
      }
   deriving (Show)
-- | A matrix represented by its Householder decomposition.  The extent
-- parameters @vert@ and @horiz@ encode the size relation between height
-- and width (see the 'Tall', 'Wide' and 'Square' aliases below).
type Householder vert horiz height width =
   Type.Matrix (Hh vert horiz height width)
-- | Decomposition with no static size relation ('Extent.Big' in both
-- directions).
type General height width = Householder Extent.Big Extent.Big height width
-- | Decomposition of a tall matrix; this is the 'Left' case of
-- 'caseTallWide'.
type Tall height width = Householder Extent.Big Extent.Small height width
-- | Decomposition of a wide matrix; this is the 'Right' case of
-- 'caseTallWide'.
type Wide height width = Householder Extent.Small Extent.Big height width
-- | Decomposition of a square matrix whose height and width share the
-- shape @sh@.
type Square sh = Householder Extent.Small Extent.Small sh sh
-- | Extent (height, width and their static size relation) of a
-- decomposition, taken from the shape of the packed split array.
extent_ ::
   Householder vert horiz height width a ->
   Extent vert horiz height width
extent_ (Householder _tau packed) =
   MatrixShape.splitExtent (Array.shape packed)
-- | Change the static extent tags of a decomposition.  The extent map is
-- applied to the shape of the packed array; no element data is touched.
mapExtent ::
   (Extent.C vertA, Extent.C horizA) =>
   (Extent.C vertB, Extent.C horizB) =>
   Extent.Map vertA horizA vertB horizB height width ->
   Householder vertA horizA height width a ->
   Householder vertB horizB height width a
mapExtent f (Householder tau split) = Householder tau reshaped
   where
      reshaped = Array.mapShape (MatrixShape.splitMapExtent f) split
-- | Classify a decomposition as tall or wide, using
-- 'MatrixShape.caseTallWideSplit' on the packed array's shape.  Only the
-- shape is re-tagged; the underlying storage and @tau@ are shared.
caseTallWide ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
   Householder vert horiz height width a ->
   Either (Tall height width a) (Wide height width a)
caseTallWide (Householder tau (Array shape a)) =
   case MatrixShape.caseTallWideSplit shape of
      Left tallShape -> Left $ Householder tau $ Array tallShape a
      Right wideShape -> Right $ Householder tau $ Array wideShape a
-- | Pretty-print a decomposition as the @tau@ vector (reindexed from zero)
-- followed by the packed split array.
instance
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
      FormatMatrix (Hh vert horiz height width) where
   formatMatrix fmt (Householder tau m) =
      let tauFlat = Array.mapShape (Shape.ZeroBased . Shape.size) tau
      in  formatArray fmt tauFlat /+/ formatArray fmt m
-- | Compute the Householder decomposition of a full matrix.
--
-- Column-major input is factored with LAPACK @geqrf@ and row-major input
-- with @gelqf@.  The input is first copied into the new packed array, so
-- the argument array is never modified.  The result holds the @tau@ vector
-- (@min height width@ entries) and the packed split array (reflectors plus
-- triangular factor), both written in place by LAPACK.
fromMatrix ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   Householder vert horiz height width a
fromMatrix (Array shape@(MatrixShape.Full order extent) a) =
   uncurry Householder $
   Array.unsafeCreateWithSizeAndResult
      (uncurry ExtShape.Min $ Extent.dimensions extent) $ \_ tauPtr ->
   ArrayIO.unsafeCreate
      (MatrixShape.Split MatrixShape.Reflector order extent) $ \qrPtr ->
   evalContT $ do
      -- (m,n) are the dimensions passed to LAPACK; m is also used as the
      -- leading dimension.  NOTE(review): presumably swapped for RowMajor
      -- by MatrixShape.dimensions - confirm there.
      let (m,n) = MatrixShape.dimensions shape
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim m
      liftIO $ do
         -- LAPACK factors in place, so work on a copy of the input.
         copyBlock (m*n) aPtr qrPtr
         case order of
            RowMajor ->
               withAutoWorkspaceInfo errorCodeMsg "gelqf" $
                  LapackGen.gelqf mPtr nPtr qrPtr ldaPtr tauPtr
            ColumnMajor ->
               withAutoWorkspaceInfo errorCodeMsg "geqrf" $
                  LapackGen.geqrf mPtr nPtr qrPtr ldaPtr tauPtr
-- | Determinant of the triangular factor stored in a tall-or-square
-- decomposition (delegates to 'Split.determinantR').
determinantR ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Householder vert Extent.Small height width a -> a
determinantR qr = Split.determinantR (split_ qr)
-- | Determinant of a square matrix from its Householder decomposition.
--
-- This is the determinant of the triangular factor, multiplied from the
-- left, in @tau@ order, by one factor per non-zero @tau@ entry.  Each
-- factor is @-(signum tau)^2@, conjugated when the split array is stored
-- row-major.  For real element types every such factor is @-1@.
determinant ::
   (Shape.C sh, Class.Floating a) =>
   Square sh a -> a
determinant (Householder tau split) =
   List.foldl' (*) (Split.determinantR split) reflectorFactors
   where
      -- one factor per non-zero tau, keeping the order of tau
      reflectorFactors =
         orderCorrection $
         map (\t -> negate (signum t ^ (2::Int))) $
         filter (not . isZero) (Array.toList tau)
      -- row-major storage: each factor is complex conjugated
      orderCorrection =
         case MatrixShape.splitOrder (Array.shape split) of
            RowMajor -> map conjugate
            ColumnMajor -> id
-- | Absolute value of the determinant of the triangular factor for tall
-- (or square) decompositions.  Decompositions classified as wide by
-- 'caseTallWide' yield zero.
determinantAbsolute ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Householder vert horiz height width a -> RealOf a
determinantAbsolute qr =
   absolute $
   case caseTallWide qr of
      Left tall -> determinantR tall
      Right _wide -> zero
-- | Least-squares solution of @A·x ≈ b@ for a tall matrix @A = Q·R@.
-- It takes the first @width@ rows of @Qᴴ·b@ and solves with the
-- triangular factor @R@.
leastSquares ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Shape.C nrhs,
    Class.Floating a) =>
   Householder horiz Extent.Small height width a ->
   Full vert horiz height nrhs a ->
   Full vert horiz width nrhs a
leastSquares qr b =
   let projected = tallMultiplyQAdjoint qr b
   in  tallSolveR NonTransposed NonConjugated qr projected
-- | Given @Q·R@ of a tall matrix @A@, solve the underdetermined system
-- @Aᴴ·y = b@ with minimum norm.  It solves with @Rᴴ@, then multiplies by
-- @Q@ after zero-padding.
minimumNorm ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Shape.C nrhs,
    Class.Floating a) =>
   Householder vert Extent.Small width height a ->
   Full vert horiz height nrhs a ->
   Full vert horiz width nrhs a
minimumNorm qr b =
   let reduced = tallSolveR Transposed Conjugated qr b
   in  tallMultiplyQ qr reduced
-- | Keep only the top rows of a matrix.  The number of rows kept is the
-- height of @extentA@.  The width of @extentA@ must equal the height of
-- the matrix; otherwise this calls 'error'.
takeRows ::
   (Extent.C vert, Extent.C horiz,
    Eq fuse, Shape.C fuse, Shape.C height, Shape.C width, Class.Floating a) =>
   Extent Extent.Small horiz height fuse ->
   Full vert horiz fuse width a ->
   Full vert horiz height width a
takeRows extentA (Array (MatrixShape.Full order extentB) b) =
   maybe
      (error "Householder.takeRows: heights mismatch")
      (\extentC ->
         Basic.takeSub (Extent.height extentB) 0 b
            (MatrixShape.Full order extentC))
      (Extent.fuse (ExtentPriv.generalizeWide extentA) extentB)
-- | Append zero rows at the bottom of a matrix, growing its height from
-- @fuse@ to the height of @extentA@.  The width of @extentA@ must equal
-- the height of the matrix; otherwise this calls 'error'.  Existing rows
-- are copied unchanged and every new element is set to zero.
addRows ::
   (Extent.C vert, Extent.C horiz,
    Eq fuse, Shape.C fuse, Shape.C height, Shape.C width, Class.Floating a) =>
   Extent vert Extent.Small height fuse ->
   Full vert horiz fuse width a ->
   Full vert horiz height width a
addRows extentA (Array shapeB@(MatrixShape.Full order extentB) b) =
   case Extent.fuse (ExtentPriv.generalizeTall extentA) extentB of
      Nothing -> error "Householder.addRows: heights mismatch"
      Just extentC ->
         Array.unsafeCreateWithSize (MatrixShape.Full order extentC) $
            \cSize cPtr ->
            withForeignPtr b $ \bPtr ->
            case order of
               -- Row-major: rows are contiguous, so B is copied as one
               -- block and the new rows form one contiguous tail.
               RowMajor -> do
                  let bSize = Shape.size shapeB
                  copyBlock bSize bPtr cPtr
                  fill zero (cSize - bSize) (advancePtr cPtr bSize)
               -- Column-major: columns of C are longer (mc) than those of
               -- B (mb), so copy column-wise with different leading
               -- dimensions, then zero the bottom (mc-mb) x n strip.
               ColumnMajor -> do
                  let n = Shape.size $ Extent.width extentB
                      mb = Shape.size $ Extent.height extentB
                      mc = Shape.size $ Extent.height extentC
                  copySubMatrix mb n mb bPtr mc cPtr
                  evalContT $ do
                     -- laset with uplo 'A' sets the whole strip; off-diagonal
                     -- and diagonal values are both zero here.
                     uploPtr <- Call.char 'A'
                     mPtr <- Call.cint (mc-mb)
                     nPtr <- Call.cint n
                     ldcPtr <- Call.leadingDim mc
                     zPtr <- Call.number zero
                     liftIO $
                        LapackGen.laset uploPtr mPtr nPtr zPtr zPtr
                           (advancePtr cPtr mb) ldcPtr
-- | The complete square orthogonal/unitary factor @Q@ (height x height)
-- of a decomposition, expanded from the packed reflectors.
extractQ ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Householder vert horiz height width a -> Matrix.Square height a
extractQ (Householder tau (Array (MatrixShape.Split _ order extent) qr)) =
   extractQAux tau packedWidth order squareExtent qr
   where
      packedWidth = Extent.width extent
      squareExtent = Extent.square $ Extent.height extent
-- | The thin orthogonal/unitary factor of a tall decomposition, with the
-- same extent (height x width) as the decomposed matrix.
tallExtractQ ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Householder vert Extent.Small height width a ->
   Full vert Extent.Small height width a
tallExtractQ (Householder tau packed) =
   case packed of
      Array (MatrixShape.Split _ order extent) qr ->
         extractQAux tau (Extent.width extent) order extent qr
-- | Expand the first @k@ packed reflectors (@k@ = size of the @tau@ vector)
-- into an explicit matrix with shape @extent@.
--
-- @widthQR@ is the width of the packed array @qr@.  It serves as that
-- array's leading dimension in the row-major case.  Column-major storage
-- uses LAPACK @ungqr@; row-major storage uses @unglq@.
extractQAux ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C widthQR,
    Class.Floating a) =>
   Vector (ExtShape.Min height widthQR) a -> widthQR ->
   Order -> Extent vert horiz height width -> ForeignPtr a ->
   Full vert horiz height width a
extractQAux (Array widthTau tau) widthQR order extent qr =
   Array.unsafeCreate (MatrixShape.Full order extent) $ \qPtr -> do
      let (height,width) = Extent.dimensions extent
      let m = Shape.size height
      let n = Shape.size width
      let k = Shape.size widthTau
      evalContT $ do
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         kPtr <- Call.cint k
         qrPtr <- ContT $ withForeignPtr qr
         tauPtr <- ContT $ withForeignPtr tau
         -- Only the reflector part is copied.  NOTE(review): this relies on
         -- ungqr/unglq initialising the remaining columns/rows of Q
         -- themselves - confirm against the LAPACK documentation.
         case order of
            RowMajor -> do
               ldaPtr <- Call.leadingDim n
               liftIO $ do
                  -- copy k x m block from leading dim widthQR to leading dim n
                  copySubMatrix k m (Shape.size widthQR) qrPtr n qPtr
                  -- note the swapped dimension arguments (n, m)
                  withAutoWorkspaceInfo errorCodeMsg "unglq" $
                     LapackGen.unglq nPtr mPtr kPtr qPtr ldaPtr tauPtr
            ColumnMajor -> do
               ldaPtr <- Call.leadingDim m
               liftIO $ do
                  -- first k columns are contiguous (leading dim m)
                  copyBlock (m*k) qrPtr qPtr
                  withAutoWorkspaceInfo errorCodeMsg "ungqr" $
                     LapackGen.ungqr mPtr nPtr kPtr qPtr ldaPtr tauPtr
-- | Multiply by the full @Q@ of a tall decomposition.  The argument is
-- first padded with zero rows up to the height of the decomposition, so
-- this computes @Q·[b; 0]@.
tallMultiplyQ ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   Householder vert Extent.Small height fuse a ->
   Full vert horiz fuse width a ->
   Full vert horiz height width a
tallMultiplyQ qr b =
   multiplyQ NonTransposed NonConjugated qr (addRows (extent_ qr) b)
-- | Compute @Qᴴ·b@ for a tall decomposition and keep only its top rows,
-- as many as the width of the decomposed matrix.
tallMultiplyQAdjoint ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq fuse, Class.Floating a) =>
   Householder horiz Extent.Small fuse height a ->
   Full vert horiz fuse width a ->
   Full vert horiz height width a
tallMultiplyQAdjoint qr b =
   let full = multiplyQ Transposed Conjugated qr b
   in  takeRows (Extent.transpose (extent_ qr)) full
-- | Multiply a matrix from the left by the @Q@ factor of a decomposition.
-- @Q@ is optionally transposed and/or conjugated, e.g. 'Transposed' with
-- 'Conjugated' applies the adjoint @Qᴴ@.  The heights must agree; this is
-- checked with 'Call.assert'.
--
-- The argument is copied into the result, which LAPACK then overwrites in
-- place with @unmqr@ (column-major reflectors) or @unmlq@ (row-major
-- reflectors).
multiplyQ ::
   (Extent.C vertA, Extent.C horizA, Shape.C widthA,
    Extent.C vertB, Extent.C horizB, Shape.C widthB,
    Shape.C height, Eq height, Class.Floating a) =>
   Transposition -> Conjugation ->
   Householder vertA horizA height widthA a ->
   Full vertB horizB height widthB a ->
   Full vertB horizB height widthB a
multiplyQ transposed conjugated
      (Householder
         (Array widthTau tau)
         (Array shapeA@(MatrixShape.Split _ orderA extentA) qr))
      (Array shapeB@(MatrixShape.Full orderB extentB) b) =
   Array.unsafeCreateWithSize shapeB $ \cSize cPtr -> do
      let (heightA,widthA) = Extent.dimensions extentA
      let (height,width) = Extent.dimensions extentB
      Call.assert "Householder.multiplyQ: height shapes mismatch"
         (heightA == height)
      -- side and (m,n) of C as seen by LAPACK, derived from the storage
      -- order of B (see sideSwapFromOrder).
      let (side,(m,n)) =
            sideSwapFromOrder orderB (Shape.size height, Shape.size width)
      evalContT $ do
         sidePtr <- Call.char side
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         let k = Shape.size widthTau
         kPtr <- Call.cint k
         -- A storage order differing from B's contributes an extra
         -- transposition.  adjointFromTranspose maps the combined flag
         -- to 'N' or to 'T'/'C' (real/complex).
         transPtr <-
            Call.char $ adjointFromTranspose qr $
            transposed <> if orderA==orderB then NonTransposed else Transposed
         -- LAPACK only distinguishes plain and (conjugate) transposed Q.
         -- Whenever the requested conjugation does not match what the
         -- trans flag already yields, work on conjugated temporary copies
         -- of the reflectors and tau.
         (qrPtr,tauPtr) <-
            if (orderA==orderB)
                  ==
                  (transposed==NonTransposed && conjugated==NonConjugated
                   ||
                   transposed==Transposed && conjugated==Conjugated)
               then
                  liftA2 (,)
                     (ContT $ withForeignPtr qr)
                     (ContT $ withForeignPtr tau)
               else
                  liftA2 (,)
                     (conjugateToTemp (Shape.size shapeA) qr)
                     (conjugateToTemp k tau)
         bPtr <- ContT $ withForeignPtr b
         ldcPtr <- Call.leadingDim m
         -- result starts as a copy of B and is updated in place
         liftIO $ copyBlock cSize bPtr cPtr
         case orderA of
            ColumnMajor -> do
               ldaPtr <- Call.leadingDim $ Shape.size heightA
               liftIO $ withAutoWorkspaceInfo errorCodeMsg "unmqr" $
                  LapackGen.unmqr sidePtr transPtr
                     mPtr nPtr kPtr qrPtr ldaPtr tauPtr cPtr ldcPtr
            RowMajor -> do
               ldaPtr <- Call.leadingDim $ Shape.size widthA
               -- With no reflectors (k == 0) unmlq is skipped and the
               -- result stays a copy of B.  NOTE(review): the ColumnMajor
               -- branch has no such guard - presumably a workaround for an
               -- unmlq issue with k = 0; confirm.
               liftIO $ when (k>0) $
                  withAutoWorkspaceInfo errorCodeMsg "unmlq" $
                  LapackGen.unmlq sidePtr transPtr
                     mPtr nPtr kPtr qrPtr ldaPtr tauPtr cPtr ldcPtr
-- | LAPACK @TRANS@ flag for a 'Transposition'.  'NonTransposed' gives
-- @'N'@.  'Transposed' gives the element-type dependent character from
-- 'invChar'.  The first argument only fixes the element type.
adjointFromTranspose :: (Class.Floating a) => f a -> Transposition -> Char
adjointFromTranspose qr trans =
   case trans of
      Transposed -> invChar qr
      NonTransposed -> 'N'
-- | Transposition character chosen by element type via
-- 'caseRealComplexFunc': @'T'@ for the real case, @'C'@ for the complex
-- case.  The argument is only used for its type.
invChar :: (Class.Floating a) => f a -> Char
invChar proxy = caseRealComplexFunc proxy realFlag complexFlag
   where
      realFlag = 'T'
      complexFlag = 'C'
-- | Extract the triangular factor from the packed split array, as a full
-- matrix of the same extent as the decomposed matrix.
extractR ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Householder vert horiz height width a ->
   Full vert horiz height width a
extractR Householder {split_ = packed} =
   Split.extractTriangle (Right MatrixShape.Triangle) packed
-- | Square upper triangular factor of a tall decomposition
-- (delegates to 'Split.tallExtractR').
tallExtractR ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Householder vert Extent.Small height width a -> Upper width a
tallExtractR Householder {split_ = packed} = Split.tallExtractR packed
-- | Multiply by the (optionally transposed) triangular factor of a tall
-- decomposition (delegates to 'Split.tallMultiplyR').
tallMultiplyR ::
   (Extent.C vertA, Extent.C vert, Extent.C horiz, Shape.C height, Eq height,
    Shape.C heightA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   Householder vertA Extent.Small heightA height a ->
   Full vert horiz height widthB a ->
   Full vert horiz height widthB a
tallMultiplyR transposed qr = Split.tallMultiplyR transposed (split_ qr)
-- | Solve with the (optionally transposed and/or conjugated) triangular
-- factor of a tall decomposition (delegates to 'Split.tallSolveR').
tallSolveR ::
   (Extent.C vertA, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   Householder vertA Extent.Small height width a ->
   Full vert horiz width nrhs a -> Full vert horiz width nrhs a
tallSolveR transposed conjugated qr =
   Split.tallSolveR transposed conjugated (split_ qr)
-- | Height and width of a decomposition are those of the packed split
-- array.
instance
   (Extent.C vert, Extent.C horiz) =>
      Type.Box (Hh vert horiz height width) where
   type HeightOf (Hh vert horiz height width) = height
   type WidthOf (Hh vert horiz height width) = width
   height qr = MatrixShape.splitHeight (Array.shape (split_ qr))
   width qr = MatrixShape.splitWidth (Array.shape (split_ qr))
-- | Matrix-vector products with @A = Q·R@: @A·x@ is computed as
-- @Q·(R·x)@; @x·A@ as @Rᵀ·(Qᵀ·x)@.
instance
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width) =>
      Multiply.MultiplyVector (Hh vert horiz height width) where
   matrixVector qr x =
      let rx = Basic.multiplyVector (extractR qr) x
      in  Basic.unliftColumn MatrixShape.ColumnMajor
            (multiplyQ NonTransposed NonConjugated qr) rx
   vectorMatrix x qr =
      let qtx =
            Basic.unliftColumn MatrixShape.ColumnMajor
               (multiplyQ Transposed NonConjugated qr) x
      in  Basic.multiplyVector (Basic.transpose $ extractR qr) qtx
-- | Products of a square decomposition @A = Q·R@ with full matrices:
-- @A·B = Q·(R·B)@, and @B·A@ is obtained by transposing
-- @Rᵀ·(Qᵀ·Bᵀ)@.
instance
   (vert ~ Extent.Small, horiz ~ Extent.Small,
    Shape.C height, height ~ width) =>
      Multiply.MultiplySquare (Hh vert horiz height width) where
   squareFull qr =
      ArrMatrix.lift1 $ \b ->
         multiplyQ NonTransposed NonConjugated qr
            (tallMultiplyR NonTransposed qr b)
   fullSquare b qr =
      ArrMatrix.lift1
         (Basic.transpose .
          tallMultiplyR Transposed qr .
          multiplyQ Transposed NonConjugated qr .
          Basic.transpose)
         b
-- | Determinant of a square decomposition, via this module's top-level
-- 'determinant'.  The class method itself is only in scope qualified as
-- @Divide.determinant@.
instance
   (vert ~ Extent.Small, horiz ~ Extent.Small,
    Shape.C height, height ~ width) =>
      Divide.Determinant (Hh vert horiz height width) where
   determinant qr = Numeric.LAPACK.Orthogonal.Private.determinant qr
-- | Linear solvers for a square decomposition.  Right solving uses
-- 'leastSquares'.  Left solving applies 'minimumNorm' to the adjoint of
-- the right-hand side and takes the adjoint of the result.
instance
   (vert ~ Extent.Small, horiz ~ Extent.Small,
    Shape.C height, height ~ width) =>
      Divide.Solve (Hh vert horiz height width) where
   solveRight qr =
      ArrMatrix.lift1 (leastSquares (mapExtent Extent.generalizeWide qr))
   solveLeft b a =
      ArrMatrix.lift1
         (Basic.adjoint .
          minimumNorm (mapExtent Extent.generalizeWide a) .
          Basic.adjoint)
         b