module Network.Xmpp.Tls where
import Control.Applicative ((<$>))
import qualified Control.Exception.Lifted as Ex
import Control.Monad
import Control.Monad.Error
import Control.Monad.State.Strict
import "crypto-random" Crypto.Random
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BSC8
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import Data.IORef
import Data.XML.Types
import Network.DNS.Resolver (ResolvConf)
import Network.TLS
import Network.Xmpp.Stream
import Network.Xmpp.Types
import System.Log.Logger (debugM, errorM, infoM)
-- | Adapt a 'StreamHandle' to the 'Backend' record used by the TLS library.
--
-- * Sending forwards to 'streamSend' and discards its result.
-- * Receiving keeps reading from the handle until the requested number of
--   bytes has been collected or a read comes back empty. A 'Left' returned by
--   'streamReceive' is rethrown as an exception.
-- * Flushing and closing are forwarded unchanged.
mkBackend :: StreamHandle -> Backend
mkBackend con = Backend { backendSend = \bs -> void (streamSend con bs)
                        , backendRecv = bufferReceive (streamReceive con)
                        , backendFlush = streamFlush con
                        , backendClose = streamClose con
                        }
  where
    -- A request for zero bytes is answered without touching the handle.
    bufferReceive _ 0 = return BS.empty
    bufferReceive recv n = BS.concat `liftM` (go n)
      where
        -- Collect chunks until @m@ bytes have arrived or a read returns no
        -- data.
        go m = do
            mbBs <- recv m
            bs <- case mbBs of
                Left e -> Ex.throwIO e
                Right r -> return r
            case BS.length bs of
                -- An empty read ends the loop with whatever was gathered.
                0 -> return []
                l -> if l < m
                     -- Short read: ask again for the bytes still missing.
                     then (bs :) `liftM` go (m - l)
                     else return [bs]
-- | The @starttls@ element in the @urn:ietf:params:xml:ns:xmpp-tls@
-- namespace, without attributes or child nodes.
starttlsE :: Element
starttlsE = Element
    { elementName = "{urn:ietf:params:xml:ns:xmpp-tls}starttls"
    , elementAttributes = []
    , elementNodes = []
    }
-- | Negotiate STARTTLS on a stream.
--
-- The stream must be in the 'Plain' state. Whether TLS is started, skipped or
-- rejected depends on the configured 'tlsBehaviour' and on the 'streamTls'
-- entry of the stream features.
--
-- Returns @Right ()@ when the stream was secured or TLS was skipped, and a
-- 'Left' otherwise:
--
-- * 'XmppNoStream' when the stream is 'Closed' or 'Finished';
-- * 'TlsStreamSecured' when the stream is already 'Secured';
-- * 'TlsNoServerSupport' when TLS is required locally but not offered;
-- * 'XmppOtherFailure' when TLS is refused locally while 'streamTls' is
--   @Just True@, when the server answers with a @failure@ element, or when it
--   answers with anything other than @proceed@;
-- * whatever failure 'pushElement' or 'pullElement' report, and any exception
--   that 'wrapExceptions' converts.
tls :: Stream -> IO (Either XmppFailure ())
tls con = fmap join -- merge the exception layer with the ErrorT layer
          . wrapExceptions
          . flip withStream con
          . runErrorT $ do
    conf <- gets $ streamConfiguration
    sState <- gets streamConnectionState
    -- Only a plain stream may be upgraded.
    case sState of
        Plain -> return ()
        Closed -> do
            liftIO $ errorM "Pontarius.Xmpp.Tls" "The stream is closed."
            throwError XmppNoStream
        Finished -> do
            liftIO $ errorM "Pontarius.Xmpp.Tls" "The stream is finished."
            throwError XmppNoStream
        Secured -> do
            liftIO $ errorM "Pontarius.Xmpp.Tls" "The stream is already secured."
            throwError TlsStreamSecured
    features <- lift $ gets streamFeatures
    -- NOTE(review): @Just True@ presumably means the server marks TLS as
    -- required; confirm against the definition of 'streamTls'.
    case (tlsBehaviour conf, streamTls features) of
        (RequireTls  , Just _   ) -> startTls
        (RequireTls  , Nothing  ) -> throwError TlsNoServerSupport
        (PreferTls   , Just _   ) -> startTls
        (PreferTls   , Nothing  ) -> skipTls
        (PreferPlain , Just True) -> startTls
        (PreferPlain , _        ) -> skipTls
        (RefuseTls   , Just True) -> throwError XmppOtherFailure
        (RefuseTls   , _        ) -> skipTls
  where
    skipTls = liftIO $ infoM "Pontarius.Xmpp.Tls" "Skipping TLS negotiation"
    startTls = do
        liftIO $ infoM "Pontarius.Xmpp.Tls" "Running StartTLS"
        params <- gets $ tlsParams . streamConfiguration
        ErrorT $ pushElement starttlsE
        answer <- lift $ pullElement
        case answer of
            Left e -> throwError e
            Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}proceed" [] []) ->
                return ()
            Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}failure" _ _) -> do
                liftIO $ errorM "Pontarius.Xmpp" "startTls: TLS initiation failed."
                throwError XmppOtherFailure
            Right r -> do
                liftIO $ errorM "Pontarius.Xmpp.Tls" $
                    "Unexpected element: " ++ show r
                -- Without a @proceed@ answer the handshake must not be
                -- started, so abort instead of falling through.
                throwError XmppOtherFailure
        hand <- gets streamHandle
        (_raw, _snk, psh, recv, ctx) <- lift $ tlsinit params (mkBackend hand)
        -- Replace the handle with one that goes through the TLS context;
        -- closing sends the TLS goodbye before closing the old handle.
        let newHand = StreamHandle { streamSend = catchPush . psh
                                   , streamReceive = wrapExceptions . recv
                                   , streamFlush = contextFlush ctx
                                   , streamClose = bye ctx >> streamClose hand
                                   }
        lift $ modify ( \x -> x {streamHandle = newHand})
        liftIO $ infoM "Pontarius.Xmpp.Tls" "Stream Secured."
        -- A failed restart is thrown as an exception; 'wrapExceptions' turns
        -- it back into a 'Left'.
        either (lift . Ex.throwIO) return =<< lift restartStream
        modify (\s -> s{streamConnectionState = Secured})
        return ()
-- | Create a TLS 'Context' from client parameters, a random generator and a
-- backend. This only reorders the arguments and delegates to 'contextNew'.
client :: (MonadIO m, CPRG rng) => ClientParams -> rng -> Backend -> m Context
client clientParams randomGen transport =
    contextNew transport clientParams randomGen
-- | Set up a TLS client session over the given backend and perform the
-- handshake.
--
-- The result is a tuple of:
--
-- 1. a source that endlessly yields data received through the context,
--    logging each chunk at debug level;
-- 2. a sink that sends every chunk it is fed through the context;
-- 3. a function that sends a single strict 'BS.ByteString';
-- 4. a reader built with 'mkReadBuffer' on top of 'recvData';
-- 5. the TLS 'Context' itself.
tlsinit :: (MonadIO m, MonadIO m1) =>
           ClientParams
        -> Backend
        -> m ( Source m1 BS.ByteString
             , Sink BS.ByteString m1 ()
             , BS.ByteString -> IO ()
             , Int -> m1 BS.ByteString
             , Context
             )
tlsinit params backend = do
    liftIO $ debugM "Pontarius.Xmpp.Tls" "TLS with debug mode enabled."
    -- Seed a system random generator from a fresh entropy pool.
    pool <- liftIO createEntropyPool
    let gen = cprgCreate pool :: SystemRNG
    ctx <- client params gen backend
    handshake ctx
    -- Incoming side: receive, log, and pass the chunk downstream, forever.
    let src = forever $ do
            chunk <- liftIO $ do
                incoming <- recvData ctx
                debugM "Pontarius.Xmpp.Tls" ("In :" ++ BSC8.unpack incoming)
                return incoming
            yield chunk
    -- Outgoing side: send chunks until upstream is exhausted.
    let snk = await >>= maybe
            (return ())
            (\chunk -> sendData ctx (BL.fromChunks [chunk]) >> snk)
    bufferedRead <- liftIO $ mkReadBuffer (recvData ctx)
    return ( src
           , snk
           , sendData ctx . BL.fromChunks . (: [])
           , \n -> liftIO (bufferedRead n)
           , ctx
           )
-- | Wrap a chunk-producing action in a reader that hands out at most @n@
-- bytes per call.
--
-- Bytes left over from a previous chunk are kept in an 'IORef' and served
-- first; the underlying action is only run when nothing is left over. Each
-- call consults a single chunk, so fewer than @n@ bytes may be returned.
mkReadBuffer :: IO BS.ByteString -> IO (Int -> IO BS.ByteString)
mkReadBuffer recv = do
    pending <- newIORef BS.empty
    return $ \n -> do
        leftover <- readIORef pending
        -- Fetch a new chunk only when the previous one is used up.
        chunk <- if BS.null leftover
                 then recv
                 else return leftover
        let (out, remaining) = BS.splitAt n chunk
        writeIORef pending remaining
        return out
-- | Connect to a host via 'connectSrv' and immediately run TLS over the
-- resulting handle.
--
-- Fails with 'TcpConnectionFailure' when 'connectSrv' yields no handle.
-- On success the returned 'StreamHandle' sends, receives and flushes through
-- the TLS context; closing it sends the TLS goodbye and then closes the
-- underlying handle.
connectTls :: ResolvConf
           -> ClientParams
           -> String
           -> ErrorT XmppFailure IO StreamHandle
connectTls config params host = do
    mbHandle <- connectSrv config host
    h <- maybe (throwError TcpConnectionFailure) return mbHandle
    let plain = handleToStreamHandle h
    (_raw, _snk, push, pull, ctx) <- tlsinit params (mkBackend plain)
    return StreamHandle
        { streamSend = catchPush . push
        , streamReceive = wrapExceptions . pull
        , streamFlush = contextFlush ctx
        , streamClose = bye ctx >> streamClose plain
        }
-- | Run an action and report the exceptions handled below as 'Left' values
-- instead of letting them propagate.
--
-- Handlers are tried in order: the exception type carried by
-- 'XmppIOException', then the four exception types carried by the
-- 'XmppTlsError', 'XmppTlsConnectionNotEstablished', 'XmppTlsTerminated' and
-- 'XmppTlsHandshakeFailed' constructors (each wrapped in 'TlsError'), and
-- finally a thrown 'XmppFailure', which is returned as is.
wrapExceptions :: IO a -> IO (Either XmppFailure a)
wrapExceptions action = Ex.catches (Right <$> action)
    [ Ex.Handler (\ioErr -> return (Left (XmppIOException ioErr)))
    , Ex.Handler (tlsFailure . XmppTlsError)
    , Ex.Handler (tlsFailure . XmppTlsConnectionNotEstablished)
    , Ex.Handler (tlsFailure . XmppTlsTerminated)
    , Ex.Handler (tlsFailure . XmppTlsHandshakeFailed)
    , Ex.Handler (\failure -> return (Left failure))
    ]
  where
    -- Tag a TLS-related exception as a 'TlsError' failure.
    tlsFailure = return . Left . TlsError