{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Control.Scheduler.Computation
( Comp(.., Par, Par'), getCompWorkers
) where
import Control.Concurrent (getNumCapabilities)
import Control.DeepSeq (NFData(..), deepseq)
import Control.Monad.IO.Class
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup
#endif
import Data.Word
-- | How a computation is to be scheduled: how many workers to use and,
-- optionally, which ones.
data Comp
  = Seq
    -- ^ A single worker ('getCompWorkers' reports @1@).
  | ParOn ![Int]
    -- ^ One worker per element of the list. The 'Int's presumably identify
    -- capabilities — NOTE(review): confirm against the scheduler. The empty
    -- list is the 'Par' pattern and is sized by 'getNumCapabilities'.
  | ParN {-# UNPACK #-} !Word16
    -- ^ Exactly @n@ workers. @ParN 0@ is the 'Par'' pattern and is sized by
    -- 'getNumCapabilities'.
  deriving (Eq)
-- | Parallel computation sized by the number of capabilities
-- ('getCompWorkers' consults 'getNumCapabilities'). Matches and builds
-- @'ParOn' []@.
pattern Par :: Comp
pattern Par = ParOn []
-- | Parallel computation sized by the number of capabilities
-- ('getCompWorkers' consults 'getNumCapabilities'). Matches and builds
-- @'ParN' 0@. NOTE(review): any run-time difference from 'Par' is decided
-- outside this module — confirm before relying on one.
pattern Par' :: Comp
pattern Par' = ParN 0
-- | Renders a 'Comp' as the expression that rebuilds it, except that
-- @ParOn []@ is rendered via its 'Par' pattern synonym.
--
-- Parentheses follow the Haskell Report convention: a constructor applied to
-- an argument is wrapped only when the surrounding precedence exceeds
-- function application (@d > 10@). Nullary forms ('Seq', 'Par') are never
-- wrapped. 'show' uses the class default, @showsPrec 0 x ""@.
instance Show Comp where
  showsPrec d comp =
    case comp of
      Seq -> showString "Seq"
      -- Must precede the 'ParOn' alternative so the empty list prints as "Par".
      Par -> showString "Par"
      ParOn ws -> showParen (d > 10) $ showString "ParOn " . showsPrec 11 ws
      ParN n -> showParen (d > 10) $ showString "ParN " . showsPrec 11 n
-- | Full evaluation: 'Seq' carries no fields, 'ParOn' forces every element of
-- its worker list, and 'ParN' forces its (already strict) count.
instance NFData Comp where
  rnf Seq = ()
  rnf (ParOn workerIds) = workerIds `deepseq` ()
  rnf (ParN count) = count `deepseq` ()
  {-# INLINE rnf #-}
-- | 'Seq' is the identity element: 'joinComp' returns the other argument
-- unchanged whenever either side is 'Seq'.
instance Monoid Comp where
  mempty = Seq
  {-# INLINE mempty #-}
  -- Canonical definition: 'mappend' is required to agree with '(<>)', so
  -- delegate to the Semigroup method instead of naming the combiner a second
  -- time (avoids -Wnoncanonical-monoid-instances and keeps them in sync).
  mappend = (<>)
  {-# INLINE mappend #-}
-- | Two strategies are merged by 'joinComp'; see it for the precedence rules.
instance Semigroup Comp where
  x <> y = joinComp x y
  {-# INLINE (<>) #-}
-- | Merge two computation strategies. This backs the 'Semigroup' instance.
--
-- The rules, in the order they are tried:
--
-- * 'Seq' on the left yields the right argument; 'Seq' on the right yields
--   the left argument.
-- * 'Par' (@ParOn []@) and 'Par'' (@ParN 0@) absorb everything else. When
--   both appear, the left-most one wins.
-- * Two non-empty 'ParOn' lists are concatenated, left list first.
-- * A non-empty 'ParOn' takes precedence over a non-zero 'ParN'.
-- * Two non-zero 'ParN' counts combine to the larger one.
--
-- The left argument is always forced first. The right argument is only
-- forced when the left is neither 'Seq' nor an absorbing element.
joinComp :: Comp -> Comp -> Comp
joinComp Seq y = y
joinComp Par _ = Par
joinComp Par' _ = Par'
-- Past this point the left side is a non-empty 'ParOn' or a non-zero 'ParN'.
joinComp _ Par = Par
joinComp _ Par' = Par'
joinComp (ParOn xs) (ParOn ys) = ParOn (xs <> ys)
joinComp x@(ParOn _) _ = x
joinComp (ParN _) y@(ParOn _) = y
joinComp (ParN n1) (ParN n2) = ParN (max n1 n2)
joinComp x@(ParN _) Seq = x
-- NOINLINE kept from the original; its motivation is not visible in this
-- module. NOTE(review): confirm it is still wanted.
{-# NOINLINE joinComp #-}
-- | Number of workers a computation strategy will use.
--
-- 'Seq' uses one worker, 'ParOn' uses as many as its list has elements, and
-- @'ParN' n@ uses exactly @n@. 'Par' (@ParOn []@) and 'Par'' (@ParN 0@) are
-- sized by 'getNumCapabilities', which is why the result lives in 'MonadIO'.
getCompWorkers :: MonadIO m => Comp -> m Int
getCompWorkers comp =
  case comp of
    Seq -> return 1
    ParOn [] -> allCapabilities
    ParOn workerIds -> return (length workerIds)
    ParN 0 -> allCapabilities
    ParN count -> return (fromIntegral count)
  where
    allCapabilities = liftIO getNumCapabilities