{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}

#include "HsNetDef.h"

module Network.Socket.Shutdown (
    ShutdownCmd(..)
  , shutdown
  , gracefulClose
  ) where

import Control.Concurrent (yield)
import qualified Control.Exception as E
import Foreign.Marshal.Alloc (mallocBytes, free)
import System.Timeout

#if !defined(mingw32_HOST_OS)
import Control.Concurrent.STM
import qualified GHC.Event as Ev
#endif

import Network.Socket.Buffer
import Network.Socket.Imports
import Network.Socket.Internal
import Network.Socket.STM
import Network.Socket.Types

-- | Which direction(s) of a connection 'shutdown' should disable.
data ShutdownCmd
  = ShutdownReceive -- ^ Disallow further receives.
  | ShutdownSend    -- ^ Disallow further sends.
  | ShutdownBoth    -- ^ Disallow further sends and receives.

-- | Map a 'ShutdownCmd' to the numeric @how@ argument passed to the C
-- @shutdown@ call: 0 for receive, 1 for send, 2 for both.
-- NOTE(review): these are hard-coded literals rather than the
-- SHUT_RD / SHUT_WR / SHUT_RDWR constants from the C headers; they match
-- the conventional values, but confirm on any new platform.
sdownCmdToInt :: ShutdownCmd -> CInt
sdownCmdToInt ShutdownReceive = 0
sdownCmdToInt ShutdownSend    = 1
sdownCmdToInt ShutdownBoth    = 2

-- | Shut down one or both halves of the connection, depending on the
-- second argument to the function.  If the second argument is
-- 'ShutdownReceive', further receives are disallowed.  If it is
-- 'ShutdownSend', further sends are disallowed.  If it is
-- 'ShutdownBoth', further sends and receives are disallowed.
--
-- Failures of the underlying C @shutdown@ call are reported through
-- 'throwSocketErrorIfMinus1Retry_' (tagged "Network.Socket.shutdown").
shutdown :: Socket -> ShutdownCmd -> IO ()
shutdown s stype = withFdSocket s $ \fd ->
  throwSocketErrorIfMinus1Retry_ "Network.Socket.shutdown" $
    c_shutdown fd (sdownCmdToInt stype)

-- Raw binding to the C @shutdown@ function: takes the socket descriptor
-- and the numeric @how@ value (see 'sdownCmdToInt') and returns the C
-- result code unchanged.
foreign import CALLCONV unsafe "shutdown" c_shutdown :: CInt -> CInt -> IO CInt

-- | Close a socket gracefully.
--   This sends a TCP FIN and checks whether a TCP FIN is received from
--   the peer.  The second argument is the time-out, in milliseconds, for
--   receiving the peer's TCP FIN.
--   In both normal and error cases, the socket is finally deallocated
--   (via 'close' in 'E.finally').
--
--   @since 3.1.1.0
gracefulClose :: Socket -> Int -> IO ()
gracefulClose s tmout0 = sendRecvFIN `E.finally` close s
  where
    sendRecvFIN :: IO ()
    sendRecvFIN = do
        -- Sending TCP FIN.
        ex <- E.try $ shutdown s ShutdownSend
        case ex of
          -- Only 'IOException' is caught: asynchronous exceptions are not
          -- swallowed here (and 'close' still runs via 'E.finally').
          Left (_ :: E.IOException) -> return ()
          Right () -> do
              -- Giving CPU time to other threads hoping that
              -- FIN arrives meanwhile.
              yield
              -- Waiting TCP FIN, using a scratch buffer that 'E.bracket'
              -- frees on every exit path.
              E.bracket (mallocBytes bufSize) free (recvEOF s tmout0)

-- | Wait (for at most @tmout0@ milliseconds) for the peer's FIN, using
-- @buf@ (of 'bufSize' bytes) as scratch space for any received bytes.
-- On non-Windows platforms the event-manager based 'recvEOFevent' is used
-- when a system event manager is available; otherwise (and always on
-- Windows) the simpler 'recvEOFloop' is used.
recvEOF :: Socket -> Int -> Ptr Word8 -> IO ()
#if !defined(mingw32_HOST_OS)
recvEOF s tmout0 buf = do
    -- 'Ev.getSystemEventManager' yields Nothing when no IO manager is
    -- running (e.g. the non-threaded RTS).
    mevmgr <- Ev.getSystemEventManager
    case mevmgr of
      Nothing -> recvEOFloop s tmout0 buf
      Just _  -> recvEOFevent s tmout0 buf
#else
recvEOF = recvEOFloop
#endif

-- Size, in bytes, of the scratch buffer used while waiting for the
-- peer's FIN.
-- Don't use 4092 here. The GHC runtime takes the global lock
-- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit.
bufSize :: Int
bufSize = 1024

-- | Fallback wait for the peer's FIN: a single 'recvBuf' of up to
-- 'bufSize' bytes into @buf@, bounded by 'timeout'.  The received length
-- (if any) is ignored.
--
-- @tmout0@ is in milliseconds (as documented for 'gracefulClose'), but
-- 'timeout' expects microseconds, so it must be scaled.  Passing it
-- unscaled made this path give up 1000 times too early.
recvEOFloop :: Socket -> Int -> Ptr Word8 -> IO ()
recvEOFloop s tmout0 buf = void $ timeout tmout $ recvBuf s buf bufSize
  where
    -- millisecond to microsecond (the same conversion recvEOFevent uses)
    tmout :: Int
    tmout = tmout0 * 1000

#if !defined(mingw32_HOST_OS)
-- | Outcome of the STM wait in 'recvEOFevent'.
data Wait
  = MoreData       -- ^ The read-side wait completed first.
  | TimeoutTripped -- ^ The time-out fired first.

-- | Wait for the peer's FIN using the GHC timer manager and an STM read
-- wait.
--
-- A timer of @tmout0@ milliseconds is registered that sets a 'TVar' to
-- 'True'.  The read wait comes from 'waitAndCancelReadSocketSTM'
-- (presumably an STM action that completes once the socket is readable,
-- paired with a cancel action — confirm in "Network.Socket.STM").  The
-- two are raced with '<|>'; the time-out branch is tried first.  If the
-- read side wins, one non-blocking 'recvBufNoWait' of up to 'bufSize'
-- bytes is done into @buf@.  Both the timer and the read wait are
-- released by 'E.bracket' on every exit path.
recvEOFevent :: Socket -> Int -> Ptr Word8 -> IO ()
recvEOFevent s tmout0 buf = do
    tmmgr <- Ev.getSystemTimerManager
    tvar <- newTVarIO False
    E.bracket (setupTimeout tmmgr tvar) (cancelTimeout tmmgr) $ \_ ->
        E.bracket (setupRead s) cancelRead $ \(rxWait, _) -> do
            let toWait :: STM ()
                toWait = readTVar tvar >>= check
                wait :: IO Wait
                wait = atomically $
                    (toWait >> return TimeoutTripped)
                        <|> (rxWait >> return MoreData)
            waitRes <- wait
            case waitRes of
              TimeoutTripped -> return ()
              -- We don't check the (positive) length.
              -- In normal case, it's 0. That is, only FIN is received.
              -- In error cases, data is available. But there is no
              -- application which can read it. So, let's stop receiving
              -- to prevent attacks.
              MoreData       -> void $ recvBufNoWait s buf bufSize
  where
    -- millisecond to microsecond
    tmout :: Int
    tmout = tmout0 * 1000
    setupTimeout :: Ev.TimerManager -> TVar Bool -> IO Ev.TimeoutKey
    setupTimeout tmmgr tvar =
        Ev.registerTimeout tmmgr tmout $ atomically $ writeTVar tvar True
    cancelTimeout :: Ev.TimerManager -> Ev.TimeoutKey -> IO ()
    cancelTimeout = Ev.unregisterTimeout
    setupRead :: Socket -> IO (STM (), IO ())
    setupRead = waitAndCancelReadSocketSTM
    cancelRead :: (a, b) -> b
    cancelRead (_, cancel) = cancel
#endif