module Network.QUIC.Connector where

import Data.IORef
import Network.QUIC.Types
import UnliftIO.STM

-- | Read-only accessors for a connection-like value.
--
-- 'getRole' is pure; every other accessor runs in 'IO'.
class Connector a where
    -- | The 'Role' of this connector.
    getRole :: a -> Role
    -- | Read the current 'EncryptionLevel'.
    getEncryptionLevel :: a -> IO EncryptionLevel
    -- | Read the current maximum packet size.
    -- NOTE(review): unit is presumably bytes; confirm against instances.
    getMaxPacketSize :: a -> IO Int
    -- | Read the current 'ConnectionState'.
    getConnectionState :: a -> IO ConnectionState
    -- | Obtain a 'PacketNumber'.
    -- NOTE(review): whether this also advances a counter is up to the
    -- instance; confirm against instances.
    getPacketNumber :: a -> IO PacketNumber
    -- | Read the liveness flag of this connector.
    getAlive :: a -> IO Bool

----------------------------------------------------------------

-- | Per-connection state record.
--
-- 'role' is an immutable value; every other field is a mutable
-- reference ('TVar' or 'IORef') created by 'newConnState'.
data ConnState = ConnState
    { role :: Role
    , connectionState :: TVar ConnectionState
    , packetNumber :: IORef PacketNumber -- squeezing three to one
    , encryptionLevel :: TVar EncryptionLevel -- to synchronize
    , maxPacketSize :: IORef Int
    , -- Explicitly separated from 'ConnectionState'.
      -- It seems that STM triggers a dead-lock if
      -- it is used in the close function of bracket,
      -- so this flag lives in an 'IORef' instead of a 'TVar'.
      connectionAlive :: IORef Bool
    }

-- | Create a fresh 'ConnState' for the given 'Role'.
--
-- Initial values:
--
-- * connection state: 'Handshaking'
-- * packet number: 0
-- * encryption level: 'InitialLevel'
-- * maximum packet size: 'defaultQUICPacketSize'
-- * alive flag: 'True'
newConnState :: Role -> IO ConnState
newConnState rl =
    ConnState rl
        <$> newTVarIO Handshaking
        <*> newIORef 0
        <*> newTVarIO InitialLevel
        <*> newIORef defaultQUICPacketSize
        <*> newIORef True

----------------------------------------------------------------

-- | The role of a connector: either 'Client' or 'Server'.
data Role = Client | Server deriving (Eq, Show)

-- | 'True' when the connector's role (as reported by 'getRole') is 'Client'.
isClient :: Connector a => a -> Bool
isClient conn = getRole conn == Client

-- | 'True' when the connector's role (as reported by 'getRole') is 'Server'.
isServer :: Connector a => a -> Bool
isServer conn = getRole conn == Server

----------------------------------------------------------------

-- | The state of a connection.
--
-- The constructor order is significant: the derived 'Ord' instance
-- follows declaration order, so 'Handshaking' is the smallest value and
-- 'Established' the largest.
data ConnectionState
    = Handshaking
    | ReadyFor0RTT
    | ReadyFor1RTT
    | Established
    deriving (Eq, Ord, Show)

-- | Check whether the connector's current state is 'Established'.
--
-- Reads the state once via 'getConnectionState'; every state other than
-- 'Established' yields 'False'.
isConnectionEstablished :: Connector a => a -> IO Bool
isConnectionEstablished conn = (== Established) <$> getConnectionState conn