{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE UndecidableInstances  #-}

module Metro.UDPServer
  ( UDPServer
  , udpServer
  , newClient
  ) where

import           Control.Monad             (void)
import           Data.ByteString           (empty)
import           Data.Hashable
import           Data.IOHashMap            (IOHashMap)
import qualified Data.IOHashMap            as HM (delete, empty, insert, lookup)
import           Metro.Class               (GetPacketId, RecvPacket,
                                            Servable (..), Transport,
                                            TransportConfig)
import           Metro.Conn
import           Metro.Node                (NodeEnv1)
import           Metro.Server              (ServerT, getServ, handleConn,
                                            serverEnv)
import           Metro.Session             (SessionT)
import           Metro.Socket              (bindTo, getDatagramAddr)
import           Metro.TP.BS               (BSHandle, bsTransportConfig,
                                            closeBSHandle, feed, newBSHandle)
import           Metro.TP.UDPSocket        (UDPSocket, udpSocket_)
import           Network.Socket            (SockAddr, Socket, addrAddress)
import qualified Network.Socket            as Socket (close)
import           Network.Socket.ByteString (recvFrom, sendAllTo)
import           System.Log.Logger         (errorM)
import           UnliftIO

-- | Runtime state of a UDP server.
--
-- The first field is the bound datagram socket that every peer shares; the
-- second is the table of per-peer stream handles, keyed by the 'show'n
-- 'SockAddr' of the peer.
data UDPServer = UDPServer
  Socket
  (IOHashMap String BSHandle)

-- | UDP has no connection handshake: all datagrams arrive on one shared
-- socket, and a per-peer "connection" is created the first time a datagram
-- from a previously unseen address is received.
instance Servable UDPServer where
  data ServerConfig UDPServer = UDPConfig String
  type SID UDPServer = SockAddr
  type STP UDPServer = UDPSocket

  -- Bind the socket described by the config and start with no known peers.
  newServer (UDPConfig hostPort) = do
    sock <- liftIO $ bindTo hostPort
    UDPServer sock <$> HM.empty

  -- Receive one datagram and route it. A datagram from a known peer is fed
  -- into that peer's handle; the first datagram from an unknown peer opens a
  -- new connection and becomes that handle's initial contents.
  servOnce us@(UDPServer serv handleList) done = do
    -- NOTE(review): 4 MiB is far larger than any UDP payload over IPv4
    -- (65507 bytes) and recvFrom sizes its receive buffer from this value on
    -- every call; kept unchanged until confirmed nothing relies on it.
    (bs, addr) <- liftIO $ recvFrom serv 4194304

    bsHandle <- HM.lookup (show addr) handleList
    case bsHandle of
      Just h  -> feed h bs
      Nothing -> do
        -- Create and register the handle *before* forking. Previously this
        -- happened inside the forked thread, so a second datagram from the
        -- same new peer could arrive first, miss the lookup, and open a
        -- duplicate connection that overwrote the first handle's entry.
        h <- newBSHandle bs
        config <- newTransportConfig us addr h
        -- Close the handle once `done` returns, even if it throws.
        void . async $ done (Just (addr, config)) `finally` closeBSHandle h

  onConnEnter _ _ = return ()

  -- Forget the peer so its next datagram starts a fresh connection.
  onConnLeave (UDPServer _ handleList) addr = HM.delete (show addr) handleList

  servClose (UDPServer serv _) = liftIO $ Socket.close serv

-- | Build a UDP server configuration from a bind address string. The string
-- is handed unchanged to 'bindTo' when the server is created.
udpServer :: String -> ServerConfig UDPServer
udpServer hostPort = UDPConfig hostPort

-- | Register @h@ as the handle for traffic from @addr@ and build a transport
-- config backed by @h@ whose outgoing data is sent to @addr@ through the
-- server's shared socket.
--
-- The handle is stored in the peer table under @show addr@, which is the
-- same key 'servOnce' looks up and 'onConnLeave' removes.
newTransportConfig
  :: (MonadIO m)
  => UDPServer
  -> SockAddr
  -> BSHandle
  -> m (TransportConfig UDPSocket)
newTransportConfig (UDPServer sock handleList) addr h = do
  HM.insert (show addr) h handleList
  pure (udpSocket_ (bsTransportConfig h sendToPeer))
  where
    -- Every outgoing chunk is addressed to this single peer.
    sendToPeer payload = sendAllTo sock payload addr

-- | Open an outbound session to a remote UDP endpoint, using this server's
-- socket for both sending and receiving.
--
-- @hostPort@ is resolved with 'getDatagramAddr'. If resolution fails, an
-- error is logged and 'Nothing' is returned. Otherwise an empty handle is
-- registered for the resolved address, the resulting transport config is
-- adapted with @mk@, and the connection is passed to 'handleConn'. The node
-- environment it yields is returned; the accompanying 'Async' is discarded.
newClient
  :: (MonadUnliftIO m, Transport tp, Show nid, Eq nid, Hashable nid, Eq k, Hashable k, GetPacketId k rpkt, RecvPacket rpkt)
  => (TransportConfig UDPSocket -> TransportConfig tp)
  -> String
  -> nid
  -> u
  -> (rpkt -> m Bool)
  -> SessionT u nid k rpkt tp m ()
  -> ServerT UDPServer u nid k rpkt tp m (Maybe (NodeEnv1 u nid k rpkt tp))
newClient mk hostPort nid uEnv preprocess sess =
  liftIO (getDatagramAddr hostPort) >>= maybe onResolveFailure connectTo
  where
    onResolveFailure = do
      liftIO . errorM "Metro.UDP" $ "Connect UDP Server " ++ hostPort ++ " failed"
      pure Nothing

    connectTo addrInfo = do
      let peer = addrAddress addrInfo
      us <- getServ <$> serverEnv
      h <- newBSHandle empty
      config <- mk <$> newTransportConfig us peer h
      connEnv <- initConnEnv config
      fmap (Just . fst) (handleConn "Server" peer connEnv nid uEnv preprocess sess)