{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
module Network.TLS.IO
( sendPacket
, sendPacket13
, recvPacket
, recvPacket13
, isRecvComplete
, checkValid
, PacketFlightM
, runPacketFlight
, loadPacket13
) where
import Control.Exception (finally, throwIO)
import Control.Monad.Reader
import Control.Monad.State.Strict
import qualified Data.ByteString as B
import Data.IORef
import Network.TLS.Context.Internal
import Network.TLS.Hooks
import Network.TLS.Imports
import Network.TLS.Receiving
import Network.TLS.Record
import Network.TLS.Record.Layer
import Network.TLS.Sending
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
-- | Encode one 'Packet' and send it through the context's record layer.
--
-- When the packet is non-empty application data and the context's
-- 'ctxNeedEmptyPacket' flag is set, an empty 'AppData' packet is encoded and
-- sent first.
-- NOTE(review): presumably the empty-record countermeasure against
-- predictable CBC IVs in old protocol versions -- confirm where
-- 'ctxNeedEmptyPacket' is set.
--
-- Encoding failures are raised via 'throwCore' (see 'writePacketBytes').
sendPacket :: Context -> Packet -> IO ()
sendPacket ctx@Context{ctxRecordLayer = recordLayer} pkt = do
    when (isNonNullAppData pkt) $ do
        withEmptyPacket <- readIORef $ ctxNeedEmptyPacket ctx
        when withEmptyPacket $
            writePacketBytes ctx recordLayer (AppData B.empty)
                >>= recordSendBytes recordLayer
    writePacketBytes ctx recordLayer pkt >>= recordSendBytes recordLayer
  where
    -- True only for application data carrying at least one byte.
    isNonNullAppData :: Packet -> Bool
    isNonNullAppData (AppData b) = not $ B.null b
    isNonNullAppData _           = False
-- | Log a 'Packet' as sent and encode it for the given record layer.
--
-- The packet's 'show' output is passed to the 'loggingPacketSent' hook
-- before encoding. Returns the encoded bytes; if 'encodePacket' yields a
-- 'TLSError', it is raised with 'throwCore'.
writePacketBytes :: Monoid bytes
                 => Context -> RecordLayer bytes -> Packet -> IO bytes
writePacketBytes ctx recordLayer pkt = do
    withLog ctx $ \logging -> loggingPacketSent logging (show pkt)
    edataToSend <- encodePacket ctx recordLayer pkt
    either throwCore return edataToSend
-- | Encode one 'Packet13' and send it through the context's record layer.
--
-- Encoding failures are raised via 'throwCore' (see 'writePacketBytes13').
sendPacket13 :: Context -> Packet13 -> IO ()
sendPacket13 ctx@Context{ctxRecordLayer = recordLayer} pkt =
    writePacketBytes13 ctx recordLayer pkt >>= recordSendBytes recordLayer
-- | Log a 'Packet13' as sent and encode it for the given record layer.
--
-- The packet's 'show' output is passed to the 'loggingPacketSent' hook
-- before encoding. Returns the encoded bytes; if 'encodePacket13' yields a
-- 'TLSError', it is raised with 'throwCore'.
writePacketBytes13 :: Monoid bytes
                   => Context -> RecordLayer bytes -> Packet13 -> IO bytes
writePacketBytes13 ctx recordLayer pkt = do
    withLog ctx $ \logging -> loggingPacketSent logging (show pkt)
    edataToSend <- encodePacket13 ctx recordLayer pkt
    either throwCore return edataToSend
-- | Receive one 'Packet' from the context's record layer.
--
-- Returns 'Left' with the 'TLSError' reported by 'recordRecv' or
-- 'processPacket'; otherwise 'Right' with the packet. Received handshake
-- messages are passed through the 'hookRecvHandshake' hook, and a
-- successfully received packet is reported to 'loggingPacketRecv'.
--
-- Two kinds of input are skipped by receiving again:
--
-- * a ChangeCipherSpec record while the TLS 1.3 HelloRetryRequest flag
--   ('getTLS13HRR') is set;
-- * a handshake packet with no messages.
--   NOTE(review): presumably a fragment of a handshake message that is
--   still being reassembled -- confirm against 'processPacket'.
recvPacket :: Context -> IO (Either TLSError Packet)
recvPacket ctx@Context{ctxRecordLayer = recordLayer} = do
    compatSSLv2 <- ctxHasSSLv2ClientHello ctx
    hrr         <- usingState_ ctx getTLS13HRR
    -- 256 extra bytes are allowed while the HRR flag is set.
    -- NOTE(review): presumably to tolerate undecrypted 0-RTT application
    -- data after a HelloRetryRequest -- confirm against 'recordRecv'.
    let appDataOverhead = if hrr then 256 else 0
    erecord <- recordRecv recordLayer compatSSLv2 appDataOverhead
    case erecord of
        Left err -> return $ Left err
        Right record
            | hrr && isCCS record -> recvPacket ctx
            | otherwise -> do
                pktRecv <- processPacket ctx record
                if isEmptyHandshake pktRecv
                    then recvPacket ctx
                    else do
                        pkt <- case pktRecv of
                            Right (Handshake hss) ->
                                ctxWithHooks ctx $ \hooks ->
                                    Right . Handshake
                                        <$> mapM (hookRecvHandshake hooks) hss
                            _ -> return pktRecv
                        case pkt of
                            Right p ->
                                withLog ctx $ \logging ->
                                    loggingPacketRecv logging $ show p
                            _ -> return ()
                        -- The SSLv2 ClientHello compatibility flag is
                        -- cleared once a packet has been handled with it on.
                        when compatSSLv2 $ ctxDisableSSLv2ClientHello ctx
                        return pkt
-- | 'True' when the record's protocol type is ChangeCipherSpec.
isCCS :: Record a -> Bool
isCCS (Record ProtocolType_ChangeCipherSpec _ _) = True
isCCS _                                          = False
-- | 'True' only for a successfully processed handshake packet that carries
-- no handshake messages.
isEmptyHandshake :: Either TLSError Packet -> Bool
isEmptyHandshake (Right (Handshake [])) = True
isEmptyHandshake _                      = False
-- | Receive one 'Packet13' from the context's record layer.
--
-- Returns 'Left' with the 'TLSError' reported by 'recordRecv13' or
-- 'processPacket13'; otherwise 'Right' with the packet. Received handshake
-- messages are passed through the 'hookRecvHandshake13' hook, and a
-- successfully received packet is reported to 'loggingPacketRecv'.
--
-- A handshake packet with no messages is skipped by receiving again.
recvPacket13 :: Context -> IO (Either TLSError Packet13)
recvPacket13 ctx@Context{ctxRecordLayer = recordLayer} = do
    erecord <- recordRecv13 recordLayer
    case erecord of
        Left err@(Error_Protocol _ BadRecordMac) -> do
            -- While the state is 'EarlyDataNotAllowed' with a positive
            -- budget, a bad-MAC record is dropped: the budget is
            -- decremented and reception is retried.
            -- NOTE(review): presumably this skips rejected 0-RTT records
            -- that cannot be decrypted -- confirm where
            -- 'EarlyDataNotAllowed' is set.
            established <- ctxEstablished ctx
            case established of
                EarlyDataNotAllowed n
                    | n > 0 -> do
                        setEstablished ctx $ EarlyDataNotAllowed (n - 1)
                        recvPacket13 ctx
                _ -> return $ Left err
        Left err -> return $ Left err
        Right record -> do
            pktRecv <- processPacket13 ctx record
            if isEmptyHandshake13 pktRecv
                then recvPacket13 ctx
                else do
                    pkt <- case pktRecv of
                        Right (Handshake13 hss) ->
                            ctxWithHooks ctx $ \hooks ->
                                Right . Handshake13
                                    <$> mapM (hookRecvHandshake13 hooks) hss
                        _ -> return pktRecv
                    case pkt of
                        Right p ->
                            withLog ctx $ \logging ->
                                loggingPacketRecv logging $ show p
                        _ -> return ()
                    return pkt
-- | 'True' only for a successfully processed TLS 1.3 handshake packet that
-- carries no handshake messages.
isEmptyHandshake13 :: Either TLSError Packet13 -> Bool
isEmptyHandshake13 (Right (Handshake13 [])) = True
isEmptyHandshake13 _                        = False
-- | 'True' when neither handshake record continuation
-- ('stHandshakeRecordCont' nor 'stHandshakeRecordCont13') is pending in the
-- TLS state.
isRecvComplete :: Context -> IO Bool
isRecvComplete ctx = usingState_ ctx $ do
    cont   <- gets stHandshakeRecordCont
    cont13 <- gets stHandshakeRecordCont13
    -- Force the Bool so the result does not retain the state as a thunk.
    return $! isNothing cont && isNothing cont13
-- | Check that the context is usable, throwing otherwise.
--
-- Throws 'ConnectionNotEstablished' when the context state is
-- 'NotEstablished', and @'PostHandshake' 'Error_EOF'@ when the context has
-- reached EOF. The establishment check runs first.
checkValid :: Context -> IO ()
checkValid ctx = do
    established <- ctxEstablished ctx
    when (established == NotEstablished) $ throwIO ConnectionNotEstablished
    eofed <- ctxEOF ctx
    when eofed $ throwIO $ PostHandshake Error_EOF
-- | Difference list of encoded chunks: a function that prepends the chunks
-- accumulated so far to a tail list, so appending is function composition.
type Builder b = [b] -> [b]
-- | Monad for grouping several packets into one flight.
--
-- It is a reader over the record layer and a mutable 'Builder' holding the
-- encoded chunks accumulated so far, running on top of 'IO'.
newtype PacketFlightM b a = PacketFlightM (ReaderT (RecordLayer b, IORef (Builder b)) IO a)
    deriving (Functor, Applicative, Monad, MonadFail, MonadIO)
-- | Run a 'PacketFlightM' action against the context's record layer.
--
-- The action starts with an empty accumulator. Whatever was accumulated is
-- sent by 'sendPendingFlight' when the action finishes, including when it
-- exits with an exception ('finally').
runPacketFlight :: Context -> (forall b . Monoid b => PacketFlightM b a) -> IO a
runPacketFlight Context{ctxRecordLayer = recordLayer} (PacketFlightM f) = do
    ref <- newIORef id
    runReaderT f (recordLayer, ref) `finally` sendPendingFlight recordLayer ref
-- | Send every chunk accumulated in the builder as a single concatenated
-- write to the record layer. Nothing is sent when no chunk was accumulated.
sendPendingFlight :: Monoid b => RecordLayer b -> IORef (Builder b) -> IO ()
sendPendingFlight recordLayer ref = do
    build <- readIORef ref
    let bss = build []
    unless (null bss) $ recordSendBytes recordLayer $ mconcat bss
-- | Encode a 'Packet13' and append the result to the current flight instead
-- of sending it immediately.
--
-- The packet is logged and encoded right away by 'writePacketBytes13', so
-- an encoding failure is raised here; the bytes are only written out when
-- the enclosing 'runPacketFlight' completes.
loadPacket13 :: Monoid b => Context -> Packet13 -> PacketFlightM b ()
loadPacket13 ctx pkt = PacketFlightM $ do
    (recordLayer, ref) <- ask
    liftIO $ do
        bs <- writePacketBytes13 ctx recordLayer pkt
        -- Strict update, so composed builders do not pile up as thunks
        -- inside the IORef. The composition appends bs after earlier chunks.
        modifyIORef' ref (. (bs :))