{-# LANGUAGE ScopedTypeVariables #-}
module System.IO.Streams.SSL
( connect
, withConnection
, sslToStreams
) where
import qualified Control.Exception as E
import Control.Monad (void)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as S
import Network.Socket (HostName, PortNumber)
import qualified Network.Socket as N
import OpenSSL.Session (SSL, SSLContext)
import qualified OpenSSL.Session as SSL
import System.IO.Streams (InputStream, OutputStream)
import qualified System.IO.Streams as Streams
-- | Number of bytes requested per 'SSL.read' when filling the input stream
-- built by 'sslToStreams' (32752 bytes).
--
-- NOTE(review): the 16-byte shortfall from a round 32 KiB is presumably
-- there to leave room for per-allocation overhead; confirm before changing.
bUFSIZ :: Int
bUFSIZ = 32 * 1024 - 16
-- | Wrap an established 'SSL' connection as a pair of io-streams.
--
-- The input stream requests up to 'bUFSIZ' bytes per 'SSL.read'. A
-- zero-length read is reported as end of stream ('Nothing').
--
-- Each chunk written to the output stream is passed to 'SSL.write'.
-- Writing end of stream ('Nothing') performs no I/O: it neither shuts down
-- the SSL session nor closes the socket.
sslToStreams :: SSL
             -> IO (InputStream ByteString, OutputStream ByteString)
sslToStreams ssl = do
    inStream  <- Streams.makeInputStream readChunk
    outStream <- Streams.makeOutputStream (maybe (return ()) (SSL.write ssl))
    return (inStream, outStream)
  where
    readChunk = do
        bytes <- SSL.read ssl bUFSIZ
        return $! toChunk bytes

    -- An empty read marks the end of the input.
    toChunk bytes
        | S.null bytes = Nothing
        | otherwise    = Just bytes
-- | Resolve @host@:@port@, open a TCP socket to the first address returned
-- by 'N.getAddrInfo', and perform an SSL client handshake over it with
-- @ctx@.
--
-- Returns the streams produced by 'sslToStreams' together with the 'SSL'
-- handle. Only the first resolved address is tried.
--
-- If any step after the socket is created throws, the socket is closed
-- before the exception propagates. On success nothing is closed here: the
-- caller must eventually shut down the SSL session and close its socket
-- (compare the cleanup performed by 'withConnection').
connect :: SSLContext
        -> HostName
        -> PortNumber
        -> IO (InputStream ByteString, OutputStream ByteString, SSL)
connect ctx host port = do
    (info:_) <- N.getAddrInfo (Just tcpHints) (Just host) (Just (show port))
    E.bracketOnError (openSocket info) N.close (handshake (N.addrAddress info))
  where
    -- Create an unconnected socket matching the resolved address.
    openSocket info =
        N.socket (N.addrFamily info) (N.addrSocketType info) (N.addrProtocol info)

    -- TCP connect, SSL handshake, then wrap the session as streams.
    handshake addr sock = do
        N.connect sock addr
        ssl <- SSL.connection ctx sock
        SSL.connect ssl
        (is, os) <- sslToStreams ssl
        return $! (is, os, ssl)

    -- Stream (TCP) sockets only, numeric service string, and only address
    -- families configured on this host.
    tcpHints = N.defaultHints
        { N.addrFlags      = [N.AI_ADDRCONFIG, N.AI_NUMERICSERV]
        , N.addrSocketType = N.Stream
        }
-- | Open an SSL connection to @host@:@port@ and run @action@ with its input
-- stream, output stream and 'SSL' handle. The connection is torn down
-- afterwards, whether @action@ returns normally or throws.
--
-- Only the first address returned by 'N.getAddrInfo' is tried. If the TCP
-- connect or the SSL handshake fails, the socket is closed and the
-- exception propagates.
--
-- Teardown is best-effort:
--
-- * the output stream is sent end of stream;
-- * a one-way ('SSL.Unidirectional') SSL shutdown is attempted;
-- * the socket is closed.
--
-- Synchronous exceptions from these steps are ignored. Asynchronous ones
-- (e.g. 'E.ThreadKilled' delivered while a step blocks) are re-thrown, but
-- only after the socket has been closed.
withConnection ::
       SSLContext
    -> HostName
    -> PortNumber
    -> (InputStream ByteString -> OutputStream ByteString -> SSL -> IO a)
    -> IO a
withConnection ctx host port action = do
    (addrInfo:_) <- N.getAddrInfo (Just hints) (Just host) (Just $ show port)
    E.bracket (connectTo addrInfo) cleanup go
  where
    go (is, os, ssl, _) = action is os ssl

    -- Same steps as 'connect', but the raw socket is also returned so that
    -- cleanup can close it.
    connectTo addrInfo = do
        let family     = N.addrFamily addrInfo
            socketType = N.addrSocketType addrInfo
            protocol   = N.addrProtocol addrInfo
            address    = N.addrAddress addrInfo
        E.bracketOnError (N.socket family socketType protocol) N.close $ \sock -> do
            N.connect sock address
            ssl <- SSL.connection ctx sock
            SSL.connect ssl
            (is, os) <- sslToStreams ssl
            return $! (is, os, ssl, sock)

    -- The socket close sits behind 'E.finally'. If an asynchronous
    -- exception escapes one of the earlier steps, the socket is still
    -- released before that exception propagates.
    cleanup (_, os, ssl, sock) = E.mask_ $
        (eatException (Streams.write Nothing os)
            >> eatException (SSL.shutdown ssl SSL.Unidirectional))
        `E.finally` eatException (N.close sock)

    hints = N.defaultHints
        { N.addrFlags      = [N.AI_ADDRCONFIG, N.AI_NUMERICSERV]
        , N.addrSocketType = N.Stream
        }

    -- Run a best-effort teardown step, discarding its result and any
    -- synchronous exception. Asynchronous exceptions are re-thrown:
    -- swallowing them here would silently cancel a kill aimed at this
    -- thread.
    eatException :: IO b -> IO ()
    eatException m = void m `E.catch` \(e :: E.SomeException) ->
        case E.fromException e of
            Just (_ :: E.SomeAsyncException) -> E.throwIO e
            Nothing                          -> return ()