{-# LANGUAGE BangPatterns, FlexibleContexts, UnboxedTuples #-}
module Statistics.Sample.KernelDensity
(
kde
) where
import Numeric.MathFunctions.Constants (m_sqrt_2_pi)
import Prelude hiding (const, min, max, sum)
import Statistics.Function (minMax, nextHighestPowerOfTwo)
import Statistics.Math.RootFinding (fromRoot, ridders)
import Statistics.Sample.Histogram (histogram_)
import Statistics.Sample.Internal (sum)
import Statistics.Transform (CD, dct, idct)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector as V
-- | Estimate the probability density function of a sample on a regular
-- mesh of points.
--
-- The mesh covers the sample's range extended by one tenth of that range
-- on each side. When the sample has at most one point, or all of its
-- points are equal, the range is taken to be 1 so the mesh never has zero
-- width.
--
-- The density is obtained by smoothing a normalised histogram of the
-- sample in the DCT domain, with the smoothing amount chosen automatically
-- by a root search. The procedure appears to follow Botev, Grotowski &
-- Kroese (2010), /Kernel density estimation via diffusion/.
--
-- Returns @(mesh, density)@: the mesh points and the estimated density at
-- each of them. Calls 'error' if the sample is empty or the requested
-- number of points is @<= 1@.
kde :: (G.Vector v CD, G.Vector v Double, G.Vector v Int)
    => Int       -- ^ Number of mesh points; rounded up with 'nextHighestPowerOfTwo'.
    -> v Double  -- ^ Sample.
    -> (v Double, v Double)
kde n0 xs = kde_ n0 (lo - range / 10) (hi + range / 10) xs
  where
    (lo,hi) = minMax xs
    -- Width used to pad the mesh; never zero, so the mesh is never degenerate.
    range | G.length xs <= 1 = 1
          | lo == hi         = 1
          | otherwise        = hi - lo
{-# INLINABLE kde #-}
{-# SPECIALIZE kde :: Int -> U.Vector Double -> (U.Vector Double, U.Vector Double) #-}
{-# SPECIALIZE kde :: Int -> V.Vector Double -> (V.Vector Double, V.Vector Double) #-}
-- | Worker for 'kde': estimate the density on @n0@ (rounded up with
-- 'nextHighestPowerOfTwo') equally spaced points from @min@ to @max@
-- inclusive.
--
-- The smoothing time @t_star@ is the root, found by Ridders' method on
-- @(0, 0.1)@ with tolerance @1e-14@, of a fixed-point equation built from
-- the squared DCT coefficients of the histogram. If no root is found it
-- falls back to @0.28 * len ** (-0.4)@. The density is the inverse DCT of
-- the coefficients damped by @exp (-k^2 * pi^2 * t_star / 2)@, scaled by
-- @1 / (2 * (max - min))@.
--
-- Returns @(mesh, density)@. Calls 'error' when the sample is empty or
-- @n0 <= 1@.
kde_ :: (G.Vector v CD, G.Vector v Double, G.Vector v Int)
     => Int       -- ^ Requested number of mesh points.
     -> Double    -- ^ Lower end of the mesh.
     -> Double    -- ^ Upper end of the mesh.
     -> v Double  -- ^ Sample.
     -> (v Double, v Double)
kde_ n0 min max xs
  | G.null xs = error "Statistics.Sample.KernelDensity.kde: empty sample"
  | n0 <= 1   = error "Statistics.Sample.KernelDensity.kde: invalid number of points"
  | otherwise = t_star `seq` (mesh, density)
  where
    -- NOTE: t_star deliberately has no bang. A strict where-binding is
    -- forced before the guards above are tested, which would run the
    -- histogram and the root search on exactly the inputs the guards are
    -- meant to reject. It is forced in the 'otherwise' branch instead, so
    -- valid inputs keep the same evaluation order as before.

    -- ni points from min to max inclusive.
    mesh = G.generate ni $ \z -> min + (d * fromIntegral z)
      where d = r / (n-1)
    -- Damp the k-th DCT coefficient and transform back.
    density = G.map (/(2 * r)) . idct $ G.zipWith damp a (G.enumFromN 0 ni)
      where damp b k = b * exp (sqr k * sqr pi * t_star * (-0.5))
    !n  = fromIntegral ni
    !ni = nextHighestPowerOfTwo n0
    !r  = max - min
    -- DCT of the histogram, normalised so its bins sum to one.
    a   = dct . G.map (/ sum h) $ h
      where h = G.map (/ len) $ histogram_ ni min max xs
    !len = fromIntegral (G.length xs)
    t_star = fromRoot (0.28 * len ** (-0.4)) . ridders 1e-14 (0,0.1) $ \x ->
             x - (len * (2 * sqrt pi) * go 6 (fNorm 7 x)) ** (-0.4)
      where
        -- Independent of (q, t): computed once here rather than inside
        -- fNorm, which the root finder calls many times.
        a2v = G.map (sqr . (*0.5)) $ G.tail a
        iv  = G.map sqr $ G.enumFromN 1 (ni - 1)
        fNorm q t = 2 * pi ** (q*2) * sum (G.zipWith g iv a2v)
          where g i a2 = i ** q * a2 * exp ((-i) * sqr pi * t)
        -- Step s from 6 down to 2, re-estimating fNorm at a time derived
        -- from the previous estimate h.
        go s !h | s == 1    = h
                | otherwise = go (s-1) (fNorm s time)
          where time  = (2 * const * k0 / len / h) ** (2 / (3 + 2 * s))
                const = (1 + 0.5 ** (s+0.5)) / 3
                -- Product of the odd numbers 1, 3 .. 2s-1, over sqrt (2*pi).
                k0    = U.product (G.enumFromThenTo 1 3 (2*s-1)) / m_sqrt_2_pi
    sqr x = x * x
{-# INLINABLE kde_ #-}
{-# SPECIALIZE kde_ :: Int -> Double -> Double -> U.Vector Double -> (U.Vector Double, U.Vector Double) #-}
{-# SPECIALIZE kde_ :: Int -> Double -> Double -> V.Vector Double -> (V.Vector Double, V.Vector Double) #-}