module Numerical.HBLAS.BLAS where
--(
import Numerical.HBLAS.UtilsFFI
import Numerical.HBLAS.BLAS.FFI
import Numerical.HBLAS.MatrixTypes
import Control.Monad.Primitive
import Data.Complex
import qualified Data.Vector.Storable.Mutable as SM
-- | Cut-off used by 'gemmAbstraction' when picking an FFI entry point: when
-- 'gemmComplexity' of the call's dimensions is at most this value the unsafe
-- binding is used, otherwise the safe binding is used.
flopsThreshold = 10000
-- | Product of the three dimensions handed to a GEMM call.  It is compared
-- against 'flopsThreshold' as a cheap estimate of how much work the call does.
gemmComplexity :: Num a => a -> a -> a -> a
gemmComplexity m n k = m * n * k
-- | Reject dimension combinations that cannot be passed to GEMM.
--
-- The dimensions of A and B are first adjusted by their transpose flags via
-- 'coordSwapper'; C is taken as given.  The result is 'True' when any of the
-- six resulting dimensions is not strictly positive, or when the shapes do not
-- line up, i.e. unless @cy == ay@, @cx == bx@ and @ax == by@ all hold for the
-- adjusted values.
isBadGemm :: (Num a, Ord a)
          => Transpose -> Transpose -> a -> a -> a -> a -> a -> a -> Bool
isBadGemm tra trb ax ay bx by cx cy =
    minimum [opAx, opAy, opBx, opBy, cx, cy] <= 0 || shapesDisagree
  where
    -- Dimensions of A and B as seen after applying their transpose flags.
    (opAx, opAy) = coordSwapper tra (ax, ay)
    (opBx, opBy) = coordSwapper trb (bx, by)
    shapesDisagree = cy /= opAy || cx /= opBx || opAx /= opBy
-- | Reorder a pair of dimensions according to a transpose flag: the two
-- non-transposing flags return the pair unchanged, the two transposing flags
-- return it with its components exchanged.
coordSwapper :: Transpose -> (a,a) -> (a,a)
coordSwapper tr (x, y) =
    case tr of
      NoTranspose     -> (x, y)
      ConjNoTranspose -> (x, y)
      Transpose       -> (y, x)
      ConjTranspose   -> (y, x)
-- | Translate an orientation singleton into the order value expected by the
-- FFI layer: 'SRow' becomes the encoding of 'BLASRowMajor' and 'SColumn' the
-- encoding of 'BLASColMajor'.
encodeNiceOrder :: SOrientation x -> CBLAS_ORDERT
encodeNiceOrder sorient =
    case sorient of
      SRow    -> encodeOrder BLASRowMajor
      SColumn -> encodeOrder BLASColMajor
-- | Convert a 'Transpose' flag straight to its FFI representation by going
-- through 'encodeNiceTranpose' and then 'encodeTranpose'.
encodeFFITranpose :: Transpose -> CBLAS_TRANSPOSET
encodeFFITranpose = encodeTranpose . encodeNiceTranpose
-- | Map each 'Transpose' constructor onto the 'BLAS_Transpose' constructor of
-- the same meaning (one-to-one, by name).
encodeNiceTranpose :: Transpose -> BLAS_Transpose
encodeNiceTranpose NoTranspose     = BlasNoTranspose
encodeNiceTranpose Transpose       = BlasTranspose
encodeNiceTranpose ConjTranspose   = BlasConjTranspose
encodeNiceTranpose ConjNoTranspose = BlasConjNoTranspose
-- | Shape shared by the GEMM entry points in this module.
--
-- In order, the arguments are: the transpose flag for the first matrix, the
-- transpose flag for the second matrix, two scalars of the element type, and
-- three mutable dense matrices that share one orientation and element type.
-- The action runs in @m@ and returns unit.
type GemmFun el orient s m =
       Transpose
    -> Transpose
    -> el
    -> el
    -> MutDenseMatrix s orient el
    -> MutDenseMatrix s orient el
    -> MutDenseMatrix s orient el
    -> m ()
-- | Build a 'GemmFun' from a pair of raw FFI bindings.
--
-- Arguments, in order:
--
--   1. the routine name, used only in the overlap error message;
--   2. the /safe/ FFI binding;
--   3. the /unsafe/ FFI binding;
--   4. a handler that turns a scalar of the element type into whatever the
--      FFI binding takes for its scalar arguments and runs a continuation
--      with it.
--
-- The resulting function calls 'error' when 'isBadGemm' rejects the
-- dimensions, or when the buffer of the third matrix overlaps the buffer of
-- either of the first two.  Otherwise it obtains pointers to the three
-- buffers, converts both scalars with the handler, and invokes one of the two
-- bindings: the unsafe one when 'gemmComplexity' of the call is at most
-- 'flopsThreshold', the safe one otherwise.
gemmAbstraction:: (SM.Storable el, PrimMonad m) => String ->
    GemmFunFFI scale el -> GemmFunFFI scale el -> (el -> (scale -> m ())->m ()) -> forall orient . GemmFun el orient (PrimState m) m
gemmAbstraction gemmName gemmSafeFFI gemmUnsafeFFI constHandler = go
  where
    -- Calls at or below the threshold are routed to the unsafe binding.
    shouldCallFast :: Int -> Int -> Int -> Bool
    shouldCallFast m n k = gemmComplexity m n k <= flopsThreshold

    go tra trb alpha beta
       (MutableDenseMatrix ornta ax ay astride abuff)
       (MutableDenseMatrix _ bx by bstride bbuff)
       (MutableDenseMatrix _ cx cy cstride cbuff)
      | isBadGemm tra trb ax ay bx by cx cy =
          error $! "bad dimension args to GEMM: ax ay bx by cx cy: " ++ show [ax, ay, bx, by, cx ,cy]
      | SM.overlaps abuff cbuff || SM.overlaps bbuff cbuff =
          error $ "the read and write inputs for: " ++ gemmName ++ " overlap. This is a programmer error. Please fix."
      | otherwise =
          unsafeWithPrim abuff $ \ap ->
          unsafeWithPrim bbuff $ \bp ->
          unsafeWithPrim cbuff $ \cp ->
          constHandler alpha $ \alphaPtr ->
          constHandler beta $ \betaPtr ->
            let -- Only the first component of A's transpose-adjusted
                -- dimensions is needed; the other two sizes come from C.
                (innerDim, _) = coordSwapper tra (ax, ay)
                ffiCall
                  | shouldCallFast cy cx innerDim = gemmUnsafeFFI
                  | otherwise                     = gemmSafeFFI
            in unsafePrimToPrim $!
                 ffiCall
                   (encodeNiceOrder ornta)   -- orientation is read from the first matrix only
                   (encodeFFITranpose tra)
                   (encodeFFITranpose trb)
                   (fromIntegral cy) (fromIntegral cx) (fromIntegral innerDim)
                   alphaPtr ap (fromIntegral astride)
                   bp (fromIntegral bstride)
                   betaPtr cp (fromIntegral cstride)
-- | GEMM over 'Float' elements, assembled from 'gemmAbstraction' and the
-- @cblas_sgemm@ bindings.
--
-- Arguments, in order: the transpose flag for A, the transpose flag for B,
-- the scalars @alpha@ and @beta@, then the matrices A, B and C.  C is the
-- matrix that is written to; its buffer must not overlap the buffer of A or
-- of B.  Per the BLAS GEMM contract the call is expected to compute
-- @C := alpha * op(A) * op(B) + beta * C@ (see the BLAS reference; not
-- verifiable from this module alone).
--
-- Calls 'error' when the dimensions are rejected by 'isBadGemm' or when the
-- buffers overlap as described above.
sgemm :: PrimMonad m
      => Transpose -> Transpose -> Float -> Float
      -> MutDenseMatrix (PrimState m) orient Float
      -> MutDenseMatrix (PrimState m) orient Float
      -> MutDenseMatrix (PrimState m) orient Float
      -> m ()
-- 'gemmAbstraction' takes the safe binding first and the unsafe binding
-- second.  Both have the same type, so the order is not checked by the
-- compiler and must be kept right by hand.
sgemm = gemmAbstraction "sgemm" cblas_sgemm_safe cblas_sgemm_unsafe (\x f -> f x)
-- | GEMM over 'Double' elements, assembled from 'gemmAbstraction' and the
-- @cblas_dgemm@ bindings.
--
-- Arguments, in order: the transpose flag for A, the transpose flag for B,
-- the scalars @alpha@ and @beta@, then the matrices A, B and C.  C is the
-- matrix that is written to; its buffer must not overlap the buffer of A or
-- of B.  Per the BLAS GEMM contract the call is expected to compute
-- @C := alpha * op(A) * op(B) + beta * C@ (see the BLAS reference; not
-- verifiable from this module alone).
--
-- Calls 'error' when the dimensions are rejected by 'isBadGemm' or when the
-- buffers overlap as described above.
dgemm :: PrimMonad m
      => Transpose -> Transpose -> Double -> Double
      -> MutDenseMatrix (PrimState m) orient Double
      -> MutDenseMatrix (PrimState m) orient Double
      -> MutDenseMatrix (PrimState m) orient Double
      -> m ()
-- 'gemmAbstraction' takes the safe binding first and the unsafe binding
-- second.  Both have the same type, so the order is not checked by the
-- compiler and must be kept right by hand.
dgemm = gemmAbstraction "dgemm" cblas_dgemm_safe cblas_dgemm_unsafe (\x f -> f x)
-- | GEMM over @'Complex' 'Float'@ elements, assembled from 'gemmAbstraction'
-- and the @cblas_cgemm@ bindings.  The two scalars are handed to the FFI
-- layer through 'withRStorable_'.
--
-- Arguments, in order: the transpose flag for A, the transpose flag for B,
-- the scalars @alpha@ and @beta@, then the matrices A, B and C.  C is the
-- matrix that is written to; its buffer must not overlap the buffer of A or
-- of B.  Per the BLAS GEMM contract the call is expected to compute
-- @C := alpha * op(A) * op(B) + beta * C@ (see the BLAS reference; not
-- verifiable from this module alone).
--
-- Calls 'error' when the dimensions are rejected by 'isBadGemm' or when the
-- buffers overlap as described above.
cgemm :: PrimMonad m
      => Transpose -> Transpose -> Complex Float -> Complex Float
      -> MutDenseMatrix (PrimState m) orient (Complex Float)
      -> MutDenseMatrix (PrimState m) orient (Complex Float)
      -> MutDenseMatrix (PrimState m) orient (Complex Float)
      -> m ()
-- 'gemmAbstraction' takes the safe binding first and the unsafe binding
-- second.  Both have the same type, so the order is not checked by the
-- compiler and must be kept right by hand.
cgemm = gemmAbstraction "cgemm" cblas_cgemm_safe cblas_cgemm_unsafe withRStorable_
-- | GEMM over @'Complex' 'Double'@ elements, assembled from 'gemmAbstraction'
-- and the @cblas_zgemm@ bindings.  The two scalars are handed to the FFI
-- layer through 'withRStorable_'.
--
-- Arguments, in order: the transpose flag for A, the transpose flag for B,
-- the scalars @alpha@ and @beta@, then the matrices A, B and C.  C is the
-- matrix that is written to; its buffer must not overlap the buffer of A or
-- of B.  Per the BLAS GEMM contract the call is expected to compute
-- @C := alpha * op(A) * op(B) + beta * C@ (see the BLAS reference; not
-- verifiable from this module alone).
--
-- Calls 'error' when the dimensions are rejected by 'isBadGemm' or when the
-- buffers overlap as described above.
zgemm :: PrimMonad m
      => Transpose -> Transpose -> Complex Double -> Complex Double
      -> MutDenseMatrix (PrimState m) orient (Complex Double)
      -> MutDenseMatrix (PrimState m) orient (Complex Double)
      -> MutDenseMatrix (PrimState m) orient (Complex Double)
      -> m ()
-- 'gemmAbstraction' takes the safe binding first and the unsafe binding
-- second.  Both have the same type, so the order is not checked by the
-- compiler and must be kept right by hand.
zgemm = gemmAbstraction "zgemm" cblas_zgemm_safe cblas_zgemm_unsafe withRStorable_