{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Network.TLS.IO
( sendPacket
, sendPacket13
, recvPacket
, recvPacket13
, isRecvComplete
, checkValid
, PacketFlightM
, runPacketFlight
, loadPacket13
) where
import Control.Exception (finally, throwIO)
import Control.Monad.Reader
import Control.Monad.State.Strict
import qualified Data.ByteString as B
import Data.IORef
import System.IO.Error (mkIOError, eofErrorType)
import Network.TLS.Context.Internal
import Network.TLS.ErrT
import Network.TLS.Hooks
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Receiving
import Network.TLS.Receiving13
import Network.TLS.Record
import Network.TLS.Sending
import Network.TLS.Sending13
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
-- | Encode a packet and write the resulting bytes to the context.
--
-- When the packet is application data with at least one byte and the
-- context's 'ctxNeedEmptyPacket' flag is set, an empty 'AppData' packet is
-- encoded and sent first.
-- NOTE(review): presumably the empty-record countermeasure against
-- predictable CBC IVs in old protocol versions -- confirm where the flag is
-- set.
sendPacket :: MonadIO m => Context -> Packet -> m ()
sendPacket ctx pkt = do
    when carriesPayload $ do
        needEmpty <- liftIO (readIORef (ctxNeedEmptyPacket ctx))
        when needEmpty $ emit (AppData B.empty)
    emit pkt
  where
    -- Encode one packet and push the encoded bytes out.
    emit p = writePacketBytes ctx p >>= sendBytes ctx

    -- True only for application data holding at least one byte.
    carriesPayload = case pkt of
        AppData payload -> not (B.null payload)
        _               -> False
-- | Report the packet to the logging hook ('loggingPacketSent') and encode
-- it with 'encodePacket'.
--
-- An encoding error is raised through 'throwCore'; otherwise the encoded
-- bytes are returned.  Nothing is sent by this function.
writePacketBytes :: MonadIO m => Context -> Packet -> m ByteString
writePacketBytes ctx pkt = liftIO logAndEncode >>= either throwCore return
  where
    logAndEncode = do
        withLog ctx $ \logging -> loggingPacketSent logging (show pkt)
        encodePacket ctx pkt
-- | Encode a TLS 1.3 packet with 'writePacketBytes13' and send the encoded
-- bytes right away.
sendPacket13 :: MonadIO m => Context -> Packet13 -> m ()
sendPacket13 ctx pkt = do
    encoded <- writePacketBytes13 ctx pkt
    sendBytes ctx encoded
-- | Report the TLS 1.3 packet to the logging hook ('loggingPacketSent') and
-- encode it with 'encodePacket13'.
--
-- An encoding error is raised through 'throwCore'; otherwise the encoded
-- bytes are returned.  Nothing is sent by this function.
writePacketBytes13 :: MonadIO m => Context -> Packet13 -> m ByteString
writePacketBytes13 ctx pkt = do
    encoded <- liftIO $ logSent >> encodePacket13 ctx pkt
    case encoded of
        Left err -> throwCore err
        Right bs -> return bs
  where
    logSent = withLog ctx $ \logging -> loggingPacketSent logging (show pkt)
-- | Hand already-encoded bytes to the context for sending.
--
-- The bytes are first reported to the logging hook ('loggingIOSent') and
-- then passed to 'contextSend'.
sendBytes :: MonadIO m => Context -> ByteString -> m ()
sendBytes ctx dataToSend = liftIO $
    withLog ctx (\logging -> loggingIOSent logging dataToSend)
        >> contextSend ctx dataToSend
-- | Decode one record from its header and raw content.
--
-- The header and content are reported to the logging hook
-- ('loggingIORecv'), then the record is decoded in the receive state with
-- 'decodeRecordM'.  If the decoded fragment is longer than the allowed size
-- the result is 'contentSizeExceeded'.
--
-- The allowed size is 16384 bytes, plus @appDataOverhead@ when the header's
-- protocol type is 'ProtocolType_AppData'.
getRecord :: Context -> Int -> Header -> ByteString -> IO (Either TLSError (Record Plaintext))
getRecord ctx appDataOverhead header@(Header pt _ _) content = do
    withLog ctx $ \logging -> loggingIORecv logging header content
    runRxState ctx $ do
        r <- decodeRecordM header content
        when (decodedLength r > sizeLimit) $
            throwError contentSizeExceeded
        return r
  where
    -- Only application data is granted the extra allowance.
    sizeLimit
        | pt == ProtocolType_AppData = 16384 + appDataOverhead
        | otherwise                  = 16384

    -- Number of bytes in the decoded fragment of a record.
    decodedLength (Record _ _ fragment) = B.length (fragmentGetBytes fragment)
-- | Protocol error reported when the decoded content of a record is larger
-- than allowed; it carries the 'RecordOverflow' alert description.
-- NOTE(review): the 'True' field presumably marks the error as fatal --
-- confirm against the definition of 'Error_Protocol'.
contentSizeExceeded :: TLSError
contentSizeExceeded = Error_Protocol (reason, True, RecordOverflow)
  where
    reason = "record content exceeding maximum size"
-- | Receive the next packet through the record path used by 'recvRecord'.
--
-- Returns 'Left' as soon as reading or decoding the record fails.  Two kinds
-- of records are consumed silently, after which reception starts over:
--
-- * ChangeCipherSpec records while the state's 'getTLS13HRR' flag is set;
--
-- * records that 'processPacket' turns into a handshake packet with no
--   messages.
--   NOTE(review): presumably a handshake fragment still waiting for its
--   continuation -- confirm in 'processPacket'.
--
-- Handshake messages are passed through 'hookRecvHandshake', a successfully
-- received packet is reported to the logging hook, and the SSLv2 ClientHello
-- compatibility mode is disabled once a packet was delivered while the mode
-- was on.
recvPacket :: MonadIO m => Context -> m (Either TLSError Packet)
recvPacket ctx = liftIO $ do
    compatSSLv2 <- ctxHasSSLv2ClientHello ctx
    hrr         <- usingState_ ctx getTLS13HRR
    -- While the HRR flag is set, application data records may be 256 bytes
    -- larger than usual.
    -- NOTE(review): presumably room for 0-RTT records the peer sent before
    -- it saw the HelloRetryRequest -- confirm against RFC 8446.
    let appDataOverhead
            | hrr       = 256
            | otherwise = 0
    erecord <- recvRecord compatSSLv2 appDataOverhead ctx
    case erecord of
        Left err -> return (Left err)
        Right record
            | hrr && isCCS record -> recvPacket ctx
            | otherwise           -> do
                pktRecv <- processPacket ctx record
                if isEmptyHandshake pktRecv
                    then recvPacket ctx
                    else deliver compatSSLv2 pktRecv
  where
    -- Run the hooks, log the result and leave SSLv2 compatibility mode.
    deliver compatSSLv2 pktRecv = do
        pkt <- applyHooks pktRecv
        logReceived pkt
        when compatSSLv2 $ ctxDisableSSLv2ClientHello ctx
        return pkt

    -- Only handshake packets go through the receive hook; everything else,
    -- errors included, is returned untouched.
    applyHooks (Right (Handshake hss)) =
        ctxWithHooks ctx $ \hooks ->
            Right . Handshake <$> mapM (hookRecvHandshake hooks) hss
    applyHooks other = return other

    -- Errors are not logged here.
    logReceived (Right p) = withLog ctx $ \logging -> loggingPacketRecv logging $ show p
    logReceived _         = return ()
-- | Read one record from the context and decode it with 'getRecord'.
--
-- The first argument selects the SSLv2-compatible header parsing; it only
-- has an effect when the module is compiled with the SSLV2_COMPATIBLE macro.
-- The second argument is the application data allowance forwarded to
-- 'getRecord'.
--
-- Normally a 5-byte header is read and decoded, then as many content bytes
-- as the header announces.  A header announcing more than 16384 + 2048 bytes
-- is rejected with 'maximumSizeExceeded' before any content is read.
-- NOTE(review): this bound looks like the TLSCiphertext length limit of
-- RFC 5246 -- confirm.
recvRecord :: Bool
           -> Int
           -> Context
           -> IO (Either TLSError (Record Plaintext))
recvRecord compatSSLv2 appDataOverhead ctx
#ifdef SSLV2_COMPATIBLE
    | compatSSLv2 = readExactBytes ctx 2 >>= onRight sslv2Header
#endif
    | otherwise = readExactBytes ctx 5 >>= onRight (recvLengthE . decodeHeader)
  where
    -- Pass a 'Left' through unchanged, otherwise continue with the action.
    onRight next = either (return . Left) next

    recvLengthE = onRight recvLength

    -- Read and decode the content announced by a decoded header.
    recvLength header@(Header _ _ readlen)
        | readlen > 16384 + 2048 = return $ Left maximumSizeExceeded
        | otherwise = do
            body <- readExactBytes ctx (fromIntegral readlen)
            onRight (getRecord ctx appDataOverhead header) body
#ifdef SSLV2_COMPATIBLE
    -- The argument holds exactly 2 bytes ('readExactBytes' only succeeds on
    -- a full read), so 'B.head' cannot fail.  A first byte with the high bit
    -- set selects the deprecated header format; otherwise the remaining 3
    -- bytes of a regular header are read and appended.
    sslv2Header header
        | B.head header >= 0x80 =
            onRight recvDeprecatedLength $ decodeDeprecatedHeaderLength header
        | otherwise =
            readExactBytes ctx 3 >>=
                onRight (recvLengthE . decodeHeader . B.append header)

    -- Deprecated format: the content is read first and the header is then
    -- decoded from the announced length and the first 3 content bytes.
    recvDeprecatedLength readlen
        | readlen > 1024 * 4 = return $ Left maximumSizeExceeded
        | otherwise = do
            res <- readExactBytes ctx (fromIntegral readlen)
            case res of
                Left e -> return $ Left e
                Right content ->
                    onRight (\h -> getRecord ctx appDataOverhead h content) $
                        decodeDeprecatedHeader readlen (B.take 3 content)
#endif
-- | Whether a record has protocol type 'ProtocolType_ChangeCipherSpec'.
isCCS :: Record a -> Bool
isCCS record = case record of
    Record ProtocolType_ChangeCipherSpec _ _ -> True
    _                                        -> False
-- | Whether a receive result is a successfully decoded handshake packet
-- that contains no handshake messages.  Errors and all other packets give
-- 'False'.
isEmptyHandshake :: Either TLSError Packet -> Bool
isEmptyHandshake (Right (Handshake hss)) = null hss
isEmptyHandshake _                       = False
-- | Receive the next packet through the record path used by 'recvRecord13'.
--
-- Returns 'Left' when reading or decoding the record fails, with one
-- exception: a 'BadRecordMac' protocol error received while the context is
-- in state @EarlyDataNotAllowed n@ with @n > 0@ lowers the counter by one
-- and starts reception over.
-- NOTE(review): presumably this skips records of rejected 0-RTT data that
-- cannot be decrypted -- confirm against RFC 8446.
--
-- Records that 'processPacket13' turns into a handshake packet with no
-- messages are consumed silently.  Handshake messages are passed through
-- 'hookRecvHandshake13', and a successfully received packet is reported to
-- the logging hook.
recvPacket13 :: MonadIO m => Context -> m (Either TLSError Packet13)
recvPacket13 ctx = liftIO $ do
    erecord <- recvRecord13 ctx
    case erecord of
        Left err@(Error_Protocol (_, True, BadRecordMac)) -> skipOrFail err
        Left err                                          -> return (Left err)
        Right record                                      -> handleRecord record
  where
    -- Retry while the early-data counter allows it, otherwise report.
    skipOrFail err = do
        established <- ctxEstablished ctx
        case established of
            EarlyDataNotAllowed n | n > 0 -> do
                setEstablished ctx $ EarlyDataNotAllowed (n - 1)
                recvPacket13 ctx
            _ -> return $ Left err

    handleRecord record = do
        pktRecv <- processPacket13 ctx record
        if isEmptyHandshake13 pktRecv
            then recvPacket13 ctx
            else do
                pkt <- applyHooks pktRecv
                logReceived pkt
                return pkt

    -- Only handshake packets go through the receive hook; everything else,
    -- errors included, is returned untouched.
    applyHooks (Right (Handshake13 hss)) =
        ctxWithHooks ctx $ \hooks ->
            Right . Handshake13 <$> mapM (hookRecvHandshake13 hooks) hss
    applyHooks other = return other

    -- Errors are not logged here.
    logReceived (Right p) = withLog ctx $ \logging -> loggingPacketRecv logging $ show p
    logReceived _         = return ()
-- | Read one record from the context and decode it with 'getRecord',
-- granting no application data allowance.
--
-- A 5-byte header is read and decoded, then as many content bytes as the
-- header announces.  A header announcing more than 16384 + 256 bytes is
-- rejected with 'maximumSizeExceeded' before any content is read.
-- NOTE(review): this bound looks like the TLSCiphertext length limit of
-- RFC 8446 -- confirm.
recvRecord13 :: Context
             -> IO (Either TLSError (Record Plaintext))
recvRecord13 ctx = do
    rawHeader <- readExactBytes ctx 5
    case rawHeader >>= decodeHeader of
        Left err     -> return (Left err)
        Right header -> recvLength header
  where
    -- Read and decode the content announced by a decoded header.
    recvLength header@(Header _ _ readlen)
        | readlen > 16384 + 256 = return $ Left maximumSizeExceeded
        | otherwise = do
            body <- readExactBytes ctx (fromIntegral readlen)
            either (return . Left) (getRecord ctx 0 header) body
-- | Whether a receive result is a successfully decoded TLS 1.3 handshake
-- packet that contains no handshake messages.  Errors and all other packets
-- give 'False'.
isEmptyHandshake13 :: Either TLSError Packet13 -> Bool
isEmptyHandshake13 (Right (Handshake13 hss)) = null hss
isEmptyHandshake13 _                         = False
-- | Protocol error reported when a record header announces more content
-- than allowed; it carries the 'RecordOverflow' alert description.
-- NOTE(review): the 'True' field presumably marks the error as fatal --
-- confirm against the definition of 'Error_Protocol'.
maximumSizeExceeded :: TLSError
maximumSizeExceeded = Error_Protocol (reason, True, RecordOverflow)
  where
    reason = "record exceeding maximum size"
-- | Ask the context for exactly @sz@ bytes.
--
-- Returns the bytes when 'contextRecv' delivered exactly that many.
-- Otherwise the context is flagged as EOF ('setEOF') and an error is
-- returned: 'Error_EOF' when nothing was received, or an 'Error_Packet'
-- stating the expected and the received byte counts.
readExactBytes :: Context -> Int -> IO (Either TLSError ByteString)
readExactBytes ctx sz = do
    received <- contextRecv ctx sz
    let got = B.length received
    if got == sz
        then return (Right received)
        else do
            setEOF ctx
            return $ Left (shortRead got)
  where
    -- Pick the error for a read that delivered the wrong number of bytes.
    shortRead 0   = Error_EOF
    shortRead got = Error_Packet ("partial packet: expecting " ++ show sz ++ " bytes, got: " ++ show got)
-- | Whether the state holds no pending handshake record continuation:
-- 'True' when both 'stHandshakeRecordCont' and 'stHandshakeRecordCont13'
-- are 'Nothing'.
isRecvComplete :: Context -> IO Bool
isRecvComplete ctx = usingState_ ctx $ do
    done   <- gets (isNothing . stHandshakeRecordCont)
    done13 <- gets (isNothing . stHandshakeRecordCont13)
    return $! done && done13
-- | Check that the context can be used, throwing otherwise.
--
-- Throws 'ConnectionNotEstablished' when the context state is
-- 'NotEstablished'.  After that check, throws an EOF 'IOError' (location
-- @"data"@) when the context has been flagged as EOF.
checkValid :: Context -> IO ()
checkValid ctx = do
    ctxEstablished ctx >>= \established ->
        when (established == NotEstablished) $ throwIO ConnectionNotEstablished
    ctxEOF ctx >>= \eofed ->
        when eofed $ throwIO eofError
  where
    eofError = mkIOError eofErrorType "data" Nothing Nothing
-- | Monad for assembling a flight of packets: 'IO' with read access to an
-- 'IORef' holding the list of encoded chunks collected so far.  Run it with
-- 'runPacketFlight'.
newtype PacketFlightM a = PacketFlightM (ReaderT (IORef [ByteString]) IO a)
    deriving (Functor, Applicative, Monad, MonadFail, MonadIO)
-- | Run a packet flight against a fresh, empty chunk list.
--
-- Once the action has finished -- normally or with an exception, since the
-- flush is attached with 'finally' -- the collected chunks are put back in
-- insertion order, concatenated and sent with a single 'sendBytes' call.
-- Nothing is sent when no chunk was collected.
runPacketFlight :: Context -> PacketFlightM a -> IO a
runPacketFlight ctx (PacketFlightM f) = do
    ref <- newIORef []
    runReaderT f ref `finally` flush ref
  where
    -- The list is kept newest-first, hence the 'reverse'.
    flush ref = do
        pending <- readIORef ref
        unless (null pending) $
            sendBytes ctx (B.concat (reverse pending))
-- | Encode a TLS 1.3 packet with 'writePacketBytes13' and queue the encoded
-- bytes on the current flight instead of sending them immediately.
--
-- The bytes are prepended to the flight's chunk list, so the list is kept
-- newest-first.
loadPacket13 :: Context -> Packet13 -> PacketFlightM ()
loadPacket13 ctx pkt = PacketFlightM $ do
    bs  <- writePacketBytes13 ctx pkt
    ref <- ask
    -- Strict variant: build the new cons cell now rather than leaving a
    -- chain of suspended updates inside the IORef.
    liftIO $ modifyIORef' ref (bs :)