{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE MonoLocalBinds      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Control.Distributed.Process.Lifted
    ( module Control.Distributed.Process.Lifted
    , module Control.Distributed.Process
    , module Control.Exception.Lifted
    )
where

import           Control.Distributed.Process.Lifted.Class


import           Control.Distributed.Process                  (Closure,
                                                               DidSpawn (..),
                                                               DiedReason (..),
                                                               DiedReason (..),
                                                               Match, Message,
                                                               MonitorRef,
                                                               NodeId (..),
                                                               NodeLinkException (..),
                                                               NodeMonitorNotification (..),
                                                               NodeStats (..),
                                                               PortLinkException (..),
                                                               PortMonitorNotification (..),
                                                               Process,
                                                               ProcessId,
                                                               ProcessInfo (..),
                                                               ProcessLinkException (..),
                                                               ProcessMonitorNotification (..),
                                                               ProcessRegistrationException (..),
                                                               ProcessTerminationException (..),
                                                               ReceivePort,
                                                               RegisterReply (..),
                                                               RemoteTable,
                                                               SendPort,
                                                               SendPortId,
                                                               SpawnRef, Static,
                                                               WhereIsReply (..),
                                                               closure,
                                                               infoLinks,
                                                               infoMessageQueueLength,
                                                               infoMonitors,
                                                               infoNode,
                                                               infoRegisteredNames,
                                                               isEncoded,
                                                               liftIO, match,
                                                               matchAny,
                                                               matchAnyIf,
                                                               matchChan,
                                                               matchIf,
                                                               matchMessage,
                                                               matchMessageIf,
                                                               matchSTM,
                                                               nodeAddress,
                                                               nodeStatsLinks,
                                                               nodeStatsMonitors,
                                                               nodeStatsNode,
                                                               nodeStatsProcesses,
                                                               nodeStatsRegisteredNames,
                                                               processNodeId,
                                                               sendPortId,
                                                               sendPortProcessId,
                                                               unsafeWrapMessage,
                                                               wrapMessage)
import qualified Control.Distributed.Process                  as Base
import           Control.Distributed.Process.MonadBaseControl ()

import           Control.Distributed.Process.Closure          (SerializableDict)
import           Control.Distributed.Process.Internal.Types   (ProcessExitException (..))
import           Control.Distributed.Process.Serializable     (Serializable)

import           Control.Exception.Lifted                     (Exception,
                                                               Handler (..),
                                                               bracket,
                                                               bracket_, catch,
                                                               catches, finally,
                                                               mask, mask_,
                                                               onException, try)
import qualified Control.Exception.Lifted                     as EX
import           Data.Typeable                                (Typeable)

-- compose arity 2 functions
(.:) :: (c->d) -> (a->b->c) -> a->b->d
f .: i = \l r -> f $ i l r

-- | Generalized version of 'Base.spawnLocal'
spawnLocal :: (MonadProcessBase m) => m () -> m ProcessId
spawnLocal = liftBaseDiscardP Base.spawnLocal

-- | Generalized version of 'Base.getSelfPid'
getSelfPid :: (MonadProcess m) => m ProcessId
getSelfPid = liftP Base.getSelfPid

-- | Generalized version of 'Base.expect'
expect :: (MonadProcess m) => forall a. Serializable a => m a
expect = liftP Base.expect

-- | Generalized version of 'Base.expectTimeout'
expectTimeout :: (MonadProcess m) => forall a. Serializable a => Int -> m (Maybe a)
expectTimeout = liftP . Base.expectTimeout

-- | Generalized version of 'Base.register'
register :: (MonadProcess m) => String -> ProcessId -> m ()
register name = liftP . Base.register name

-- | Generalized version of 'Base.whereis'
whereis :: (MonadProcess m) => String -> m (Maybe ProcessId)
whereis = liftP . Base.whereis

-- | Generalized version of 'Base.catchesExit'
catchesExit :: forall m a. (MonadProcessBase m) => m a -> [ProcessId -> Message -> m (Maybe a)] -> m a
-- Had to re-implement due to requirement to preserve structure of Maybe a in
-- each possible handler.
catchesExit act handlers = catch act (`handleExit` handlers)
  where
    handleExit :: ProcessExitException
               -> [ProcessId -> Message -> m (Maybe a)]
               -> m a
    handleExit ex [] = EX.throwIO ex
    handleExit ex@(ProcessExitException from msg) (h:hs) = do
      r <- h from msg
      case r of
        Nothing -> handleExit ex hs
        Just p  -> return p

-- | Generalized version of 'Base.delegate'
delegate :: MonadProcess m => ProcessId -> (Message -> Bool) -> m ()
delegate = liftP .: Base.delegate

-- | Generalized version of 'Base.forward'
forward :: MonadProcess m => Message -> ProcessId -> m ()
forward = liftP .: Base.forward

-- | Generalized version of 'Base.getLocalNodeStats'
getLocalNodeStats :: MonadProcess m => m NodeStats
getLocalNodeStats = liftP Base.getLocalNodeStats

-- | Generalized version of 'Base.getNodeStats'
getNodeStats :: MonadProcess m => NodeId -> m (Either DiedReason NodeStats)
getNodeStats = liftP . Base.getNodeStats

-- | Generalized version of 'Base.getProcessInfo'
getProcessInfo :: MonadProcess m => ProcessId -> m (Maybe ProcessInfo)
getProcessInfo = liftP . Base.getProcessInfo

-- | Generalized version of 'Base.getSelfNode'
getSelfNode :: MonadProcess m => m NodeId
getSelfNode = liftP Base.getSelfNode

-- | Generalized version of 'Base.kill'
kill :: MonadProcess m => ProcessId -> String -> m ()
kill = liftP .: Base.kill

-- | Generalized version of 'Base.link'
link :: MonadProcess m => ProcessId -> m ()
link = liftP . Base.link

-- | Generalized version of 'Base.linkNode'
linkNode :: MonadProcess m => NodeId -> m ()
linkNode = liftP . Base.linkNode

-- | Generalized version of 'Base.linkPort'
linkPort :: MonadProcess m => SendPort a -> m ()
linkPort = liftP . Base.linkPort

-- | Generalized version of 'Base.monitor'
monitor :: MonadProcess m => ProcessId -> m MonitorRef
monitor = liftP . Base.monitor

-- | Generalized version of 'Base.monitorNode'
monitorNode :: MonadProcess m => NodeId -> m MonitorRef
monitorNode = liftP . Base.monitorNode

-- | Generalized version of 'Base.receiveTimeout'
receiveTimeout :: MonadProcess m => Int -> [Match b] -> m (Maybe b)
receiveTimeout = liftP .: Base.receiveTimeout

-- | Generalized version of 'Base.receiveWait'
receiveWait :: MonadProcess m => [Match b] -> m b
receiveWait = liftP . Base.receiveWait

-- | Generalized version of 'Base.reconnect'
reconnect :: MonadProcess m => ProcessId -> m ()
reconnect = liftP . Base.reconnect

-- | Generalized version of 'Base.reconnectPort'
reconnectPort :: MonadProcess m => SendPort a -> m ()
reconnectPort = liftP . Base.reconnectPort

-- | Generalized version of 'Base.registerRemoteAsync'
registerRemoteAsync :: MonadProcess m => NodeId -> String -> ProcessId -> m ()
registerRemoteAsync n = liftP .: Base.registerRemoteAsync n

-- | Generalized version of 'Base.relay'
relay :: MonadProcess m => ProcessId -> m ()
relay = liftP . Base.relay

-- | Generalized version of 'Base.reregister'
reregister :: MonadProcess m => String -> ProcessId -> m ()
reregister = liftP .: Base.reregister

-- | Generalized version of 'Base.reregisterRemoteAsync'
reregisterRemoteAsync :: MonadProcess m => NodeId -> String -> ProcessId -> m ()
reregisterRemoteAsync n = liftP .: Base.reregisterRemoteAsync n

-- | Generalized version of 'Base.say'
say :: MonadProcess m => String -> m ()
say = liftP . Base.say

-- | Generalized version of 'Base.spawn'
spawn :: MonadProcess m => NodeId -> Closure (Process ()) -> m ProcessId
spawn = liftP .: Base.spawn

-- | Generalized version of 'Base.spawnAsync'
spawnAsync :: MonadProcess m => NodeId -> Closure (Process ()) -> m SpawnRef
spawnAsync = liftP .: Base.spawnAsync

-- | Generalized version of 'Base.spawnLink'
spawnLink :: MonadProcess m => NodeId -> Closure (Process ()) -> m ProcessId
spawnLink = liftP .: Base.spawnLink

-- | Generalized version of 'Base.spawnMonitor'
spawnMonitor :: MonadProcess m => NodeId -> Closure (Process ()) -> m (ProcessId, MonitorRef)
spawnMonitor = liftP .: Base.spawnMonitor

-- | Generalized version of 'Base.spawnSupervised'
spawnSupervised :: MonadProcess m => NodeId -> Closure (Process ()) -> m (ProcessId, MonitorRef)
spawnSupervised = liftP .: Base.spawnSupervised

-- | Generalized version of 'Base.terminate'
terminate :: MonadProcess m => m a
terminate = liftP Base.terminate

-- | Generalized version of 'Base.unlink'
unlink :: MonadProcess m => ProcessId -> m ()
unlink = liftP . Base.unlink

-- | Generalized version of 'Base.unlinkNode'
unlinkNode :: MonadProcess m => NodeId -> m ()
unlinkNode = liftP . Base.unlinkNode

-- | Generalized version of 'Base.unlinkPort'
unlinkPort :: MonadProcess m => SendPort a -> m ()
unlinkPort = liftP . Base.unlinkPort

-- | Generalized version of 'Base.unmonitor'
unmonitor :: MonadProcess m => MonitorRef -> m ()
unmonitor = liftP . Base.unmonitor

-- | Generalized version of 'Base.unregister'
unregister :: MonadProcess m => String -> m ()
unregister = liftP . Base.unregister

-- | Generalized version of 'Base.unregisterRemoteAsync'
unregisterRemoteAsync :: MonadProcess m => NodeId -> String -> m ()
unregisterRemoteAsync = liftP .: Base.unregisterRemoteAsync

-- | Generalized version of 'Base.whereisRemoteAsync'
whereisRemoteAsync :: MonadProcess m => NodeId -> String -> m ()
whereisRemoteAsync = liftP .: Base.whereisRemoteAsync

-- | Generalized version of 'Base.withMonitor'
withMonitor :: ProcessId -> (MonitorRef -> Process a) -> Process a
withMonitor = liftP .: Base.withMonitor

-- | Generalized version of 'Base.withMonitor_'
withMonitor_ :: MonadProcessBase m => ProcessId -> m a -> m a
withMonitor_ pid ma = controlP $ \runInP ->
                        Base.withMonitor_ pid (runInP ma)

-- | Generalized version of 'Base.call'
call :: (MonadProcess m,Serializable a)
     => Static (SerializableDict a) -> NodeId -> Closure (Process a) -> m a
call s = liftP .: Base.call s

-- | Generalized version of 'Base.catchExit'
catchExit :: (MonadProcessBase m,Show a,Serializable a)
          => m b -> (ProcessId -> a -> m b) -> m b
catchExit ma handler = controlP $ \runInP ->
                           Base.catchExit (runInP ma)
                                          (\pid msg -> runInP $ handler pid msg)

-- | Generalized version of 'Base.die'
die :: (MonadProcess m, Serializable a) => a -> m b
die = liftP . Base.die

-- | Generalized version of 'Base.exit'
exit :: (MonadProcess m, Serializable a) => ProcessId -> a -> m ()
exit = liftP .: Base.exit

-- | Generalized version of 'Base.handleMessage'
handleMessage :: (MonadProcess m,Serializable a)
              => Message -> (a -> Process b) -> m (Maybe b)
handleMessage msg f = liftP $ Base.handleMessage msg f

-- | Generalized version of 'Base.handleMessageIf'
handleMessageIf :: (MonadProcess m,Serializable a)
                => Message -> (a -> Bool) -> (a -> Process b) -> m (Maybe b)
handleMessageIf msg p f = liftP $ Base.handleMessageIf msg p f

-- | Generalized version of 'Base.handleMessageIf_'
handleMessageIf_ :: (MonadProcess m,Serializable a)
                 => Message -> (a -> Bool) -> (a -> Process ()) -> m ()
handleMessageIf_ msg p f = liftP $ Base.handleMessageIf_ msg p f

-- | Generalized version of 'Base.handleMessage_'
handleMessage_ :: (MonadProcess m, Serializable a) => Message -> (a -> Process ()) -> m ()
handleMessage_ msg f = liftP $ Base.handleMessage_ msg f

-- | Generalized version of 'Base.mergePortsBiased'
mergePortsBiased :: (MonadProcess m,Serializable a)
                 => [ReceivePort a] -> m (ReceivePort a)
mergePortsBiased = liftP . Base.mergePortsBiased

-- | Generalized version of 'Base.mergePortsRR'
mergePortsRR :: (MonadProcess m, Serializable a) => [ReceivePort a] -> m (ReceivePort a)
mergePortsRR = liftP . Base.mergePortsRR

-- | Generalized version of 'Base.monitorPort'
monitorPort :: (MonadProcess m, Serializable a) => SendPort a -> m MonitorRef
monitorPort = liftP . Base.monitorPort

-- | Generalized version of 'Base.newChan'
newChan :: (MonadProcess m, Serializable a) => m (SendPort a, ReceivePort a)
newChan = liftP Base.newChan

-- | Generalized version of 'Base.nsend'
nsend :: (MonadProcess m, Serializable a) => String -> a -> m ()
nsend = liftP .: Base.nsend

-- | Generalized version of 'Base.nsendRemote'
nsendRemote :: (MonadProcess m, Serializable a) => NodeId -> String -> a -> m ()
nsendRemote n = liftP .: Base.nsendRemote n

-- | Generalized version of 'Base.proxy'
proxy :: (MonadProcess m, Serializable a) => ProcessId -> (a -> Process Bool) -> m ()
proxy = liftP .: Base.proxy

-- | Generalized version of 'Base.receiveChan'
receiveChan :: (MonadProcess m, Serializable a) => ReceivePort a -> m a
receiveChan = liftP . Base.receiveChan

-- | Generalized version of 'Base.receiveChanTimeout'
receiveChanTimeout :: (MonadProcess m, Serializable a) => Int -> ReceivePort a -> m (Maybe a)
receiveChanTimeout = liftP .: Base.receiveChanTimeout

-- | Generalized version of 'Base.send'
send :: (MonadProcess m, Serializable a) => ProcessId -> a -> m ()
send = liftP .: Base.send

-- | Generalized version of 'Base.sendChan'
sendChan :: (MonadProcess m, Serializable a) => SendPort a -> a -> m ()
sendChan = liftP .: Base.sendChan

-- | Generalized version of 'Base.spawnChannel'
spawnChannel :: (MonadProcess m,Serializable a)
             => Static (SerializableDict a)
             -> NodeId
             -> Closure (ReceivePort a -> Process ())
             -> m (SendPort a)
spawnChannel s = liftP .: Base.spawnChannel s

-- | Generalized version of 'Base.spawnChannelLocal'
spawnChannelLocal :: (MonadProcess m,Serializable a)
                  => (ReceivePort a -> Process ()) -> m (SendPort a)
spawnChannelLocal = liftP . Base.spawnChannelLocal

-- | Generalized version of 'Base.unClosure'
unClosure :: (MonadProcess m, Typeable a) => Closure a -> m a
unClosure = liftP . Base.unClosure

-- | Generalized version of 'Base.unStatic'
unStatic :: (MonadProcess m, Typeable a) => Static a -> m a
unStatic = liftP . Base.unStatic

-- | Generalized version of 'Base.unsafeNSend'
unsafeNSend :: (MonadProcess m, Serializable a) => String -> a -> m ()
unsafeNSend = liftP .: Base.unsafeNSend

-- | Generalized version of 'Base.unsafeSend'
unsafeSend :: (MonadProcess m, Serializable a) => ProcessId -> a -> m ()
unsafeSend = liftP .: Base.unsafeSend

-- | Generalized version of 'Base.unsafeSendChan'
unsafeSendChan :: (MonadProcess m, Serializable a) => SendPort a -> a -> m ()
unsafeSendChan = liftP .: Base.unsafeSendChan

-- | Generalized version of 'Base.unwrapMessage'
unwrapMessage :: (MonadProcess m, Serializable a) => Message -> m (Maybe a)
unwrapMessage = liftP . Base.unwrapMessage