{-# LANGUAGE BangPatterns, CPP, FlexibleContexts, Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
#if __GLASGOW_HASKELL__ >= 704
{-# OPTIONS_GHC -fsimpl-tick-factor=200 #-}
#endif

-- |
-- Module    : Statistics.Function
-- Copyright : (c) 2009, 2010, 2011 Bryan O'Sullivan
-- License   : BSD3
--
-- Maintainer  : bos@serpentine.com
-- Stability   : experimental
-- Portability : portable
--
-- Useful functions.

module Statistics.Function
    (
    -- * Scanning
      minMax
    -- * Sorting
    , sort
    , inplaceSortIO
    -- * Indexing
    , indices
    -- * Bit twiddling
    , nextHighestPowerOfTwo
    -- * Comparison
    , within
    -- * Arithmetic
    , square
    -- * Vectors
    , unsafeModify
    -- * Combinators
    , for
    , rfor
    ) where

#include "MachDeps.h"

import Control.Applicative
import Control.Monad.ST (ST, stToIO)
import Data.Bits ((.|.), shiftR)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
import Numeric.MathFunctions.Comparison (within)
import Basement.Monad
import Prelude -- Silence redundant import warnings

-- | Sort a vector of 'Double's into ascending order.
--
-- The argument is not modified: the in-place sort is run through
-- 'U.modify', which applies it to a vector that is safe to mutate and
-- returns that as the result.
sort :: U.Vector Double -> U.Vector Double
sort xs = U.modify inplaceSortST xs
{-# NOINLINE sort #-}

-- | Sort a mutable vector of 'Double's in place, in ascending order.
--
-- This is a quicksort that takes the middle element of each range as the
-- pivot and uses a three-way partition (elements less than, equal to and
-- greater than the pivot).  Gathering every element equal to the pivot
-- into the middle band means vectors with many repeated values do not
-- degrade to O(n^2).  For example, a constant range is finished in a
-- single linear pass.
--
-- The pivot itself always ends up in the middle band, so both recursive
-- calls work on strictly smaller ranges.  This keeps the sort terminating
-- even when the vector contains NaN (GHC's 'compare' on 'Double' answers
-- 'GT' whenever an operand is NaN).  The final placement of NaNs is
-- unspecified.
inplaceSortST :: M.MVector s Double
              -> ST s ()
inplaceSortST mvec = qsort 0 (M.length mvec - 1)
  where
    -- Sort the inclusive index range [lo, hi]; empty and singleton
    -- ranges are already sorted.
    qsort !lo !hi
      | lo >= hi  = pure ()
      | otherwise = do
          (lt, gt) <- partition3 lo hi
          -- Recurse into the smaller side first.  The larger side is
          -- then handled by the trailing call, so the nesting of pending
          -- calls stays logarithmic in the range length.
          if lt - lo < hi - gt
            then qsort lo (lt - 1) >> qsort (gt + 1) hi
            else qsort (gt + 1) hi >> qsort lo (lt - 1)

    -- Rearrange [lo, hi] as  (< pivot) (== pivot) (> pivot)  and return
    -- the inclusive bounds (lt, gt) of the middle band.  That band always
    -- contains at least the pivot element.
    partition3 lo hi = do
      -- Park the middle element at lo; it seeds the "equal" band.
      M.unsafeSwap mvec lo (lo + (hi - lo) `div` 2)
      pivot <- M.unsafeRead mvec lo
      -- Invariant maintained by go:
      --   [lo, lt) < pivot, [lt, i) == pivot, (gt, hi] > pivot,
      --   and [i, gt] has not been examined yet.
      -- Each step increments i or decrements gt, so the loop terminates.
      let go !lt !i !gt
            | i > gt    = pure (lt, gt)
            | otherwise = do
                x <- M.unsafeRead mvec i
                case compare x pivot of
                  LT -> M.unsafeSwap mvec lt i >> go (lt + 1) (i + 1) gt
                  GT -> M.unsafeSwap mvec i gt >> go lt i (gt - 1)
                  EQ -> go lt (i + 1) gt
      go lo (lo + 1) hi

-- | Sort a mutable vector of 'Double's in place, in ascending order,
-- from 'IO'.
--
-- This runs the very same algorithm as 'inplaceSortST' (the sort used by
-- 'sort'), lifted into 'IO' with 'stToIO'; @PrimState IO@ is 'RealWorld'.
-- Delegating keeps a single implementation instead of two hand-maintained
-- copies that can drift apart.
inplaceSortIO :: M.MVector (PrimState IO) Double
              -> IO ()
inplaceSortIO mvec = stToIO (inplaceSortST mvec)

-- | Return the indices of a vector: @0, 1, ..., n-1@, where @n@ is its
-- length.  An empty vector yields an empty vector of indices.
indices :: (G.Vector v a, G.Vector v Int) => v a -> v Int
indices a = G.enumFromN 0 (G.length a)
{-# INLINE indices #-}

-- Strict, unpacked (running minimum, running maximum) pair used as the
-- accumulator of 'minMax'.  The strict fields keep the fold free of thunks.
data MM = MM {-# UNPACK #-} !Double {-# UNPACK #-} !Double

-- | Compute the minimum and maximum of a vector in one pass.
--
-- The fold starts from @(Infinity, -Infinity)@, so an empty vector
-- yields exactly that pair.
minMax :: (G.Vector v Double) => v Double -> (Double, Double)
minMax xs = case G.foldl' step (MM posInf negInf) xs of
              MM smallest largest -> (smallest, largest)
  where
    posInf = 1 / 0
    negInf = -1 / 0
    step (MM curMin curMax) x = MM (min curMin x) (max curMax x)
{-# INLINE minMax #-}

-- | Efficiently compute the next highest power of two for a
-- non-negative integer.  If the given value is already a power of
-- two, it is returned unchanged.  If it is zero or negative, zero is
-- returned.
--
-- Inputs larger than the greatest power of two representable as an
-- 'Int' have no representable answer; the result then overflows.
nextHighestPowerOfTwo :: Int -> Int
nextHighestPowerOfTwo n
  -- Checked explicitly: for 'minBound', @n - 1@ wraps around to
  -- 'maxBound' and the bit smearing below would yield 'minBound'
  -- instead of the documented zero.
  | n <= 0    = 0
#if WORD_SIZE_IN_BITS == 64
  | otherwise = 1 + _i32
#else
  | otherwise = 1 + i16
#endif
  where
    -- Smear the highest set bit of (n - 1) into every lower position,
    -- giving 2^k - 1; adding one then yields the power of two.
    i0   = n - 1
    i1   = i0  .|. i0  `shiftR` 1
    i2   = i1  .|. i1  `shiftR` 2
    i4   = i2  .|. i2  `shiftR` 4
    i8   = i4  .|. i4  `shiftR` 8
    i16  = i8  .|. i8  `shiftR` 16
    _i32 = i16 .|. i16 `shiftR` 32
-- It could be implemented as
--
-- > nextHighestPowerOfTwo n = 1 + foldl' go (n-1) [1, 2, 4, 8, 16, 32]
-- >   where go m i = m .|. m `shiftR` i
--
-- But GHC does not inline foldl' (probably because it is recursive), and
-- as a result the function walks a list of boxed Ints. The hand-rolled
-- version uses unboxed arithmetic.

-- | Square a number, i.e. multiply it by itself.
square :: Double -> Double
square v = v * v

-- | Simple for loop.  Runs the action on each index from /start/
-- (inclusive) up to /end/ (exclusive), in increasing order.  If /start/
-- is not less than /end/, the loop does nothing.
for :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
for n0 !n f = loop n0
  where
    -- '>=' rather than '==': with start > end an equality test would never
    -- fire, and the loop would run through the whole Int range.
    loop i | i >= n    = return ()
           | otherwise = f i >> loop (i+1)
{-# INLINE for #-}

-- | Simple reverse-for loop.  Runs the action on each index from
-- /start/-1 down to /end/ (inclusive), in decreasing order.  If /start/
-- is not greater than /end/, the loop does nothing.
rfor :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
rfor n0 !n f = loop n0
  where
    -- '<=' rather than '==': with start < end an equality test would never
    -- fire, and the loop would run down through the whole Int range.
    loop i | i <= n    = return ()
           | otherwise = let i' = i-1 in f i' >> loop i'
{-# INLINE rfor #-}

-- | Replace the element at index @i@ of a mutable vector with @f@ applied
-- to it.  Uses the unchecked 'M.unsafeRead' / 'M.unsafeWrite', so the
-- caller must guarantee that @i@ is in bounds.
unsafeModify :: M.MVector s Double -> Int -> (Double -> Double) -> ST s ()
unsafeModify v i f = M.unsafeRead v i >>= M.unsafeWrite v i . f
{-# INLINE unsafeModify #-}