{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE StrictData #-}
{-# OPTIONS_GHC -Wall #-}

-- | Specification of a performance measurement type suitable for the 'PerfT' monad transformer.
module Perf.Measure
  ( Measure (..),
    runMeasure,
    runMeasureN,
    cost,
    cputime,
    realtime,
    count,
    cycles,
  )
where

import Control.Monad
import Data.Fixed (Fixed (MkFixed))
import Data.Time.Clock
import Perf.Cycle
import System.CPUTime
import System.CPUTime.Rdtsc
import Prelude

-- $setup
-- >>> import Data.Foldable (foldl')

-- | A 'Measure' consists of a base value, a monadic effect run before the
-- measured action ('prestep'), and a monadic effect run after it
-- ('poststep') that turns the result of 'prestep' into the measured value.
--
-- For example, the measure specified below returns 1 every time a
-- measurement is requested. It is the basis of a simple counter for loopy
-- code.
--
-- >>> let counter = Measure 0 (pure ()) (const (pure 1))
--
-- Because the type of the 'prestep' result is existential, 'prestep' and
-- 'poststep' cannot be used as selector functions. Pattern-match on the
-- 'Measure' constructor instead.
data Measure m b = forall a.
  (Num b) =>
  Measure
  { -- | Base value of the measurement. Every 'Measure' defined in this module
    -- sets it to 0. 'runMeasure', 'runMeasureN' and 'cost' do not read it.
    measure :: b,
    -- | Effect run before the measured action. Its result is handed to
    -- 'poststep'.
    prestep :: m a,
    -- | Effect run after the measured action. It receives the result of
    -- 'prestep' and produces the measured value.
    poststep :: a -> m b
  }

-- | Measure a single run of an effect.
--
-- The steps run in this order:
--
-- 1. 'prestep' runs.
-- 2. The effect runs.
-- 3. 'poststep' runs, given the result of 'prestep'.
--
-- The effect's result is forced to weak head normal form before 'poststep'
-- runs, so that evaluation is included in the measurement. The measured value
-- is returned together with the effect's result.
--
-- >>> r <- runMeasure count (pure "joy")
-- >>> r
-- (1,"joy")
runMeasure :: Monad m => Measure m b -> m a -> m (b, a)
runMeasure (Measure _ pre post) a = do
  p <- pre
  -- Force the result before taking the closing measurement.
  !a' <- a
  m' <- post p
  return (m', a')

-- | Measure once, but run an effect multiple times.
--
-- 'prestep' runs once. The effect then runs @n@ times in total, and
-- 'poststep' runs once at the end. The returned measurement therefore covers
-- all of the runs.
--
-- Only the result of the final run is returned, and it is forced to weak
-- head normal form before 'poststep' runs. Results of the earlier runs are
-- discarded unforced.
--
-- When @n <= 1@ the effect runs exactly once.
--
-- >>> r <- runMeasureN 1000 count (pure "joys")
-- >>> r
-- (1,"joys")
runMeasureN :: Monad m => Int -> Measure m b -> m a -> m (b, a)
runMeasureN n (Measure _ pre post) a = do
  p <- pre
  -- replicateM_ does nothing for a count <= 0, so n <= 1 gives a single run.
  replicateM_ (n - 1) a
  -- Force the last result before taking the closing measurement.
  !a' <- a
  m' <- post p
  return (m', a')

-- | The cost of taking a measurement, in the measure's own units.
--
-- It runs 'prestep' immediately followed by 'poststep', with nothing in
-- between.
--
-- >>> r <- cost count
-- >>> r
-- 1
cost :: Monad m => Measure m b -> m b
cost (Measure _ pre post) = pre >>= post

-- | A measure using 'getCPUTime' from "System.CPUTime". The unit is
-- picoseconds.
--
-- >>> r <- runMeasure cputime (pure $ foldl' (+) 0 [0..1000])
--
-- > (34000000,500500)
cputime :: Measure IO Integer
cputime = Measure 0 start stop
  where
    start :: IO Integer
    start = getCPUTime
    -- CPU time elapsed since the reading taken by 'start'.
    stop :: Integer -> IO Integer
    stop t0 = do
      t1 <- getCPUTime
      return (t1 - t0)

-- | A wall-clock measure using 'getCurrentTime'. The unit is seconds.
--
-- >>> r <- runMeasure realtime (pure $ foldl' (+) 0 [0..1000])
--
-- > (0.000046,500500)
realtime :: Measure IO Double
realtime = Measure 0 start stop
  where
    start :: IO UTCTime
    start = getCurrentTime
    -- Seconds elapsed since the timestamp taken by 'start'.
    stop :: UTCTime -> IO Double
    stop t0 = do
      t1 <- getCurrentTime
      return (fromNominalDiffTime (diffUTCTime t1 t0))

-- | Convert a 'NominalDiffTime' to a number of seconds as a 'Double'.
--
-- 'nominalDiffTimeToSeconds' yields a picosecond-resolution 'Fixed' value.
-- Its underlying integer is the number of picoseconds.
--
-- Dividing that integer by @1e12@ rounds only once, because @1e12@ is
-- exactly representable as a 'Double'. The result is correctly rounded
-- whenever the picosecond count fits in 53 bits. Multiplying by @1e-12@,
-- which is not exactly representable, would round twice.
fromNominalDiffTime :: NominalDiffTime -> Double
fromNominalDiffTime t = fromInteger picos / 1e12
  where
    (MkFixed picos) = nominalDiffTimeToSeconds t

-- | A 'Measure' used to count iterations. Every measurement is 1.
--
-- >>> r <- runMeasure count (pure ())
-- >>> r
-- (1,())
count :: Measure IO Int
count = Measure 0 start stop
  where
    start :: IO ()
    start = return ()
    stop :: () -> IO Int
    stop () = return 1

-- | A 'Measure' that uses 'rdtsc' from "System.CPUTime.Rdtsc" to read the
-- CPU time-stamp counter. The unit is cycles.
--
-- >>> r <- runMeasureN 1000 cycles (pure ())
--
-- > (120540,()) -- ghci-level
-- > (18673,())  -- compiled with -O2
cycles :: Measure IO Cycle
cycles = Measure 0 start stop
  where
    start :: IO Cycle
    start = rdtsc
    -- Counter ticks elapsed since the value read by 'start'.
    stop :: Cycle -> IO Cycle
    stop c0 = do
      c1 <- rdtsc
      return (c1 - c0)