{-# LANGUAGE OverloadedStrings #-}

module Network.TLS.Record.Layer (
    RecordLayer(..)
  , newTransparentRecordLayer
  ) where

import Network.TLS.Imports
import Network.TLS.Record
import Network.TLS.Struct

import qualified Data.ByteString as B

-- | A record of the operations used to encode, send and receive TLS records.
--
-- The type parameter @bytes@ is the representation produced by the encoders
-- and consumed by 'recordSendBytes'.
data RecordLayer bytes = RecordLayer
    { -- Writing.hs

      -- | Encode a plaintext record into @bytes@, or fail with a 'TLSError'.
      recordEncode    :: Record Plaintext -> IO (Either TLSError bytes)
      -- | Same shape as 'recordEncode'; presumably the TLS 1.3 encoder
      -- (inferred from the name only — confirm against callers).
    , recordEncode13  :: Record Plaintext -> IO (Either TLSError bytes)
      -- | Emit previously encoded @bytes@.
    , recordSendBytes :: bytes -> IO ()

      -- Reading.hs

      -- | Receive one plaintext record, or fail with a 'TLSError'.
      -- NOTE(review): the meaning of the 'Bool' and 'Int' arguments is
      -- defined by callers and is not visible here.
    , recordRecv      :: Bool -> Int -> IO (Either TLSError (Record Plaintext))
      -- | Argument-less receive; presumably the TLS 1.3 counterpart of
      -- 'recordRecv' (inferred from the name — confirm).
    , recordRecv13    :: IO (Either TLSError (Record Plaintext))
    }

-- | Build a 'RecordLayer' that neither frames nor encrypts records.
--
-- Outgoing records are handled by 'transparentEncodeRecord': change-cipher-spec
-- and alert records produce nothing. Every other record produces a single
-- entry that pairs the result of @get@ with the fragment's raw bytes.
-- 'transparentSendBytes' merges consecutive entries that share an annotation,
-- drops empty payloads, and then hands the list to @send@. Bytes returned by
-- @recv@ are wrapped as handshake records by 'transparentRecvRecord'. The
-- 'Bool' and 'Int' arguments of 'recordRecv' are ignored.
newTransparentRecordLayer :: Eq ann
                          => IO ann
                             -- ^ @get@: produces the annotation attached to
                             -- each emitted fragment
                          -> ([(ann, ByteString)] -> IO ())
                             -- ^ @send@: receives the coalesced, annotated
                             -- fragments
                          -> IO (Either TLSError ByteString)
                             -- ^ @recv@: supplies incoming bytes or an error
                          -> RecordLayer [(ann, ByteString)]
newTransparentRecordLayer get send recv = RecordLayer
    { recordEncode    = transparentEncodeRecord get
    , recordEncode13  = transparentEncodeRecord get
    , recordSendBytes = transparentSendBytes send
      -- both receive paths share the same behaviour; the extra arguments of
      -- the non-1.3 variant are not needed by the transparent layer
    , recordRecv      = \_ _ -> transparentRecvRecord recv
    , recordRecv13    = transparentRecvRecord recv
    }

-- | Encode a plaintext record for the transparent record layer.
--
-- Change-cipher-spec and alert records yield an empty list and do not run
-- the annotation action. For any other record, the annotation action is run
-- exactly once and its result is paired with the fragment's raw bytes.
-- This function never returns 'Left'.
transparentEncodeRecord :: IO ann -> Record Plaintext -> IO (Either TLSError [(ann, ByteString)])
transparentEncodeRecord _ (Record ProtocolType_ChangeCipherSpec _ _) =
    pure (Right [])
transparentEncodeRecord _ (Record ProtocolType_Alert _ _) =
    -- all alerts are silent and must be transported externally based on
    -- TLS exceptions raised by the library
    pure (Right [])
transparentEncodeRecord get (Record _ _ frag) = do
    a <- get
    pure (Right [(a, fragmentGetBytes frag)])

-- | Coalesce annotated fragments and pass them to @send@.
--
-- Consecutive entries with equal annotations (see 'compress') are merged by
-- concatenating their payloads in order. Merged payloads that are empty are
-- dropped. @send@ is called exactly once, even if the resulting list is
-- empty.
transparentSendBytes :: Eq ann => ([(ann, ByteString)] -> IO ()) -> [(ann, ByteString)] -> IO ()
transparentSendBytes send input = send
    [ (a, bs) | (a, frgs) <- compress input
              , let bs = B.concat frgs
              , not (B.null bs)
    ]

-- | Run @recv@ and wrap any successfully received bytes as a plaintext
-- handshake record tagged 'TLS12'. Errors from @recv@ are passed through
-- unchanged.
transparentRecvRecord :: IO (Either TLSError ByteString)
                      -> IO (Either TLSError (Record Plaintext))
transparentRecvRecord recv =
    fmap (Record ProtocolType_Handshake TLS12 . fragmentPlaintext) <$> recv

-- | Group runs of consecutive entries that share an annotation. Values within
-- each run keep their original order.
--
-- Each run is keyed by the annotation of its first element. Later elements
-- join the run while their annotation compares equal to that key.
-- Non-adjacent entries with equal annotations stay in separate groups.
--
-- >>> compress [(1,'a'),(1,'b'),(2,'c'),(1,'d')]
-- [(1,"ab"),(2,"c"),(1,"d")]
compress :: Eq ann => [(ann, val)] -> [(ann, [val])]
compress []         = []
compress ((a,v):xs) =
    let (ys, zs) = span ((== a) . fst) xs
     in (a, v : map snd ys) : compress zs