{-# LANGUAGE TypeFamilies #-}
module Metro.TP.UDPSocket
  ( UDPSocket
  , udpSocket
  , udpSocket_
  ) where

import           Control.Monad             (forever)
import           Data.ByteString           (empty)
import           Metro.Class               (Transport (..))
import           Metro.Socket              (bindTo, getDatagramAddr)
import           Metro.TP.BS               (BSTransport, bsTransportConfig,
                                            feed, newBSHandle)
import           Network.Socket            (addrAddress, close)
import           Network.Socket.ByteString (recvFrom, sendAllTo)
import           System.Log.Logger         (errorM)
import           UnliftIO                  (Async, async, cancel, finally)

-- | A UDP transport layered over a 'BSTransport'.
--
-- Built by 'newTransport' either around an existing byte-stream transport
-- (no background loop) or around a freshly bound UDP socket (with a
-- background receive loop that 'closeTransport' cancels).
data UDPSocket
  = UDPSocket
      (Maybe (Async ()))  -- ^ background receive loop, if one was started
      BSTransport         -- ^ byte-stream transport used by recvData/sendData

instance Transport UDPSocket where
  -- Two ways to obtain a UDPSocket:
  --   RawSocket wraps an existing BSTransport config; no socket and no
  --   receive loop are created.
  --   SocketUri resolves the URI with getDatagramAddr and exchanges
  --   datagrams with that address from a freshly bound local UDP socket.
  data TransportConfig UDPSocket =
    RawSocket (TransportConfig BSTransport)
    | SocketUri String

  newTransport (RawSocket h) = UDPSocket Nothing <$> newTransport h
  newTransport (SocketUri h) = do
    addrInfo <- getDatagramAddr h
    case addrInfo of
      -- Unresolvable address: fails with an ErrorCall (via 'error').
      Nothing -> error $ "Connect UDP Server " ++ h ++ " failed"
      Just addrInfo0 -> do
        let addr0 = addrAddress addrInfo0
            -- Receive-buffer size requested from recvFrom on every call.
            -- A UDP datagram cannot exceed 65535 bytes (16-bit length
            -- field), so the former 4 MiB request was never needed.
            maxDatagramSize = 65535 :: Int
        bsHandle <- newBSHandle empty
        -- Bind an ephemeral local port to send from and receive on.
        sock <- bindTo "udp://0.0.0.0:0"

        -- Feed datagrams from the resolved server address into the
        -- byte-stream handle; datagrams from any other peer are logged
        -- and dropped.
        let receiveLoop = forever $ do
              (bs, addr1) <- recvFrom sock maxDatagramSize
              if addr0 == addr1
                then feed bsHandle bs
                else errorM "Metro.UDP" $ "Receive unknown address " ++ show addr1

        -- The loop owns the socket: it is closed whenever the loop ends,
        -- including when closeTransport cancels it. Previously the socket
        -- was never closed and its file descriptor leaked.
        io <- async $ receiveLoop `finally` close sock

        -- Outgoing data is always sent to the resolved server address.
        tp <- newTransport $ bsTransportConfig bsHandle $ flip (sendAllTo sock) addr0
        return $ UDPSocket (Just io) tp

  -- Reads and writes delegate straight to the wrapped BSTransport.
  recvData (UDPSocket _ soc) = recvData soc
  sendData (UDPSocket _ soc) = sendData soc

  -- Cancel the receive loop (if any) first; cancel waits for the loop
  -- thread to exit, so its socket is released before the byte-stream
  -- transport is closed.
  closeTransport (UDPSocket io soc) = mapM_ cancel io >> closeTransport soc

-- | Configuration for a UDP client: the URI is resolved with
-- 'getDatagramAddr' when the transport is created, and datagrams are
-- exchanged with that address from a locally bound UDP socket.
udpSocket :: String -> TransportConfig UDPSocket
udpSocket = SocketUri

-- | Wrap an existing 'BSTransport' configuration as a 'UDPSocket'
-- configuration. No UDP socket or background receive loop is created;
-- all I/O goes through the wrapped transport.
udpSocket_ :: TransportConfig BSTransport -> TransportConfig UDPSocket
udpSocket_ = RawSocket