{-# LANGUAGE GADTs #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
module Database.TDS.Query where
import qualified Database.TDS.Proto as Proto
import Database.TDS.Types
import Control.Exception ( Exception, SomeException(..)
, bracket, onException
, throwIO, catch, mask )
import Data.Bifunctor
import Data.Bits
import qualified Data.ByteString.Streaming as SBS
import qualified Data.ByteString.Internal as IBS
import Data.Foldable
import Data.Maybe
import Data.Ratio
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Text.IO as T
import Data.Word
import Debug.Trace
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import Numeric
import System.IO
import qualified Streaming as S
import qualified Streaming.Prelude as S
-- | Exception carrying a human-readable description of a statement that did
-- not behave as expected. It is raised, for example, by 'execNoRows' when a
-- statement returns rows. 'withTransaction' rethrows it without issuing a
-- rollback.
newtype MsSQLRuntimeError = MsSQLRuntimeError T.Text
  deriving (Show)

-- Default 'Exception' methods: 'displayException' falls back to 'show'.
instance Exception MsSQLRuntimeError
-- | Run an action inside a server-side transaction on the given connection.
--
-- @BEGIN TRANSACTION@ is sent first, with asynchronous exceptions masked.
-- The action and the following @COMMIT TRANSACTION@ then run with the
-- caller's masking state restored. The action's result is returned.
--
-- Failure handling:
--
-- * An 'MsSQLRuntimeError' thrown by the action or by the commit is
--   rethrown unchanged, and no rollback is attempted.
--   NOTE(review): presumably the server is assumed to have already ended the
--   transaction in that case; confirm.
-- * Any other exception is written to stderr (prefixed with
--   @"Caught exception "@). Then @ROLLBACK TRANSACTION@ is sent and the
--   original exception is rethrown. If the rollback itself throws, that
--   exception propagates instead.
withTransaction :: Connection -> IO a -> IO a
withTransaction conn go = mask $ \restore -> do
  beginTransaction conn
  outcome <- attempt restore `catch` abandon
  either throwIO pure outcome
  where
    -- Body plus commit. A SQL runtime error is turned into a value here, so
    -- the rollback handler never sees it.
    attempt restore =
      restore (Right <$> go <* commitTransaction conn)
        `catch` (\sqlErr@(MsSQLRuntimeError {}) -> pure (Left sqlErr))

    -- Any other failure: log it, roll back, and re-raise the original.
    abandon :: SomeException -> IO (Either MsSQLRuntimeError b)
    abandon err = do
      hPutStrLn stderr ("Caught exception " ++ show err)
      rollbackTransaction conn
      throwIO err
-- | Send @q@ as a SQL batch and require that the response stream is empty.
--
-- If 'S.inspect' finds any element in the returned result stream, this
-- throws 'MsSQLRuntimeError' naming the statement. If the response is not a
-- @ResponseResultReceived (RowResults _)@, the pattern match fails in 'IO'
-- and raises an 'IOError'.
execNoRows :: Connection -> T.Text -> IO ()
execNoRows tds q = do
  awaitResponse <- tdsSendPacket tds batch
  ResponseResultReceived (Proto.RowResults rows) <- awaitResponse
  firstStep <- S.inspect rows
  either (\() -> pure ()) (\_ -> throwIO unexpectedRows) firstStep
  where
    batch = Proto.mkPacket (Proto.mkPacketHeader Proto.SQLBatch mempty) q
    unexpectedRows =
      MsSQLRuntimeError ("Expected no rows for statement: " <> q)
-- | Transaction-control statements, each sent through 'execNoRows'. Each one
-- therefore throws 'MsSQLRuntimeError' if the server unexpectedly returns
-- rows.
beginTransaction, commitTransaction, rollbackTransaction :: Connection -> IO ()
-- | Send @BEGIN TRANSACTION@.
beginTransaction = flip execNoRows "BEGIN TRANSACTION"
-- | Send @COMMIT TRANSACTION@.
commitTransaction = flip execNoRows "COMMIT TRANSACTION"
-- | Send @ROLLBACK TRANSACTION@.
rollbackTransaction = flip execNoRows "ROLLBACK TRANSACTION"
-- | Send @sqlTxt@ as a SQL batch and print every column of every row of every
-- returned result set to stdout, using 'printColumn'.
--
-- The column metadata attached to each result set is not used. Each column
-- decoder receives the remaining byte stream and hands back what it did not
-- consume.
query :: Connection -> T.Text -> IO ()
query conn sqlTxt = do
  awaitResponse <- tdsSendPacket conn batch
  ResponseResultReceived (Proto.RowResults resultSets) <- awaitResponse
  _ <- S.mapsM_
         (\(S.Compose (_columns S.:> rowStream)) ->
            S.mapsM_
              (\row ->
                 S.mapsM_
                   (\(Proto.RawColumn columnType columnData next) ->
                      next <$> printColumn columnType columnData)
                   row)
              rowStream)
         resultSets
  pure ()
  where
    batch = Proto.mkPacket (Proto.mkPacketHeader Proto.SQLBatch mempty) sqlTxt
-- | Split the first byte off a streaming 'SBS.ByteString'. Returns the byte
-- together with the rest of the stream. Calls 'fail' with
-- @"take8: no more bytes"@ when the stream is empty.
--
-- NOTE(review): 'fail' under a plain 'Monad' constraint only type-checks on
-- GHC versions before MonadFail was split out of Monad (< 8.8). Newer
-- compilers need a 'MonadFail' constraint here and in callers.
take8 :: Monad m => SBS.ByteString m () -> m (Word8, SBS.ByteString m ())
take8 bs =
  SBS.uncons bs
    >>= maybe (fail "take8: no more bytes") (\(byte, rest) -> pure (byte, rest))
-- | Read a little-endian 16-bit word from the front of the stream. Returns it
-- together with the rest of the stream.
--
-- The two bytes are combined explicitly (low byte first), so the result does
-- not depend on host byte order or buffer alignment. The previous version
-- used a native-endian 'peek' on a possibly unaligned pointer. It also never
-- checked that two bytes were actually available, so a truncated stream
-- caused an out-of-bounds read. Now a short stream fails through 'take8'.
take16LE :: S.MonadIO m => SBS.ByteString m () -> m (Word16, SBS.ByteString m ())
take16LE bs = do
  (lo, bs') <- take8 bs
  (hi, bs'') <- take8 bs'
  pure (fromIntegral lo .|. (fromIntegral hi `shiftL` 8), bs'')
-- | Read a column data-length prefix whose width is given by the column's
-- 'Proto.TypeLen'. The length is returned widened to 'Word16', together with
-- the remaining stream:
--
-- * 'Proto.ShortLen': two bytes, little-endian.
-- * 'Proto.ByteLen': a single byte. This branch previously called
--   'take16LE' and so consumed one byte too many, desynchronising the rest
--   of the row.
takeLength :: S.MonadIO m => Proto.TypeLen -> SBS.ByteString m () -> m (Word16, SBS.ByteString m ())
takeLength Proto.ShortLen = take16LE
takeLength Proto.ByteLen = fmap (first fromIntegral) . take8
-- | Read @n@ bytes and assemble them into an unsigned little-endian
-- 'Integer' (the first byte is least significant). Returns it with the rest
-- of the stream.
--
-- For @n <= 0@ nothing is consumed and the result is 0. If the stream runs
-- out, this fails via 'take8'.
takeLE :: Monad m => Int -> SBS.ByteString m () -> m (Integer, SBS.ByteString m ())
takeLE n input = foldlM step (0, input) [0, 8 .. 8 * (n - 1)]
  where
    -- Accumulator is forced on every step to avoid building a thunk chain.
    step (!acc, rest) bitOffset = do
      (byte, rest') <- take8 rest
      pure (acc .|. (fromIntegral byte `shiftL` bitOffset), rest')
-- | Decode a DECIMAL/NUMERIC value from the byte stream and print it to
-- stdout, prefixed with @s(sz, precScale):@. Returns the remaining stream.
--
-- * When the 'Bool' is 'True', the value is preceded by a one-byte actual
--   length. A length of 0 prints @(NULL)@. Any other length is decoded as a
--   non-NULL value of that size.
-- * A non-NULL value is read as one sign byte (0 means negative, anything
--   else positive) followed by @sz - 1@ little-endian magnitude bytes. The
--   magnitude is divided by @10 ^ scale@ and printed as a 'Rational'.
--
-- The precision in 'Proto.PrecScale' is only echoed in the label. An unused,
-- precision-derived storage-size computation was removed.
printNumeric :: String -> Bool -> Word8 -> Proto.PrecScale -> SBS.ByteString IO ()
             -> IO (SBS.ByteString IO ())
printNumeric s True sz precScale d = do
  (realSz, d') <- take8 d
  if realSz == 0
    then do
      putStrLn (s ++ "(" ++ show sz ++ ", " ++ show precScale ++ "): (NULL)")
      pure d'
    else printNumeric s False realSz precScale d'
printNumeric s False sz precScale@(Proto.PrecScale _ scale) d = do
  putStr (s ++ "(" ++ show sz ++ ", " ++ show precScale ++ "):")
  (sign, d') <- take8 d
  -- sz counts the sign byte, so the magnitude occupies the remaining bytes.
  (magnitude, d'') <- takeLE (fromIntegral sz - 1) d'
  let unsigned = magnitude % (10 ^ fromIntegral scale) :: Rational
      value
        | sign == 0 = negate unsigned
        | otherwise = unsigned
  print value
  pure d''
-- | Decode one column value of type @ty@ from the row byte stream and print
-- it to stdout. Returns the stream positioned after the value.
--
-- Supported types:
--
-- * National-character varchar: a length prefix (see 'takeLength'), with
--   0xFFFF meaning NULL, then UTF-16LE text.
-- * 'Proto.IntNType' integers: nullable ones carry a one-byte width, where 0
--   means NULL.
-- * DECIMAL and NUMERIC, via 'printNumeric'.
--
-- Any other type fails in 'IO'.
printColumn :: Proto.ColumnData -> SBS.ByteString IO () -> IO (SBS.ByteString IO ())
printColumn ty d =
  case Proto.cdBaseTypeInfo ty of
    Proto.VarcharType typeLen Proto.NationalChar _maxLen _coll -> do
      (dataLen, d') <- takeLength typeLen d
      if dataLen == 0xFFFF
        then do
          putStrLn "(NULL)"
          pure d'
        else do
          byteData S.:> d'' <- SBS.toStrict (SBS.splitAt (fromIntegral dataLen) d')
          T.putStrLn (TE.decodeUtf16LE byteData)
          pure d''
    -- Non-nullable integers were previously consumed but never printed.
    Proto.IntNType False bytes ->
      printInt bytes d
    Proto.IntNType True bytes -> do
      (realWidth, d') <- take8 d
      if realWidth == 0
        then do
          putStrLn ("INT(" ++ show bytes ++ "): (NULL)")
          pure d'
        else printInt bytes d'
    Proto.DecimalNType nullable sz precScale ->
      printNumeric "DECIMAL" nullable sz precScale d
    Proto.NumericNType nullable sz precScale ->
      printNumeric "NUMERIC" nullable sz precScale d
    _ -> fail ("Can't print data of type " ++ show ty)
  where
    -- Read a @bytes@-wide little-endian integer, print it labelled with its
    -- width, and return the rest of the stream.
    printInt bytes input = do
      (raw, rest) <- takeLE width input
      putStrLn ("INT(" ++ show bytes ++ "): " ++ show (toSigned raw))
      pure rest
      where
        width = fromIntegral bytes
        -- SQL Server's 2/4/8-byte integers (SMALLINT/INT/BIGINT) are
        -- two's-complement signed. The 1-byte TINYINT is unsigned.
        toSigned u
          | width > 1 && testBit u (8 * width - 1) = u - bit (8 * width)
          | otherwise = u