{-# OPTIONS_GHC -funbox-strict-fields #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
module Database.TDS.Types where
import Control.Applicative
import Control.Concurrent.STM
import Control.Exception
import Control.Monad.Identity
import Data.Monoid ((<>))
import qualified Data.Text as T
import Data.Word
import qualified Database.TDS.Proto as Proto
import qualified Database.TDS.Proto.Errors as Proto
import System.IO
-- | States a client connection can be in.
--
-- The constructor order is significant: the derived 'Enum', 'Ord' and
-- 'Bounded' instances follow declaration order, so 'Connecting' is
-- 'minBound' and 'Final' is 'maxBound'.
--
-- NOTE(review): the names look like the client states of the MS-TDS
-- protocol state machine — confirm against the specification.
data ClientState
  = Connecting
  | SentPreLogin
  | SentTLSSSLNegotiation
  | SentLogin7WithCompleteAuthenticationToken
  | SentLogin7WithSPNEGO
  | SentLogin7WithFAIR
  | LoggedIn
  | SentClientRequest
  | SentAttention
  | RoutingCompleted
  | Final
  deriving (Show, Enum, Eq, Ord, Bounded)
-- | Decide whether a packet of the given type may be sent while the client
-- is in the given state.
--
-- * 'Proto.PreLogin' only from 'Connecting'
-- * 'Proto.Login7' only from 'SentPreLogin'
-- * 'Proto.TrMgrRequest' and 'Proto.SQLBatch' only from 'LoggedIn'
-- * 'Proto.Attention' from any state
-- * every other packet type: never
sendableInState :: Proto.PacketType sender resp d -> ClientState -> Bool
sendableInState packetTy current =
  case packetTy of
    Proto.PreLogin     -> current == Connecting
    Proto.Login7       -> current == SentPreLogin
    Proto.TrMgrRequest -> current == LoggedIn
    Proto.SQLBatch     -> current == LoggedIn
    -- The state is deliberately not inspected for attention packets.
    Proto.Attention    -> True
    _                  -> False
-- | Compute the state the client enters after sending a packet of the given
-- type from the given state.
--
-- Every combination that is not listed explicitly yields 'Final'.
--
-- 'Proto.TrMgrRequest' (from 'LoggedIn') and 'Proto.Attention' (from
-- 'SentClientRequest') are listed explicitly because 'sendableInState'
-- allows them; without these equations sending a permitted packet would
-- drop the connection straight into 'Final'.
--
-- NOTE(review): 'sendableInState' accepts 'Proto.Attention' in /any/ state,
-- but only the transition out of 'SentClientRequest' is modelled here; all
-- other states still fall through to 'Final' — confirm this is intended.
stateTransition :: Proto.PacketType sender resp d -> ClientState -> ClientState
stateTransition Proto.PreLogin     Connecting        = SentPreLogin
stateTransition Proto.Login7       SentPreLogin      = SentLogin7WithCompleteAuthenticationToken
stateTransition Proto.SQLBatch     LoggedIn          = SentClientRequest
stateTransition Proto.TrMgrRequest LoggedIn          = SentClientRequest
stateTransition Proto.Attention    SentClientRequest = SentAttention
stateTransition _                  _                 = Final
-- | Wrapper around the connection timeout value.
--
-- NOTE(review): the unit is not visible in this module — presumably
-- seconds; confirm at the use site.
newtype ConnectionTimeout
  = ConnectionTimeout Int
  deriving (Show, Eq, Ord)

-- | The default connection timeout, @ConnectionTimeout 15@.
connectionTimeout :: ConnectionTimeout
connectionTimeout =
  ConnectionTimeout defaultTimeout
  where
    defaultTimeout :: Int
    defaultTimeout = 15
-- | Broad classification of a 'TDSError'.
--
-- Only 'Show' is derived, so values can be rendered but not compared.
data TDSErrorType
  = TDSNoSuchHost
  | TDSSocketError
  | TDSServerUninitialized
  | TDSServerBusy
  | TDSServerQuit
  | TDSInvalidStateTransition
  | TDSInvalidResponse
  deriving Show
-- | An error raised by the TDS client.
--
-- The type parameters of the offending request packet (@sender@, @resp@,
-- @d@) are existentially quantified, so errors for any kind of request
-- share this one type. The @Show d@ constraint is captured by the
-- constructor. All fields are strict.
data TDSError where
  TDSError
    :: Show d
    => { tdsErrorType :: !TDSErrorType
         -- Classification of the failure.
       , tdsErrorState :: !ClientState
         -- Client state recorded with the error.
       , tdsErrorRequest :: !(Maybe (Proto.Packet sender resp d Identity))
         -- The request associated with the error, if any. Because its type
         -- is existential it can only be reached by pattern matching.
       , tdsErrorMessage :: !String
         -- Human-readable description.
       }
    -> TDSError

deriving instance Show TDSError
-- | Build a 'TDSError' that carries no request packet.
--
-- Takes the error type, the client state and the message. The request
-- field is 'Nothing'; its otherwise ambiguous packet type is pinned to a
-- client packet with no response and a @()@ payload so that the
-- constructor's @Show d@ constraint can be solved.
tdsErrorNoReq :: TDSErrorType -> ClientState -> String -> TDSError
tdsErrorNoReq ty st = TDSError ty st noRequest
  where
    noRequest :: Maybe (Proto.Packet 'Proto.Client 'Proto.NoResponse () Identity)
    noRequest = Nothing
-- | 'TDSError' can be thrown as an exception.
--
-- 'displayException' renders a heading, the message, the error type and
-- the client state, followed by the request (via 'Proto.displayRequest')
-- when one is attached.
instance Exception TDSError where
  displayException e = unlines (summary ++ requestLines)
    where
      summary :: [String]
      summary =
        [ "TDS Error\n"
        , " " <> tdsErrorMessage e <> "\n"
        , " Type: " <> show (tdsErrorType e)
        , " State: " <> show (tdsErrorState e)
        ]

      -- The request field has an existential type, so it must be taken
      -- apart with a pattern match rather than a selector function.
      requestLines :: [String]
      requestLines =
        case e of
          TDSError { tdsErrorRequest = Just req } ->
            [ " Request: ", Proto.displayRequest req ]
          TDSError { tdsErrorRequest = Nothing } ->
            []
-- | Where to connect: an optional host name and an optional port.
data ConnectionInfo = ConnectionInfo
  { _tdsConnHost :: Maybe String
    -- ^ Host name, if one has been specified.
  , _tdsConnPort :: Maybe Word16
    -- ^ Port number, if one has been specified.
  }
  deriving (Show, Eq, Ord)

-- | Right-biased combination: for each field the right operand's value is
-- used when present, otherwise the left operand's.
instance Semigroup ConnectionInfo where
  older <> newer =
    ConnectionInfo
      { _tdsConnHost = pick _tdsConnHost
      , _tdsConnPort = pick _tdsConnPort
      }
    where
      -- Take the field from the right operand first, then the left.
      pick field = field newer <|> field older

-- | The empty value has neither host nor port.
instance Monoid ConnectionInfo where
  mempty = ConnectionInfo Nothing Nothing
  mappend = (<>)
-- | Authentication settings. Currently a single constructor with no
-- fields.
data TdsAuth
  = TdsAuth
  deriving (Show, Eq, Ord)
-- | Connection options.
--
-- Values are meant to be built from small pieces and combined with the
-- 'Monoid' instance. For the text fields an empty 'T.Text' means
-- "not set".
data Options = Options
  { _tdsOptionsSecure :: Bool
    -- ^ Flag toggled by 'tdsOptionSecure'.
    -- NOTE(review): presumably requests an encrypted connection — confirm.
  , _tdsAuth          :: Maybe TdsAuth
    -- ^ Optional authentication settings.
  , _tdsConnInfo      :: ConnectionInfo
    -- ^ Host and port to connect to.
  , _tdsOnMessage     :: Proto.Message -> IO ()
    -- ^ Callback taking a 'Proto.Message'.
  , _tdsOnEnvChange   :: Proto.EnvChange -> IO ()
    -- ^ Callback taking a 'Proto.EnvChange'.
  , _tdsUser          :: T.Text
    -- ^ User name.
  , _tdsPassword      :: T.Text
    -- ^ Password.
  , _tdsDatabase      :: T.Text
    -- ^ Database name.
  , _tdsAppName       :: T.Text
    -- ^ Application name.
  , _tdsClientName    :: T.Text
    -- ^ Client name.
  }
-- | Field-wise combination of two 'Options' values.
--
-- * '_tdsOptionsSecure': logical OR of both sides.
-- * '_tdsAuth': the right operand wins when it is set.
-- * '_tdsConnInfo': combined with the 'ConnectionInfo' semigroup.
-- * '_tdsOnMessage' / '_tdsOnEnvChange': /both/ handlers run, left first.
-- * text fields: the left operand wins unless it is empty.
--
-- NOTE(review): the text fields are left-biased while '_tdsAuth' is
-- right-biased; the asymmetry is preserved here as-is — confirm that it
-- is intended.
instance Semigroup Options where
  a <> b =
    Options
      { _tdsOptionsSecure = _tdsOptionsSecure a || _tdsOptionsSecure b
      , _tdsAuth          = _tdsAuth b <|> _tdsAuth a
      , _tdsConnInfo      = _tdsConnInfo a <> _tdsConnInfo b
        -- The handlers are functions, so writing @f >> g@ directly would
        -- use the function monad and silently drop @f@. Apply each
        -- handler to the argument and sequence the resulting IO actions.
      , _tdsOnMessage     = \msg -> _tdsOnMessage a msg >> _tdsOnMessage b msg
      , _tdsOnEnvChange   = \chg -> _tdsOnEnvChange a chg >> _tdsOnEnvChange b chg
      , _tdsUser          = altText (_tdsUser a) (_tdsUser b)
      , _tdsPassword      = altText (_tdsPassword a) (_tdsPassword b)
      , _tdsDatabase      = altText (_tdsDatabase a) (_tdsDatabase b)
      , _tdsAppName       = altText (_tdsAppName a) (_tdsAppName b)
      , _tdsClientName    = altText (_tdsClientName a) (_tdsClientName b)
      }
    where
      -- First argument unless it is empty, in which case the second.
      altText x y
        | T.null x  = y
        | otherwise = x

-- | The empty options: not secure, no auth, no connection info, handlers
-- that do nothing, and empty text fields.
instance Monoid Options where
  mempty =
    Options
      { _tdsOptionsSecure = False
      , _tdsAuth          = Nothing
      , _tdsConnInfo      = mempty
      , _tdsOnMessage     = \_ -> pure ()
      , _tdsOnEnvChange   = \_ -> pure ()
      , _tdsUser          = ""
      , _tdsPassword      = ""
      , _tdsDatabase      = ""
      , _tdsAppName       = ""
      , _tdsClientName    = ""
      }
  mappend = (<>)
-- | Baseline options: host @localhost@, port @1433@ and client name
-- @haskell-tds@; every other field keeps its 'mempty' value.
defaultOptions :: Options
defaultOptions =
  mempty
    { _tdsConnInfo   = localServer
    , _tdsClientName = "haskell-tds"
    }
  where
    localServer :: ConnectionInfo
    localServer = ConnectionInfo (Just "localhost") (Just 1433)
-- | Options fragment that sets '_tdsOptionsSecure' to 'True' and leaves
-- everything else at its 'mempty' value.
tdsOptionSecure :: Options
tdsOptionSecure = mempty
  { _tdsOptionsSecure = True
  }
-- | Options fragment that sets only the host to connect to.
tdsOptionHost :: String -> Options
tdsOptionHost h = mempty { _tdsConnInfo = hostOnly }
  where
    hostOnly :: ConnectionInfo
    hostOnly = mempty { _tdsConnHost = Just h }
-- | Options fragment that sets only the port to connect to.
tdsOptionPort :: Word16 -> Options
tdsOptionPort p = mempty { _tdsConnInfo = portOnly }
  where
    portOnly :: ConnectionInfo
    portOnly = mempty { _tdsConnPort = Just p }
-- | Options fragment that sets the user name and the password together.
tdsOptionUserAndPassword :: T.Text -> T.Text -> Options
tdsOptionUserAndPassword user pw =
  mempty
    { _tdsUser     = user
    , _tdsPassword = pw
    }
-- | Options fragment that sets only the database name.
tdsOptionDatabase :: T.Text -> Options
tdsOptionDatabase db =
  mempty
    { _tdsDatabase = db
    }
-- | Options fragment that installs handlers writing every server message
-- and every environment change to 'stderr'.
--
-- Each line of a message report is prefixed with a tag derived from the
-- severity of the message class.
tdsDebugLogging :: Options
tdsDebugLogging =
  mempty
    { _tdsOnMessage   = debugMsg
    , _tdsOnEnvChange = debugEnvChange
    }
  where
    debugEnvChange c = hPutStrLn stderr ("[ENVCHANGE] " ++ show c)

    debugMsg msg =
      do -- NOTE(review): only two severities are matched here; if the
         -- severity type has further constructors this match is partial
         -- and the handler will throw for them — confirm against
         -- "Database.TDS.Proto.Errors".
         let level = case Proto.clsSeverity (Proto.messageClass msg) of
                       Proto.Information -> "[INFO ] "
                       Proto.Fatal -> "[FATAL ] "
             Proto.SQLError msgCode = Proto.messageCode msg
             Proto.ErrorClass cls = Proto.messageClass msg
         hPutStr stderr . unlines . map (level <>) $
           [ -- Print the numeric code, not the whole message record.
             "Message Code: " ++ show msgCode
           , "Message Status: " ++ show (Proto.messageSt msg)
           , "Message Class: " ++ show cls
           , ""
           , " " ++ T.unpack (Proto.messageText msg)
           , ""
           , "Server name: " ++ T.unpack (Proto.messageServerName msg)
           , "Process name: " ++ T.unpack (Proto.messageProcName msg)
           , "Line Number: " ++ show (Proto.messageLineNum msg)
           ]
-- | Outcome of waiting for a response, indexed by the response type.
--
-- The index records whether the request was cancelable, so
-- 'ResponseResultCancelled' can only be constructed for responses whose
-- cancelable flag is @'True@.
data ResponseResult (ty :: Proto.ResponseType *) where
  -- | The request was cancelled. Only available for cancelable responses.
  ResponseResultCancelled
    :: ResponseResult ('Proto.ResponseType 'True a)

  -- | A response value was received. The 'Show' constraint is stored in
  -- the constructor and is what allows the derived 'Show' instance.
  ResponseResultReceived
    :: Show a
    => a
    -> ResponseResult ('Proto.ResponseType cancelable a)

deriving instance Show (ResponseResult ty)
-- | Per-request cancellation bookkeeping, indexed by whether the request
-- can be cancelled at all.
data CancelInfo (canCancel :: Bool) where
  -- | Nothing to track for a request that cannot be cancelled.
  NonCancelable
    :: CancelInfo 'False

  -- | State for a cancelable request.
  --
  -- NOTE(review): the roles of the two variables are not visible in this
  -- module — presumably a "cancel requested" flag and a completion
  -- signal; confirm at the use site.
  Cancelable
    :: TVar Bool
    -> TMVar ()
    -> CancelInfo 'True
-- | Allocate the 'CancelInfo' matching a type-level cancelable flag.
class MkCancelable (canCancel :: Bool) where
  -- | Create fresh cancellation state inside an STM transaction.
  mkCancelable :: STM (CancelInfo canCancel)

-- | Non-cancelable requests need no state.
instance MkCancelable 'False where
  mkCancelable = return NonCancelable

-- | Cancelable requests get a new 'TVar' initialised to 'False' and a new
-- empty 'TMVar'.
instance MkCancelable 'True where
  mkCancelable = do
    flag   <- newTVar False
    signal <- newEmptyTMVar
    return (Cancelable flag signal)
-- | A TDS connection, represented as a record of operations together with
-- its shared state and the options it was created with.
data Connection = Connection
  { tdsSendPacket
      :: forall (sender :: Proto.Sender) (d :: *)
                (cancelable :: Bool) (res :: *)
       . ( Proto.Payload d, Proto.Response res
         , Proto.KnownBool cancelable
         , MkCancelable cancelable )
      => Proto.Packet 'Proto.Client
           ('Proto.ExpectsResponse ('Proto.ResponseType cancelable res))
           d Identity
      -> IO (IO (ResponseResult ('Proto.ResponseType cancelable res)))
    -- ^ Takes a client packet that expects a response. The result is an
    -- 'IO' action that itself yields an 'IO' action producing the
    -- 'ResponseResult'.
    -- NOTE(review): presumably the outer action sends the packet and the
    -- inner one waits for the reply — confirm against the implementation.
  , tdsCancel :: IO ()
    -- ^ NOTE(review): presumably cancels the request in flight — confirm.
  , tdsQuit :: IO ()
    -- ^ NOTE(review): presumably closes the connection — confirm.
  , tdsConnectionState :: !(TVar ClientState)
    -- ^ Current client state, shared through a 'TVar'.
  , tdsConnectionOptions :: !Options
    -- ^ Options stored with the connection.
  }
-- | Read the connection's current 'ClientState' in its own STM
-- transaction.
getReadyState :: Connection -> IO ClientState
getReadyState = atomically . readTVar . tdsConnectionState
-- | 'True' unless the connection is in the 'Final' state.
isConnected :: Connection -> IO Bool
isConnected c = (/= Final) <$> getReadyState c
-- | 'True' exactly when the connection is in the 'LoggedIn' state.
isReady :: Connection -> IO Bool
isReady c = do
  st <- getReadyState c
  pure (st == LoggedIn)
-- | 'True' while the connection is in 'SentClientRequest' or
-- 'SentAttention'.
isBusy :: Connection -> IO Bool
isBusy c = fmap busy (getReadyState c)
  where
    busy :: ClientState -> Bool
    busy st = st `elem` [SentClientRequest, SentAttention]
-- | 'True' exactly when the connection is in the 'SentClientRequest'
-- state.
isCancelable :: Connection -> IO Bool
isCancelable c = (== SentClientRequest) <$> getReadyState c