{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Split where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Box as Box
import qualified Numeric.LAPACK.Matrix.Triangular.Private as TriPriv
import qualified Numeric.LAPACK.Matrix.Triangular.Basic as Tri
import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Triangular.Private (diagonalPointers, unpack)
import Numeric.LAPACK.Matrix.Triangular.Basic (UnitLower, Upper)
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor, ColumnMajor), transposeFromOrder,
          swapOnRowMajor, sideSwapFromOrder,
          Triangle, uploFromOrder, flipOrder)
import Numeric.LAPACK.Matrix.Modifier
         (Transposition, transposeOrder,
          Conjugation(NonConjugated, Conjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Linear.Private (solver, withInfo)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Private (copyBlock, conjugateToTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import System.IO.Unsafe (unsafePerformIO)

import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (poke)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | Array with a 'MatrixShape.Split' shape.
--
-- The element buffer is laid out like a full matrix ('toFull' only swaps the
-- shape).  The @lower@ parameter is a type-level tag for the content below
-- the diagonal; 'Triangle' and 'Corrupt' are the tags used in this module.
type Split lower vert horiz height width =
      Array (MatrixShape.Split lower vert horiz height width)

-- | Square 'Split' matrix whose height and width are both @sh@.
type Square lower sh = Split lower Extent.Small Extent.Small sh sh


-- | Determinant of the triangular factor held in a split matrix.
--
-- It is the product of the @min m n@ diagonal entries of the stored
-- @m × n@ matrix.  Consecutive diagonal elements lie one leading dimension
-- plus one element apart in memory.  The leading dimension is @m@ for
-- column-major and @n@ for row-major storage.
--
-- 'unsafePerformIO' only wraps a read of the array buffer.  This presumes
-- that 'Private.product' does not write through the pointer (confirm).
determinantR ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Split lower vert Extent.Small height width a -> a
determinantR (Array (MatrixShape.Split _ order extent) a) =
   unsafePerformIO $
      withForeignPtr a $ \aPtr ->
         Private.product diagLength aPtr (leadingDim + 1)
   where
      (height, width) = Extent.dimensions extent
      rows = Shape.size height
      cols = Shape.size width
      diagLength = min rows cols
      leadingDim =
         case order of
            ColumnMajor -> rows
            RowMajor -> cols


-- | Copy one triangle of a split matrix into a fresh full matrix with the
-- same order and extent.
--
-- * @Left _@: copies the lower trapezoid (including the diagonal), then
--   uses @laset@ to zero the strictly upper part and set the diagonal to one.
-- * @Right _@: zeroes the strictly lower part and the diagonal with
--   @laset@, then copies the upper trapezoid (including the diagonal).
--
-- For row-major storage LAPACK sees the transposed matrix.  'swapOnRowMajor'
-- therefore exchanges both the triangle characters and the dimensions.
extractTriangle ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Either lower Triangle ->
   Split lower vert horiz height width a ->
   Full vert horiz height width a
extractTriangle part (Array (MatrixShape.Split _ order extent) qr) =
   Array.unsafeCreate (MatrixShape.Full order extent) $ \rPtr ->
   evalContT $ do
      -- m and n are the dimensions as LAPACK sees them.
      let (height, width) = Extent.dimensions extent
          ((lowerChar, m), (upperChar, n)) =
            swapOnRowMajor order
               (('L', Shape.size height), ('U', Shape.size width))
      lowerPtr <- Call.char lowerChar
      upperPtr <- Call.char upperChar
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      srcPtr <- ContT $ withForeignPtr qr
      ldSrcPtr <- Call.leadingDim m
      ldDstPtr <- Call.leadingDim m
      zeroPtr <- Call.number zero
      onePtr <- Call.number one
      let copyPart triPtr =
            LapackGen.lacpy triPtr mPtr nPtr srcPtr ldSrcPtr rPtr ldDstPtr
          fillPart triPtr offDiagPtr diagPtr =
            LapackGen.laset triPtr mPtr nPtr offDiagPtr diagPtr rPtr ldDstPtr
      liftIO $
         case part of
            Left _ -> do
               copyPart lowerPtr
               fillPart upperPtr zeroPtr onePtr
            Right _ -> do
               fillPart lowerPtr zeroPtr zeroPtr
               copyPart upperPtr


-- | Extract the unit lower triangular factor of a split matrix whose height
-- does not exceed its width.
--
-- The matrix is viewed as a full matrix and passed to 'TriPriv.takeLower'
-- together with 'MatrixShape.Unit' and a callback.  Given the order, size and
-- buffer pointer, the callback writes 'one' to every diagonal element
-- (located with 'diagonalPointers').
wideExtractL ::
   (Extent.C horiz, Shape.C height, Shape.C width, Class.Floating a) =>
   Split lower Extent.Small horiz height width a -> UnitLower height a
wideExtractL split =
   TriPriv.takeLower (MatrixShape.Unit, setUnitDiagonal) (toFull split)
   where
      setUnitDiagonal order size lPtr =
         mapM_ (\ptr -> poke ptr one) (diagonalPointers order size lPtr)

-- | Extract the upper triangular factor of a split matrix whose width does
-- not exceed its height.  The matrix is viewed as a full matrix and its
-- upper triangle is taken with 'Tri.takeUpper'.
tallExtractR ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Split lower vert Extent.Small height width a -> Upper width a
tallExtractR split = Tri.takeUpper (toFull split)

-- | Reinterpret a split matrix as a full matrix with the same order and
-- extent, dropping the @lower@ tag.  Only the shape is replaced
-- ('Array.mapShape'); the elements are not touched.
toFull ::
   Split lower vert horiz height width a ->
   Full vert horiz height width a
toFull = Array.mapShape forgetLower
   where
      forgetLower (MatrixShape.Split _ order extent) =
         MatrixShape.Full order extent


-- | Multiply the unit lower triangular factor of a split matrix (optionally
-- transposed) with a full matrix, using BLAS @trmm@ with an implicit unit
-- diagonal.
--
-- Calls 'error' if the height of the split matrix differs from the height of
-- the right factor.
wideMultiplyL ::
   (Extent.C horizA, Extent.C vert, Extent.C horiz, Shape.C height, Eq height,
    Shape.C widthA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   Split Triangle Extent.Small horizA height widthA a ->
   Full vert horiz height widthB a ->
   Full vert horiz height widthB a
wideMultiplyL transposed a b
   | MatrixShape.splitHeight (Array.shape a) == Matrix.height b =
      multiplyTriangular ('L','U') 'U' transposed a b
   | otherwise =
      error "wideMultiplyL: height shapes mismatch"

-- | Multiply the upper triangular factor of a split matrix (optionally
-- transposed) with a full matrix, using BLAS @trmm@ with a non-unit diagonal.
--
-- The width of the split matrix (the order of the triangle) must equal the
-- height of the right factor.  Otherwise 'error' is called with a message
-- naming this function.
tallMultiplyR ::
   (Extent.C vertA, Extent.C vert, Extent.C horiz, Shape.C height, Eq height,
    Shape.C heightA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   Split lower vertA Extent.Small heightA height a ->
   Full vert horiz height widthB a ->
   Full vert horiz height widthB a
tallMultiplyR transposed a b =
   if MatrixShape.splitWidth (Array.shape a) == Matrix.height b
      then multiplyTriangular ('U','L') 'N' transposed a b
      -- The message used to name the non-existent "wideMultiplyR".
      else error "tallMultiplyR: height shapes mismatch"

-- | Multiply a triangle stored inside a split matrix with a full matrix,
-- using BLAS @trmm@.
--
-- * @(normalPart, transposedPart)@: the @uplo@ characters for column-major
--   and row-major storage of the split matrix.  Row-major storage is handed
--   to BLAS as the transpose, so the transposition character is flipped too.
-- * @diag@: the BLAS @diag@ character.
--
-- The right factor is copied into the result buffer first, because @trmm@
-- overwrites its @B@ argument in place.  No shape check is done here;
-- callers are expected to perform it.
multiplyTriangular ::
   (Extent.C vertA, Extent.C horizA, Extent.C vertB, Extent.C horizB,
    Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   (Char,Char) -> Char -> Transposition ->
   Split lower vertA horizA heightA widthA a ->
   Full vertB horizB heightB widthB a ->
   Full vertB horizB heightB widthB a
multiplyTriangular (normalPart, transposedPart) diag transposed
      (Array (MatrixShape.Split _ orderA extentA) a)
      (Array (MatrixShape.Full orderB extentB) b) =
   Array.unsafeCreate (MatrixShape.Full orderB extentB) $ \cPtr ->
   evalContT $ do
      let (heightA, widthA) = Extent.dimensions extentA
          (heightB, widthB) = Extent.dimensions extentB
          orderOpB = transposeOrder transposed orderB
          (uplo, orderOpA, lda) =
            case orderA of
               ColumnMajor -> (normalPart, orderOpB, Shape.size heightA)
               RowMajor ->
                  (transposedPart, flipOrder orderOpB, Shape.size widthA)
          (side, (m, n)) =
            sideSwapFromOrder orderB (Shape.size heightB, Shape.size widthB)
      sidePtr <- Call.char side
      uploPtr <- Call.char uplo
      transaPtr <- Call.char (transposeFromOrder orderOpA)
      diagPtr <- Call.char diag
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim lda
      bPtr <- ContT $ withForeignPtr b
      ldcPtr <- Call.leadingDim m
      alphaPtr <- Call.number one
      liftIO $ do
         -- trmm works in place on its B argument, so start from a copy of b.
         copyBlock (m*n) bPtr cPtr
         BlasGen.trmm sidePtr uploPtr transaPtr diagPtr
            mPtr nPtr alphaPtr aPtr ldaPtr cPtr ldcPtr


-- | Solve a linear system with the unit lower triangular factor of a split
-- matrix whose height does not exceed its width.  The factor may optionally
-- be transposed and/or conjugated.  Uses LAPACK @trtrs@.
--
-- The triangle is the leading @height × height@ part of the stored
-- @height × width@ matrix.  'solveTriangular' derives the leading dimension
-- (and the size of a conjugated copy) from the stored dimensions, so it is
-- given @height@ and @width@.  Passing the triangle order @height@ twice
-- would give a wrong leading dimension for row-major storage with
-- @width > height@.
wideSolveL ::
   (Extent.C horizA, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   Split Triangle Extent.Small horizA height width a ->
   Full vert horiz height nrhs a -> Full vert horiz height nrhs a
wideSolveL transposed conjugated
      (Array (MatrixShape.Split _ orderA extentA) a) =
   let (heightA, widthA) = Extent.dimensions extentA
   in solver "Split.wideSolveL" heightA $ \_n nPtr nrhsPtr xPtr ldxPtr -> do
      uploPtr <- Call.char $ uploFromOrder $ flipOrder orderA
      diagPtr <- Call.char 'U'
      -- Dimensions of the stored matrix, not of the triangle.
      -- Row-major storage has leading dimension 'width'.
      let storedRows = Shape.size heightA
          storedCols = Shape.size widthA
      solveTriangular transposed conjugated orderA storedRows storedCols a
         uploPtr diagPtr nPtr nrhsPtr xPtr ldxPtr

-- | Solve a linear system with the upper triangular factor of a split
-- matrix whose width does not exceed its height.  The factor may optionally
-- be transposed and/or conjugated.  Uses LAPACK @trtrs@ with a non-unit
-- diagonal.
--
-- The right-hand side must have as many rows as the split matrix has
-- columns; 'solver' checks this.
tallSolveR ::
   (Extent.C vertA, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   Split lower vertA Extent.Small height width a ->
   Full vert horiz width nrhs a -> Full vert horiz width nrhs a
tallSolveR transposed conjugated
      (Array (MatrixShape.Split _ orderA extentA) a) =
   solver "Split.tallSolveR" widthA solveWithR
   where
      (heightA, widthA) = Extent.dimensions extentA
      solveWithR n nPtr nrhsPtr xPtr ldxPtr = do
         uploPtr <- Call.char (uploFromOrder orderA)
         diagPtr <- Call.char 'N'
         solveTriangular transposed conjugated orderA (Shape.size heightA) n a
            uploPtr diagPtr nPtr nrhsPtr xPtr ldxPtr

-- | Shared driver that calls LAPACK @trtrs@ on a triangle stored inside a
-- split matrix.
--
-- * @m@ and @n@ are the dimensions of the stored matrix.  The leading
--   dimension is @m@ for column-major and @n@ for row-major storage.  A
--   conjugated temporary copies @m*n@ elements.
-- * Row-major storage appears to LAPACK as the transpose.  The requested
--   transposition is therefore folded into the storage order first, and then
--   mapped to the @trans@ character.
-- * Conjugation without a (LAPACK-side) transposition has no @trans@ code.
--   In that case the matrix is conjugated into a temporary buffer first.
--
-- The @info@ result of @trtrs@ is handled by 'withInfo'.
solveTriangular ::
   Class.Floating a =>
   Transposition -> Conjugation ->
   Order -> Int -> Int -> ForeignPtr a ->
   Ptr CChar -> Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr CInt -> ContT r IO ()
solveTriangular transposed conjugated orderA m n a
      uploPtr diagPtr nPtr nrhsPtr xPtr ldxPtr = do
   let borrowA = ContT $ withForeignPtr a
       effectiveOrder = transposeOrder transposed orderA
       (trans, getA) =
         case conjugated of
            NonConjugated ->
               case effectiveOrder of
                  RowMajor -> ('T', borrowA)
                  ColumnMajor -> ('N', borrowA)
            Conjugated ->
               case effectiveOrder of
                  RowMajor -> ('C', borrowA)
                  ColumnMajor -> ('N', conjugateToTemp (m*n) a)
       lda =
         case orderA of
            ColumnMajor -> m
            RowMajor -> n
   transPtr <- Call.char trans
   aPtr <- getA
   ldaPtr <- Call.leadingDim lda
   liftIO $
      withInfo "trtrs" $
         LapackGen.trtrs uploPtr transPtr diagPtr
            nPtr nrhsPtr aPtr ldaPtr xPtr ldxPtr


-- | Type-level tag for the @lower@ slot of the split matrices built by
-- 'takeHalf'.  The name presumably signals that the part outside the stored
-- triangle carries no meaningful data (confirm against 'unpack').
data Corrupt = Corrupt deriving Eq

{-
We could use Plain.Class.ShapeOrder
but this would currently cause an import cycle.
-}
{- |
> let b = takeHalf a
> ==>
> isTriangular b && a == addTransposed b

The input buffer is expanded into a square @n × n@ split matrix with
'unpack'.  The diagonal of the result is then multiplied by @0.5@ using BLAS
@scal@ with stride @n+1@.  Presumably the input holds a symmetric matrix in
packed triangular form (confirm against callers).
-}
takeHalf ::
   (Box.Box symShape, Box.HeightOf symShape ~ sh, Shape.C sh,
    Class.Floating a) =>
   (symShape -> Order) -> Array symShape a -> Square Corrupt sh a
takeHalf shapeOrder (Array symShape a) =
   Array.unsafeCreate
      (MatrixShape.Split Corrupt order (Extent.square dim)) fill
   where
      dim = Box.height symShape
      order = shapeOrder symShape
      fill bPtr = evalContT $ do
         let n = Shape.size dim
         aPtr <- ContT $ withForeignPtr a
         nPtr <- Call.cint n
         halfPtr <- Call.number 0.5
         diagStridePtr <- Call.cint (n+1)
         liftIO $ do
            unpack order n aPtr bPtr
            -- Diagonal entries of an n×n matrix are n+1 elements apart.
            BlasGen.scal nPtr halfPtr bPtr diagStridePtr