{-# LANGUAGE ScopedTypeVariables #-}

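-- | Utilities for forked threads: starting them with a guaranteed exit handler,
--   killing them and waiting for them to finish, and (for some helpers)
--   rethrowing a child's exception on the parent thread.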
module General.Thread(
    withThreadsBoth,
    withThreadSlave,
    allocateThread,
    Thread, newThreadFinally, stopThreads
    ) where

import General.Cleanup
import Data.Hashable
import Control.Concurrent.Extra
import Control.Exception
import General.Extra
import Control.Monad.Extra


-- | Handle for a thread created by 'newThreadFinally': the thread's id, plus a
--   barrier that is signalled once the thread, including its cleanup, has exited.
data Thread
    = Thread
        ThreadId      -- the forked thread
        (Barrier ())  -- signalled when the forked thread finishes

-- | Two 'Thread' handles are equal exactly when their 'ThreadId's are equal;
--   the barrier is ignored.
instance Eq Thread where
    Thread a _ == Thread b _ = a == b

-- | Hash a 'Thread' by its 'ThreadId' only, so equal handles (which compare
--   by 'ThreadId') hash equally.
instance Hashable Thread where
    hashWithSalt salt (Thread t _) = hashWithSalt salt t

-- | Fork a thread that runs @act@ and then @cleanup@, returning a 'Thread'
--   handle for it.
--
--   The inner thread is unmasked even if you started masked: only @act@ runs
--   unmasked, while the surrounding exit logic and @cleanup@ run with
--   asynchronous exceptions masked (the thread is forked under 'mask_').
--   Any exception from @act@ is caught and handed to @cleanup@ as a 'Left',
--   together with the handle of the running thread.
--   The barrier inside the returned 'Thread' is signalled once @cleanup@ has
--   finished, or thrown.
newThreadFinally :: IO a -> (Thread -> Either SomeException a -> IO ()) -> IO Thread
newThreadFinally act cleanup = do
    bar <- newBarrier
    -- fork masked so the 'finally' handler is installed before any
    -- asynchronous exception can be delivered to the new thread
    t <- mask_ $ forkIOWithUnmask $ \unmask -> flip finally (signalBarrier bar ()) $ do
        res <- try $ unmask act
        me <- myThreadId
        cleanup (Thread me bar) res
    pure $ Thread t bar


-- | Kill every given thread, then wait until each one has completely finished,
--   including the cleanup passed to 'newThreadFinally'.
stopThreads :: [Thread] -> IO ()
stopThreads threads = do
    -- if a thread is in a masked action, killing it may take some time, so kill them in parallel
    forM_ threads $ \(Thread t _) -> forkIO $ killThread t
    forM_ threads $ \(Thread _ bar) -> waitBarrier bar


-- | Run both actions concurrently, each on its own thread, and return both
--   results.
--
--   If either action throws, the exception is forwarded to the calling thread,
--   both threads are killed, and the exception is reraised once both threads
--   have finished.
--   Not called much, so simplicity over performance (2 threads).
withThreadsBoth :: IO a -> IO b -> IO (a, b)
withThreadsBoth act1 act2 = do
    bar1 <- newBarrier
    bar2 <- newBarrier
    parent <- myThreadId
    -- set before tearing down, so children stop forwarding exceptions
    -- (in particular the ThreadKilled we are about to send them)
    ignore <- newVar False
    -- forkChild: run the action unmasked on a new thread; forward a failure to
    -- the parent unless teardown has begun; always publish the outcome.
    -- The barrier is signalled in 'finally' because the child may be killed
    -- while blocked in the interruptible 'throwTo parent', and teardown below
    -- waits on that barrier.
    let forkChild :: IO x -> Barrier (Either SomeException x) -> IO ThreadId
        forkChild act bar = forkIOWithUnmask $ \unmaskChild -> do
            res <- try $ unmaskChild act
            (unlessM (readVar ignore) $ whenLeft res $ throwTo parent)
                `finally` signalBarrier bar res

        -- A failing child forwards its exception before signalling, so a
        -- 'Left' is not expected here; rethrow it rather than fail a pattern.
        waitChild :: Barrier (Either SomeException x) -> IO x
        waitChild bar = either throwIO pure =<< waitBarrier bar
    mask $ \unmask -> do
        t1 <- forkChild act1 bar1
        t2 <- forkChild act2 bar2
        res :: Either SomeException (a, b) <- try $ unmask $ do
            v1 <- waitChild bar1
            v2 <- waitChild bar2
            pure (v1, v2)
        writeVar ignore True
        -- if a thread is in a masked action, killing it may take some time,
        -- so start killing t2 from a helper thread before killing t1 here
        void $ forkIO $ killThread t2
        killThread t1
        void $ waitBarrier bar1
        void $ waitBarrier bar2
        either throwIO pure res


-- | Run @slave@ in a separate thread while @act@ runs on the current thread,
--   returning the result of @act@.
--
--   After @act@ terminates (normally or with an exception), the slave thread
--   is killed. If @slave@ raises an exception while @act@ is still running, it
--   is rethrown on the calling thread.
withThreadSlave :: IO () -> IO a -> IO a
withThreadSlave slave act = withCleanup $ \cleanup -> do
    allocateThread cleanup slave
    act


-- | Run the given action in a separate thread, registering its termination
--   with @cleanup@.
--
--   On cleanup, the thread is killed and waited for before continuing.
--   If the action raises an exception before cleanup has started, it is
--   rethrown on the parent thread.
allocateThread :: Cleanup -> IO () -> IO ()
allocateThread cleanup act = do
    bar <- newBarrier
    parent <- myThreadId
    -- set at cleanup so the ThreadKilled we send is not forwarded to the parent
    ignore <- newVar False
    void $ allocate cleanup
        (mask_ $ forkIOWithUnmask $ \unmask ->
            -- signal in 'finally': the thread may be killed while blocked in
            -- the interruptible 'throwTo parent', and the release action
            -- below waits on this barrier
            flip finally (signalBarrier bar ()) $ do
                res :: Either SomeException () <- try $ unmask act
                unlessM (readVar ignore) $ whenLeft res $ throwTo parent
        )
        (\t -> do
            writeVar ignore True
            killThread t
            waitBarrier bar)