{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Reading and writing sparse matrices in the Matrix Market coordinate
-- text format, as streaming conduits over 'B.ByteString' chunks.
module Data.Matrix.Static.IO
    ( fromMM
    , fromMM'
    , toMM
    , IOElement(..)
    ) where

import qualified Data.ByteString.Char8 as B
import Conduit
import Control.Monad (when)
import qualified Data.Vector.Generic as G
import Data.Matrix.Dynamic (Dynamic(..))
import qualified Data.Vector.Unboxed as U
import Data.ByteString.Lex.Fractional (readExponential, readSigned)
import Data.ByteString.Lex.Integral (readDecimal, readDecimal_)
import Data.Singletons
import Data.Maybe
import Data.Singletons.TypeLits
import Data.Double.Conversion.ByteString (toShortest)
import Text.Printf (printf)

import qualified Data.Matrix.Static.Sparse as S

-- | Element type named in the fourth word of the @%%MatrixMarket@ banner.
data MMElem = MMReal
            | MMComplex
            | MMInteger
            | MMPattern
            deriving (Eq)

-- | Scalars that can be read from and written to a Matrix Market file.
class U.Unbox a => IOElement a where
    -- | Parse one textual field. Calls 'error' when the field cannot be
    -- parsed.
    decodeElem :: B.ByteString -> a

    -- | Render one value as a textual field.
    encodeElem :: a -> B.ByteString

    -- | The banner element type that corresponds to @a@.
    elemType :: Proxy a -> MMElem

instance IOElement Int where
    decodeElem x = fst . fromMaybe errMsg . readSigned readDecimal $ x
      where
        errMsg = error $ "readInt: Fail to cast ByteString to Int:" ++ show x
    encodeElem = B.pack . show
    elemType _ = MMInteger

instance IOElement Double where
    decodeElem x = fst . fromMaybe errMsg . readSigned readExponential $ x
      where
        errMsg = error $ "readDouble: Fail to cast ByteString to Double:" ++ show x
    encodeElem = toShortest
    elemType _ = MMReal

-- | Read a Matrix Market stream into a matrix whose dimensions are taken
-- from the file's size line and hidden behind 'Dynamic'.
--
-- Calls 'error' when the banner's element type differs from @a@, when the
-- number of entries read differs from the count in the size line, or when
-- any line is malformed.
fromMM' :: forall o m v a. (PrimMonad m, G.Vector v a, IOElement a)
        => ConduitT B.ByteString o m (Dynamic S.SparseMatrix v a)
fromMM' = linesUnboundedAsciiC .| do
    (ty, (r, c, nnz)) <- parseHeader
    when (elemType (Proxy :: Proxy a) /= ty) $
        error "Element types do not match"
    vec <- readTriplets nnz
    withSomeSing (fromIntegral (r :: Int)) $ \(SNat :: Sing r) ->
        withSomeSing (fromIntegral (c :: Int)) $ \(SNat :: Sing c) ->
            return $ Dynamic (S.fromTriplet vec :: S.SparseMatrix r c v a)

-- | Read a Matrix Market stream into a matrix whose dimensions are fixed by
-- the type-level indices @r@ and @c@.
--
-- Calls 'error' when the banner's element type differs from @a@, when the
-- dimensions in the size line differ from @(r, c)@, when the number of
-- entries read differs from the count in the size line, or when any line is
-- malformed.
fromMM :: forall o m r c v a.
          (PrimMonad m, SingI r, SingI c, G.Vector v a, IOElement a)
       => ConduitT B.ByteString o m (S.SparseMatrix r c v a)
fromMM = linesUnboundedAsciiC .| do
    (ty, (r, c, nnz)) <- parseHeader
    when (elemType (Proxy :: Proxy a) /= ty) $
        error "Element types do not match"
    when ((r, c) /= (nrow, ncol)) $
        error $ "Dimensions do not match: " <> show (r,c) <> "/="
            <> show (nrow,ncol)
    -- The entry count is validated on the triplets actually read, exactly
    -- as 'fromMM'' does, rather than on the matrix's internal storage.
    S.fromTriplet <$> readTriplets nnz
  where
    nrow = fromIntegral $ fromSing (sing :: Sing r) :: Int
    ncol = fromIntegral $ fromSing (sing :: Sing c) :: Int

-- | Consume the remaining lines as coordinate entries and check that their
-- number equals the count announced in the size line.
readTriplets :: (PrimMonad m, IOElement a)
             => Int   -- ^ number of entries announced in the size line
             -> ConduitT B.ByteString o m (U.Vector (Int, Int, a))
readTriplets nnz = do
    vec <- streamTriplet .| sinkVector
    when (U.length vec /= nnz) $
        error $ "number of non-zeros do not match: " <> show nnz <> "/="
            <> show (U.length vec)
    return vec

-- | Write a matrix as Matrix Market coordinate text: the banner, an empty
-- comment line, the size line, then one line per triplet produced by
-- 'S.toTriplet' with indices shifted to be 1-based.
--
-- Only real and integer element types have a banner; any other element
-- type calls 'error'.
toMM :: forall m r c v a i. (Monad m, S.Zero a, IOElement a, G.Vector v a)
     => S.SparseMatrix r c v a
     -> ConduitT i B.ByteString m ()
toMM mat@(S.SparseMatrix vec _ _) =
    ( do yield header
         yield "%"
         yield $ B.pack $ printf "%d %d %d" r c n
         S.toTriplet mat .| mapC f
    ) .| unlinesAsciiC
  where
    -- Indices are stored 0-based here and written 1-based.
    f (i, j, x) = B.unwords [B.pack $ show (i+1), B.pack $ show (j+1), encodeElem x]
    header = case elemType (Proxy :: Proxy a) of
        MMReal -> "%%MatrixMarket matrix coordinate real general"
        MMInteger -> "%%MatrixMarket matrix coordinate integer general"
        MMComplex -> error "toMM: complex matrices are not supported"
        MMPattern -> error "toMM: pattern matrices are not supported"
    (r, c) = S.dim mat
    n = G.length vec

-- | Parse the banner line, skip the comment section, and parse the size
-- line. Returns the banner's element type together with
-- @(rows, columns, entries)@.
--
-- Lines that start with @%@ and lines containing only whitespace are
-- skipped between the banner and the size line. Malformed input calls
-- 'error'.
parseHeader :: Monad m => ConduitT B.ByteString o m (MMElem, (Int, Int, Int))
parseHeader = do
    ty <- headC >>= \case
        Nothing -> error "Empty file"
        Just header -> return $ parseBanner header
    dropWhileC isCommentOrBlank
    headC >>= \case
        Nothing -> error "Missing size line"
        Just x -> case map decodeElem (B.words x) of
            [r, c, nnz] -> return (ty, (r, c, nnz))
            _ -> error $ "Cannot parse size line: " <> show x
  where
    -- 'B.words' yields no words for an empty or whitespace-only line, so
    -- this never inspects the first byte of an empty string.
    isCommentOrBlank l = "%" `B.isPrefixOf` l || null (B.words l)
    parseBanner x
        | "%%MatrixMarket" `B.isPrefixOf` x = case B.words x of
            [_, _, _, ty, _] -> case ty of
                "real" -> MMReal
                "complex" -> MMComplex
                "integer" -> MMInteger
                "pattern" -> MMPattern
                t -> error $ "Unknown type: " <> show t
            _ -> error $ "Cannot parse header: " <> show x
        | otherwise = error $ "Cannot parse header: " <> show x
{-# INLINE parseHeader #-}

-- | Turn each data line @i j x@ into a triplet with 0-based indices.
--
-- Blank lines are skipped. A line with any other number of fields, or with
-- an index that is not a positive decimal integer, calls 'error'.
streamTriplet :: (Monad m, IOElement a)
              => ConduitT B.ByteString (Int, Int, a) m ()
streamTriplet = concatMapC (parseLine . B.words)
  where
    parseLine [] = Nothing
    parseLine [i, j, x] = Just (readIndex i, readIndex j, decodeElem x)
    parseLine x = error $ "Formatting error: " <> show x
{-# INLINE streamTriplet #-}

-- | Parse a 1-based index field and return it 0-based. The whole field must
-- be a decimal integer that is at least 1; otherwise 'error' is called.
readIndex :: B.ByteString -> Int
readIndex s = case readDecimal s of
    Just (k, rest) | B.null rest && k >= 1 -> k - 1
    _ -> error $ "Invalid index: " <> show s
{-# INLINE readIndex #-}