{-# LANGUAGE ScopedTypeVariables #-}
module System.IO.Streams.TLS
( TLSConnection
, connect
, connectTLS
, tLsToConnection
, accept
, module Data.TLSSetting
) where
import qualified Control.Exception as E
import Data.Connection
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import Data.TLSSetting
import qualified Network.Socket as N
import Network.TLS (ClientParams, Context, ServerParams)
import qualified Network.TLS as TLS
import qualified System.IO.Streams as Stream
import qualified System.IO.Streams.TCP as TCP
-- | A 'Connection' running over TLS.
--
-- Its extra info is the TLS 'TLS.Context' paired with the 'N.SockAddr'
-- obtained when the connection was established ('connectTLS' / 'accept').
type TLSConnection = Connection (TLS.Context, N.SockAddr)
-- | Wrap an established TLS 'Context' and its address as a 'TLSConnection'.
--
-- * Reading: each chunk comes from 'TLS.recvData'. An empty chunk is
--   reported as end of stream ('Nothing'). A synchronous exception raised
--   while receiving is also reported as end of stream.
-- * Asynchronous exceptions (e.g. from 'Control.Concurrent.killThread' or a
--   timeout) are re-thrown, so a thread blocked in a read can still be
--   interrupted.
-- * Writing: data is handed to 'TLS.sendData'.
-- * Closing: runs 'closeTLS' on the context.
--
-- The @(ctx, addr)@ pair is stored unchanged as the connection's extra info.
tLsToConnection :: (Context, N.SockAddr)
                -> IO TLSConnection
tLsToConnection (ctx, addr) = do
    is <- Stream.makeInputStream input
    return (Connection is (TLS.sendData ctx) (closeTLS ctx) (ctx, addr))
  where
    -- | One read step for the input stream.
    input :: IO (Maybe B.ByteString)
    input = receive `E.catch` onRecvError

    -- | Receive one chunk; an empty chunk signals end of stream.
    receive :: IO (Maybe B.ByteString)
    receive = do
        s <- TLS.recvData ctx
        return $! if B.null s then Nothing else Just s

    -- Treat receive failures (e.g. the peer vanished) as end of stream.
    -- Asynchronous exceptions must not be swallowed, so they propagate.
    onRecvError :: E.SomeException -> IO (Maybe B.ByteString)
    onRecvError e = case E.fromException e of
        Just (_ :: E.SomeAsyncException) -> E.throwIO e
        Nothing                          -> return Nothing
-- | Close a TLS 'Context'. Best effort: this function never throws.
--
-- 'TLS.bye' is attempted first. It can fail when the peer has already gone
-- away (e.g. a broken pipe or a connection reset). It is paired with
-- 'E.finally' so that 'TLS.contextClose' runs even when 'TLS.bye' throws;
-- otherwise the underlying socket would be left open.
--
-- Any exception from either step is deliberately discarded. This function is
-- used as the release action of 'E.bracketOnError' and as the close action of
-- a 'TLSConnection', where a failure to say goodbye is not actionable.
closeTLS :: Context -> IO ()
closeTLS ctx =
    (TLS.bye ctx `E.finally` TLS.contextClose ctx)
        `E.catch` (\(_ :: E.SomeException) -> return ())
-- | Open a TCP connection to @host@:@port@ and run a TLS client handshake
-- over it.
--
-- The supplied 'ClientParams' have 'TLS.clientServerIdentification'
-- overwritten with @(subname', BC.pack (show port))@. Here @subname'@ is the
-- given @subname@, or @host@ when @subname@ is 'Nothing'.
--
-- Returns the established 'Context' and the 'N.SockAddr' reported by
-- 'TCP.connectSocket'. On failure the resources are released before the
-- exception propagates:
--
-- * if 'TLS.contextNew' fails, the socket is closed;
-- * if the handshake fails, 'closeTLS' is run and then the socket is closed.
connectTLS :: ClientParams         -- ^ TLS client parameters
           -> Maybe String         -- ^ server name to identify as; defaults to @host@
           -> N.HostName           -- ^ host to connect to
           -> N.PortNumber         -- ^ port to connect to
           -> IO (Context, N.SockAddr)
connectTLS prms subname host port =
    -- Bracket the raw socket too, so a failing 'TLS.contextNew' cannot leak it.
    -- If 'closeTLS' has already closed it, the extra 'N.close' is expected to
    -- be harmless.
    E.bracketOnError (TCP.connectSocket host port) (N.close . fst) $ \(sock, addr) ->
        E.bracketOnError (TLS.contextNew sock prms') closeTLS $ \ctx -> do
            TLS.handshake ctx
            return (ctx, addr)
  where
    subname' :: String
    subname' = maybe host id subname

    prms' :: ClientParams
    prms' = prms { TLS.clientServerIdentification = (subname', BC.pack (show port)) }
-- | Connect to a TLS server and wrap the resulting session as a
-- 'TLSConnection'.
--
-- Equivalent to running 'connectTLS' with the same arguments and passing its
-- result to 'tLsToConnection'.
connect :: ClientParams            -- ^ TLS client parameters
        -> Maybe String            -- ^ server name to identify as; defaults to @host@
        -> N.HostName              -- ^ host to connect to
        -> N.PortNumber            -- ^ port to connect to
        -> IO TLSConnection
connect prms subname host port = do
    established <- connectTLS prms subname host port
    tLsToConnection established
-- | Accept one client on a listening socket, run the TLS server handshake,
-- and wrap the session as a 'TLSConnection'.
--
-- The connection's extra info holds the 'Context' and the peer address
-- returned by 'N.accept'. On failure the resources are released before the
-- exception propagates:
--
-- * if 'TLS.contextNew' fails, the accepted socket is closed;
-- * if the handshake fails, 'closeTLS' is run and then the socket is closed.
accept :: ServerParams             -- ^ TLS server parameters
       -> N.Socket                 -- ^ listening socket
       -> IO TLSConnection
accept prms sock =
    -- Bracket the accepted socket, so a failing 'TLS.contextNew' cannot leak
    -- it. If 'closeTLS' has already closed it, the extra 'N.close' is expected
    -- to be harmless.
    E.bracketOnError (N.accept sock) (N.close . fst) $ \(sock', addr) ->
        E.bracketOnError (TLS.contextNew sock' prms) closeTLS $ \ctx -> do
            TLS.handshake ctx
            tLsToConnection (ctx, addr)