{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PackageImports #-}

module Network.Xmpp.Tls where

import                 Control.Applicative ((<$>))
import qualified       Control.Exception.Lifted as Ex
import                 Control.Monad
import                 Control.Monad.Error
import                 Control.Monad.State.Strict
import "crypto-random" Crypto.Random
import qualified       Data.ByteString as BS
import qualified       Data.ByteString.Char8 as BSC8
import qualified       Data.ByteString.Lazy as BL
import                 Data.Conduit
import                 Data.IORef
import                 Data.XML.Types
import                 Network.DNS.Resolver (ResolvConf)
import                 Network.TLS
import                 Network.Xmpp.Stream
import                 Network.Xmpp.Types
import                 System.Log.Logger (debugM, errorM, infoM)

-- | Adapt a 'StreamHandle' so that it can serve as the transport 'Backend' of
-- a TLS context.
--
-- Sending discards the result of 'streamSend'. Receiving keeps calling
-- 'streamReceive' until the requested number of bytes has been gathered or a
-- read comes back empty; a 'Left' from 'streamReceive' is re-thrown as an
-- exception.
mkBackend :: StreamHandle -> Backend
mkBackend con = Backend { backendSend = void . streamSend con
                        , backendRecv = receiveUpTo
                        , backendFlush = streamFlush con
                        , backendClose = streamClose con
                        }
  where
    -- A request for zero bytes is answered without touching the stream.
    receiveUpTo 0 = return BS.empty
    receiveUpTo wanted = BS.concat `liftM` collect wanted

    -- Gather chunks until 'missing' bytes have arrived or the stream yields
    -- an empty chunk.
    collect missing = do
        chunk <- either Ex.throwIO return =<< streamReceive con missing
        let got = BS.length chunk
        case () of
            _ | got == 0 -> return []
              | got < missing -> (chunk :) `liftM` collect (missing - got)
              | otherwise -> return [chunk]

-- | The @starttls@ element in the @urn:ietf:params:xml:ns:xmpp-tls@
-- namespace, without attributes or child nodes.
starttlsE :: Element
starttlsE = Element
    { elementName = "{urn:ietf:params:xml:ns:xmpp-tls}starttls"
    , elementAttributes = []
    , elementNodes = []
    }

-- | Checks for TLS support and runs the STARTTLS procedure if applicable.
--
-- Whether TLS is negotiated depends on the configured 'tlsBehaviour' and on
-- the 'streamTls' entry of the stream features (see the table in the body).
--
-- Returns a 'Left' value in the following situations:
--
-- * 'XmppNoStream' if the stream is closed or finished
--
-- * 'TlsStreamSecured' if the stream is already secured
--
-- * 'TlsNoServerSupport' if TLS is required locally but not offered
--
-- * 'XmppOtherFailure' if TLS is refused locally while the features carry
--   @Just True@, if the server answers with a @failure@ element, or if the
--   server answers with any element other than an empty @proceed@
--
-- Exceptions raised while negotiating are converted by 'wrapExceptions'.
tls :: Stream -> IO (Either XmppFailure ())
tls con = fmap join -- We can have Left values both from exceptions and the
                    -- error monad. Join unifies them into one error layer
          . wrapExceptions
          . flip withStream con
          . runErrorT $ do
    conf <- gets $ streamConfiguration
    sState <- gets streamConnectionState
    -- Only a plain (not yet secured, still open) stream may be upgraded.
    case sState of
        Plain -> return ()
        Closed -> do
            liftIO $ errorM "Pontarius.Xmpp.Tls" "The stream is closed."
            throwError XmppNoStream
        Finished -> do
            liftIO $ errorM "Pontarius.Xmpp.Tls" "The stream is finished."
            throwError XmppNoStream
        Secured -> do
            liftIO $ errorM "Pontarius.Xmpp.Tls" "The stream is already secured."
            throwError TlsStreamSecured
    features <- lift $ gets streamFeatures
    -- NOTE(review): 'Just True' presumably means the server marks TLS as
    -- required and 'Just False' as optional -- confirm against the feature
    -- parser.
    case (tlsBehaviour conf, streamTls features) of
        (RequireTls  , Just _   ) -> startTls
        (RequireTls  , Nothing  ) -> throwError TlsNoServerSupport
        (PreferTls   , Just _   ) -> startTls
        (PreferTls   , Nothing  ) -> skipTls
        (PreferPlain , Just True) -> startTls
        (PreferPlain , _        ) -> skipTls
        (RefuseTls   , Just True) -> throwError XmppOtherFailure
        (RefuseTls   , _        ) -> skipTls
  where
    skipTls = liftIO $ infoM "Pontarius.Xmpp.Tls" "Skipping TLS negotiation"
    startTls = do
        liftIO $ infoM "Pontarius.Xmpp.Tls" "Running StartTLS"
        params <- gets $ tlsParams . streamConfiguration
        ErrorT $ pushElement starttlsE
        answer <- lift $ pullElement
        case answer of
            Left e -> throwError e
            Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}proceed" [] []) ->
                return ()
            Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}failure" _ _) -> do
                liftIO $ errorM "Pontarius.Xmpp" "startTls: TLS initiation failed."
                throwError XmppOtherFailure
            Right r -> do
                -- The server did not tell us to proceed, so the TLS
                -- handshake must not be started on this stream: abort
                -- instead of merely logging and carrying on.
                liftIO $ errorM "Pontarius.Xmpp.Tls" $
                            "Unexpected element: " ++ show r
                throwError XmppOtherFailure
        hand <- gets streamHandle
        (_raw, _snk, psh, recv, ctx) <- lift $ tlsinit params (mkBackend hand)
        -- From here on all traffic goes through the TLS context; closing the
        -- secured handle also closes the underlying plain handle.
        let newHand = StreamHandle { streamSend = catchPush . psh
                                   , streamReceive = wrapExceptions . recv
                                   , streamFlush = contextFlush ctx
                                   , streamClose = bye ctx >> streamClose hand
                                   }
        lift $ modify ( \x -> x {streamHandle = newHand})
        liftIO $ infoM "Pontarius.Xmpp.Tls" "Stream Secured."
        -- A failed stream restart is raised as an exception (and picked up by
        -- 'wrapExceptions') rather than through the error monad.
        either (lift . Ex.throwIO) return =<< lift restartStream
        modify (\s -> s{streamConnectionState = Secured})
        return ()

-- | Create a new TLS 'Context' from client parameters, a random number
-- generator and a transport backend. This only reorders the arguments of
-- 'contextNew'.
client :: (MonadIO m, CPRG rng) => ClientParams -> rng -> Backend -> m Context
client params gen backend = contextNew backend params gen

-- | Create a TLS client context on top of the given backend, run the
-- handshake and hand back the means to use the secured channel.
--
-- The resulting tuple contains, in order:
--
-- * a source that endlessly yields the chunks returned by 'recvData', logging
--   each chunk at debug level (it does not go through the read buffer below)
--
-- * a sink that passes every incoming chunk to 'sendData' until its input
--   ends
--
-- * a function sending one strict 'BS.ByteString' with 'sendData'
--
-- * a buffered read function (see 'mkReadBuffer') taking the maximum number
--   of bytes to return
--
-- * the TLS 'Context' itself
tlsinit :: (MonadIO m, MonadIO m1) =>
        ClientParams
     -> Backend
     -> m ( Source m1 BS.ByteString
          , Sink BS.ByteString m1 ()
          , BS.ByteString -> IO ()
          , Int -> m1 BS.ByteString
          , Context
          )
tlsinit params backend = do
    liftIO $ debugM "Pontarius.Xmpp.Tls" "TLS with debug mode enabled."
    pool <- liftIO createEntropyPool
    ctx <- client params (cprgCreate pool :: SystemRNG) backend
    handshake ctx
    bufferedRead <- liftIO $ mkReadBuffer (recvData ctx)
    let -- Fetch one chunk of decrypted data and log it.
        receiveLogged = do
            chunk <- recvData ctx
            debugM "Pontarius.Xmpp.Tls" ("In :" ++ BSC8.unpack chunk)
            return chunk
        incoming = forever $ liftIO receiveLogged >>= yield
        outgoing = await >>= maybe
            (return ())
            (\chunk -> sendData ctx (BL.fromChunks [chunk]) >> outgoing)
    return ( incoming
           , outgoing
             -- Note: sendData already sends the data to the debug output
           , \chunk -> sendData ctx $ BL.fromChunks [chunk]
           , \n -> liftIO (bufferedRead n)
           , ctx
           )

-- | Turn an action that returns chunks of arbitrary size into a read function
-- that returns at most the requested number of bytes per call.
--
-- Bytes that were received but not yet handed out are kept in an 'IORef'.
-- The underlying action is only run when that buffer is empty, and each call
-- runs it at most once, so a call may return fewer bytes than requested.
mkReadBuffer :: IO BS.ByteString -> IO (Int -> IO BS.ByteString)
mkReadBuffer recv = fmap readFrom (newIORef BS.empty)
  where
    readFrom ref n = do
        available <- refill =<< readIORef ref
        let (wanted, leftover) = BS.splitAt n available
        writeIORef ref leftover
        return wanted

    -- Use the buffered bytes if there are any, otherwise fetch a new chunk.
    refill pending
        | BS.null pending = recv
        | otherwise = return pending

-- | Connect to an XMPP server and secure the connection with TLS before
-- starting the XMPP streams
--
-- Fails with 'TcpConnectionFailure' if 'connectSrv' does not yield a
-- connection. If setting up TLS throws an exception, the plain connection is
-- closed before the exception propagates.
--
-- /NB/ RFC 6120 does not specify this method, but some servers, notably GCS,
-- seem to use it.
connectTls :: ResolvConf -- ^ Resolv conf to use (try 'defaultResolvConf' as a
                         -- default)
           -> ClientParams  -- ^ TLS parameters to use when securing the connection
           -> String     -- ^ Host to use when connecting (will be resolved
                         -- using SRV records)
           -> ErrorT XmppFailure IO StreamHandle
connectTls config params host = do
    h <- connectSrv config host >>= \h' -> case h' of
        Nothing -> throwError TcpConnectionFailure
        Just h'' -> return h''
    let hand = handleToStreamHandle h
    -- Do not leave the freshly opened connection dangling when the TLS setup
    -- (e.g. the handshake) fails: close it, then let the exception continue.
    (_raw, _snk, psh, recv, ctx) <- liftIO $
        tlsinit params (mkBackend hand) `Ex.onException` streamClose hand
    -- All traffic goes through the TLS context; closing the secured handle
    -- also closes the underlying plain handle.
    return $ StreamHandle { streamSend = catchPush . psh
                          , streamReceive = wrapExceptions . recv
                          , streamFlush = contextFlush ctx
                          , streamClose = bye ctx >> streamClose hand
                          }

-- | Run an action and turn the exceptions handled below into 'Left' values.
--
-- Exceptions accepted by 'XmppIOException' are wrapped with it, the four TLS
-- exception types are wrapped as 'TlsError', and a thrown 'XmppFailure' is
-- returned unchanged. Any other exception propagates.
wrapExceptions :: IO a -> IO (Either XmppFailure a)
wrapExceptions f = fmap Right f `Ex.catches`
    [ Ex.Handler (\ioErr -> failure (XmppIOException ioErr))
    , Ex.Handler (\e -> tlsFailure (XmppTlsError e))
    , Ex.Handler (\e -> tlsFailure (XmppTlsConnectionNotEstablished e))
    , Ex.Handler (\e -> tlsFailure (XmppTlsTerminated e))
    , Ex.Handler (\e -> tlsFailure (XmppTlsHandshakeFailed e))
    , Ex.Handler failure
    ]
  where
    failure err = return (Left err)
    tlsFailure tlsErr = failure (TlsError tlsErr)