{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE TypeFamilies #-}

-- Copyright (C) 2010 John Millikin
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see <http://www.gnu.org/licenses/>.

module Network.Protocol.TLS.GNU
  ( TLS
  , TLST
  , Session
  , Error (..)
  , runTLS
  , runTLS'
  , runClient
  , getSession
  , handshake
  , rehandshake
  , putBytes
  , getBytes
  , checkPending

  -- * Settings
  , Transport (..)
  , handleTransport
  , Credentials
  , setCredentials
  , certificateCredentials

  , F.DigestAlgorithm(..)
  , hash
  ) where

import qualified Control.Concurrent.MVar as M
import Control.Monad (when, foldM, foldM_)
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B
import Data.IORef
import qualified Foreign as F
import qualified Foreign.C as F
import Foreign.Concurrent as FC
import qualified System.IO as IO
import System.IO.Unsafe (unsafePerformIO)
import UnexceptionalIO.Trans (Unexceptional)
import qualified UnexceptionalIO.Trans as UIO

import qualified Network.Protocol.TLS.GNU.Foreign as F

-- | A GnuTLS failure. The payload is the (negative) return code that the
-- failing gnutls call reported.
data Error = Error Integer
  deriving (Show)

-- | Lock held around every call to 'F.gnutls_global_init' and
-- 'F.gnutls_global_deinit', so those calls never run concurrently.
-- NOINLINE keeps the 'unsafePerformIO' from being duplicated, so there is
-- exactly one such lock.
globalInitMVar :: M.MVar ()
{-# NOINLINE globalInitMVar #-}
globalInitMVar = unsafePerformIO $ M.newMVar ()

-- | Call 'F.gnutls_global_init' under 'globalInitMVar', throwing the
-- mapped 'Error' when it reports a negative return code.
globalInit :: (Unexceptional m) => E.ExceptT Error m ()
globalInit = do
  let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
  F.ReturnCode rc <- UIO.unsafeFromIO init_
  when (rc < 0) $ E.throwE $ mapError rc

-- | Counterpart of 'globalInit'. Every successful 'globalInit' must be
-- matched by exactly one call to this.
globalDeinit :: IO ()
globalDeinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit

data Session = Session
  { sessionPtr :: F.ForeignPtr F.Session

  -- TLS credentials are not copied into the gnutls session struct,
  -- so pointers to them must be kept alive until the credentials
  -- are no longer needed.
  --
  -- TODO: Have some way to mark credentials as no longer needed.
  -- The current code just keeps them alive for the duration
  -- of the session, which may be excessive.
  , sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
  }

type TLS a = TLST IO a

type TLST m a = E.ExceptT Error (R.ReaderT Session m) a

-- | Run a TLS computation against an existing session, collecting any
-- gnutls 'Error' into a 'Left'.
runTLS :: (Unexceptional m) => Session -> TLST m a -> m (Either Error a)
runTLS s = E.runExceptT . runTLS' s

-- | Like 'runTLS', but leaves errors inside 'E.ExceptT'.
runTLS' :: Session -> TLST m a -> E.ExceptT Error m a
runTLS' s = E.mapExceptT (flip R.runReaderT s)

-- | Create a client-side session over the given transport and run the
-- computation in it. Session creation errors are returned as 'Left'.
runClient :: (Unexceptional m) => Transport -> TLST m a -> m (Either Error a)
runClient transport tls = do
  -- 2 is GNUTLS_CLIENT in gnutls.h.
  eitherSession <- newSession transport (F.ConnectionEnd 2)
  case eitherSession of
    Left err -> return (Left err)
    Right session -> runTLS session tls

-- | Allocate and initialise a gnutls session for the given connection
-- end, wiring its push/pull callbacks to the 'Transport'.
--
-- The returned 'Session' owns the gnutls session, the callback FunPtrs,
-- and one reference on the gnutls global state; a finalizer releases
-- all of them.
newSession :: (Unexceptional m) => Transport -> F.ConnectionEnd -> m (Either Error Session)
newSession transport end = UIO.unsafeFromIO . F.alloca $ \sPtr -> E.runExceptT $ do
  globalInit
  F.ReturnCode rc <- UIO.unsafeFromIO $ F.gnutls_init sPtr end
  when (rc < 0) $ do
    -- globalInit succeeded but no finalizer owns that reference yet, so
    -- release it here instead of leaking it.
    UIO.unsafeFromIO globalDeinit
    E.throwE $ mapError rc
  UIO.unsafeFromIO $ do
    ptr <- F.peek sPtr
    let session = F.Session ptr
    push <- F.wrapTransportFunc (pushImpl transport)
    pull <- F.wrapTransportFunc (pullImpl transport)
    F.gnutls_transport_set_push_function session push
    F.gnutls_transport_set_pull_function session pull
    -- NOTE(review): the return code is ignored here, as before; a failure
    -- would leave the session without a priority set.
    _ <- F.gnutls_set_default_priority session
    creds <- newIORef []
    fp <- FC.newForeignPtr ptr $ do
      F.gnutls_deinit session
      globalDeinit
      F.freeHaskellFunPtr push
      F.freeHaskellFunPtr pull
    return (Session fp creds)

-- | The session the current TLS computation runs against.
getSession :: (Monad m) => TLST m Session
getSession = lift R.ask

-- | Run the gnutls handshake, throwing on a negative return code.
handshake :: (Unexceptional m) => TLST m ()
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC

-- | Run a gnutls rehandshake, throwing on a negative return code.
rehandshake :: (Unexceptional m) => TLST m ()
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC

-- | Send all of the given bytes over the session.
--
-- Each strict chunk is sent with 'F.gnutls_record_send', retrying on
-- partial writes. The first negative return code stops sending; the
-- remaining chunks are skipped, and that code is thrown as an 'Error'.
putBytes :: (Unexceptional m) => BL.ByteString -> TLST m ()
putBytes = putChunks . BL.toChunks
  where
    putChunks chunks = do
      maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
      case maybeErr of
        Nothing -> return ()
        Just err -> E.throwE $ mapError $ fromIntegral err

    -- Accumulator is Nothing while everything has succeeded so far, or
    -- Just the gnutls error code once a send has failed.
    putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop
      where
        loop ptr len = do
          sent <- F.gnutls_record_send s ptr (fromIntegral len)
          let sent' = fromIntegral sent
          -- A negative result is an error code, not a byte count. It must
          -- be checked first: otherwise it would move the pointer backwards
          -- and grow the remaining length.
          if sent' < 0
            then return (Just sent')
            else if sent' < len
              then loop (F.plusPtr ptr sent') (len - sent')
              else return Nothing
    putChunk _ err _ = return err

-- | Receive at most @count@ bytes with 'F.gnutls_record_recv'. A
-- negative return code is thrown as an 'Error'. A zero-byte result is
-- returned as an empty ByteString.
getBytes :: (Unexceptional m) => Integer -> TLST m BL.ByteString
getBytes count = do
  (mbytes, len) <- unsafeWithSession $ \s ->
    F.allocaBytes (fromInteger count) $ \ptr -> do
      len <- F.gnutls_record_recv s ptr (fromInteger count)
      bytes <- if len >= 0
        then do
          chunk <- B.packCStringLen (ptr, fromIntegral len)
          return $ Just $ BL.fromChunks [chunk]
        else return Nothing
      return (bytes, len)
  case mbytes of
    Just bytes -> return bytes
    Nothing -> E.throwE $ mapError $ fromIntegral len

-- | The value of 'F.gnutls_record_check_pending' for this session.
checkPending :: (Unexceptional m) => TLST m Integer
checkPending = unsafeWithSession $ \s -> do
  pending <- F.gnutls_record_check_pending s
  return $ toInteger pending

-- | The byte channel gnutls reads from and writes to.
--
-- 'transportPull' is given the maximum number of bytes wanted and should
-- return no more than that.
data Transport = Transport
  { transportPush :: BL.ByteString -> IO ()
  , transportPull :: Integer -> IO BL.ByteString
  }

-- | gnutls pull callback: fill @buf@ (capacity @bufSize@) from the
-- transport and report how many bytes were written.
pullImpl :: Transport -> F.TransportFunc
pullImpl t _ buf bufSize = do
  pulled <- transportPull t $ toInteger bufSize
  -- The C buffer holds only bufSize bytes. If a transport returns more,
  -- copying it all would overflow the buffer, so the excess is dropped.
  let bytes = BL.take (fromIntegral bufSize) pulled
      copyChunk ptr chunk = B.unsafeUseAsCStringLen chunk $ \(cstr, len) -> do
        F.copyArray (F.castPtr ptr) cstr len
        return $ F.plusPtr ptr len
  foldM_ copyChunk buf $ BL.toChunks bytes
  return $ fromIntegral $ BL.length bytes

-- | gnutls push callback: hand @bufSize@ bytes at @buf@ to the transport
-- and report them all as written.
pushImpl :: Transport -> F.TransportFunc
pushImpl t _ buf bufSize = do
  -- Copy rather than alias: the buffer belongs to gnutls, while the
  -- transport may retain the ByteString after this callback returns.
  bytes <- B.packCStringLen (F.castPtr buf, fromIntegral bufSize)
  transportPush t $ BL.fromChunks [bytes]
  return bufSize

-- | A 'Transport' that writes to and reads from a 'IO.Handle'.
--
-- NOTE(review): 'BL.hGet' waits for the full requested count (or EOF);
-- confirm this never stalls when gnutls requests more than the peer has
-- sent.
handleTransport :: IO.Handle -> Transport
handleTransport h = Transport (BL.hPut h) (BL.hGet h . fromInteger)

data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)

-- | Attach credentials to the current session, throwing on a negative
-- return code from 'F.gnutls_credentials_set'.
setCredentials :: (Unexceptional m) => Credentials -> TLST m ()
setCredentials (Credentials ctype fp) = do
  rc <- unsafeWithSession $ \s ->
    F.withForeignPtr fp $ \ptr -> F.gnutls_credentials_set s ctype ptr
  checkRC rc
  -- gnutls keeps only a pointer to the credentials, so retain the
  -- ForeignPtr for the lifetime of the session.
  s <- getSession
  UIO.unsafeFromIO $
    atomicModifyIORef' (sessionCredentials s) (\creds -> (fp : creds, ()))

-- | Allocate a fresh certificate credentials structure. It is freed by
-- 'F.gnutls_certificate_free_credentials_funptr' once unreachable.
certificateCredentials :: (Unexceptional m) => TLST m Credentials
certificateCredentials = do
  (rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \pp -> do
    allocRC <- F.gnutls_certificate_allocate_credentials pp
    ptr' <- if F.unRC allocRC < 0
      then return F.nullPtr
      else F.peek pp
    return (allocRC, ptr')
  checkRC rc
  fp <- UIO.unsafeFromIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
  -- 1 is GNUTLS_CRD_CERTIFICATE in gnutls.h.
  return $ Credentials (F.CredentialsType 1) fp

-- | Run a raw gnutls action against the current session pointer.
--
-- This must only be called with IO actions that do not throw
-- NonPseudoException.
unsafeWithSession :: (Unexceptional m) => (F.Session -> IO a) -> TLST m a
unsafeWithSession io = do
  s <- getSession
  UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session

-- | Throw the mapped 'Error' when a gnutls return code is negative.
checkRC :: (Monad m) => F.ReturnCode -> E.ExceptT Error m ()
checkRC (F.ReturnCode x) = when (x < 0) $ E.throwE $ mapError x

mapError :: F.CInt -> Error
mapError = Error . toInteger

-- | Length in bytes of the digest for a gnutls digest algorithm code
-- (gnutls/crypto.h).
foreign import ccall unsafe "gnutls_hash_get_len"
  gnutls_hash_get_len :: F.CInt -> IO F.CUInt

-- | Compute the digest of @input@ with the given algorithm.
--
-- The output buffer is sized with 'gnutls_hash_get_len'. The hash handle
-- is released even when hashing fails. The digest is copied into a fresh
-- ByteString, so embedded NUL bytes are preserved.
hash :: (Unexceptional m) => F.DigestAlgorithm -> B.ByteString -> E.ExceptT Error m B.ByteString
hash algo input = E.ExceptT $ UIO.unsafeFromIO $ do
  digestLen <- fromIntegral <$> gnutls_hash_get_len (fromIntegral (fromEnum algo))
  F.alloca $ \hashp -> F.allocaBytes digestLen $ \output -> E.runExceptT $ do
    checkRC =<< UIO.unsafeFromIO (F.gnutls_hash_init hashp (fromIntegral $ fromEnum algo))
    hsh <- UIO.unsafeFromIO $ F.peek hashp
    rc <- UIO.unsafeFromIO $ do
      rc' <- B.unsafeUseAsCStringLen input $ \(cstr, len) ->
        F.gnutls_hash hsh cstr (fromIntegral len)
      -- Always deinit (which also writes the digest into output) so the
      -- handle is not leaked when gnutls_hash fails.
      F.gnutls_hash_deinit hsh output
      return rc'
    checkRC rc
    -- Copy the digest out: output is freed when allocaBytes returns.
    UIO.unsafeFromIO $ B.packCStringLen (output, digestLen)