module Internal.Matrix where
import Data.Complex ( Complex )
import Foreign.C.String ( CString, newCString, withCString )
import Foreign.C.Types ( CInt(..) )
import Foreign.Marshal.Alloc ( free )
import Foreign.Marshal.Array ( newArray )
import Foreign.Marshal.Utils ( with )
import Foreign.Ptr ( Ptr )
import Foreign.Storable ( Storable )
import System.IO.Unsafe ( unsafePerformIO )
import Text.Printf

import Control.DeepSeq ( NFData(..) )

import Internal.Devel
import Internal.Vector
import Internal.Vectorized hiding ((#), (#!))
-- | Storage order used when a matrix is laid out in its flat vector.
data MatrixOrder
    = RowMajor
    | ColumnMajor
    deriving (Show, Eq)
-- | A dense matrix viewed through a flat storage vector with explicit strides.
--
-- Element @(i,j)@ is stored at index @i * xRow + j * xCol@ of 'xdat'. The
-- storage may be longer than @rows * cols@ when the matrix is a slice.
data Matrix t = Matrix
    { irows :: !Int         -- ^ number of rows
    , icols :: !Int         -- ^ number of columns
    , xRow  :: !Int         -- ^ storage distance between consecutive rows
    , xCol  :: !Int         -- ^ storage distance between consecutive columns
    , xdat  :: !(Vector t)  -- ^ underlying storage
    }
-- | Number of rows of a matrix.
rows :: Matrix t -> Int
rows m = irows m

-- | Number of columns of a matrix.
cols :: Matrix t -> Int
cols m = icols m
-- | Shape of a matrix as @(rows, cols)@.
size m = (irows m, icols m)

-- | True when elements of a row are adjacent in storage, or there is only one column.
rowOrder m = cols m == 1 || xCol m == 1

-- | True when elements of a column are adjacent in storage, or there is only one row.
colOrder m = rows m == 1 || xRow m == 1

-- | True for single-row or single-column matrices.
is1d m = rows m == 1 || cols m == 1

-- | True when the storage holds more elements than the matrix exposes.
isSlice m = dim (xdat m) > rows m * cols m
-- | Storage order of a matrix. Anything that is not row-ordered is reported as
-- 'ColumnMajor'.
orderOf :: Matrix t -> MatrixOrder
orderOf m
    | rowOrder m = RowMajor
    | otherwise  = ColumnMajor
-- | Debug helper that prints the internal layout of a matrix.
--
-- Format: @<rows>x<cols> <slice|full> <1d|rows|cols> <xRow>:<xCol> (<storage length>)@.
showInternal :: Storable t => Matrix t -> IO ()
showInternal m =
    printf "%dx%d %s %s %d:%d (%d)\n"
        (rows m) (cols m) kind layout (xRow m) (xCol m) (dim (xdat m))
  where
    kind :: String
    kind
        | isSlice m = "slice"
        | otherwise = "full"
    layout :: String
    layout
        | is1d m     = "1d"
        | rowOrder m = "rows"
        | otherwise  = "cols"
-- | Transpose without copying: swap the dimensions and the strides while
-- sharing the same storage vector.
trans :: Matrix t -> Matrix t
trans m = m { irows = icols m, icols = irows m, xRow = xCol m, xCol = xRow m }
-- | Return the matrix unchanged if it is already row-ordered; otherwise copy
-- it into row-major storage.
cmat :: (Element t) => Matrix t -> Matrix t
cmat m = if rowOrder m then m else extractAll RowMajor m

-- | Return the matrix unchanged if it is already column-ordered; otherwise
-- copy it into column-major storage.
fmat :: (Element t) => Matrix t -> Matrix t
fmat m = if colOrder m then m else extractAll ColumnMajor m
-- | Pass a matrix to a foreign-call builder in raw form.
--
-- The builder @g@ receives rows, cols and the storage pointer. The continuation
-- @f@ runs the result while the pointer is valid.
amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r
amatr x f g = unsafeWith (xdat x) $ \ptr ->
    f (g (fi (rows x)) (fi (cols x)) ptr)

-- | Like 'amatr', but also passes the row and column strides before the pointer.
amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r
amat x f g = unsafeWith (xdat x) $ \ptr ->
    f (g (fi (rows x)) (fi (cols x)) (fi (xRow x)) (fi (xCol x)) ptr)
-- | Matrices are handed to C either as (rows, cols, ptr) in raw form, or as
-- (rows, cols, row stride, col stride, ptr).
instance Storable t => TransArray (Matrix t) where
    type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b
    type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b
    applyRaw = amatr
    apply = amat
infixr 1 #

-- | Feed one argument (via 'apply') into a foreign-call continuation. The
-- operator is right-associative, so @a # b # k@ means @a # (b # k)@.
(#) a b = apply a b

-- | Close an argument chain: feed @a@ and then @b@, finishing with 'id'.
(#!) a b = a # (b # id)
-- | Copy every element of a matrix into a newly created matrix with the
-- requested storage order.
--
-- Mode 0 makes 'extractR' read each index vector as an inclusive
-- @[first, last]@ range (see extractAux). The last row and column indices are
-- therefore @rows m - 1@ and @cols m - 1@.
copy ord m = extractR ord m 0 (idxs [0, rows m - 1]) 0 (idxs [0, cols m - 1])

-- | Pure wrapper around 'copy'.
--
-- NOTE(review): the unsafePerformIO presumably relies on 'copy' writing only
-- into the matrix it creates. Confirm the C kernels never mutate their input.
extractAll ord m = unsafePerformIO (copy ord m)
-- | Row-major contents of a matrix as a single vector.
--
-- The storage is shared when it is already contiguous and row-ordered.
-- Otherwise a row-major copy is made.
flatten :: Element t => Matrix t -> Vector t
flatten m
    | rowOrder m && not (isSlice m) = xdat m
    | otherwise                     = xdat (extractAll RowMajor m)
-- | Matrix contents as a list of rows, each row a list of elements.
toLists :: (Element t) => Matrix t -> [[t]]
toLists m = [toList row | row <- toRows m]
-- | Common length of a list of sizes, treating 1 as a wildcard that can be
-- broadcast.
--
-- Returns 'Nothing' for an empty list or when two sizes other than 1 disagree.
compatdim :: [Int] -> Maybe Int
compatdim []     = Nothing
compatdim (d:ds) = go d ds
  where
    -- cur is the size agreed on so far
    go cur [] = Just cur
    go cur (e:es)
        | cur == e || cur == 1 = go e es
        | e == 1               = go cur es
        | otherwise            = Nothing
-- | Build a row-major matrix from row vectors.
--
-- All rows must have the same length, except that single-element rows are
-- broadcast to that length. An empty list gives a 0x0 matrix.
fromRows :: Element t => [Vector t] -> Matrix t
fromRows [] = emptyM 0 0
fromRows vs =
    case compatdim lens of
        Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show lens
        Just 0  -> emptyM nrows 0
        Just n  -> matrixFromVector RowMajor nrows n (vjoin [widen n v | v <- vs])
  where
    lens  = map dim vs
    nrows = length vs
    -- stretch a singleton to length n; rows that already fit pass through
    widen n v
        | n == 0     = fromList []
        | dim v == n = v
        | otherwise  = constantD (v @> 0) n
-- | Split a matrix into its row vectors.
--
-- For a row-ordered matrix, each row is a 'subVector' slice of the storage.
-- Otherwise each row is extracted into a new vector.
toRows :: Element t => Matrix t -> [Vector t]
toRows m
    | rowOrder m = map sub rowRange
    | otherwise  = map ext rowRange
  where
    -- valid row indices are 0 .. rows m - 1 (empty for a 0-row matrix)
    rowRange = [0 .. rows m - 1]
    -- row k starts at offset k * xRow and spans cols m elements
    sub k = subVector (k * xRow m) (cols m) (xdat m)
    -- row mode 1 selects the explicit index list [k];
    -- column mode 0 selects the inclusive range [0, cols m - 1]
    ext k = xdat $ unsafePerformIO $
        extractR RowMajor m 1 (idxs [k]) 0 (idxs [0, cols m - 1])
-- | Build a matrix from column vectors (the transpose of 'fromRows').
fromColumns :: Element t => [Vector t] -> Matrix t
fromColumns vs = trans (fromRows vs)

-- | Split a matrix into its column vectors (the rows of its transpose).
toColumns :: Element t => Matrix t -> [Vector t]
toColumns = toRows . trans
-- | Bounds-checked element access at @(row, column)@.
(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
infixl 9 @@>
m @@> (i, j)
    | inRange   = atM' m i j
    | otherwise = error "matrix indexing out of range"
  where
    inRange = i >= 0 && i < rows m && j >= 0 && j < cols m

-- | Unchecked element access that uses the matrix strides.
atM' m i j = xdat m `at'` (i * xRow m + j * xCol m)
-- | View a vector as an @r x c@ matrix in the given storage order.
--
-- The vector must hold exactly @r * c@ elements. Previously the single-row and
-- single-column cases skipped this check, so a mismatched vector silently gave
-- a matrix of the wrong shape.
--
-- Single-row and single-column matrices keep their special stride choices, so
-- they work as both row- and column-ordered.
matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t
matrixFromVector o r c v
    | r * c /= dim v =
        error $ "can't reshape vector dim = " ++ show (dim v) ++ " to matrix " ++ shDim (r, c)
    | r == 1 = Matrix { irows = 1, icols = c, xdat = v, xRow = c, xCol = 1 }
    | c == 1 = Matrix { irows = r, icols = 1, xdat = v, xRow = 1, xCol = r }
    | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 }
    | otherwise     = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r }
-- | Allocate an @r x c@ matrix with new storage in the given order. The
-- contents are whatever 'createVector' provides.
createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
createMatrix ord r c = matrixFromVector ord r c <$> createVector (r * c)
-- | Group a vector into rows of @c@ elements, giving a row-major matrix.
-- A width of 0 asks for a 0x0 matrix.
reshape :: Storable t => Int -> Vector t -> Matrix t
reshape c v
    | c == 0    = matrixFromVector RowMajor 0 0 v
    | otherwise = matrixFromVector RowMajor (dim v `div` c) c v
-- | Apply a vector function to the elements of a matrix, keeping its shape.
--
-- Slices are first flattened to row-major order. Full matrices keep their
-- storage order.
liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
liftMatrix f m =
    let r = rows m
        c = cols m
    in if isSlice m
        then matrixFromVector RowMajor r c (f (flatten m))
        else matrixFromVector (orderOf m) r c (f (xdat m))
-- | Combine two same-shaped matrices element-wise with a vector function.
--
-- Both inputs are flattened in the storage order of the first matrix.
liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
liftMatrix2 f m1 m2
    | size m2 /= shape = error "nonconformant matrices in liftMatrix2"
    | rowOrder m1      = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2))
    | otherwise        = matrixFromVector ColumnMajor r c (f (flatten (trans m1)) (flatten (trans m2)))
  where
    shape@(r, c) = size m1
-- | Element types supported by the low-level matrix kernels.
class (Storable a) => Element a where
    -- | Build a vector of the given length from a single value (used to
    -- broadcast singletons).
    constantD :: a -> Int -> Vector a
    -- | Copy selected rows and columns into a new matrix with the given order.
    -- Mode 0 means the index vector is an inclusive @[first, last]@ range;
    -- otherwise its length gives the count (see extractAux).
    extractR  :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
    -- | Rectangle write at a row/column offset.
    -- NOTE(review): which matrix is the destination is not visible here.
    setRect   :: Int -> Int -> Matrix a -> Matrix a -> IO ()
    -- | Sort indices as a 'CInt' vector.
    sortI     :: Ord a => Vector a -> Vector CInt
    -- | Sorted values.
    sortV     :: Ord a => Vector a -> Vector a
    -- | Element-wise comparison producing a 'CInt' code per element.
    compareV  :: Ord a => Vector a -> Vector a -> Vector CInt
    -- | Element-wise selection driven by a 'CInt' condition vector.
    selectV   :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
    -- | Gather elements using two 'CInt' index matrices.
    remapM    :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
    -- | Row operation with an op code, a scalar and row/column ranges.
    rowOp     :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
    -- | Matrix-multiply kernel taking a vector and three matrices.
    gemm      :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
-- | 'Float' kernels.
instance Element Float where
    -- construction, copying and arithmetic
    constantD = constantAux cconstantF
    extractR  = extractAux c_extractF
    setRect   = setRectAux c_setRectF
    rowOp     = rowOpAux c_rowOpF
    gemm      = gemmg c_gemmF
    -- ordering, selection and gathering
    sortI     = sortIdxF
    sortV     = sortValF
    compareV  = compareF
    selectV   = selectF
    remapM    = remapF
-- | 'Double' kernels.
instance Element Double where
    -- construction, copying and arithmetic
    constantD = constantAux cconstantR
    extractR  = extractAux c_extractD
    setRect   = setRectAux c_setRectD
    rowOp     = rowOpAux c_rowOpD
    gemm      = gemmg c_gemmD
    -- ordering, selection and gathering
    sortI     = sortIdxD
    sortV     = sortValD
    compareV  = compareD
    selectV   = selectD
    remapM    = remapD
-- | 'Complex' 'Float' kernels.
instance Element (Complex Float) where
    -- construction, copying and arithmetic
    constantD = constantAux cconstantQ
    extractR  = extractAux c_extractQ
    setRect   = setRectAux c_setRectQ
    rowOp     = rowOpAux c_rowOpQ
    gemm      = gemmg c_gemmQ
    -- selection and gathering
    selectV   = selectQ
    remapM    = remapQ
    -- These methods require Ord a, and base has no Ord instance for Complex,
    -- so they cannot be called.
    sortI     = undefined
    sortV     = undefined
    compareV  = undefined
-- | 'Complex' 'Double' kernels.
instance Element (Complex Double) where
    -- construction, copying and arithmetic
    constantD = constantAux cconstantC
    extractR  = extractAux c_extractC
    setRect   = setRectAux c_setRectC
    rowOp     = rowOpAux c_rowOpC
    gemm      = gemmg c_gemmC
    -- selection and gathering
    selectV   = selectC
    remapM    = remapC
    -- These methods require Ord a, and base has no Ord instance for Complex,
    -- so they cannot be called.
    sortI     = undefined
    sortV     = undefined
    compareV  = undefined
-- | 'CInt' kernels (the int32_t variants of the C routines).
instance Element CInt where
    -- construction, copying and arithmetic
    constantD = constantAux cconstantI
    extractR  = extractAux c_extractI
    setRect   = setRectAux c_setRectI
    rowOp     = rowOpAux c_rowOpI
    gemm      = gemmg c_gemmI
    -- ordering, selection and gathering
    sortI     = sortIdxI
    sortV     = sortValI
    compareV  = compareI
    selectV   = selectI
    remapM    = remapI
-- | 'Z' kernels (the int64_t variants of the C routines).
instance Element Z where
    -- construction, copying and arithmetic
    constantD = constantAux cconstantL
    extractR  = extractAux c_extractL
    setRect   = setRectAux c_setRectL
    rowOp     = rowOpAux c_rowOpL
    gemm      = gemmg c_gemmL
    -- ordering, selection and gathering
    sortI     = sortIdxL
    sortV     = sortValL
    compareV  = compareL
    selectV   = selectL
    remapM    = remapL
-- | Extract a rectangular view of a matrix without copying.
--
-- @subMatrix (r0,c0) (rt,ct) m@ selects @rt@ rows starting at row @r0@ and
-- @ct@ columns starting at column @c0@. A non-positive size gives an empty
-- matrix. A region outside the matrix raises an error.
subMatrix :: Element a
          => (Int,Int)  -- ^ (r0,c0) starting position
          -> (Int,Int)  -- ^ (rt,ct) size of the submatrix
          -> Matrix a   -- ^ input matrix
          -> Matrix a   -- ^ view onto the selected region
subMatrix (r0,c0) (rt,ct) m
    | rt <= 0 || ct <= 0 = matrixFromVector RowMajor (max 0 rt) (max 0 ct) (fromList [])
    -- rt and ct are known to be positive here
    | 0 <= r0 && r0 + rt <= rows m &&
      0 <= c0 && c0 + ct <= cols m = res
    | otherwise = error $ "wrong subMatrix " ++ show ((r0,c0),(rt,ct)) ++ " of " ++ shSize m
  where
    -- storage offset of the top-left element
    p = r0 * xRow m + c0 * xCol m
    -- storage span from the top-left to the bottom-right element (inclusive)
    tot | rowOrder m = ct + (rt - 1) * xRow m
        | otherwise  = rt + (ct - 1) * xCol m
    res = m { irows = rt, icols = ct, xdat = subVector p tot (xdat m) }
-- | Largest value in the list, or 0 if any value is 0.
--
-- This is the target size when broadcasting: an empty dimension wins. An empty
-- list also gives 0 instead of crashing in 'minimum'/'maximum'.
maxZ :: (Ord a, Num a) => [a] -> a
maxZ [] = 0
maxZ xs
    | minimum xs == 0 = 0
    | otherwise       = maximum xs
-- | Broadcast a list of matrices to a common shape (see conformMTo).
conformMs ms = map (conformMTo target) ms
  where
    target = (maxZ (map rows ms), maxZ (map cols ms))

-- | Broadcast a list of vectors to a common length (see conformVTo).
conformVs vs = map (conformVTo len) vs
  where
    len = maxZ (map dim vs)
-- | Expand a matrix to shape @(r,c)@.
--
-- A 1x1 matrix is filled with its single element. A column @(r,1)@ is
-- repeated across columns, and a row @(1,c)@ is repeated down rows. A matrix
-- that already has the shape is returned unchanged; any other shape is an
-- error.
conformMTo (r,c) m
    | shape == (r,c) = m
    | shape == (1,1) = matrixFromVector RowMajor r c (constantD (m @@> (0,0)) (r*c))
    | shape == (r,1) = repCols c m
    | shape == (1,c) = repRows r m
    | otherwise      = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c)
  where
    shape = size m

-- | Expand a vector to length @n@; a singleton is repeated.
conformVTo n v = case dim v of
    d | d == n    -> v
      | d == 1    -> constantD (v @> 0) n
      | otherwise -> error $ "vector of dim=" ++ show d ++ " cannot be expanded to dim=" ++ show n
-- | Stack @n@ copies of the flattened matrix as rows.
repRows n x = let row = flatten x in fromRows (replicate n row)

-- | Place @n@ copies of the flattened matrix side by side as columns.
repCols n x = let col = flatten x in fromColumns (replicate n col)

-- | Shape of a matrix rendered for error messages, e.g. @(2x3)@.
shSize m = shDim (size m)
-- | Render a shape as @(RxC)@ for error messages.
shDim (r,c) = concat ["(", show r, "x", show c, ")"]
-- | Matrix backed by empty storage. Only valid when @r * c == 0@.
emptyM r c = matrixFromVector RowMajor r c noData
  where
    noData = fromList []
-- | Deep evaluation of a matrix: forces the first stored element when the
-- storage is non-empty.
instance (Storable t, NFData t) => NFData (Matrix t) where
    rnf m = if dim v > 0 then rnf (v @> 0) else ()
      where
        v = xdat m
-- | Shared implementation of 'extractR' for every element type.
--
-- Mode 0 treats the index vector as an inclusive @[first, last]@ range, so
-- the count is @last - first + 1@. Any other mode uses the vector length as
-- the count. The result is allocated in order @ord@ and filled by the C
-- kernel @f@.
extractAux f ord m moder vr modec vc = do
    let nr = if moder == 0 then fromIntegral (vr @> 1 - vr @> 0 + 1) else dim vr
        nc = if modec == 0 then fromIntegral (vc @> 1 - vc @> 0 + 1) else dim vc
    r <- createMatrix ord nr nc
    (vr # vc # m #! r) (f moder modec) #| "extract"
    return r
-- | Type of the C extract kernels. From extractAux, the arguments are: row
-- mode, column mode, row index vector, column index vector, the source matrix
-- and the destination matrix. The CInt result is passed to #|.
type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt))))
-- One kernel per element type: D = Double, F = Float, C = Complex Double,
-- Q = Complex Float, I = CInt, L = Z.
foreign import ccall unsafe "extractD" c_extractD :: Extr Double
foreign import ccall unsafe "extractF" c_extractF :: Extr Float
foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double)
foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float)
foreign import ccall unsafe "extractI" c_extractI :: Extr CInt
foreign import ccall unsafe "extractL" c_extractL :: Extr Z
-- | Shared implementation of 'setRect': passes both matrices and the offsets
-- @i@ and @j@ to the C kernel @f@ and checks its result.
setRectAux f i j m r = call #| "setRect"
  where
    call = (m #! r) (f (fi i) (fi j))
-- | Type of the C setRect kernels: two offsets followed by two matrices, in
-- the order setRectAux passes them.
type SetRect x = I -> I -> x ::> x::> Ok
-- One kernel per element type: D = Double, F = Float, C = Complex Double,
-- Q = Complex Float, I = I, L = Z.
foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double
foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float
foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double)
foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float)
foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I
foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
-- | Shared driver for the sort kernels. Allocates an output vector as long as
-- @v@ and lets the C kernel @f@ fill it from @v@.
--
-- NOTE(review): unsafePerformIO relies on the kernel writing only into the
-- output it was given. Confirm against the C sources.
sortG f v = unsafePerformIO $ createVector (dim v) >>= \out -> do
    (v #! out) f #| "sortG"
    pure out
-- Element-type specialisations of sortG. The sortIdx* variants wrap the
-- sort_index* kernels (CInt output vector). The sortVal* variants wrap the
-- sort_values* kernels (output has the input element type).
-- Kept point-free and unsigned: the types come from the foreign imports.
sortIdxD = sortG c_sort_indexD
sortIdxF = sortG c_sort_indexF
sortIdxI = sortG c_sort_indexI
sortIdxL = sortG c_sort_indexL
sortValD = sortG c_sort_valD
sortValF = sortG c_sort_valF
sortValI = sortG c_sort_valI
sortValL = sortG c_sort_valL
-- C sort kernels: input vector first, output vector second (see sortG).
-- The L variants use the :> aliases instead of spelled-out CV types; they are
-- presumably equivalent spellings. TODO confirm in Internal.Devel.
foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt))
foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt))
foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok
foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt))
foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt))
foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt))
foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
-- | Shared driver for the comparison kernels. The output vector has the
-- length of @v@ and is filled by the C kernel @f@ from @u@ and @v@.
--
-- NOTE(review): unsafePerformIO relies on the kernel writing only into the
-- output it was given.
compareG f u v = unsafePerformIO $ do
    out <- createVector (dim v)
    let call = (u # v #! out) f
    call #| "compareG"
    pure out
-- Element-type specialisations of compareG (output is a CInt vector).
-- Kept point-free and unsigned: the types come from the foreign imports.
compareD = compareG c_compareD
compareF = compareG c_compareF
compareI = compareG c_compareI
compareL = compareG c_compareL
-- C comparison kernels: two input vectors, then a CInt output vector
-- (see compareG).
foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt)))
foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt)))
foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
-- | Shared driver for the selection kernels. The CInt vector @c@ drives a
-- choice among @u@, @v@ and @w@. The output has the length of @v@.
--
-- NOTE(review): unsafePerformIO relies on the kernel writing only into the
-- output it was given.
selectG f c u v w = unsafePerformIO $ do
    out <- createVector (dim v)
    let call = (c # u # v # w #! out) f
    call #| "selectG"
    pure out
-- Element-type specialisations of selectG.
-- Kept point-free and unsigned: the types come from the foreign imports.
selectD = selectG c_selectD
selectF = selectG c_selectF
selectI = selectG c_selectI
selectL = selectG c_selectL
selectC = selectG c_selectC
selectQ = selectG c_selectQ
-- | Type of the C selection kernels: a CInt condition vector, three candidate
-- vectors and an output vector (see selectG). The C symbols are named choose*.
type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))
foreign import ccall unsafe "chooseD" c_selectD :: Sel Double
foreign import ccall unsafe "chooseF" c_selectF :: Sel Float
foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt
foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
-- | Shared driver for the remap kernels. Allocates a row-major result shaped
-- like the index matrix @i@ and fills it via the C kernel @f@ from @i@, @j@
-- and @m@.
--
-- NOTE(review): unsafePerformIO relies on the kernel writing only into the
-- output it was given.
remapG f i j m = unsafePerformIO $ do
    let (nr, nc) = (rows i, cols i)
    out <- createMatrix RowMajor nr nc
    (i # j # m #! out) f #| "remapG"
    pure out
-- Element-type specialisations of remapG.
-- Kept point-free and unsigned: the types come from the foreign imports.
remapD = remapG c_remapD
remapF = remapG c_remapF
remapI = remapG c_remapI
remapL = remapG c_remapL
remapC = remapG c_remapC
remapQ = remapG c_remapQ
-- | Type of the C remap kernels: two CInt index matrices, the source matrix and
-- the output matrix (see remapG).
type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))
foreign import ccall unsafe "remapD" c_remapD :: Rem Double
foreign import ccall unsafe "remapF" c_remapF :: Rem Float
foreign import ccall unsafe "remapI" c_remapI :: Rem CInt
foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double)
foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float)
foreign import ccall unsafe "remapL" c_remapL :: Rem Z
-- | Shared implementation of 'rowOp'.
--
-- The scalar @x@ is passed to the C kernel @f@ by pointer, together with the
-- op code @c@, the row range @i1..i2@, the column range @j1..j2@ and the
-- matrix @m@.
--
-- 'with' scopes the temporary so it is released even when the checked call
-- (#|) raises an error. The previous newArray/free pair leaked in that case.
rowOpAux f c x i1 i2 j1 j2 m =
    with x $ \px ->
        (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #| "rowOp"
-- | Type of the C row-operation kernels: op code, pointer to the scalar,
-- row range (i1, i2), column range (j1, j2), then the matrix (see rowOpAux).
type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok
foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R
foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float
foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C
foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float)
foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I
foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z
-- Variants with an extra leading I/Z argument (presumably a modulus, going by
-- the C names). Not used by any Element instance.
foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I
foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
-- | Shared implementation of 'gemm': passes the vector @v@ and the matrices
-- @m1@, @m2@, @m3@ (in that order) to the C kernel @f@ and checks its result.
gemmg f v m1 m2 m3 = call #| "gemmg"
  where
    call = (v # m1 # m2 #! m3) f
-- | Type of the C gemm kernels: one vector followed by three matrices
-- (see gemmg).
type Tgemm x = x :> x ::> x ::> x ::> Ok
foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R
foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float
foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C
foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float)
foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I
foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z
-- Variants with an extra leading I/Z argument (presumably a modulus, going by
-- the C names). Not used by any Element instance.
foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I
foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
-- | C routine behind 'saveMatrix': file name, format string, then the matrix.
foreign import ccall unsafe "saveMatrix" c_saveMatrix
    :: CString -> CString -> Double ::> Ok
-- | Write a 'Double' matrix to a file through the C routine @saveMatrix@.
--
-- Both strings are scoped with 'withCString', so they are released even when
-- the checked call (#|) raises an error. The previous newCString/free pair
-- leaked them in that case.
saveMatrix
    :: FilePath       -- ^ output file name
    -> String         -- ^ format string handed to C (presumably printf-style; TODO confirm)
    -> Matrix Double  -- ^ matrix to write
    -> IO ()
saveMatrix name format m =
    withCString name $ \cname ->
        withCString format $ \cformat ->
            (m # id) (c_saveMatrix cname cformat) #| "saveMatrix"