module Hookup
(
Connection,
connect,
connectWithSocket,
close,
recv,
recvLine,
send,
putBuf,
ConnectionParams(..),
SocksParams(..),
TlsParams(..),
PEM.PemPasswordSupply(..),
defaultTlsParams,
ConnectionFailure(..),
CommandReply(..)
, getClientCertificate
, getPeerCertificate
, getPeerCertFingerprintSha1
, getPeerCertFingerprintSha256
, getPeerCertFingerprintSha512
, getPeerPubkeyFingerprintSha1
, getPeerPubkeyFingerprintSha256
, getPeerPubkeyFingerprintSha512
) where
import Control.Concurrent
import Control.Exception
import Control.Monad
import System.IO.Error (isDoesNotExistError, ioeGetErrorString)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import Data.Foldable
import Data.List (intercalate, partition)
import Network.Socket (AddrInfo, HostName, PortNumber, SockAddr, Socket, Family)
import qualified Network.Socket as Socket
import qualified Network.Socket.ByteString as SocketB
import OpenSSL.Session (SSL, SSLContext)
import qualified OpenSSL as SSL
import qualified OpenSSL.Session as SSL
import OpenSSL.X509.SystemStore
import OpenSSL.X509 (X509)
import qualified OpenSSL.X509 as X509
import qualified OpenSSL.PEM as PEM
import qualified OpenSSL.EVP.Digest as Digest
import Data.Attoparsec.ByteString (Parser)
import qualified Data.Attoparsec.ByteString as Parser
import Hookup.OpenSSL (installVerification, getPubKeyDer)
import Hookup.Socks5
-- | Describes where to connect and how: destination, optional SOCKS5
-- proxy, optional TLS, and optional local bind address.
data ConnectionParams = ConnectionParams
  { cpHost  :: HostName          -- ^ Destination hostname; also the hostname given to TLS setup
  , cpPort  :: PortNumber        -- ^ Destination port
  , cpSocks :: Maybe SocksParams -- ^ SOCKS5 proxy to tunnel through, if any
  , cpTls   :: Maybe TlsParams   -- ^ TLS settings; 'Nothing' for a plaintext connection
  , cpBind  :: Maybe HostName    -- ^ Local address to bind outgoing sockets to, if any
  }
-- | Location of the SOCKS5 proxy server.
data SocksParams = SocksParams
  { spHost :: HostName   -- ^ Proxy hostname
  , spPort :: PortNumber -- ^ Proxy port
  }
-- | Settings used when establishing a TLS session.
data TlsParams = TlsParams
  { tpClientCertificate        :: Maybe FilePath        -- ^ Path to a PEM client certificate to present
  , tpClientPrivateKey         :: Maybe FilePath        -- ^ Path to the PEM private key for the client certificate
  , tpClientPrivateKeyPassword :: PEM.PemPasswordSupply -- ^ Password supply used when reading the private key
  , tpServerCertificate        :: Maybe FilePath        -- ^ CA file to trust; 'Nothing' loads the system certificate store
  , tpCipherSuite              :: String                -- ^ Cipher list handed to OpenSSL
  , tpInsecure                 :: Bool                  -- ^ 'True' disables peer certificate verification
  }
-- | Failures raised while connecting, negotiating with a SOCKS5 proxy,
-- or reading lines from a connection.
data ConnectionFailure
  = HostnameResolutionFailure HostName String
    -- ^ Resolving the hostname failed, or no usable source/destination
    -- address pairing was found; the string gives the reason.
  | ConnectionFailure [IOError]
    -- ^ Every connection attempt failed; one error per attempt, most recent first.
  | LineTooLong
    -- ^ The line length limit was reached without finding a newline.
  | LineTruncated
    -- ^ The connection closed in the middle of a line.
  | SocksError CommandReply
    -- ^ The proxy answered the CONNECT request with a non-success reply.
  | SocksAuthenticationError
    -- ^ The proxy did not select the "no authentication required" method.
  | SocksProtocolError
    -- ^ A proxy reply could not be parsed or was followed by unexpected bytes.
  | SocksBadDomainName
    -- ^ The destination hostname is 256 bytes or longer.
  deriving Show
-- | Human-readable rendering of each failure.
instance Exception ConnectionFailure where
  displayException failure =
    case failure of
      HostnameResolutionFailure h s ->
        "hostname resolution failed (" ++ h ++ "): " ++ s
      ConnectionFailure errs ->
        "connection attempt failed due to: " ++
        intercalate ", " (map displayException errs)
      LineTooLong -> "line length exceeded maximum"
      LineTruncated -> "connection closed while reading line"
      SocksError reply -> "SOCKS command rejected: " ++ describeReply reply
      SocksAuthenticationError -> "SOCKS authentication method rejected"
      SocksProtocolError -> "SOCKS server protocol error"
      SocksBadDomainName -> "SOCKS domain name length limit exceeded"
    where
      -- Text for each known SOCKS reply code; anything else shows its number.
      describeReply r =
        case r of
          Succeeded         -> "succeeded"
          GeneralFailure    -> "general SOCKS server failure"
          NotAllowed        -> "connection not allowed by ruleset"
          NetUnreachable    -> "network unreachable"
          HostUnreachable   -> "host unreachable"
          ConnectionRefused -> "connection refused"
          TTLExpired        -> "TTL expired"
          CmdNotSupported   -> "command not supported"
          AddrNotSupported  -> "address type not supported"
          CommandReply n    -> "unknown reply " ++ show n
-- | Baseline TLS settings: peer verification on, CA certificates taken from
-- the system store, the @HIGH@ cipher list, and no client certificate.
defaultTlsParams :: TlsParams
defaultTlsParams =
  TlsParams
    { tpInsecure                 = False
    , tpCipherSuite              = "HIGH"
    , tpServerCertificate        = Nothing
    , tpClientCertificate        = Nothing
    , tpClientPrivateKey         = Nothing
    , tpClientPrivateKeyPassword = PEM.PwNone
    }
-- | Open a socket to the destination in the parameters, either directly or
-- through the configured SOCKS5 proxy. If the SOCKS handshake fails, the
-- proxy socket is closed before the exception propagates.
openSocket :: ConnectionParams -> IO Socket
openSocket params =
  maybe direct viaProxy (cpSocks params)
  where
    direct = openSocket' (cpHost params) (cpPort params) (cpBind params)

    viaProxy sp =
      do proxySock <- openSocket' (spHost sp) (spPort sp) (cpBind params)
         (socksConnect proxySock (cpHost params) (cpPort params) >> return proxySock)
           `onException` Socket.close proxySock
-- | Run an attoparsec parser against bytes read from the socket.
--
-- Input is fetched one byte at a time, so the parser is never handed bytes
-- past the end of the message. Anything other than a complete parse with
-- no leftover input is reported as 'SocksProtocolError'.
netParse :: Show a => Socket -> Parser a -> IO a
netParse sock parser =
  Parser.parseWith (SocketB.recv sock 1) parser B.empty >>= finish
  where
    finish (Parser.Done rest x) | B.null rest = return x
    finish _ = throwIO SocksProtocolError
-- | Run the SOCKS5 CONNECT exchange on an already-connected proxy socket.
--
-- Offers only the "no authentication required" method, then asks the proxy
-- to connect to the given host (sent as a domain name) and port. Throws
-- 'SocksAuthenticationError', 'SocksBadDomainName', 'SocksError' or
-- 'SocksProtocolError' on failure.
socksConnect :: Socket -> HostName -> PortNumber -> IO ()
socksConnect sock host port =
  do sendMsg (buildClientHello hello)
     validateHello =<< netParse sock parseServerHello
     -- names of 256+ bytes are rejected (presumably because SOCKS5 stores
     -- the domain-name length in a single byte)
     when (B.length hostBytes >= 256) (throwIO SocksBadDomainName)
     sendMsg (buildRequest request)
     validateResponse =<< netParse sock parseResponse
  where
    sendMsg = SocketB.sendAll sock
    hostBytes = B8.pack host
    hello = ClientHello { cHelloMethods = [AuthNoAuthenticationRequired] }
    request = Request
      { reqCommand = Connect
      , reqAddress = Address (DomainName hostBytes) port
      }
-- | Accept the proxy's greeting only if it chose the
-- "no authentication required" method.
validateHello :: ServerHello -> IO ()
validateHello hello =
  if sHelloMethod hello == AuthNoAuthenticationRequired
    then return ()
    else throwIO SocksAuthenticationError
-- | Accept the proxy's CONNECT reply only if it reports success; otherwise
-- throw 'SocksError' carrying the reply code.
validateResponse :: Response -> IO ()
validateResponse response
  | reply == Succeeded = return ()
  | otherwise          = throwIO (SocksError reply)
  where
    reply = rspReply response
-- | Resolve the destination (and the optional bind address), pair up
-- addresses of matching families, and race connection attempts over them.
-- Throws 'HostnameResolutionFailure' when no source/destination pair exists.
openSocket' ::
  HostName ->
  PortNumber ->
  Maybe HostName ->
  IO Socket
openSocket' h p mbBind =
  do srcAddrs <- traverse (resolve Nothing) mbBind
     dstAddrs <- resolve (Just p) h
     case interleaveAddressFamilies (matchBindAddrs srcAddrs dstAddrs) of
       [] -> throwIO (HostnameResolutionFailure h "No source/destination address family match")
       candidates -> attempt candidates
-- | Lookup hints used for every address resolution: stream sockets, with
-- the AI_ADDRCONFIG and AI_NUMERICSERV flags set.
hints :: AddrInfo
hints =
  Socket.defaultHints
    { Socket.addrFlags      = [Socket.AI_ADDRCONFIG, Socket.AI_NUMERICSERV]
    , Socket.addrSocketType = Socket.Stream
    }
-- | Resolve a hostname (with an optional numeric port) to candidate
-- addresses. A "does not exist" lookup error is rethrown as
-- 'HostnameResolutionFailure'; any other IO error is rethrown unchanged.
resolve :: Maybe PortNumber -> HostName -> IO [AddrInfo]
resolve mbPort host =
  try (Socket.getAddrInfo (Just hints) (Just host) (fmap show mbPort))
    >>= either rethrow return
  where
    rethrow ioe
      | isDoesNotExistError ioe =
          throwIO (HostnameResolutionFailure host (ioeGetErrorString ioe))
      | otherwise = throwIO ioe
-- | Pair each destination address with a source address to bind.
--
-- Without a bind address, every destination is kept with no source. With
-- one, each destination is paired with the first source address of the
-- same family; destinations without such a source are dropped.
matchBindAddrs :: Maybe [AddrInfo] -> [AddrInfo] -> [(Maybe SockAddr, AddrInfo)]
matchBindAddrs mbSrc dst =
  case mbSrc of
    Nothing -> map ((,) Nothing) dst
    Just src ->
      [ (Just (Socket.addrAddress s), d)
      | d <- dst
      , s <- take 1 (filter (sameFamily d) src)
      ]
  where
    sameFamily a b = Socket.addrFamily a == Socket.addrFamily b
-- | Stagger between successive connection attempts, in the microsecond
-- units 'threadDelay' expects (150 ms).
connAttemptDelay :: Int
connAttemptDelay = 150000
-- | Race connection attempts over the candidate addresses. The i-th
-- attempt starts after a delay of @i * connAttemptDelay@ microseconds.
-- The first socket to connect is returned. If every attempt fails,
-- 'ConnectionFailure' is thrown with all of the errors.
--
-- Once a winner is found (or this call is interrupted), the remaining
-- attempts are killed. A losing attempt that still manages to connect has
-- its socket closed, in either of two cases:
--
-- * its thread is killed while handing the socket over, or
-- * the socket was left unclaimed in the result slot.
attempt ::
  [(Maybe SockAddr, AddrInfo)] ->
  IO Socket
attempt xs =
  do comm <- newEmptyMVar
     let -- Forked from bracket's masked acquire step, so each thread starts
         -- masked; only the delay and the connect itself run unmasked.
         mkThread i (mbSrc, ai) =
           forkIOWithUnmask $ \unmask ->
             do res <- try (unmask (delayThenConnect i mbSrc ai))
                -- putMVar blocks (and is interruptible) while the slot is
                -- full; close our socket if killThread lands there.
                putMVar comm res `onException` traverse_ Socket.close res

         delayThenConnect i mbSrc ai =
           do threadDelay (connAttemptDelay * i)
              connectToAddrInfo mbSrc ai

         -- Stop every attempt, then close a socket nobody collected.
         cleanup threads =
           do traverse_ killThread threads
              traverse_ (traverse_ Socket.close) =<< tryTakeMVar comm

     bracket (zipWithM mkThread [0 ..] xs)
             cleanup
             (\_ -> gather (length xs) [] comm)
-- | Collect the given number of attempt results. Returns the first socket
-- received. If all results are errors, throws 'ConnectionFailure' with the
-- errors, most recent first.
gather ::
  Int ->
  [IOError] ->
  MVar (Either IOError Socket) ->
  IO Socket
gather 0 failures _ = throwIO (ConnectionFailure failures)
gather remaining failures comm =
  takeMVar comm >>= either retry return
  where
    retry ex = gather (remaining - 1) (ex : failures) comm
-- | Reorder candidates so IPv6 and non-IPv6 destinations alternate,
-- starting with IPv6. Relative order within each family is kept, and
-- whichever family runs out first lets the other's remainder follow.
interleaveAddressFamilies :: [(Maybe SockAddr, AddrInfo)] -> [(Maybe SockAddr, AddrInfo)]
interleaveAddressFamilies candidates = alternate ipv6 others
  where
    (ipv6, others) = partition isIPv6 candidates
    isIPv6 (_, ai) = Socket.addrFamily ai == Socket.AF_INET6
    -- take one from the front list, then swap roles
    alternate [] rest = rest
    alternate (y:ys) rest = y : alternate rest ys
-- | Create a socket for the address, optionally bind it to a source
-- address, and connect it. The socket is closed if any step throws.
connectToAddrInfo :: Maybe SockAddr -> AddrInfo -> IO Socket
connectToAddrInfo mbSrc info =
  bracketOnError (socket' info) Socket.close setup
  where
    setup sock =
      do mapM_ (bind' sock) mbSrc
         Socket.connect sock (Socket.addrAddress info)
         return sock
-- | Bind the socket to the source address, except when that address is
-- the IPv4 or IPv6 wildcard host (all-zero), where binding is skipped.
bind' :: Socket -> SockAddr -> IO ()
bind' sock addr =
  case addr of
    Socket.SockAddrInet _ 0 -> return ()
    Socket.SockAddrInet6 _ _ (0, 0, 0, 0) _ -> return ()
    _ -> Socket.bind sock addr
-- | Create an unconnected socket matching the family, type and protocol of
-- a resolved address.
socket' :: AddrInfo -> IO Socket
socket' ai = Socket.socket fam sockType proto
  where
    fam = Socket.addrFamily ai
    sockType = Socket.addrSocketType ai
    proto = Socket.addrProtocol ai
-- | The transport under a connection: a TLS session (together with the
-- client certificate configured for it, if any) or a plain socket.
data NetworkHandle = SSL (Maybe X509) SSL | Socket Socket
-- | Obtain a socket from the given action and, when 'cpTls' is set, start
-- TLS on it using 'cpHost' as the server hostname.
openNetworkHandle ::
  ConnectionParams ->
  IO Socket ->
  IO NetworkHandle
openNetworkHandle params mkSocket =
  maybe plain secure (cpTls params)
  where
    plain = Socket <$> mkSocket
    secure tls = uncurry SSL <$> startTls tls (cpHost params) mkSocket
-- | Close the transport.
--
-- For TLS, first performs a unidirectional TLS shutdown and then closes
-- the underlying socket. The socket is closed even if the shutdown throws
-- (for example because the peer already went away), which would otherwise
-- leak the file descriptor. The exception is then rethrown.
closeNetworkHandle :: NetworkHandle -> IO ()
closeNetworkHandle (Socket s) = Socket.close s
closeNetworkHandle (SSL _ s) =
  SSL.shutdown s SSL.Unidirectional
    `finally` traverse_ Socket.close (SSL.sslSocket s)
-- | Write all of the given bytes to the transport.
networkSend :: NetworkHandle -> ByteString -> IO ()
networkSend nh bytes =
  case nh of
    Socket sock -> SocketB.sendAll sock bytes
    SSL _ ssl -> SSL.write ssl bytes
-- | Read up to the given number of bytes from the transport.
networkRecv :: NetworkHandle -> Int -> IO ByteString
networkRecv nh size =
  case nh of
    Socket sock -> SocketB.recv sock size
    SSL _ ssl -> SSL.read ssl size
-- | An open connection: a buffer of bytes already received but not yet
-- consumed (shared by 'recv', 'recvLine' and 'putBuf'), plus the transport.
data Connection = Connection (MVar ByteString) NetworkHandle
-- | Open a connection as described by the parameters: connect (directly or
-- via SOCKS5), start TLS if configured, and attach an empty receive buffer.
connect ::
  ConnectionParams ->
  IO Connection
connect params =
  do nh <- openNetworkHandle params (openSocket params)
     Connection <$> newMVar B.empty <*> pure nh
-- | Build a connection over an already-connected socket. Only 'cpTls' and
-- 'cpHost' are consulted (to start TLS if configured); SOCKS and bind
-- settings are ignored.
connectWithSocket ::
  ConnectionParams ->
  Socket ->
  IO Connection
connectWithSocket params sock =
  do nh <- openNetworkHandle params (pure sock)
     flip Connection nh <$> newMVar B.empty
-- | Close the connection's transport; for TLS this includes a TLS shutdown.
close ::
  Connection ->
  IO ()
close conn =
  case conn of
    Connection _ nh -> closeNetworkHandle nh
-- | Receive the next chunk of data.
--
-- If the connection's buffer is non-empty, all of it is returned (not
-- limited by the size argument) and the buffer is emptied. Otherwise up to
-- the given number of bytes are read from the transport.
recv ::
  Connection ->
  Int ->
  IO ByteString
recv (Connection buffer nh) size =
  swapMVar buffer B.empty >>= \pending ->
    if B.null pending
      then networkRecv nh size
      else return pending
-- | Receive one line, without its trailing newline and without a @\\r@
-- immediately before that newline.
--
-- Bytes following the newline remain in the connection buffer for later
-- reads. Returns 'Nothing' if the peer closes the connection while no
-- bytes are buffered.
--
-- Throws:
--
-- * 'LineTooLong' once at least the given number of bytes have been
--   collected without a newline (the same number also caps each read);
-- * 'LineTruncated' if the connection closes part-way through a line.
--
-- The buffer is updated through 'modifyMVar', so when an exception is
-- thrown the buffer reverts to its contents from before the call. Bytes
-- read during the call are then discarded.
recvLine ::
  Connection ->
  Int ->
  IO (Maybe ByteString)
recvLine (Connection buf h) n =
  modifyMVar buf $ \bs ->
    go (B.length bs) bs []
  where
    -- bsn: total length of bs plus every chunk in bss
    -- bs : newest chunk, the only one not yet searched for a newline
    -- bss: older chunks, newest first
    go bsn bs bss =
      case B8.elemIndex '\n' bs of
        -- b begins with the newline, so B.tail drops exactly that byte
        Just i -> return (B.tail b,
                          Just (cleanEnd (B.concat (reverse (a:bss)))))
          where
            (a,b) = B.splitAt i bs
        Nothing ->
          do when (bsn >= n) (throwIO LineTooLong)
             more <- networkRecv h n
             if B.null more
               -- an empty read is treated as the peer closing the connection
               then if bsn == 0 then return (B.empty, Nothing)
                                else throwIO LineTruncated
               else go (bsn + B.length more) more (bs:bss)
-- | Push bytes back onto the front of the connection's buffer, so they are
-- returned before anything already buffered.
putBuf ::
  Connection ->
  ByteString ->
  IO ()
putBuf (Connection buffer _) bs =
  modifyMVar_ buffer $ \old -> pure $! B.append bs old
-- | Drop a single trailing @\\r@, if there is one.
cleanEnd :: ByteString -> ByteString
cleanEnd bs
  | not (B.null bs), B8.last bs == '\r' = B.init bs
  | otherwise = bs
-- | Send all of the given bytes over the connection.
send ::
  Connection ->
  ByteString ->
  IO ()
send conn bytes =
  case conn of
    Connection _ nh -> networkSend nh bytes
-- | Establish a TLS session over a socket obtained from the given action.
--
-- The context is configured first:
--
-- * cipher list;
-- * hostname verification (via 'installVerification');
-- * peer verification unless 'tpInsecure';
-- * OpenSSL compatibility options;
-- * CA certificates;
-- * the optional client certificate and key.
--
-- Only then is the socket created. The handshake is performed with the
-- hostname set as the TLS server name. Returns the client certificate
-- (if configured) and the session.
--
-- If creating the session or the handshake throws, the socket is closed
-- before the exception propagates. Previously it was leaked.
startTls ::
  TlsParams ->
  String ->
  IO Socket ->
  IO (Maybe X509, SSL)
startTls tp hostname mkSocket = SSL.withOpenSSL $
  do ctx <- SSL.context
     SSL.contextSetCiphers ctx (tpCipherSuite tp)
     installVerification ctx hostname
     SSL.contextSetVerificationMode ctx (verificationMode (tpInsecure tp))
     -- all compatibility workarounds except DONT_INSERT_EMPTY_FRAGMENTS
     -- (presumably to keep OpenSSL's empty-fragment CBC countermeasure)
     SSL.contextAddOption ctx SSL.SSL_OP_ALL
     SSL.contextRemoveOption ctx SSL.SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS
     setupCaCertificates ctx (tpServerCertificate tp)
     clientCert <- traverse (setupCertificate ctx) (tpClientCertificate tp)
     traverse_ (setupPrivateKey ctx (tpClientPrivateKeyPassword tp)) (tpClientPrivateKey tp)
     bracketOnError mkSocket Socket.close $ \sock ->
       do ssl <- SSL.connection ctx sock
          SSL.setTlsextHostName ssl hostname
          SSL.connect ssl
          return (clientCert, ssl)
-- | Trust the CA certificates in the given file or, when no file is
-- given, those from the system store.
setupCaCertificates :: SSLContext -> Maybe FilePath -> IO ()
setupCaCertificates ctx =
  maybe (contextLoadSystemCerts ctx) (SSL.contextSetCAFile ctx)
-- | Read a PEM certificate from the file, install it as the context's
-- certificate, and return it.
setupCertificate :: SSLContext -> FilePath -> IO X509
setupCertificate ctx path =
  readFile path >>= PEM.readX509 >>= \cert ->
    cert <$ SSL.contextSetCertificate ctx cert
-- | Read a PEM private key from the file (decrypting it with the given
-- password supply) and install it in the context.
setupPrivateKey :: SSLContext -> PEM.PemPasswordSupply -> FilePath -> IO ()
setupPrivateKey ctx password path =
  readFile path >>= \pem ->
    PEM.readPrivateKey pem password >>= SSL.contextSetPrivateKey ctx
-- | Peer verification policy. Insecure mode verifies nothing; otherwise
-- the peer must present a certificate, which is then verified.
verificationMode :: Bool -> SSL.VerificationMode
verificationMode True = SSL.VerifyNone
verificationMode False =
  SSL.VerifyPeer
    { SSL.vpFailIfNoPeerCert = True
    , SSL.vpClientOnce       = True
    , SSL.vpCallback         = Nothing
    }
-- | The certificate the TLS peer presented, as reported by OpenSSL.
-- Plaintext connections yield 'Nothing'.
getPeerCertificate :: Connection -> IO (Maybe X509.X509)
getPeerCertificate (Connection _ (SSL _ ssl)) = SSL.getPeerCertificate ssl
getPeerCertificate (Connection _ Socket{}) = return Nothing
-- | The client certificate configured for this TLS connection, if any.
-- Plaintext connections yield 'Nothing'.
getClientCertificate :: Connection -> Maybe X509.X509
getClientCertificate (Connection _ (SSL cert _)) = cert
getClientCertificate (Connection _ Socket{}) = Nothing
-- | SHA-1 digest of the peer certificate's DER encoding.
getPeerCertFingerprintSha1 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha1 conn = getPeerCertFingerprint "sha1" conn

-- | SHA-256 digest of the peer certificate's DER encoding.
getPeerCertFingerprintSha256 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha256 conn = getPeerCertFingerprint "sha256" conn

-- | SHA-512 digest of the peer certificate's DER encoding.
getPeerCertFingerprintSha512 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha512 conn = getPeerCertFingerprint "sha512" conn
-- | Digest of the peer certificate's DER encoding, using the OpenSSL
-- digest with the given name. Yields 'Nothing' when there is no peer
-- certificate or the digest name is not recognised.
getPeerCertFingerprint :: String -> Connection -> IO (Maybe ByteString)
getPeerCertFingerprint name conn =
  getPeerCertificate conn >>= maybe (return Nothing) fingerprint
  where
    fingerprint cert =
      do der <- X509.writeDerX509 cert
         mbDigest <- Digest.getDigestByName name
         case mbDigest of
           Nothing -> return Nothing
           Just digest -> return $! Just $! Digest.digestLBS digest der
-- | SHA-1 digest of the peer's public key DER bytes.
getPeerPubkeyFingerprintSha1 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha1 conn = getPeerPubkeyFingerprint "sha1" conn

-- | SHA-256 digest of the peer's public key DER bytes.
getPeerPubkeyFingerprintSha256 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha256 conn = getPeerPubkeyFingerprint "sha256" conn

-- | SHA-512 digest of the peer's public key DER bytes.
getPeerPubkeyFingerprintSha512 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha512 conn = getPeerPubkeyFingerprint "sha512" conn
-- | Digest of the DER bytes that 'getPubKeyDer' produces for the peer
-- certificate, using the OpenSSL digest with the given name. Yields
-- 'Nothing' when there is no peer certificate or the digest name is not
-- recognised.
getPeerPubkeyFingerprint :: String -> Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprint name conn =
  getPeerCertificate conn >>= maybe (return Nothing) fingerprint
  where
    fingerprint cert =
      do der <- getPubKeyDer cert
         mbDigest <- Digest.getDigestByName name
         case mbDigest of
           Nothing -> return Nothing
           Just digest -> return $! Just $! Digest.digestBS digest der