----------------------------------------------------------------------------
-- |
-- Module      :  TestUtils
-- Copyright   :  (c) Sergey Vinokurov 2022
-- License     :  Apache-2.0 (see LICENSE)
-- Maintainer  :  serg.foo@gmail.com
----------------------------------------------------------------------------

{-# LANGUAGE BangPatterns   #-}
{-# LANGUAGE DeriveGeneric  #-}
{-# LANGUAGE NamedFieldPuns #-}

module TestUtils
  ( Delay(..)
  , sleep
  , Iterations(..)
  , callN
  , Thread(..)
  , runThread
  , Threads(..)
  , spawnAndCall
  ) where

import Control.Concurrent
import Control.Concurrent.Async
import Control.Monad
import Control.Monad.IO.Class
import Data.Foldable
import Data.List.NonEmpty (NonEmpty(..))
import GHC.Generics (Generic)
import Test.QuickCheck

-- | A pause length measured in microseconds, the unit that 'threadDelay'
-- takes (see 'sleep').
newtype Delay = Delay
  { unDelay :: Int -- ^ Number of microseconds to pause for.
  } deriving (Show, Eq)

-- | Pause the current thread for the given 'Delay'.
--
-- A delay of exactly zero returns immediately without lifting any IO;
-- every other value is passed to 'threadDelay' unchanged.
sleep :: MonadIO m => Delay -> m ()
sleep (Delay micros) =
  unless (micros == 0) $
    liftIO (threadDelay micros)

-- | Delays are drawn uniformly from 0 to 10 microseconds inclusive.
--
-- Shrinking reuses the 'Int' shrinker and keeps only candidates in
-- @[0, 25]@.
-- NOTE(review): the shrink upper bound (25) is wider than the generator's
-- upper bound (10); confirm whether the two are meant to agree.
instance Arbitrary Delay where
  arbitrary = fmap Delay (chooseInt (0, 10))
  shrink (Delay micros) =
    [ Delay candidate
    | candidate <- shrink micros
    , candidate >= 0
    , candidate <= 25
    ]

-- | How many times an action should be repeated (see 'callN').
newtype Iterations = Iterations
  { unIterations :: Int -- ^ Repetition count; 'callN' runs nothing for values @<= 0@.
  } deriving (Show, Eq)

-- | Iteration counts are drawn uniformly from 0 to 50 inclusive. Shrinking
-- reuses the 'Int' shrinker and keeps only candidates in that same range.
instance Arbitrary Iterations where
  arbitrary = fmap Iterations (chooseInt (0, 50))
  shrink (Iterations count) =
    [ Iterations candidate
    | candidate <- shrink count
    , candidate >= 0
    , candidate <= 50
    ]

-- | Run @action@ the requested number of times in sequence, discarding its
-- results. A count of zero or less performs no effects at all.
callN :: Applicative m => Iterations -> m a -> m ()
callN (Iterations !times) action = replicateM_ times action

-- | Description of one worker for 'runThread'. The worker performs
-- 'tIterations' rounds; each round hands 'tIncrement' to the step callback
-- and then hands 'tDelay' to the sleep callback.
data Thread = Thread
  { tDelay      :: Delay      -- ^ Passed to the sleep callback after every round.
  , tIncrement  :: Int        -- ^ Passed to the step callback on every round.
  , tIterations :: Iterations -- ^ Number of rounds to perform.
  }
  deriving (Show, Eq, Generic)

-- | Random threads pair an arbitrary 'Delay' and 'Iterations' with an
-- increment drawn uniformly from @[-1000, 1000]@.
--
-- Shrinking uses 'genericShrink', which shrinks the record's fields through
-- their own 'Arbitrary' instances. Any candidate whose increment falls
-- outside @[-1000, 1000]@ is then discarded.
instance Arbitrary Thread where
  arbitrary =
    Thread
      <$> arbitrary
      <*> chooseInt (-1000, 1000)
      <*> arbitrary
  shrink thread =
    [ candidate
    | candidate <- genericShrink thread
    , abs (tIncrement candidate) <= 1000
    ]

-- | Execute a 'Thread' description. Each round calls @f@ with the thread's
-- increment and then calls @doSleep@ with its delay, and this is repeated
-- once per iteration (via 'callN'). Results of both callbacks are discarded.
runThread :: MonadIO m => Thread -> (Delay -> m a) -> (Int -> m b) -> m ()
runThread (Thread delay increment iterations) doSleep f =
  callN iterations oneRound
  where
    oneRound = f increment *> doSleep delay

-- | A non-empty collection of 'Thread' descriptions. This newtype carries
-- its own 'Arbitrary' instance.
newtype Threads = Threads
  { unThreads :: NonEmpty Thread -- ^ Always at least one thread.
  } deriving (Show, Eq)

-- | Random 'Threads' contain one mandatory 'Thread' plus 0 to 31 extra ones,
-- i.e. between 1 and 32 in total. Shrinking applies 'genericShrink' to the
-- underlying 'NonEmpty', so a shrunk value is never empty.
instance Arbitrary Threads where
  arbitrary = chooseInt (0, 31) >>= fmap Threads . nonEmptyOf
    where
      -- One head element followed by exactly @extra@ tail elements.
      nonEmptyOf extra = (:|) <$> arbitrary <*> replicateM extra arbitrary
  shrink (Threads ts) = Threads <$> genericShrink ts

-- | Create a resource with @mkRes@, then run @action res t@ concurrently for
-- every element @t@ of @threads@. Wait for all of them to finish and return
-- the resource.
--
-- 'mapConcurrently_' is used rather than spawning with 'async' and waiting
-- one by one. If any action throws, or the caller is interrupted while
-- waiting, the remaining actions are cancelled and the exception is
-- re-thrown, instead of leaving orphaned threads running.
spawnAndCall :: Traversable f => f b -> IO a -> (a -> b -> IO ()) -> IO a
spawnAndCall threads mkRes action = do
  res <- mkRes
  mapConcurrently_ (action res) threads
  pure res