{-# OPTIONS_GHC -funbox-strict-fields #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
module Database.TDS.Proto where
import Database.TDS.Proto.Errors
import Control.Exception (Exception)
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Tardis
import Control.Monad.Trans.Maybe
import Data.Attoparsec.Binary
import qualified Data.Attoparsec.ByteString as A hiding (Done)
import qualified Data.Attoparsec.ByteString.Streaming as S
import Data.Bits ( bit, (.|.), (.&.), xor
, shiftL, shiftR )
import qualified Data.ByteString as BS
import Data.ByteString.Builder
import Data.ByteString.Builder.Extra hiding (Done)
import qualified Data.ByteString.Internal as IBS
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Streaming as SBS
import Data.Int
import Data.Maybe
import Data.Monoid
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Text.Foreign as T
import Data.Word
import qualified Data.Vector as V
import Debug.Trace
import Foreign.C.Types
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Marshal.Utils (copyBytes)
import Foreign.Ptr
import Foreign.Storable
import Streaming (Compose(..))
import qualified Streaming as S
import qualified Streaming.Prelude as S
import qualified Streaming.Internal as S (inspect)
-- | Packet size constant named for cancel packets.
-- NOTE(review): how callers use this value is not visible in this module.
cancelPacketSize :: CSize
cancelPacketSize = 1024

-- | Upper bound used for payload-carrying packets.
-- NOTE(review): units are presumably bytes -- confirm against the caller.
maximumPayloadPacketSize :: CSize
maximumPayloadPacketSize = 65536

-- | Size of a packet header; 'writeHdr' fills byte offsets 0 through 7.
pktHdrSz :: CSize
pktHdrSz = 8
-- | Which side of the connection originates a packet.  Used promoted (via
-- @DataKinds@) as a type-level index on 'PacketType' and friends.
data Sender
  = Client
  | Server
  deriving (Show)

-- | Whether a packet type is answered by the peer, and if so with what.
data ResponseInfo a
  = ExpectsResponse (ResponseType a)
  | NoResponse
  deriving (Show)

-- | Description of an expected response: a boolean flag together with the
-- response payload type.
-- NOTE(review): the meaning of the flag is not established in this module.
data ResponseType a
  = ResponseType Bool a
  deriving (Show)

-- | Framing of a payload: a single batch or a stream of tokens.
data StreamType
  = TokenlessStream
  | TokenStream
  deriving (Show)
-- | Uninhabited placeholder payload for packet types that are not supported
-- yet.
data Unimplemented

-- | There are no values to show; the method itself is an error.
instance Show Unimplemented where
  show = error "Unimplemented"

-- | 16-bit process identifier carried in every packet header.
newtype SPID = SPID Word16
  deriving (Show, Eq)
-- | The kinds of packet this module knows about.  The indices record who
-- sends the packet, what (if anything) is expected in response, and the
-- payload type the packet carries.
data PacketType (sender :: Sender) (resp :: ResponseInfo *) (d :: *) where
  PreLogin ::
    PacketType
      'Client
      ('ExpectsResponse ('ResponseType 'False PreLogin))
      PreLogin
  Login7 ::
    PacketType
      'Client
      ('ExpectsResponse ('ResponseType 'False Login7Ack))
      Login7
  SQLBatch ::
    PacketType
      'Client
      ('ExpectsResponse ('ResponseType 'True RowResults))
      T.Text
  BulkLoad :: PacketType 'Client 'NoResponse Unimplemented
  RPC :: PacketType 'Client 'NoResponse Unimplemented
  Attention ::
    PacketType
      'Client
      ('ExpectsResponse ('ResponseType 'False ()))
      ()
  TrMgrRequest :: PacketType 'Client 'NoResponse Unimplemented
  TabularResult :: PacketType 'Server 'NoResponse a

deriving instance Show (PacketType sender resp d)
-- | The one-byte type tag written as the first byte of a packet header.
packetType :: PacketType sender resp d -> Word8
packetType = \case
  SQLBatch -> 0x01
  RPC -> 0x03
  TabularResult -> 0x04
  Attention -> 0x06
  BulkLoad -> 0x07
  TrMgrRequest -> 0x0E
  Login7 -> 0x10
  PreLogin -> 0x12
-- | Status bit field of a packet header, indexed by the sending side.
newtype PacketStatus (s :: Sender) = PacketStatus Word8
  deriving (Show)

-- | Combining two statuses takes the union of their bits.
instance Semigroup (PacketStatus s) where
  PacketStatus l <> PacketStatus r = PacketStatus (l .|. r)

-- | The empty status has no bits set.
instance Monoid (PacketStatus s) where
  mempty = PacketStatus 0
  mappend = (<>)
-- | @hasStatus sts wanted@ is 'True' when @sts@ shares at least one set bit
-- with @wanted@.
hasStatus :: PacketStatus s -> PacketStatus s -> Bool
hasStatus (PacketStatus sts) (PacketStatus wanted) =
  (sts .&. wanted) /= 0
-- | Status bit 0; usable by either sender.
pktStatusEndOfMessage :: PacketStatus s
pktStatusEndOfMessage = PacketStatus (bit 0)

-- | Status bit 1 (client packets only).
pktStatusIgnore :: PacketStatus 'Client
pktStatusIgnore = PacketStatus (bit 1)

-- | Status bit 2 (client packets only).
pktStatusResetConn :: PacketStatus 'Client
pktStatusResetConn = PacketStatus (bit 2)

-- | Status bit 3 (client packets only).
pktStatusResetConnSkipTran :: PacketStatus 'Client
pktStatusResetConnSkipTran = PacketStatus (bit 3)
-- | One-byte sequence number stored in a packet header.
newtype PacketSequenceID = PacketSequenceID Word8
  deriving (Show, Eq)

-- | Type-level conditional: selects @a@ when the flag is @'True@ and @b@
-- when it is @'False@.
type family Packed (packed :: Bool) (a :: *) (b :: *) where
  Packed 'True a b = a
  Packed 'False a b = b
-- | The fixed header placed in front of every packet.  'writeHdr' and
-- 'readHdr' define its 8-byte wire layout.
data PacketHeader (sender :: Sender) (resp :: ResponseInfo *) (d :: *) where
  PacketHeader ::
    { pktHdrType :: !(PacketType sender resp d)
    -- ^ packet kind (byte 0 on the wire)
    , pktHdrStatus :: !(PacketStatus sender)
    -- ^ status bits (byte 1)
    , pktHdrLength :: !Word16
    -- ^ length field (bytes 2-3, big-endian)
    , pktHdrSPID :: !SPID
    -- ^ process id (bytes 4-5, big-endian)
    , pktHdrSeqID :: !PacketSequenceID
    -- ^ sequence number (byte 6)
    , pktHdrWindow :: !Word8
    -- ^ window byte (byte 7)
    } ->
    PacketHeader sender resp d

deriving instance Show (PacketHeader sender resp d)
-- | Build a header for the given packet type and status flags.  Length,
-- SPID, sequence id and window all start out as zero.
mkPacketHeader ::
  PacketType sender resp d ->
  PacketStatus sender ->
  PacketHeader sender resp d
mkPacketHeader ty flags =
  PacketHeader
    { pktHdrType = ty
    , pktHdrStatus = flags
    , pktHdrLength = 0
    , pktHdrSPID = SPID 0
    , pktHdrSeqID = PacketSequenceID 0
    , pktHdrWindow = 0
    }
-- | Serialise a client packet header into the 8 bytes starting at the given
-- pointer: type, status, length (big-endian), SPID (big-endian), sequence
-- id and window, in that order.
writeHdr :: Ptr a -> PacketHeader 'Client resp d -> IO ()
writeHdr ptr (PacketHeader ty (PacketStatus status) len (SPID spid) (PacketSequenceID seqNo) window) = do
  pokeByteOff bytes 0 (packetType ty)
  pokeByteOff bytes 1 status
  putWord16BE 2 len
  putWord16BE 4 spid
  pokeByteOff bytes 6 seqNo
  pokeByteOff bytes 7 window
  where
    bytes :: Ptr Word8
    bytes = castPtr ptr

    -- Store a 16-bit value at the given byte offset, high byte first.
    putWord16BE :: Int -> Word16 -> IO ()
    putWord16BE off w = do
      pokeByteOff bytes off (fromIntegral (w `shiftR` 8) :: Word8)
      pokeByteOff bytes (off + 1) (fromIntegral w :: Word8)
-- | Decode a packet header from the 8 bytes at the given pointer.  Yields
-- 'Nothing' when the type tag in byte 0 is not the tag of the expected
-- packet type; otherwise the remaining fields are read in wire order.
readHdr ::
  PacketType sender resp d ->
  Ptr a ->
  IO (Maybe (PacketHeader sender resp d))
readHdr pktTy ptr = do
  tag <- peekByteOff ptr 0
  if tag == packetType pktTy
    then do
      status <- PacketStatus <$> peekByteOff ptr 1
      len <- getWord16BE 2
      spid <- SPID <$> getWord16BE 4
      seqNo <- PacketSequenceID <$> peekByteOff ptr 6
      window <- peekByteOff ptr 7
      pure (Just (PacketHeader pktTy status len spid seqNo window))
    else pure Nothing
  where
    -- Load a big-endian 16-bit value from the given byte offset.
    getWord16BE :: Int -> IO Word16
    getWord16BE off = do
      hi <- peekByteOff ptr off :: IO Word8
      lo <- peekByteOff ptr (off + 1) :: IO Word8
      pure ((fromIntegral hi `shiftL` 8) .|. fromIntegral lo)
-- | A header paired with its payload.  The payload is wrapped in a functor
-- @f@ so one shape serves both plain values ('Identity') and encoded forms.
data Packet (sender :: Sender) (resp :: ResponseInfo *) (d :: *) (f :: * -> *) where
  Packet ::
    { pktHdr :: !(PacketHeader sender resp d)
    , pktData :: !(f d)
    } ->
    Packet sender resp d f

-- | Pair a header with a plain payload value.
mkPacket ::
  PacketHeader sender resp d ->
  d ->
  Packet sender resp d Identity
mkPacket hdr payload = Packet hdr (Identity payload)
-- | Replace the plain payload of a packet by its wire encoder, keeping the
-- header untouched.
packetEncoding ::
  Payload d =>
  Packet sender resp d Identity ->
  Packet sender resp d PacketEncoding
packetEncoding (Packet hdr (Identity payload)) = Packet hdr encoded
  where
    encoded = PacketEncoding (encodePayload payload)

deriving instance Show d => Show (Packet sender resp d Identity)

-- | Render a not-yet-encoded packet for diagnostics.
displayRequest :: Show d => Packet sender resp d Identity -> String
displayRequest pkt = show pkt
-- | How a payload is turned into bytes, indexed by its framing.
data PayloadEncoder (streaming :: StreamType) where
  -- | A token stream: only the bytes are known.
  PayloadStreamEncoder :: Builder -> PayloadEncoder 'TokenStream
  -- | A batch: a declared byte length together with the bytes themselves.
  PayloadBatchEncoder :: Word -> Builder -> PayloadEncoder 'TokenlessStream

-- | The encoder for a payload of type @pld@, usable as the functor argument
-- of 'Packet'.
newtype PacketEncoding pld
  = PacketEncoding (PayloadEncoder (PayloadStreaming pld))

-- | Types that can be sent as the payload of a packet.
class Show pld => Payload pld where
  -- | Framing used by this payload type.
  type PayloadStreaming pld :: StreamType

  -- | Produce the wire encoder for a payload value.
  encodePayload :: pld -> PayloadEncoder (PayloadStreaming pld)

-- | The unit payload encodes to zero bytes.
instance Payload () where
  type PayloadStreaming () = 'TokenlessStream
  encodePayload = const (PayloadBatchEncoder 0 mempty)
-- | Text payloads are sent as UTF-16LE.
--
-- The declared length is the size of the encoded bytes.  It is measured on
-- the encoded form rather than derived from 'T.length', because characters
-- outside the Basic Multilingual Plane occupy four bytes (a surrogate pair)
-- and would otherwise be under-counted, truncating the payload.
instance Payload Text where
  type PayloadStreaming Text = 'TokenlessStream
  encodePayload t =
    let encoded = TE.encodeUtf16LE t
     in PayloadBatchEncoder
          (fromIntegral (BS.length encoded))
          (byteString encoded)
-- | Run an encoder action and package its output as a batch encoder.
--
-- The forwards-travelling state starts as the empty builder at position 0
-- and ends as the finished bytes plus their total length; the
-- backwards-travelling state starts as the empty fixup.
runBatchEncoder ::
  Monoid fixup =>
  Tardis fixup (Builder, Word) () ->
  PayloadEncoder 'TokenlessStream
runBatchEncoder action = PayloadBatchEncoder (snd output) (fst output)
  where
    finalStates = snd (runTardis action (mempty, (mempty, 0)))
    output = snd finalStates
-- | The fixups recorded by the part of the encoder that runs /after/ this
-- point (read from the backwards-travelling state).
getFixups :: Tardis fixup (Builder, Word) fixup
getFixups = getFuture

-- | Number of bytes accounted for so far via 'emit'.
getPosition :: Tardis fixup (Builder, Word) Word
getPosition = fmap snd getPast

-- | Record a fixup for earlier parts of the encoder to read via
-- 'getFixups'.
fixup ::
  Monoid fixup =>
  fixup ->
  Tardis fixup (Builder, Word) ()
fixup f = modifyBackwards (f `mappend`)

-- | Append bytes to the output and advance the position by the stated
-- length.  The caller must pass the true size of the builder's output; it is
-- not checked.
emit :: Word -> Builder -> Tardis fixup (Builder, Word) ()
emit len b = modifyForwards advance
  where
    advance (soFar, pos) = (soFar <> b, len + pos)
-- | Major component of a PRELOGIN version option.
newtype MajorVersion = MajorVersion Word8
  deriving (Show, Num)

-- | Minor component of a PRELOGIN version option.
newtype MinorVersion = MinorVersion Word8
  deriving (Show, Num)

-- | Build-number component of a PRELOGIN version option.
newtype BuildNumber = BuildNumber Word16
  deriving (Show, Num)

-- | Sub-build-number component of a PRELOGIN version option.
newtype SubBuildNumber = SubBuildNumber Word16
  deriving (Show, Num)

-- | Encryption setting exchanged during PRELOGIN.  The PRELOGIN encoder
-- writes these as the bytes 0x00 to 0x03, in declaration order.
data Encryption
  = EncryptionOff
  | EncryptionOn
  | EncryptionNotSupported
  | EncryptionRequired
  deriving (Show)

-- | A 32-byte nonce held as four 64-bit words.
data Nonce
  = Nonce !Word64 !Word64 !Word64 !Word64
  deriving (Show)
-- | A single option of a PRELOGIN message.
data PreLoginOption
  = VersionOption !MajorVersion !MinorVersion !BuildNumber !SubBuildNumber
  | EncryptionOption !Encryption
  | InstanceNameOption !Text
  | ThreadIdOption !Word32
  | MarsOption !Bool
  | FedAuthRequiredOption !Word8
  | NonceOption !Nonce
  deriving (Show)

-- | A list of PRELOGIN options; lists of options combine with '<>'.
type PreLoginOptions = [PreLoginOption]

-- | An option list holding just a version option.
versionOption ::
  MajorVersion ->
  MinorVersion ->
  BuildNumber ->
  SubBuildNumber ->
  PreLoginOptions
versionOption major minor build subBuild =
  [VersionOption major minor build subBuild]

-- | An option list holding just @EncryptionOption EncryptionNotSupported@.
encryptionOff :: PreLoginOptions
encryptionOff = [EncryptionOption EncryptionNotSupported]

-- | Payload of a PRELOGIN packet, used both as request and as response.
newtype PreLogin = PreLoginP
  { preLoginOptions :: PreLoginOptions
  }
  deriving (Show)
-- | PRELOGIN wire encoding.
--
-- The message starts with a table holding one 5-byte entry per option
-- (token, big-endian offset, big-endian length), terminated by @0xFF@, and is
-- followed by the option bodies.  The table is written before the bodies, so
-- the offsets and lengths are obtained as fixups from the "future" part of
-- the encoder.
instance Payload PreLogin where
  type PayloadStreaming PreLogin = 'TokenlessStream

  encodePayload (PreLoginP options) =
    runBatchEncoder $ do
      fixups <- getFixups
      emitFixups options fixups
      emit 1 (word8 0xFF)
      forM_ options $ \option -> do
        before <- getPosition
        buildOption option
        after <- getPosition
        fixup
          [ ( fromIntegral before
            , fromIntegral (after - before)
            )
          ]
    where
      -- One table entry per option.  The fixup list is matched lazily
      -- because its contents are only produced later in the encoder.
      emitFixups [] _ = pure ()
      emitFixups (option : rest) ~(~(offset, len) : fixups) = do
        emit 1 (word8 (optionToken option))
        emit 2 (word16BE offset)
        emit 2 (word16BE len)
        emitFixups rest fixups

      optionToken :: PreLoginOption -> Word8
      optionToken VersionOption {} = 0x00
      optionToken EncryptionOption {} = 0x01
      optionToken InstanceNameOption {} = 0x02
      optionToken ThreadIdOption {} = 0x03
      optionToken MarsOption {} = 0x04
      optionToken FedAuthRequiredOption {} = 0x06
      optionToken NonceOption {} = 0x07

      buildOption
        ( VersionOption
            (MajorVersion major)
            (MinorVersion minor)
            (BuildNumber build)
            (SubBuildNumber subBuild)
          ) =
          emit 1 (word8 major)
            >> emit 1 (word8 minor)
            >> emit 2 (word16BE build)
            >> emit 2 (word16BE subBuild)
      buildOption (EncryptionOption enc) =
        buildEncryption enc
      buildOption (InstanceNameOption nm) =
        let nmBs = TE.encodeUtf8 nm
         in emit (fromIntegral (BS.length nmBs)) (byteString nmBs)
              -- The NUL terminator is one real byte of output and must be
              -- counted; declaring it as 0 bytes made this option's length,
              -- every later option's offset and the total payload length one
              -- byte too small.
              >> emit 1 (word8 0)
      buildOption (ThreadIdOption w) =
        emit 4 (word32BE w)
      buildOption (MarsOption False) = emit 1 (word8 0x00)
      buildOption (MarsOption True) = emit 1 (word8 0x01)
      buildOption (FedAuthRequiredOption a) =
        emit 1 (word8 a)
      buildOption (NonceOption (Nonce a3 a2 a1 a0)) =
        emit
          32
          ( word64LE a0
              <> word64LE a1
              <> word64LE a2
              <> word64LE a3
          )

      buildEncryption e =
        emit 1 . word8 $
          case e of
            EncryptionOff -> 0x00
            EncryptionOn -> 0x01
            EncryptionNotSupported -> 0x02
            EncryptionRequired -> 0x03
-- | Protocol version word.
newtype TDSVersion = TDSVersion
  { fromTDSVersion :: Word32
  }
  deriving (Show)

-- | Version word of the client program, as sent in LOGIN7.
newtype ClientProgVersion = ClientProgVersion
  { fromClientProgVersion :: Word32
  }
  deriving (Show)

-- | Client process id, as sent in LOGIN7.
newtype ClientPID = ClientPID
  { fromClientPID :: Word32
  }
  deriving (Show)

-- | Connection id, as sent in LOGIN7.
newtype ConnectionID = ConnectionID
  { fromConnectionID :: Word32
  }
  deriving (Show)
-- | The 32-bit option flags word of a LOGIN7 message.
newtype Login7Options = Login7Options
  { fromLogin7Options :: Word32
  }
  deriving (Show)

-- | Combining option sets takes the union of their bits.
instance Semigroup Login7Options where
  Login7Options l <> Login7Options r = Login7Options (l .|. r)

-- | The empty option set has no bits set.
instance Monoid Login7Options where
  mempty = Login7Options 0
  mappend = (<>)

-- | Default flag word, @0x000003e0@.
defaultLoginOptions :: Login7Options
defaultLoginOptions = Login7Options 0x000003e0
-- | Placeholder for LOGIN7 feature extensions; carries no data yet.
data Login7Feature = Login7Feature
  deriving (Show)

-- | Locale identifier part of a collation.
newtype LCID = LCID Word16
  deriving (Show)

-- | Flag bits of a collation.
newtype CollationFlags = CollationFlags Word8
  deriving (Show)

-- | Version part of a collation.
newtype CollationVersion = CollationVersion Word8
  deriving (Show)

-- | A collation: locale id, flags and version.
data Collation = Collation
  { collationLCID :: !LCID
  , collationFlags :: !CollationFlags
  , collationVersion :: !CollationVersion
  }
  deriving (Show)

-- | Six-byte client identifier, held as a 16-bit and a 32-bit part.
data ClientID
  = ClientID !Word16 !Word32
  deriving (Show, Eq)
-- | Version word @0x00000071@.  The LOGIN7 encoder writes this word
-- big-endian, unlike the other LOGIN7 header words.
tdsVersion71 :: TDSVersion
tdsVersion71 = TDSVersion {fromTDSVersion = 0x00000071}
-- | Payload of a LOGIN7 packet.  The numeric fields make up the fixed part
-- of the message; the text fields are written as variable-length data after
-- it and referenced from the fixed part by offset and length.
data Login7 = Login7P
  { login7_tdsVersion :: !TDSVersion
  , login7_packetSize :: !Word32
  , login7_clientProgVer :: !ClientProgVersion
  , login7_clientPID :: !ClientPID
  , login7_connectionID :: !ConnectionID
  , login7_flags :: !Login7Options
  , login7_clientTmZone :: !Word32
  , login7_collation :: !Collation
  , login7_hostName :: !Text
  , login7_userName :: !Text
  , login7_password :: !Text
  -- ^ obfuscated by the encoder before being written
  , login7_appName :: !Text
  , login7_serverName :: !Text
  , login7_extension :: !Text
  , login7_cltIntName :: !Text
  , login7_language :: !Text
  , login7_database :: !Text
  , login7_clientID :: !ClientID
  , login7_SSPI :: !Text
  , login7_atchDbFile :: !Text
  , login7_changePasswd :: !Text
  , login7_SSPILong :: !Word32
  , login7_extraFeatures :: [Login7Feature]
  }
  deriving (Show)
-- | Values that the LOGIN7 encoder only learns after it has written the
-- places that need them: the total message length and the offset of each
-- variable-length field.  An offset of 0 means "not recorded".
data Login7Fixups = Login7Fixups
  { login7_totalLen :: !(Sum Word32)
  , login7_hostNameOfs :: !Word16
  , login7_userNameOfs :: !Word16
  , login7_passwordOfs :: !Word16
  , login7_appNameOfs :: !Word16
  , login7_serverNameOfs :: !Word16
  , login7_extensionOfs :: !Word16
  , login7_cltIntNameOfs :: !Word16
  , login7_languageOfs :: !Word16
  , login7_databaseOfs :: !Word16
  , login7_SSPIOfs :: !Word16
  , login7_atchDbFileOfs :: !Word16
  , login7_changePasswdOfs :: !Word16
  }
  deriving (Show)
-- | Total lengths add up; for every offset the right operand wins unless it
-- is 0, in which case the left operand's value is kept.
instance Semigroup Login7Fixups where
  l <> r =
    Login7Fixups
      (login7_totalLen l <> login7_totalLen r)
      (pick login7_hostNameOfs)
      (pick login7_userNameOfs)
      (pick login7_passwordOfs)
      (pick login7_appNameOfs)
      (pick login7_serverNameOfs)
      (pick login7_extensionOfs)
      (pick login7_cltIntNameOfs)
      (pick login7_languageOfs)
      (pick login7_databaseOfs)
      (pick login7_SSPIOfs)
      (pick login7_atchDbFileOfs)
      (pick login7_changePasswdOfs)
    where
      pick field
        | field r == 0 = field l
        | otherwise = field r

-- | The empty fixup has a zero total length and no offsets recorded.
instance Monoid Login7Fixups where
  mempty = Login7Fixups mempty 0 0 0 0 0 0 0 0 0 0 0 0
  mappend = (<>)
-- | LOGIN7 wire encoding.
--
-- The fixed part (length, versions, flags, collation, offset/length pairs
-- and client id) is written first, followed by the variable-length strings
-- in UTF-16LE.  The total length and the string offsets are only known once
-- the strings have been written, so the fixed part reads them as fixups
-- ('getFixups') that the later part of the encoder records with 'fixup'.
instance Payload Login7 where
  type PayloadStreaming Login7 = 'TokenlessStream

  encodePayload d =
    runBatchEncoder $ do
      -- Lazy match: the total length comes from the end of the encoder.
      ~(Sum totalLen) <- login7_totalLen <$> getFixups
      emit 4 (word32LE (fromIntegral totalLen))
      emit 4 (word32BE (fromTDSVersion (login7_tdsVersion d)))
      emit 4 (word32LE (login7_packetSize d))
      emit 4 (word32LE (fromClientProgVersion (login7_clientProgVer d)))
      emit 4 (word32LE (fromClientPID (login7_clientPID d)))
      emit 4 (word32LE (fromConnectionID (login7_connectionID d)))
      emit 4 (word32LE (fromLogin7Options (login7_flags d)))
      emit 4 (word32LE (login7_clientTmZone d))
      emitCollation (login7_collation d)
      emitOffset login7_hostName login7_hostNameOfs
      emitOffset login7_userName login7_userNameOfs
      emitOffset login7_password login7_passwordOfs
      emitOffset login7_appName login7_appNameOfs
      emitOffset login7_serverName login7_serverNameOfs
      emitOffset login7_extension login7_extensionOfs
      emitOffset login7_cltIntName login7_cltIntNameOfs
      emitOffset login7_language login7_languageOfs
      emitOffset login7_database login7_databaseOfs
      let ClientID hi lo = login7_clientID d
      emit 4 (word32LE lo)
      emit 2 (word16LE hi)
      emitOffset login7_SSPI login7_SSPIOfs
      emitOffset login7_atchDbFile login7_atchDbFileOfs
      -- End of the fixed part: default the host-name offset to this
      -- position; 'emitData' overrides it when the host name is non-empty.
      curPos <- fromIntegral <$> getPosition
      fixup (mempty {login7_hostNameOfs = curPos})
      emitData id login7_hostName (\x -> mempty {login7_hostNameOfs = x})
      emitData id login7_userName (\x -> mempty {login7_userNameOfs = x})
      emitData enc login7_password (\x -> mempty {login7_passwordOfs = x})
      emitData id login7_appName (\x -> mempty {login7_appNameOfs = x})
      emitData id login7_serverName (\x -> mempty {login7_serverNameOfs = x})
      emitData id login7_extension (\x -> mempty {login7_extensionOfs = x})
      emitData id login7_cltIntName (\x -> mempty {login7_cltIntNameOfs = x})
      emitData id login7_language (\x -> mempty {login7_languageOfs = x})
      emitData id login7_database (\x -> mempty {login7_databaseOfs = x})
      pos <- fromIntegral <$> getPosition
      fixup (mempty {login7_totalLen = Sum pos})
    where
      -- Write the (offset, length) pair of one string field.  The length is
      -- counted in UTF-16 code units, measured on the encoded form so that
      -- it always agrees with the bytes written by 'emitData' (a character
      -- outside the BMP takes two code units).
      emitOffset getBs getOfs = do
        let units = BS.length (TE.encodeUtf16LE (getBs d)) `div` 2
        ofs <- getOfs <$> getFixups
        emit 2 (word16LE ofs)
        emit 2 (word16LE (fromIntegral units))

      -- Password obfuscation: swap the nibbles of each byte, then XOR with
      -- 0xA5.
      enc b =
        let swapped =
              ((b `shiftR` 4) .&. 0xF)
                .|. ((b .&. 0xF) `shiftL` 4)
         in swapped `xor` 0xA5

      -- Write one string (after mapping @f@ over its encoded bytes) and
      -- record where it starts.  Empty strings write nothing and record no
      -- offset.  The position advances by the real encoded byte count.
      emitData f getBs setOfs = do
        let encoded = BS.map f (TE.encodeUtf16LE (getBs d))
        if BS.null encoded
          then pure ()
          else do
            ofs <- fromIntegral <$> getPosition
            fixup (setOfs ofs)
            emit (fromIntegral (BS.length encoded)) (byteString encoded)

      -- Pack the collation into one little-endian word: LCID in the low
      -- bits, flags starting at bit 20, version in bits 28-31.  This mirrors
      -- the unpacking done by the collation parser of 'typeInfoParser'.  The
      -- fields must be shifted LEFT into place; shifting right discarded
      -- them entirely.
      emitCollation
        ( Collation
            (LCID lcid)
            (CollationFlags flags)
            (CollationVersion vers)
          ) =
          emit
            4
            ( word32LE
                ( ((fromIntegral vers .&. 0xF) `shiftL` 28)
                    .|. (fromIntegral flags `shiftL` 20)
                    .|. fromIntegral lcid
                )
            )
-- | How a response is decoded, indexed by its framing.
data ResponseDecoder (streaming :: StreamType) (res :: *) where
  -- | Decode from a buffer, given its pointer and size; 'Nothing' signals a
  -- decode failure.
  DecodeBatchResponse ::
    (Ptr Word8 -> Word16 -> IO (Maybe res)) ->
    ResponseDecoder 'TokenlessStream res
  -- | Decode from a stream of tokens.  The 'IO' action is a finaliser that
  -- the decoder runs when it is done with the stream.
  DecodeTokenStream ::
    (forall r. IO () -> S.Stream TokenStream IO r -> IO res) ->
    ResponseDecoder 'TokenStream res

-- | Types that can be received as the response to a request.
class Show res => Response res where
  -- | Framing used by this response type.
  type ResponseStreaming res :: StreamType

  -- | The decoder for this response type.
  responseDecoder :: ResponseDecoder (ResponseStreaming res) res

-- | Placeholder instance; using its decoder is an error.
instance Response Unimplemented where
  type ResponseStreaming Unimplemented = 'TokenlessStream
  responseDecoder = error "responseDecoder{Unimplemented}"
-- | Decoder over an in-memory buffer: the state is the buffer pointer and
-- the current read offset, the environment is the buffer size, and failure
-- is reported through 'MaybeT'.
type BatchDecoder = StateT (Ptr Word8, Word16) (ReaderT Word16 (MaybeT IO))

-- | Wrap a 'BatchDecoder' as a response decoder.  Decoding starts at offset
-- 0 of the supplied buffer.
decodeBatch :: BatchDecoder a -> ResponseDecoder 'TokenlessStream a
decodeBatch decoder = DecodeBatchResponse run
  where
    run ptr sz =
      runMaybeT (runReaderT (evalStateT decoder (ptr, 0)) sz)
-- | Read one byte at the current offset and advance by one.
--
-- Fails (with 'Nothing' in the underlying 'MaybeT') when the offset has
-- reached the buffer size.  Valid offsets are @0 .. sz - 1@, so the bound is
-- exclusive, matching 'seek'; the previous check (@ofs > sz@) allowed a read
-- of the byte one past the end of the buffer.
read8 :: BatchDecoder Word8
read8 = do
  sz <- ask
  (res, ofs) <- get
  if ofs >= sz
    then lift (lift (MaybeT (pure Nothing)))
    else do
      put (res, ofs + 1)
      liftIO (peek (res `plusPtr` fromIntegral ofs))
-- | Read a big-endian 16-bit word (two bytes, high byte first).
read16BE :: BatchDecoder Word16
read16BE = join8 <$> read8 <*> read8
  where
    join8 hi lo = (fromIntegral hi `shiftL` 8) .|. fromIntegral lo

-- | Read a big-endian 32-bit word.
read32BE :: BatchDecoder Word32
read32BE = join16 <$> read16BE <*> read16BE
  where
    join16 hi lo = (fromIntegral hi `shiftL` 16) .|. fromIntegral lo

-- | Read a big-endian 64-bit word.
read64BE :: BatchDecoder Word64
read64BE = join32 <$> read32BE <*> read32BE
  where
    join32 hi lo = (fromIntegral hi `shiftL` 32) .|. fromIntegral lo

-- | Read a little-endian 16-bit word (two bytes, low byte first).
read16LE :: BatchDecoder Word16
read16LE = join8 <$> read8 <*> read8
  where
    join8 lo hi = (fromIntegral hi `shiftL` 8) .|. fromIntegral lo

-- | Read a little-endian 32-bit word.
read32LE :: BatchDecoder Word32
read32LE = join16 <$> read16LE <*> read16LE
  where
    join16 lo hi = (fromIntegral hi `shiftL` 16) .|. fromIntegral lo

-- | Read a little-endian 64-bit word.
read64LE :: BatchDecoder Word64
read64LE = join32 <$> read32LE <*> read32LE
  where
    join32 lo hi = (fromIntegral hi `shiftL` 32) .|. fromIntegral lo
-- | Read a zero-terminated string that occupies at most @len@ bytes,
-- terminator included, starting at the current offset.  Fails when no zero
-- byte is found within @len@ bytes.  The bytes before the terminator are
-- decoded with 'T.peekCStringLen'.
ztText :: Word16 -> BatchDecoder Text
ztText len = do
  start <- tell
  scan start 0
  where
    -- @n@ counts the non-zero bytes consumed so far.
    scan start n
      | n >= len = lift (lift (MaybeT (pure Nothing)))
      | otherwise = do
          byte <- read8
          if byte == 0
            then extract start n
            else scan start (n + 1)

    extract start n = do
      (base, _) <- get
      liftIO
        ( T.peekCStringLen
            ( castPtr base `plusPtr` fromIntegral start
            , fromIntegral n
            )
        )
-- | The current read offset.
tell :: BatchDecoder Word16
tell = gets snd

-- | Move the read offset to an absolute position.  Fails unless the
-- position is strictly inside the buffer.
seek :: Word16 -> BatchDecoder ()
seek ofs = do
  sz <- ask
  if ofs >= sz
    then lift (lift (MaybeT (pure Nothing)))
    else modify (\(base, _) -> (base, ofs))
-- | PRELOGIN response decoding.
--
-- The option table at the start of the message is read first, up to the
-- @0xFF@ terminator.  Each table entry becomes a deferred action that seeks
-- to the option's offset and decodes its body; those actions are run, in
-- table order, once the whole table has been read.  An unknown option token,
-- or a known token with an unacceptable length, fails the whole decode.
instance Response PreLogin where
  type ResponseStreaming PreLogin = 'TokenlessStream

  responseDecoder =
    decodeBatch $ do
      bodyReaders <- readTable
      PreLoginP . catMaybes <$> sequence bodyReaders
    where
      failDecode :: BatchDecoder a
      failDecode = lift (lift (MaybeT (pure Nothing)))

      -- Read table entries until the terminator, collecting one deferred
      -- body reader per entry.
      readTable = do
        tag <- read8
        if tag == (0xFF :: Word8)
          then pure []
          else do
            ofs <- read16BE
            len <- read16BE
            rest <- readTable
            pure ((seek ofs >> readOption tag len) : rest)

      -- Decode one option body.  An alternative whose guard fails falls
      -- through to the final catch-all.
      readOption :: Word8 -> Word16 -> BatchDecoder (Maybe PreLoginOption)
      readOption tag len =
        case tag of
          0x00
            | len >= 6 ->
                Just
                  <$> ( VersionOption
                          <$> (MajorVersion <$> read8)
                          <*> (MinorVersion <$> read8)
                          <*> (BuildNumber <$> read16BE)
                          <*> (SubBuildNumber <$> read16BE)
                      )
          0x01
            | len >= 1 ->
                fmap (Just . EncryptionOption) $ do
                  setting <- read8
                  case setting of
                    0x00 -> pure EncryptionOff
                    0x01 -> pure EncryptionOn
                    0x02 -> pure EncryptionNotSupported
                    0x03 -> pure EncryptionRequired
                    _ -> failDecode
          0x02
            | len >= 1 ->
                Just . InstanceNameOption <$> ztText len
          0x03
            | len >= 4 ->
                Just . ThreadIdOption <$> read32BE
          0x04
            | len >= 1 ->
                Just . MarsOption <$> ((/= 0) <$> read8)
          0x06
            | len >= 1 ->
                Just . FedAuthRequiredOption <$> read8
          0x07
            | len == 32 ->
                Just
                  <$> ( NonceOption
                          <$> ( Nonce
                                  <$> read64LE
                                  <*> read64LE
                                  <*> read64LE
                                  <*> read64LE
                              )
                      )
          _ -> failDecode
-- | Contents of a login acknowledgement: interface byte, protocol version,
-- server program name and server program version.
data Login7Ack = Login7Ack !Word8 !TDSVersion !T.Text !ProgVersion
  deriving (Show)

-- | The login response is a token stream.  It is folded while remembering
-- the most recent LOGINACK token seen; a DONE token ends decoding, and all
-- other tokens are reported on stdout and skipped.
instance Response Login7Ack where
  type ResponseStreaming Login7Ack = 'TokenStream

  responseDecoder = DecodeTokenStream $ \finish tokens ->
    S.streamFold (atEnd finish) runEffect (onToken finish) tokens Nothing
    where
      -- The stream ran out: finalise, then require that an ack was seen.
      atEnd finish _ seen = do
        S.liftIO finish
        maybe (fail "Token stream incomplete") pure seen

      -- Run an effect of the stream and continue with the current state.
      runEffect step seen = step >>= \continue -> continue seen

      -- Handle one layer of the token stream.
      onToken finish layer seen =
        case layer of
          OneToken (LoginAck i v progNm progVers) next ->
            next (Just (Login7Ack i v progNm progVers))
          OneToken Done {} _
            | Just ack <- seen ->
                S.liftIO finish >> pure ack
            | otherwise -> fail "Premature DONE"
          OneToken token next -> do
            liftIO (putStrLn ("Unhandled token during LOGIN : " ++ show (() <$ token)))
            next seen
          ContParse {} -> fail "Can't parse unhandled token in LOGIN"
-- | Status word of a DONE token; see 'doneHasMore' and 'doneHasError'.
newtype DoneSts = DoneSts Word16
  deriving (Show)

-- | Version word of the server program.
newtype ProgVersion = ProgVersion Word32
  deriving (Show)

-- | Contents of an ERROR or INFO token.
data Message = Message
  { messageCode :: !SQLError
  , messageSt :: !Word8
  , messageClass :: !ErrorClass
  , messageText :: !Text
  , messageServerName :: !Text
  , messageProcName :: !Text
  , messageLineNum :: Word16
  }
  deriving (Show)

-- | Messages can be thrown as exceptions.
instance Exception Message
-- | Contents of an ENVCHANGE token.  Only the database, language and
-- collation changes are decoded; every other change type becomes
-- 'EnvChangeUnknown'.
data EnvChange
  = EnvChangeDatabase !T.Text !T.Text
  | EnvChangeLanguage !T.Text !T.Text
  | EnvChangeCollation !BS.ByteString !BS.ByteString
  | EnvChangeUnknown
  deriving (Show)
-- | Tokens of a response stream, parameterised by the representation of row
-- data.  Constructors without fields are token kinds that the parser does
-- not decode.
data Token' f
  = TvpRow
  | Offset
  | ReturnStatus
  | ColMetadata !ColumnMetadata
  | AltMetadata
  | TableName
  | ColumnInfo
  | Order (V.Vector Word8)
  | Error !Message
  | Info !Message
  | ReturnValue
  | LoginAck !Word8 !TDSVersion !T.Text !ProgVersion
  | FeatureExtAck
  | Row !f
  | NbcRow
  | AltRow
  | EnvChange !EnvChange
  | SessionState
  | SSPI
  | FedAuthInfo
  | Done !DoneSts !Word16 !Word64
  | DoneProc
  | DoneInProc
  deriving (Functor, Show)

-- | A token whose row data is the not-yet-consumed byte stream.
type Token = Token' (SBS.ByteString IO ())

-- | One layer of a parsed token stream.
data TokenStream f
  = -- | A fully parsed token, followed by the rest of the stream.
    OneToken !Token f
  | -- | A token that holds the remaining input; the consumer reads what it
    -- needs from it and hands the leftover bytes to the continuation.
    ContParse !Token (SBS.ByteString IO () -> f)
  deriving (Functor)
-- | Whether bit 0 of the DONE status is set.
doneHasMore :: DoneSts -> Bool
doneHasMore (DoneSts sts) = sts .&. bit 0 /= 0

-- | Whether bit 1 of the DONE status is set.
doneHasError :: DoneSts -> Bool
doneHasError (DoneSts sts) = sts .&. bit 1 /= 0
-- | Turn a byte stream into a stream of tokens by running 'parseToken'
-- repeatedly.  A parse failure fails the stream.  A token that needs the
-- remaining input (a row) is emitted as 'ContParse' carrying that input, and
-- parsing resumes from whatever bytes the consumer hands back.
parseTokenStream :: SBS.ByteString IO () -> S.Stream TokenStream IO ()
parseTokenStream input = do
  (outcome, rest) <- liftIO (S.parse parseToken input)
  case outcome of
    Right (Right tok) ->
      S.wrap (OneToken tok (parseTokenStream rest))
    Right (Left mkTok) ->
      S.wrap (ContParse (mkTok rest) parseTokenStream)
    Left err ->
      fail ("Stream decode error: " ++ show err)
-- | UTF-16LE text prefixed by a one-byte character count.  Returns the total
-- number of bytes consumed (prefix included) together with the text.
bVarChar :: A.Parser (Int, Text)
bVarChar = do
  chars <- fromIntegral <$> A.anyWord8
  let byteLen = chars * 2
  txt <- TE.decodeUtf16LE <$> A.take byteLen
  pure (1 + byteLen, txt)

-- | UTF-16LE text prefixed by a two-byte little-endian character count.
-- Returns the total number of bytes consumed together with the text.
usVarChar :: A.Parser (Int, Text)
usVarChar = do
  chars <- fromIntegral <$> anyWord16le
  let byteLen = chars * 2
  txt <- TE.decodeUtf16LE <$> A.take byteLen
  pure (2 + byteLen, txt)

-- | Raw bytes prefixed by a one-byte length.  Returns the total number of
-- bytes consumed together with a copy of the bytes.
bVarByte :: A.Parser (Int, BS.ByteString)
bVarByte = do
  byteLen <- fromIntegral <$> A.anyWord8
  raw <- A.take byteLen
  pure (1 + byteLen, BS.copy raw)
-- | Parse a single token, dispatching on its one-byte tag.
--
-- Most tokens are returned fully parsed ('Right').  A ROW token (0xD1) is
-- returned as 'Left': its body cannot be parsed here, so the caller supplies
-- the remaining byte stream to build the token.  Token kinds that are
-- recognised but not supported fail with the token's name; unknown tags fail
-- with a dump of the next 16 bytes.
parseToken :: A.Parser (Either (SBS.ByteString IO () -> Token) Token)
parseToken = do
  tag <- A.anyWord8
  case tag of
    0x01 -> fail "TVPROW"
    0x78 -> fail "OFFSET"
    0x79 -> fail "RETURNSTATUS"
    0x81 -> Right . ColMetadata <$> colMetadataP
    0x88 -> fail "ALTMETADATA"
    0xA4 -> fail "TABNAME"
    0xA5 -> fail "COLINFO"
    0xA9 -> Right . Order <$> orderP
    0xAA -> Right . Error <$> messageP
    0xAB -> Right . Info <$> messageP
    0xAC -> fail "RETURNVALUE"
    0xAD -> Right <$> loginAckP
    0xAE -> fail "FEATUREEXTACK"
    0xD1 -> pure (Left Row)
    0xD2 -> fail "NBCROW"
    0xD3 -> fail "ALTROW"
    0xE3 -> Right <$> envChangeP
    0xE4 -> fail "SESSIONSTATE"
    0xED -> fail "SSPI"
    0xEE -> fail "FEDAUTHINFO"
    0xFD ->
      fmap
        Right
        ( Done
            <$> (DoneSts <$> anyWord16le)
            <*> anyWord16le
            <*> (fromIntegral <$> anyWord32le)
        )
    0xFE -> fail "DONEPROC"
    0xFF -> fail "DONEINPROC"
    _ -> do
      d <- BS.pack <$> replicateM 16 A.anyWord8
      fail ("Unknown token in TDS stream: " ++ show tag ++ " " ++ show d)
  where
    -- Skip the part of a length-prefixed token body that was not consumed by
    -- the field parsers.  @totalLen@ is the declared body length and
    -- @actualLength@ the number of body bytes already consumed.
    --
    -- The difference is computed in 'Int'.  Computed in 'Word16' (the type
    -- of the length prefix) it can never be negative: the check below was
    -- dead code, and a token whose fields overran the declared length caused
    -- a wrapped-around skip of up to 65535 bytes.
    nextPacket :: Word16 -> Int -> A.Parser ()
    nextPacket totalLen actualLength
      | remaining < 0 = fail "len - actualLength < 0"
      | otherwise = () <$ A.take remaining
      where
        remaining = fromIntegral totalLen - actualLength

    loginAckP = do
      len <- anyWord16le
      iface <- A.anyWord8
      tdsVersion <- TDSVersion <$> anyWord32le
      (progByteLen, progName) <- bVarChar
      progVersion <- ProgVersion <$> anyWord32le
      -- interface byte + version word + program name + program version
      let actualLength = 1 + 4 + progByteLen + 4
      nextPacket len actualLength
      pure (LoginAck iface tdsVersion progName progVersion)

    envChangeP = do
      len <- anyWord16le
      changeType <- A.anyWord8
      (actualLen, chg) <-
        case changeType of
          1 -> do
            (newDbLen, newDb) <- bVarChar
            (oldDbLen, oldDb) <- bVarChar
            pure (oldDbLen + newDbLen, EnvChangeDatabase oldDb newDb)
          2 -> do
            (newLangLen, newLang) <- bVarChar
            (oldLangLen, oldLang) <- bVarChar
            pure (oldLangLen + newLangLen, EnvChangeLanguage oldLang newLang)
          7 -> do
            (newCollLen, newColl) <- bVarByte
            (oldCollLen, oldColl) <- bVarByte
            pure (oldCollLen + newCollLen, EnvChangeCollation oldColl newColl)
          -- Unrecognised change: nothing is decoded and the whole body is
          -- skipped below.
          _ -> pure (0, EnvChangeUnknown)
      -- The extra byte accounts for the change-type byte.
      nextPacket len (1 + actualLen)
      pure (EnvChange chg)

    orderP = do
      len <- anyWord16le
      V.fromList <$> replicateM (fromIntegral len) A.anyWord8

    messageP = do
      len <- anyWord16le
      msgCode <- SQLError <$> anyWord32le
      st <- A.anyWord8
      cls <- ErrorClass <$> A.anyWord8
      (msgByteLen, msgText) <- usVarChar
      (serverNameByteLen, serverName) <- bVarChar
      (procNameByteLen, procName) <- bVarChar
      lnNum <- anyWord16le
      -- code + state + class + the three strings + line number
      let actualLength =
            4 + 1 + 1
              + msgByteLen
              + serverNameByteLen
              + procNameByteLen
              + 2
      nextPacket len actualLength
      pure (Message msgCode st cls msgText serverName procName lnNum)
-- | Placeholder; carries no data yet.
data EncryptionAlgorithm = EncryptionAlgorithm
  deriving (Show)

-- | Placeholder; carries no data yet.
data EncryptionAlgorithmType = EncryptionAlgorithmType
  deriving (Show)

-- | Distinguishes ordinary from national character types.
data CharKind
  = NormalChar
  | NationalChar
  deriving (Show)

-- | Precision and scale of a decimal or numeric type.
data PrecScale = PrecScale !Word8 !Word8
  deriving (Show)

-- | Width of the length field in a character type's description: one byte
-- or two.
data TypeLen
  = ByteLen
  | ShortLen
  deriving (Show)
-- | Column type descriptions, as produced by 'typeInfoParser'.
--
-- Where a constructor has a 'Bool' field, the parser sets it to 'True' for
-- the length-prefixed form of the type and to 'False' for the fixed-size
-- form; the numeric field that follows is the length.
data TypeInfo
  = NullType
  | IntNType !Bool !Word8
  | GuidType !Word8
  | DecimalType !Bool !Word8 !PrecScale
  | NumericType !Bool !Word8 !PrecScale
  | BitNType !Bool !Word8
  | DecimalNType !Bool !Word8 !PrecScale
  | NumericNType !Bool !Word8 !PrecScale
  | FloatNType !Bool !Word8
  | MoneyNType !Bool !Word8
  | DtTmNType !Bool !Word8
  | DateNType !Bool !Word8
  | TimeNType !Bool !Word8
  | DtTm2NType !Word8
  | DtTmOfsType !Word8
  | CharType !TypeLen !CharKind !Word16 !(Maybe Collation)
  | VarcharType !TypeLen !CharKind !Word16 !(Maybe Collation)
  | BinaryType !Word16
  | VarBinType !Word16
  | ImageType !Word32
  | NTextType !Word32 !Collation
  | SSVarType !Word32
  | TextType !Word32 !Collation
  | XMLType !Word32
  deriving (Show)
-- | Encryption metadata of a column.  'colDataParser' never produces it.
data CryptoMetadata = CryptoMetadata
  { cryptoOrdinal :: !Word16
  , cryptoUserType :: !Word32
  , cryptoBaseType :: !TypeInfo
  , cryptoEncAlgo :: !EncryptionAlgorithm
  , cryptoAlgName :: !T.Text
  , cryptoAlgType :: !EncryptionAlgorithmType
  , cryptoNormVers :: !Word8
  }
  deriving (Show)

-- | Description of one result column.
data ColumnData = ColumnData
  { cdUserType :: !Word32
  , cdFlags :: !Word16
  , cdBaseTypeInfo :: !TypeInfo
  , cdTableName :: !(Maybe T.Text)
  , cdCrypto :: !(Maybe CryptoMetadata)
  , cdColName :: !T.Text
  }
  deriving (Show)

-- | Contents of a COLMETADATA token: the descriptions of all columns of a
-- result set.  'colMetadataP' always leaves 'cmCekTbl' empty.
data ColumnMetadata = ColumnMetadata
  { cmCekTbl :: BS.ByteString
  , cmColData :: [ColumnData]
  }
  deriving (Show)
-- | One column of a row that has not been decoded yet: the column's
-- description, the byte stream positioned at the column's data, and a
-- continuation that takes the bytes left over once the column's data has
-- been consumed.
data RawColumn f where
  RawColumn ::
    ColumnData ->
    SBS.ByteString IO () ->
    (SBS.ByteString IO () -> f) ->
    RawColumn f

deriving instance Functor RawColumn
-- | Parse a column type description: a one-byte type tag followed by
-- whatever that type requires (a length, a collation, precision and scale).
--
-- Fixed-size types carry their size implicitly and are flagged 'False';
-- length-prefixed types read their length from the input and are flagged
-- 'True'.
typeInfoParser :: A.Parser TypeInfo
typeInfoParser = do
  tag <- A.anyWord8
  case tag of
    0x1F -> pure $ NullType
    0x22 -> lLen $ ImageType
    0x23 -> collation $ lLen $ TextType
    0x24 -> bLen $ GuidType
    0x25 -> bLen $ VarBinType
    0x26 -> bLen $ IntNType True
    0x27 -> noColl $ bLen $ VarcharType ByteLen NormalChar
    0x28 -> bLen $ DateNType True
    0x29 -> bLen $ TimeNType True
    0x2A -> bLen $ DtTm2NType
    0x2B -> bLen $ DtTmOfsType
    0x2D -> bLen $ BinaryType
    0x2F -> noColl $ bLen $ CharType ByteLen NormalChar
    -- 0x37 is the legacy decimal tag, the counterpart of 0x3F (legacy
    -- numeric) below; the nullable decimal tag is 0x6A.  It previously
    -- produced 'DecimalNType', which made it indistinguishable from 0x6A and
    -- left 'DecimalType' unused.
    0x37 -> precScale $ bLen $ DecimalType True
    0x3F -> precScale $ bLen $ NumericType True
    0x30 -> pure $ IntNType False 1
    0x32 -> pure $ BitNType False 1
    0x34 -> pure $ IntNType False 2
    0x38 -> pure $ IntNType False 4
    0x3A -> pure $ DtTmNType False 4
    0x3B -> pure $ FloatNType False 4
    0x3C -> pure $ MoneyNType False 8
    -- Fixed 8-byte datetime: flagged 'False' like every other fixed-size
    -- type here (compare 0x3A); the length-prefixed form is 0x6F.
    0x3D -> pure $ DtTmNType False 8
    0x3E -> pure $ FloatNType False 8
    0x62 -> lLen $ SSVarType
    0x63 -> collation $ lLen $ NTextType
    0x68 -> bLen $ BitNType True
    0x6A -> precScale $ bLen $ DecimalNType True
    0x6C -> precScale $ bLen $ NumericNType True
    0x6D -> bLen $ FloatNType True
    0x6E -> bLen $ MoneyNType True
    0x6F -> bLen $ DtTmNType True
    0x7A -> pure $ MoneyNType False 4
    0x7F -> pure $ IntNType False 8
    0xA5 -> usLen $ VarBinType
    0xA7 -> optColl $ usLen $ VarcharType ShortLen NormalChar
    0xAD -> usLen $ BinaryType
    0xAF -> optColl $ usLen $ CharType ShortLen NormalChar
    0xE7 -> optColl $ usLen $ VarcharType ShortLen NationalChar
    0xEF -> optColl $ usLen $ CharType ShortLen NationalChar
    0xF1 -> usLen $ XMLType
    0xF0 -> fail "CLR UDT not supported"
    _ -> fail ("Unknown TypeInfo tag " ++ show tag)
  where
    -- Read a 1-, 2- or 4-byte (little-endian) length and feed it to the
    -- given constructor.
    bLen, usLen, lLen :: Num a => (a -> b) -> A.Parser b
    bLen mk = mk . fromIntegral <$> A.anyWord8
    usLen mk = mk . fromIntegral <$> anyWord16le
    lLen mk = mk . fromIntegral <$> anyWord32le

    -- Precision and scale follow the length on the wire, so the length
    -- parser has to run first, as in 'collation' and 'optColl'.  Reading
    -- the two bytes before the length assigned every field the wrong byte.
    precScale mk = do
      withLen <- mk
      prec <- A.anyWord8
      scale <- A.anyWord8
      pure (withLen (PrecScale prec scale))

    collation mk =
      mk <*> collationP

    optColl mk =
      mk <*> (fmap Just collationP)

    noColl mk = mk <*> pure Nothing

    -- A collation is five bytes: a little-endian word holding the LCID (low
    -- 20 bits, truncated to the 16 bits 'LCID' can hold), the flags
    -- (starting at bit 20) and the version (bits 28-31), followed by a
    -- sort-id byte that is read and dropped.
    collationP = do
      lcidData <- anyWord32le
      let lcid = LCID (fromIntegral (lcidData .&. 0xFFFFF))
          colFlags = CollationFlags (fromIntegral ((lcidData `shiftR` 20) .&. 0xFF))
          version = CollationVersion (fromIntegral ((lcidData `shiftR` 28) .&. 0xF))
      _sortId <- A.anyWord8
      pure (Collation lcid colFlags version)
-- | Parse the description of one column: a two-byte user type, two bytes of
-- flags, the type description and the column name.  Types that come with a
-- table name (text, ntext, image) are not supported and fail the parse.  No
-- crypto metadata is ever read.
colDataParser :: A.Parser ColumnData
colDataParser = do
  userType <- fromIntegral <$> anyWord16le
  flags <- anyWord16le
  typeInfo <- typeInfoParser
  tblName <-
    if carriesTableName typeInfo
      then fail "TODO table name parsing"
      else pure Nothing
  (_, colNm) <- bVarChar
  pure (ColumnData userType flags typeInfo tblName Nothing colNm)
  where
    carriesTableName = \case
      NTextType {} -> True
      TextType {} -> True
      ImageType {} -> True
      _ -> False
-- | Parse the body of a COLMETADATA token: a two-byte column count followed
-- by that many column descriptions.  A count of @0xFFFF@ is treated like a
-- count of zero; in both cases a further @0xFFFF@ word is required and the
-- result has no columns.
colMetadataP :: A.Parser ColumnMetadata
colMetadataP = do
  rawCount <- anyWord16le
  let columnsPresent :: Int
      columnsPresent
        | rawCount == 0xFFFF = 0
        | otherwise = fromIntegral rawCount
  colsData <- replicateM columnsPresent colDataParser
  case columnsPresent of
    0 -> do
      metadata <- anyWord16le
      when (metadata /= 0xFFFF) (fail "TDS COLMETADATA decode error: Columns present")
      pure (ColumnMetadata mempty [])
    _ -> pure (ColumnMetadata mempty colsData)
-- | The response to a SQL batch: a stream of result sets.  Each result set
-- is its column metadata paired with a stream of rows, and each row is a
-- stream of raw columns.
newtype RowResults = RowResults
  { getRowResults ::
      S.Stream
        ( Compose
            (S.Of ColumnMetadata)
            (S.Stream (S.Stream RawColumn IO) IO)
        )
        IO
        ()
  }

-- | The stream cannot be rendered; a fixed placeholder is shown instead.
instance Show RowResults where
  show = const "RowResults <stream>"
-- | Row results are decoded lazily from the token stream; the finaliser
-- runs once the stream of result sets has been consumed to its end.
instance Response RowResults where
  type ResponseStreaming RowResults = 'TokenStream

  responseDecoder =
    DecodeTokenStream
      ( \finish tokens ->
          pure (RowResults (resultsDecoder tokens >> liftIO finish))
      )
    where
      -- Look for the next COLMETADATA token, which opens a result set.  A
      -- DONE token, or the end of the input, ends the results; other tokens
      -- are skipped.
      resultsDecoder ::
        S.Stream TokenStream IO a ->
        S.Stream
          ( Compose
              (S.Of ColumnMetadata)
              (S.Stream (S.Stream RawColumn IO) IO)
          )
          IO
          ()
      resultsDecoder tokens =
        liftIO (S.inspect tokens) >>= \case
          Left _ -> pure ()
          Right (OneToken (ColMetadata mt) rest) ->
            S.wrap (Compose (mt S.:> parseRowData mt rest))
          Right (OneToken Done {} _) ->
            pure ()
          Right (OneToken _ rest) -> resultsDecoder rest
          Right (ContParse tok _) -> fail ("Can't parse " ++ show (() <$ tok))

      -- Produce the rows of one result set.  A DONE token ends the set; its
      -- status decides whether further result sets follow.
      parseRowData ::
        ColumnMetadata ->
        S.Stream TokenStream IO a ->
        S.Stream
          (S.Stream RawColumn IO)
          IO
          ( S.Stream
              ( Compose
                  (S.Of ColumnMetadata)
                  (S.Stream (S.Stream RawColumn IO) IO)
              )
              IO
              ()
          )
      parseRowData cols tokens =
        liftIO (S.inspect tokens) >>= \case
          Left _ -> fail "Tokens ended while parsing row results"
          Right (OneToken (Done sts _ _) rest)
            | doneHasMore sts -> pure (resultsDecoder rest)
            | otherwise -> pure (pure ())
          Right (OneToken _ rest) -> parseRowData cols rest
          Right (ContParse (Row rowData) resume) ->
            S.wrap (parseColumns cols (cmColData cols) rowData resume)
          Right (ContParse tok _) -> fail ("Unknown token while parsing row result: " ++ show (() <$ tok))

      -- Produce the columns of one row, threading the byte stream through
      -- the consumer of each column.  After the last column the leftover
      -- bytes go back to the token parser.
      parseColumns ::
        ColumnMetadata ->
        [ColumnData] ->
        SBS.ByteString IO () ->
        (SBS.ByteString IO () -> S.Stream TokenStream IO a) ->
        S.Stream
          RawColumn
          IO
          ( S.Stream
              (S.Stream RawColumn IO)
              IO
              ( S.Stream
                  ( Compose
                      (S.Of ColumnMetadata)
                      (S.Stream (S.Stream RawColumn IO) IO)
                  )
                  IO
                  ()
              )
          )
      parseColumns mt pending rowData resume =
        case pending of
          [] -> pure (parseRowData mt (resume rowData))
          col : others ->
            S.wrap
              ( RawColumn
                  col
                  rowData
                  (\leftover -> parseColumns mt others leftover resume)
              )
-- | Payload of a packet that has been split into buffer-sized pieces.
data SplitEncoding sender resp pld
  = -- | The final piece: writes its bytes to the destination and returns
    -- how many bytes were written.
    LastPacket (Ptr () -> IO CSize)
  | -- | A non-final piece: writes its bytes to the destination and returns
    -- the packet that carries the next piece.
    OnePacket (Ptr () -> IO (Packet sender resp pld (SplitEncoding sender resp)))

-- | A packet whose payload is delivered piece by piece.
type SplitPacket sender resp pld =
  Packet sender resp pld (SplitEncoding sender resp)
-- | Split an encoded packet into pieces that fit a buffer of the given
-- size.
--
-- Returns the piece size (the smaller of the buffer size and the payload's
-- declared length) and the first packet of the chain.  Sequence ids count up
-- from 1.  Stream-encoded payloads are not supported yet.
--
-- A payload that produces no bytes (for example the unit payload) yields a
-- single final packet that writes nothing; previously this case called
-- 'error'.
splitPacket ::
  CSize ->
  Packet 'Client resp d PacketEncoding ->
  (Maybe CSize, SplitPacket 'Client resp d)
splitPacket bufSz (Packet hdr (PacketEncoding encoder)) =
  case encoder of
    PayloadBatchEncoder len builder ->
      let chunkSz = min bufSz (fromIntegral len)
          chunks =
            toLazyByteStringWith
              ( untrimmedStrategy
                  (fromIntegral chunkSz)
                  (fromIntegral chunkSz)
              )
              mempty
              builder

          -- Header for the piece that follows @i@ earlier pieces.
          withSeq i = hdr {pktHdrSeqID = PacketSequenceID (i + 1)}

          -- Only reachable when the payload produced no chunks at all: the
          -- recursive call below is made only when chunks remain.
          unfoldSplit i _ [] =
            Packet (withSeq i) (LastPacket (\_ -> pure 0))
          unfoldSplit i lenLeft (a : as) =
            let (ptr, ofs, chunkLen) = IBS.toForeignPtr a
                -- Never copy more than the final chunk holds, even if the
                -- declared length claims more bytes remain.
                finalLen = max 0 (min chunkLen lenLeft)
             in Packet (withSeq i) $
                  if null as || chunkLen == lenLeft
                    then LastPacket $ \dst ->
                      withForeignPtr ptr $ \ptrRaw ->
                        copyBytes dst (castPtr ptrRaw `plusPtr` ofs) finalLen
                          >> pure (fromIntegral finalLen)
                    else OnePacket $ \dst ->
                      withForeignPtr ptr $ \ptrRaw ->
                        copyBytes dst (castPtr ptrRaw `plusPtr` ofs) (fromIntegral chunkSz)
                          >> pure
                            ( unfoldSplit
                                (i + 1)
                                (lenLeft - BS.length a)
                                as
                            )
       in ( Just chunkSz
          , unfoldSplit 0 (fromIntegral len) (BL.toChunks chunks)
          )
    PayloadStreamEncoder _ ->
      error "PayloadStreamEncoder TODO"
-- | Recover the value of a type-level 'Bool' at run time.
class KnownBool (b :: Bool) where
  -- | The 'Bool' denoted by the type index of the proxy argument; the
  -- argument itself is never inspected.
  boolVal :: p b -> Bool

instance KnownBool 'True where
  boolVal = const True

instance KnownBool 'False where
  boolVal = const False