{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
module Data.TDigest.Postprocess.Internal (
HasHistogram (..),
HistBin (..),
histogramFromCentroids,
quantile,
mean,
variance,
cdf,
validateHistogram,
Affine (..),
) where
import Data.Foldable (toList)
import Data.Functor.Compose (Compose (..))
import Data.Functor.Identity (Identity (..))
import Data.List.NonEmpty (NonEmpty (..), nonEmpty)
import Data.Proxy (Proxy (..))
import Data.Semigroup (Semigroup (..))
import Data.Semigroup.Foldable (foldMap1)
import Prelude ()
import Prelude.Compat
import qualified Data.List.NonEmpty as NE
import Data.TDigest.Internal
-- | A single histogram bin.
--
-- Bins produced by 'histogramFromCentroids' are contiguous: each bin's
-- 'hbMax' equals the next bin's 'hbMin', and 'validateHistogram' checks this.
data HistBin = HistBin
    { hbMin       :: !Mean    -- ^ lower bound of the bin
    , hbMax       :: !Mean    -- ^ upper bound of the bin
    , hbValue     :: !Mean    -- ^ representative value (the centroid mean)
    , hbWeight    :: !Weight  -- ^ weight carried by this bin
    , hbCumWeight :: !Weight  -- ^ sum of the weights of all preceding bins
    }
    deriving Show
-- | Things that can be viewed as a histogram.
--
-- The histogram is wrapped in an 'Affine' container @f@ (for example
-- 'Identity' when a histogram always exists, 'Maybe' when the source may be
-- empty). The functional dependency fixes @f@ once @a@ is known.
class Affine f => HasHistogram a f | a -> f where
    -- | The bins of the histogram, if any.
    histogram :: a -> f (NonEmpty HistBin)

    -- | Total weight of everything in @a@.
    totalWeight :: a -> Weight
-- | A non-empty list of bins is already a histogram.
--
-- The head deliberately matches any @'NonEmpty' e@. The @HistBin ~ e@
-- constraint then fixes the element type once this instance is selected,
-- which helps type inference when the element type is not yet known. With
-- a concrete @NonEmpty HistBin@ head the constraint would mention a variable
-- absent from the head and do nothing.
instance (HistBin ~ e) => HasHistogram (NonEmpty e) Identity where
    histogram = Identity

    -- The last bin's cumulative weight plus its own weight is the sum of
    -- all bin weights.
    totalWeight = tw . NE.last
      where
        tw hb = hbWeight hb + hbCumWeight hb
-- | A plain list of bins is a histogram unless it is empty.
--
-- As with the 'NonEmpty' instance, the head matches any @[e]@ and the
-- @HistBin ~ e@ constraint fixes the element type afterwards. An empty
-- list has total weight @0@.
instance (HistBin ~ e) => HasHistogram [e] Maybe where
    histogram = nonEmpty
    totalWeight = affine 0 totalWeight . histogram
-- | Turn a non-empty sequence of centroids into contiguous histogram bins.
--
-- Bin boundaries sit halfway between neighbouring centroid means. The first
-- bin starts at the first mean and the last bin ends at the last mean; a
-- single centroid becomes a zero-width bin. Each bin's 'hbCumWeight' is the
-- running sum of the weights of the bins before it.
-- NOTE(review): the input is presumably sorted by mean; that is not checked here.
histogramFromCentroids :: NonEmpty Centroid -> NonEmpty HistBin
histogramFromCentroids ((x, w) :| others) = case others of
    []          -> HistBin x x x w 0 :| []
    (nx, _) : _ -> HistBin x (halfway x nx) x w 0 :| go x w others
  where
    -- Emit bins for the remaining centroids. @prev@ is the previous
    -- centroid's mean and @acc@ the weight accumulated so far.
    go :: Mean -> Weight -> [(Mean, Weight)] -> [HistBin]
    go _ _ [] = []
    go prev acc ((cx, cw) : more) =
        HistBin (halfway prev cx) upper cx cw acc : go cx (acc + cw) more
      where
        -- The last bin is closed at its own centroid.
        upper = case more of
            []          -> cx
            (ux, _) : _ -> halfway cx ux

    halfway a b = (a + b) / 2
-- | Linearly interpolated quantile of a histogram.
--
-- @q@ is scaled by the total weight @tw@ to get a target cumulative weight.
-- The function picks the first bin whose cumulative weight range extends
-- past that target, falling back to the last bin. It then interpolates
-- linearly between that bin's 'hbMin' and 'hbMax'. @q@ is not range-checked
-- here.
--
-- Pattern matching on ':|' keeps the 'NonEmpty' guarantee, so there is no
-- unreachable empty-list case.
quantile :: Double -> Weight -> NonEmpty HistBin -> Double
quantile q tw (firstBin :| laterBins) = go firstBin laterBins
  where
    target = q * tw

    -- Position of the target weight within one bin, mapped onto [min, max].
    interpolate (HistBin a b _ w t) = a + (b - a) * (target - t) / w

    go bin [] = interpolate bin
    go bin (next : rest)
        | target < hbCumWeight bin + hbWeight bin = interpolate bin
        | otherwise                               = go next rest
-- | Weighted mean of the bin values: each bin contributes its 'hbValue'
-- with weight 'hbWeight'.
--
-- Assumes positive weights; a merge whose combined weight is zero yields NaN.
mean :: NonEmpty HistBin -> Double
mean bins = getMean (foldMap1 (\hb -> Mean (hbWeight hb) (hbValue hb)) bins)

-- | Partial aggregate for 'mean': total weight, then the weighted mean so far.
data Mean' = Mean !Double !Double

-- | Extract the mean, discarding the accumulated weight.
getMean :: Mean' -> Double
getMean (Mean _ m) = m

-- | Combine two partial means, weighting each by its accumulated weight.
instance Semigroup Mean' where
    (<>) (Mean wa ma) (Mean wb mb) =
        let total = wa + wb
        in  Mean total ((ma * wa + mb * wb) / total)
-- | Weighted sample variance of the bin values.
--
-- Each bin contributes its 'hbValue' with weight 'hbWeight'. The result is
-- Bessel-corrected by treating the total weight as a count, i.e. the sum of
-- squared deviations divided by @totalWeight - 1@. A total weight of 1
-- therefore divides by zero.
variance :: NonEmpty HistBin -> Double
variance = getVariance . foldMap1 toVariance
  where
    toVariance (HistBin _ _ x w _) = Variance w x 0

-- | Partial aggregate for 'variance': total weight, weighted mean, and the
-- sum of weighted squared deviations from that mean.
data Variance = Variance !Double !Double !Double

-- | Sample variance from the aggregate (divides by weight minus one).
getVariance :: Variance -> Double
getVariance (Variance w _ d) = d / (w - 1)

-- | Pairwise merge of two partial aggregates (Chan, Golub & LeVeque).
--
-- The previous form @w1*x1^2 + w2*x2^2 - w*x^2@ is mathematically equal
-- but subtracts two huge, nearly-equal numbers when the means are large
-- relative to the spread. For example, two unit-weight bins at 1e9 and
-- 1e9+1 lose the true 0.5 entirely. Using the difference of means avoids
-- that cancellation.
instance Semigroup Variance where
    Variance w1 x1 d1 <> Variance w2 x2 d2 = Variance w x d
      where
        w     = w1 + w2
        x     = (x1 * w1 + x2 * w2) / w
        delta = x2 - x1
        d     = d1 + d2 + delta * delta * (w1 * w2 / w)
-- | Cumulative distribution function of a histogram, evaluated at @x@.
--
-- Returns 0 when @x@ lies below the first bin encountered, and 1 once every
-- bin has been passed. Inside a bin the cumulative weight is interpolated
-- linearly and divided by @n@.
cdf :: Double     -- ^ point at which to evaluate
    -> Double     -- ^ normaliser (presumably the total weight; confirm at call sites)
    -> [HistBin]
    -> Double
cdf x n = foldr step 1
  where
    -- 'later' is only demanded when x lies at or beyond this bin's upper
    -- bound, so the fold stops at the first matching bin.
    step (HistBin lo hi _ w cum) later
        | x < lo    = 0
        | x < hi    = (cum + w * (x - lo) / (hi - lo)) / n
        | otherwise = later
-- | Check the structural invariants of a histogram and return it unchanged
-- if they hold.
--
-- For every pair of adjacent bins:
--
-- * the left bin's 'hbMax' must equal the right bin's 'hbMin' (no gaps);
-- * the left bin's 'hbCumWeight' plus its 'hbWeight' must equal the right
--   bin's 'hbCumWeight' (exact floating-point equality).
--
-- The first violated pair is reported as 'Left' with both bins shown.
validateHistogram :: Foldable f => f HistBin -> Either String (f HistBin)
validateHistogram bs = bs <$ mapM_ validPair (adjacent (toList bs))
  where
    validPair (lb@(HistBin _ lmax _ lwt lcw), rb@(HistBin rmin _ _ _ rcw)) = do
        check (lmax == rmin) "gap between bins"
        check (lcw + lwt == rcw) "mismatch in weight cumulation"
      where
        check False err = Left $ err ++ " " ++ show (lb, rb)
        check True  _   = Right ()

    -- 'drop 1' instead of the partial 'tail'; zero or one bin gives no pairs.
    adjacent xs = zip xs (drop 1 xs)
-- | Containers holding at most one element.
--
-- 'affine' eliminates such a container: apply the function to the element
-- if there is one, otherwise return the supplied default. 'fromAffine' is
-- the special case with the identity function.
class Traversable t => Affine t where
    affine :: b -> (a -> b) -> t a -> b
    affine def f = fromAffine def . fmap f

    fromAffine :: a -> t a -> a
    fromAffine def = affine def id

    {-# MINIMAL fromAffine | affine #-}

-- | Always exactly one element; the default is never used.
instance Affine Identity where
    fromAffine _ (Identity a) = a

-- | Zero or one element.
instance Affine Maybe where
    affine def f m = case m of
        Nothing -> def
        Just a  -> f a

-- | Never any element; the container itself is not inspected.
instance Affine Proxy where
    affine def = const (const def)

-- | Nesting two affine containers is affine: the same default is used
-- whether the outer or the inner layer is empty.
instance (Affine f, Affine g) => Affine (Compose f g) where
    affine def f = affine def (affine def f) . getCompose