{-# LANGUAGE NamedFieldPuns #-}

module Network.HTTP2.Arch.Context where

import Data.IORef
import Network.HTTP.Types (Method)
import Network.Socket (SockAddr)
import UnliftIO.STM

import Imports hiding (insert)
import Network.HPACK
import Network.HTTP2.Arch.Cache (Cache, emptyCache)
import qualified Network.HTTP2.Arch.Cache as Cache
import Network.HTTP2.Arch.Rate
import Network.HTTP2.Arch.Stream
import Network.HTTP2.Arch.Types
import Network.HTTP2.Frame

-- | Which end of the HTTP/2 connection a 'Context' represents.
data Role
    = Client
    | Server
    deriving (Eq, Show)

----------------------------------------------------------------

-- | Role-specific state stored in a 'Context'.
--
-- 'RIS' wraps server-side data and 'RIC' wraps client-side data;
-- 'newContext' derives the connection's 'Role' from which constructor
-- it is given.
data RoleInfo
    = RIS ServerInfo
    | RIC ClientInfo

-- | Server-side state.
data ServerInfo = ServerInfo
    { inputQ :: TQueue (Input Stream)
      -- ^ Queue of incoming 'Input' items; created empty by 'newServerInfo'.
    }

-- | Client-side state.
data ClientInfo = ClientInfo
    { scheme    :: ByteString
      -- ^ Scheme given to 'newClientInfo'.
    , authority :: ByteString
      -- ^ Authority given to 'newClientInfo'.
    , cache     :: IORef (Cache (Method, ByteString) Stream)
      -- ^ Streams keyed by @(method, path)@; maintained by 'insertCache'
      --   and read by 'lookupCache'.
    }

-- | Unwrap the server-side state of a 'RoleInfo'.
--
-- Partial: calls 'error' when given a client ('RIC') value.
toServerInfo :: RoleInfo -> ServerInfo
toServerInfo rinfo = case rinfo of
    RIS info -> info
    RIC _    -> error "toServerInfo"

-- | Unwrap the client-side state of a 'RoleInfo'.
--
-- Partial: calls 'error' when given a server ('RIS') value.
toClientInfo :: RoleInfo -> ClientInfo
toClientInfo rinfo = case rinfo of
    RIC info -> info
    RIS _    -> error "toClientInfo"

-- | Build server-side 'RoleInfo' with a fresh, empty input queue.
newServerInfo :: IO RoleInfo
newServerInfo = do
    q <- newTQueueIO
    pure $ RIS ServerInfo { inputQ = q }

-- | Build client-side 'RoleInfo'.
--
-- The scheme and authority are stored as given. The stream cache starts
-- as @emptyCache lim@ (presumably @lim@ bounds the number of cached
-- entries — see "Network.HTTP2.Arch.Cache").
newClientInfo :: ByteString -> ByteString -> Int -> IO RoleInfo
newClientInfo scm auth lim = do
    ref <- newIORef (emptyCache lim)
    pure $ RIC ClientInfo
        { scheme    = scm
        , authority = auth
        , cache     = ref
        }

-- | Store a stream in the client cache under the key @(method, path)@.
--
-- Only meaningful for client-side 'RoleInfo'; calls 'error' for 'RIS'.
insertCache :: Method -> ByteString -> Stream -> RoleInfo -> IO ()
insertCache m path v rinfo = case rinfo of
    RIC ClientInfo{cache = ref} ->
        atomicModifyIORef' ref $ \c -> (Cache.insert (m, path) v c, ())
    RIS _ ->
        error "insertCache"

-- | Look up a cached stream by @(method, path)@.
--
-- Only meaningful for client-side 'RoleInfo'; calls 'error' for 'RIS'.
lookupCache :: Method -> ByteString -> RoleInfo -> IO (Maybe Stream)
lookupCache m path rinfo = case rinfo of
    RIC ClientInfo{cache = ref} -> do
        c <- readIORef ref
        pure (Cache.lookup (m, path) c)
    RIS _ ->
        error "lookupCache"

----------------------------------------------------------------

-- | Per-connection state for an HTTP/2 connection.
--
-- Field order matters: 'newContext' fills the record in this order.
data Context = Context
    { role               :: Role
      -- ^ Which end of the connection this is; derived from 'roleInfo'.
    , roleInfo           :: RoleInfo
      -- ^ Role-specific state (server input queue or client stream cache).
    -- Settings
    , myFirstSettings    :: IORef Bool
      -- ^ Flag concerning our initial SETTINGS; starts as 'False'.
      --   (Exact meaning set by callers — TODO confirm.)
    , myPendingAlist     :: IORef (Maybe SettingsList)
      -- ^ Settings list awaiting application; starts as 'Nothing'.
      --   NOTE(review): presumably sent but not yet acknowledged.
    , mySettings         :: IORef Settings
      -- ^ Our settings; starts as 'defaultSettings'.
    , peerSettings       :: IORef Settings
      -- ^ The peer's settings; starts as 'defaultSettings'. Its
      --   'initialWindowSize' seeds new streams in 'openStream'.
    , streamTable        :: StreamTable
      -- ^ Live streams: 'openStream' inserts, 'closed' removes.
    , concurrency        :: IORef Int
      -- ^ Incremented by 'opened', decremented by 'closed'.
    , continued          :: IORef (Maybe StreamId)
      -- ^ RFC 9113 says "Other frames (from any stream) MUST NOT
      --   occur between the HEADERS frame and any CONTINUATION
      --   frames that might follow". This field is used to implement
      --   this requirement. Starts as 'Nothing'.
    , myStreamId         :: IORef StreamId
      -- ^ Next id for a locally-initiated stream (1 for a client,
      --   2 for a server); advanced by 2 in 'getMyNewStreamId'.
    , peerStreamId       :: IORef StreamId
      -- ^ Peer stream id recorded via 'setPeerStreamID'; starts at 0.
    , outputBufferLimit  :: IORef Int
      -- ^ Output buffer limit computed in 'newContext'.
    , outputQ            :: TQueue (Output Stream)
      -- ^ Queue of pending stream output.
    , outputQStreamID    :: TVar StreamId
      -- ^ Starts at the same value as 'myStreamId'.
    , controlQ           :: TQueue Control
      -- ^ Queue of pending control output.
    , encodeDynamicTable :: DynamicTable
      -- ^ HPACK dynamic table used when encoding headers.
    , decodeDynamicTable :: DynamicTable
      -- ^ HPACK dynamic table used when decoding headers.
    , txConnectionWindow :: TVar WindowSize
      -- ^ The connection window for sending data.
    , rxConnectionInc    :: IORef WindowSize
      -- ^ Window update for receiving data. This is a difference
      --   (increment), not an absolute window; starts at 0.
    , pingRate           :: Rate
      -- ^ Rate counter for PING frames (presumably flood protection).
    , settingsRate       :: Rate
      -- ^ Rate counter for SETTINGS frames (presumably flood protection).
    , emptyFrameRate     :: Rate
      -- ^ Rate counter for empty frames (presumably flood protection).
    , rstRate            :: Rate
      -- ^ Rate counter for RST_STREAM frames (presumably flood protection).
    , mySockAddr         :: SockAddr
      -- ^ Local socket address.
    , peerSockAddr       :: SockAddr
      -- ^ Remote socket address.
    }

----------------------------------------------------------------

-- | Allocate a fresh 'Context' for one HTTP/2 connection.
--
-- * The 'Role' is 'Client' for an 'RIC' value and 'Server' otherwise.
-- * Locally-initiated stream ids start at 1 for a client and 2 for a
--   server.
-- * The output buffer limit is the given buffer size, capped at
--   @defaultPayloadLength + frameHeaderLength@.
-- * Both settings refs start as 'defaultSettings', the send window as
--   'defaultWindowSize', and all counters at zero.
newContext :: RoleInfo -> BufferSize -> SockAddr -> SockAddr -> IO Context
newContext rinfo siz mysa peersa = do
    firstSettingsRef <- newIORef False
    pendingAlistRef  <- newIORef Nothing
    mySettingsRef    <- newIORef defaultSettings
    peerSettingsRef  <- newIORef defaultSettings
    table            <- newStreamTable
    concurrencyRef   <- newIORef 0
    continuedRef     <- newIORef Nothing
    myStreamIdRef    <- newIORef firstStreamId
    peerStreamIdRef  <- newIORef 0
    bufLimitRef      <- newIORef bufferLimit
    outQ             <- newTQueueIO
    outQStreamId     <- newTVarIO firstStreamId
    ctlQ             <- newTQueueIO
    -- My SETTINGS_HEADER_TABLE_SIZE
    encTable         <- newDynamicTableForEncoding defaultDynamicTableSize
    -- NOTE(review): 4096 is presumably the decoder's header buffer size.
    decTable         <- newDynamicTableForDecoding defaultDynamicTableSize 4096
    txWindow         <- newTVarIO defaultWindowSize
    rxInc            <- newIORef 0
    pingR            <- newRate
    settingsR        <- newRate
    emptyFrameR      <- newRate
    rstR             <- newRate
    pure Context
        { role               = connRole
        , roleInfo           = rinfo
        , myFirstSettings    = firstSettingsRef
        , myPendingAlist     = pendingAlistRef
        , mySettings         = mySettingsRef
        , peerSettings       = peerSettingsRef
        , streamTable        = table
        , concurrency        = concurrencyRef
        , continued          = continuedRef
        , myStreamId         = myStreamIdRef
        , peerStreamId       = peerStreamIdRef
        , outputBufferLimit  = bufLimitRef
        , outputQ            = outQ
        , outputQStreamID    = outQStreamId
        , controlQ           = ctlQ
        , encodeDynamicTable = encTable
        , decodeDynamicTable = decTable
        , txConnectionWindow = txWindow
        , rxConnectionInc    = rxInc
        , pingRate           = pingR
        , settingsRate       = settingsR
        , emptyFrameRate     = emptyFrameR
        , rstRate            = rstR
        , mySockAddr         = mysa
        , peerSockAddr       = peersa
        }
  where
    connRole :: Role
    connRole = case rinfo of
        RIC{} -> Client
        _     -> Server

    -- Clients use odd stream ids, servers even ones.
    firstStreamId :: StreamId
    firstStreamId = if connRole == Client then 1 else 2

    -- Never buffer more than one default-sized frame (payload + header).
    bufferLimit :: Int
    bufferLimit = min siz (defaultPayloadLength + frameHeaderLength)

----------------------------------------------------------------

-- | 'True' when this context is the client end of the connection.
isClient :: Context -> Bool
isClient = (== Client) . role

-- | 'True' when this context is the server end of the connection.
isServer :: Context -> Bool
isServer = (== Server) . role

----------------------------------------------------------------

-- | Hand out the id for a new locally-initiated stream.
--
-- Atomically returns the current 'myStreamId' and advances it by 2, which
-- keeps client ids odd and server ids even. No check is made against the
-- 31-bit stream id limit.
getMyNewStreamId :: Context -> IO StreamId
getMyNewStreamId ctx = atomicModifyIORef' (myStreamId ctx) bump
  where
    bump current = (current + 2, current)

-- | Read the peer stream id last stored with 'setPeerStreamID'.
getPeerStreamID :: Context -> IO StreamId
getPeerStreamID = readIORef . peerStreamId

-- | Overwrite the recorded peer stream id.
setPeerStreamID :: Context -> StreamId -> IO ()
setPeerStreamID = writeIORef . peerStreamId

----------------------------------------------------------------

{-# INLINE setStreamState #-}
-- | Unconditionally overwrite a stream's state.
--
-- The 'Context' argument is accepted but not used.
setStreamState :: Context -> Stream -> StreamState -> IO ()
setStreamState _ strm newState = writeIORef (streamState strm) newState

-- | Count a stream in 'concurrency' and mark it @Open Nothing JustOpened@.
opened :: Context -> Stream -> IO ()
opened ctx strm = do
    let counter = concurrency ctx
    atomicModifyIORef' counter (\n -> (n + 1, ()))
    setStreamState ctx strm (Open Nothing JustOpened)

-- | Record that the peer has finished sending on this stream.
--
-- State transitions:
--
-- * @Closed _@: left unchanged; nothing else happens.
-- * @Open (Just cc) _@ (already half-closed locally): becomes @Closed cc@
--   and 'closed' is run with @cc@.
-- * anything else: becomes 'HalfClosedRemote'.
halfClosedRemote :: Context -> Stream -> IO ()
halfClosedRemote ctx stream@Stream{streamState} = do
    -- Strict modify, matching the other IORef updates in this module, so
    -- the ref never holds an unevaluated transition.
    closingCode <- atomicModifyIORef' streamState closeHalf
    traverse_ (closed ctx stream) closingCode
  where
    closeHalf :: StreamState -> (StreamState, Maybe ClosedCode)
    closeHalf x@(Closed _)       = (x, Nothing)
    closeHalf (Open (Just cc) _) = (Closed cc, Just cc)
    closeHalf _                  = (HalfClosedRemote, Nothing)

-- | Record that we have finished sending on this stream, with code @cc@.
--
-- State transitions:
--
-- * @Closed _@: left unchanged.
-- * 'HalfClosedRemote': the peer already finished, so the stream becomes
--   @Closed cc@ and 'closed' is run with @cc@.
-- * @Open _ o@: becomes @Open (Just cc) o@ (half-closed locally). The
--   existing 'OpenState' @o@ is kept even if the stream was already
--   half-closed locally, so a repeated call cannot reset it.
-- * anything else: becomes @Open (Just cc) JustOpened@.
halfClosedLocal :: Context -> Stream -> ClosedCode -> IO ()
halfClosedLocal ctx stream@Stream{streamState} cc = do
    -- Strict modify, matching the other IORef updates in this module.
    shouldFinalize <- atomicModifyIORef' streamState closeHalf
    when shouldFinalize $
        closed ctx stream cc
  where
    closeHalf :: StreamState -> (StreamState, Bool)
    closeHalf x@(Closed _)     = (x, False)
    closeHalf HalfClosedRemote = (Closed cc, True)
    closeHalf (Open _ o)       = (Open (Just cc) o, False)
    closeHalf _                = (Open (Just cc) JustOpened, False)

-- | Fully close a stream.
--
-- The stream is removed from 'streamTable', 'concurrency' is decremented,
-- and its state is set to @Closed cc@.
--
-- NOTE(review): the decrement is unconditional. 'closed' can therefore
-- decrement for a stream that 'opened' never counted, for example one
-- created by 'openStream' from a frame other than HEADERS or PUSH_PROMISE.
-- It also decrements again if called twice on the same stream.
closed :: Context -> Stream -> ClosedCode -> IO ()
closed ctx strm cc = do
    remove (streamTable ctx) (streamNumber strm)
    -- TODO: prevent double-counting
    atomicModifyIORef' (concurrency ctx) (\n -> (n - 1, ()))
    setStreamState ctx strm (Closed cc) -- anyway

-- | Create a stream with id @sid@, register it in 'streamTable', and
-- return it.
--
-- The stream's window starts at the peer's 'initialWindowSize'. It is
-- marked 'opened' (and so counted in 'concurrency') only when created by
-- a HEADERS or PUSH_PROMISE frame.
openStream :: Context -> StreamId -> FrameType -> IO Stream
openStream ctx sid ftyp = do
    settings <- readIORef (peerSettings ctx)
    strm <- newStream sid (initialWindowSize settings)
    when (ftyp `elem` [FrameHeaders, FramePushPromise]) $
        opened ctx strm
    insert (streamTable ctx) sid strm
    pure strm