{-# LANGUAGE BangPatterns, CPP, FlexibleContexts, Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
#if __GLASGOW_HASKELL__ >= 704
{-# OPTIONS_GHC -fsimpl-tick-factor=200 #-}
#endif
module Statistics.Function
(
minMax
, sort
, inplaceSortIO
, indices
, nextHighestPowerOfTwo
, within
, square
, unsafeModify
, for
, rfor
) where
#include "MachDeps.h"
import Control.Applicative
import Control.Monad.ST (ST, stToIO)
import Data.Bits ((.|.), shiftR)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
import Numeric.MathFunctions.Comparison (within)
import Basement.Monad
import Prelude
-- | Sort an unboxed vector of 'Double's.
--
-- The input is handed to 'G.modify', which runs 'inplaceSortST' on a
-- mutable view of the data and returns the resulting vector.
sort :: U.Vector Double -> U.Vector Double
sort xs = G.modify inplaceSortST xs
{-# NOINLINE sort #-}
-- | Sort a mutable vector of 'Double's in place, in ascending order.
--
-- The algorithm is a quicksort that partitions each range around the
-- value of its middle element.  Both partition scans stop on elements
-- that compare equal to the pivot, so runs of tied values are shared
-- between the two sub-ranges.  (Sending every tie to the same side
-- makes the sort quadratic on inputs with many duplicates.)
--
-- The relative order of elements that compare equal (for example
-- @0.0@ and @-0.0@) is unspecified, as is the final position of any
-- NaN, because NaN is unordered with respect to '<' and '>'.
inplaceSortST :: M.MVector s Double
              -> ST s ()
inplaceSortST mvec = qsort 0 (M.length mvec - 1)
  where
    -- Sort the inclusive index range [lo, hi].
    qsort !lo !hi
      | lo >= hi  = pure ()
      | otherwise = do
          -- Midpoint written as lo + (hi - lo) / 2 so it cannot overflow.
          pivot  <- M.unsafeRead mvec (lo + (hi - lo) `div` 2)
          (i, j) <- partition pivot lo hi
          -- The first partition step always performs a swap, so
          -- j < hi and i > lo: both sub-ranges are strictly smaller.
          qsort lo j
          qsort i hi

    -- Hoare-style partition of [lo, hi] around the value @pivot@.
    -- Returns (i, j) with j < i such that nothing in [lo, j] is greater
    -- than the pivot and nothing in [i, hi] is smaller than it; any
    -- element strictly between j and i is already in its final place.
    partition pivot = go
      where
        go !i !j
          | i > j     = pure (i, j)
          | otherwise = do
              (i', x) <- scanUp i
              (j', y) <- scanDown j
              if i' <= j'
                then do
                  M.unsafeWrite mvec i' y
                  M.unsafeWrite mvec j' x
                  go (i' + 1) (j' - 1)
                else pure (i', j')

        -- First index at or above k whose element is not below the
        -- pivot.  It cannot leave the range: on the first pass the
        -- pivot element itself stops it, afterwards the element most
        -- recently swapped into the upper side does.
        scanUp !k = do
          x <- M.unsafeRead mvec k
          if x < pivot then scanUp (k + 1) else pure (k, x)

        -- Mirror image of scanUp, walking downwards; bounded by the
        -- pivot element or by the element most recently swapped into
        -- the lower side.
        scanDown !k = do
          x <- M.unsafeRead mvec k
          if x > pivot then scanDown (k - 1) else pure (k, x)

-- | 'IO' counterpart of 'inplaceSortST': sort the vector in place.
--
-- This delegates to 'inplaceSortST' through 'stToIO' instead of
-- carrying a second copy of the algorithm.  It relies on
-- @PrimState IO@ being @RealWorld@, which direct use of
-- 'M.unsafeRead' on this vector in 'IO' requires as well.
inplaceSortIO :: M.MVector (PrimState IO) Double
              -> IO ()
inplaceSortIO mvec = stToIO (inplaceSortST mvec)
-- | Return the indices of a vector: @0, 1, ..., length - 1@.
--
-- An empty input yields an empty vector of indices.
indices :: (G.Vector v a, G.Vector v Int) => v a -> v Int
indices a = G.enumFromN 0 (G.length a)
{-# INLINE indices #-}
-- Strict, unpacked accumulator holding the running minimum (first
-- field) and running maximum (second field).
data MM = MM {-# UNPACK #-} !Double {-# UNPACK #-} !Double

-- | Compute the minimum and maximum of a vector in a single pass.
--
-- The fold starts from @(Infinity, -Infinity)@, so that pair is the
-- result for an empty vector.
minMax :: (G.Vector v Double) => v Double -> (Double, Double)
minMax = toPair . G.foldl' widen seed
  where
    seed               = MM (1/0) (-1/0)
    widen (MM lo hi) x = MM (min lo x) (max hi x)
    toPair (MM lo hi)  = (lo, hi)
{-# INLINE minMax #-}
-- | Round a positive integer up to the nearest power of two.  A value
-- that is already a power of two is returned unchanged; zero and
-- negative inputs give zero.
--
-- The computation copies the highest set bit of @n - 1@ into every
-- lower bit position and then adds one.  Inputs above the largest
-- power of two representable in 'Int' overflow that final increment.
nextHighestPowerOfTwo :: Int -> Int
nextHighestPowerOfTwo n = 1 + smeared
  where
    -- OR a value with itself shifted right by s bits.
    fill s x = x .|. (x `shiftR` s)
    upTo16   = fill 16 . fill 8 . fill 4 . fill 2 . fill 1 $ n - 1
#if WORD_SIZE_IN_BITS == 64
    -- A 64-bit word needs one more doubling of the shift distance.
    smeared  = fill 32 upTo16
#else
    smeared  = upTo16
#endif
-- | Multiply a number by itself.
square :: Double -> Double
square v = v * v
-- | Simple ascending loop: run @f i@ for every @i@ from @n0@ up to,
-- but not including, @n@, in increasing order.
--
-- When @n0 >= n@ the range is empty and @f@ is never called.  The
-- termination test is @>=@ rather than @==@ so that a start index
-- beyond the bound cannot send the counter running through the rest
-- of the 'Int' range.
for :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
for n0 !n f = loop n0
  where
    loop i | i >= n    = return ()
           | otherwise = f i >> loop (i+1)
{-# INLINE for #-}
-- | Simple descending loop: run @f i@ for every @i@ from @n0 - 1@
-- down to, and including, @n@, in decreasing order.
--
-- When @n0 <= n@ the range is empty and @f@ is never called.  The
-- termination test is @<=@ rather than @==@ so that a start index
-- below the bound cannot send the counter running through the rest
-- of the 'Int' range.
rfor :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
rfor n0 !n f = loop n0
  where
    loop i | i <= n    = return ()
           | otherwise = let i' = i-1 in f i' >> loop i'
{-# INLINE rfor #-}
-- | Replace the element at index @i@ of a mutable vector with the
-- result of applying @f@ to it.
--
-- Built on 'M.unsafeRead' and 'M.unsafeWrite', so the index is not
-- bounds-checked; the caller must ensure it is valid.
unsafeModify :: M.MVector s Double -> Int -> (Double -> Double) -> ST s ()
unsafeModify v i f = M.unsafeRead v i >>= M.unsafeWrite v i . f
{-# INLINE unsafeModify #-}