{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.TDigest.Tree.Internal where
import Control.DeepSeq (NFData (..))
import Control.Monad.ST (ST, runST)
import Data.Binary (Binary (..))
import Data.Either (isRight)
import Data.Foldable (toList)
import Data.List.Compat (foldl')
import Data.List.NonEmpty (nonEmpty)
import Data.Ord (comparing)
import Data.Proxy (Proxy (..))
import Data.Semigroup (Semigroup (..))
import Data.Semigroup.Reducer (Reducer (..))
import GHC.TypeLits (KnownNat, Nat, natVal)
import Prelude ()
import Prelude.Compat
import qualified Data.Vector.Algorithms.Heap as VHeap
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as MVU
import Data.TDigest.Internal
import qualified Data.TDigest.Postprocess.Internal as PP
-- | A t-digest stored as a binary tree of centroids ordered by mean
-- (smaller means to the left, larger means to the right).
--
-- The type-level @compression@ natural is read back with 'natVal' by the
-- insertion and compression code to compute per-centroid weight thresholds.
data TDigest (compression :: Nat)
    = Node
        {-# UNPACK #-} !Size       -- number of nodes in this subtree, this node included
        {-# UNPACK #-} !Mean       -- mean of the centroid held by this node
        {-# UNPACK #-} !Weight     -- weight of the centroid held by this node
        {-# UNPACK #-} !Weight     -- total weight of the whole subtree
        !(TDigest compression)     -- left subtree
        !(TDigest compression)     -- right subtree
    | Nil                          -- empty tree
  deriving (Show)
-- | Two digests are merged with 'combineDigest'.
instance KnownNat comp => Semigroup (TDigest comp) where
    lhs <> rhs = combineDigest lhs rhs
-- | A 'Double' sample is reduced into a digest by inserting it; 'cons' and
-- 'snoc' differ only in argument order and both call 'insert'.
instance KnownNat comp => Reducer Double (TDigest comp) where
    cons sample digest = insert sample digest
    snoc digest sample = insert sample digest
    unit sample = singleton sample
-- | The empty digest is the identity; 'mappend' merges with 'combineDigest'.
instance KnownNat comp => Monoid (TDigest comp) where
    mempty = emptyTDigest
    mappend lhs rhs = combineDigest lhs rhs
-- | Forces the digest to weak head normal form only. Every field of the
-- 'Node' constructor carries a strictness annotation, so nothing lazy is
-- left below the head constructor.
instance NFData (TDigest comp) where
    rnf digest = seq digest ()
-- | Serialised as the in-order list of centroids. Decoding rebuilds the
-- tree by inserting the centroids one by one into an empty digest; no
-- 'compress' step is run afterwards.
instance KnownNat comp => Binary (TDigest comp) where
    put digest = put (getCentroids digest)
    get = fmap rebuild get
      where
        -- The expression signature pins the element type being decoded.
        rebuild cs = foldl' (flip insertCentroid) emptyTDigest (cs :: [Centroid])
-- | Histogram support: an empty digest has no histogram ('Nothing'),
-- otherwise it is built from the in-order centroid list.
instance PP.HasHistogram (TDigest comp) Maybe where
    histogram digest = PP.histogramFromCentroids <$> nonEmpty (getCentroids digest)
    -- The right-hand side is this module's top-level 'totalWeight'; the class
    -- method is only imported qualified, so this is not self-recursive.
    totalWeight = totalWeight
-- | All centroids of the digest as @(mean, weight)@ pairs, produced by an
-- in-order traversal (left subtree, node, right subtree).
getCentroids :: TDigest comp -> [Centroid]
getCentroids digest = collect digest []
  where
    -- Prepend the centroids of a subtree onto an already-built tail.
    collect Nil rest = rest
    collect (Node _ x w _ l r) rest = collect l ((x, w) : collect r rest)
-- | Total weight stored in the digest, read from the cached subtree weight
-- of the root node; @0@ for the empty digest.
totalWeight :: TDigest comp -> Weight
totalWeight digest = case digest of
    Nil -> 0
    Node _ _ _ subtreeWeight _ _ -> subtreeWeight
-- | Number of nodes in the tree, read from the cached size of the root
-- node; @0@ for the empty digest.
size :: TDigest comp -> Int
size digest = case digest of
    Nil -> 0
    Node nodeCount _ _ _ _ _ -> nodeCount
-- | Mean of the left-most centroid, or 'posInf' for the empty digest.
minimumValue :: TDigest comp -> Mean
minimumValue Nil = posInf
minimumValue (Node _ rootMean _ _ left _) = leftmost rootMean left
  where
    -- Keep following left children, remembering the last mean seen.
    leftmost best Nil = best
    leftmost _ (Node _ x _ _ l _) = leftmost x l
-- | Mean of the right-most centroid, or 'negInf' for the empty digest.
maximumValue :: TDigest comp -> Mean
maximumValue Nil = negInf
maximumValue (Node _ rootMean _ _ _ right) = rightmost rootMean right
  where
    -- Keep following right children, remembering the last mean seen.
    rightmost best Nil = best
    rightmost _ (Node _ x _ _ _ r) = rightmost x r
-- | The digest containing no centroids (the 'Nil' tree).
emptyTDigest :: TDigest comp
emptyTDigest = Nil
-- | Merge two digests.
--
-- If either side is empty the other is returned unchanged. Otherwise the
-- centroids of the tree with fewer nodes are inserted one by one into the
-- other tree (the second argument is the one drained when the node counts
-- are equal), and the result is passed through 'compress'.
combineDigest
    :: KnownNat comp
    => TDigest comp
    -> TDigest comp
    -> TDigest comp
combineDigest a Nil = a
combineDigest Nil b = b
combineDigest a@(Node sizeA _ _ _ _ _) b@(Node sizeB _ _ _ _ _) =
    if sizeA < sizeB
        then drainInto b a
        else drainInto a b
  where
    -- Insert every centroid of @source@ into @target@, then compress.
    drainInto target source =
        compress (foldl' (flip insertCentroid) target (getCentroids source))
-- | Insert one centroid @(mean, weight)@ into the digest.
--
-- Inserting into the empty tree yields a single-node tree. Otherwise the
-- tree is descended by comparing the new mean with each node's mean. Along
-- the way the new weight may be merged into an existing node, up to that
-- node's threshold; any weight that does not fit is inserted further down
-- as a separate node. Rebuilt paths go through 'balanceL' / 'balanceR'.
--
-- This function does not call 'compress'.
insertCentroid
    :: forall comp. KnownNat comp
    => Centroid
    -> TDigest comp
    -> TDigest comp
insertCentroid (x, w) Nil = singNode x w
insertCentroid (mean, weight) td = go 0 mean weight False td
  where
    -- Total weight of the tree once the new centroid has been added.
    n :: Weight
    n = totalWeight td + weight
    -- Compression parameter taken from the type-level natural.
    compression :: Double
    compression = fromInteger $ natVal (Proxy :: Proxy comp)
    -- Arguments, in order:
    --   * cumulative weight to the left of the subtree being visited
    --   * mean to insert
    --   * weight to insert
    --   * flag set to True once part of the weight has been merged higher
    --     up; the remainder is then pushed down without further merging
    --   * subtree to insert into
    go
        :: Weight
        -> Mean
        -> Weight
        -> Bool
        -> TDigest comp
        -> TDigest comp
    go _ newX newW _ Nil = singNode newX newW
    go cum newX newW e (Node s x w tw l r) = case compare newX x of
        -- Same mean: add the weight to this node in place.
        EQ -> Node s x (w + newW) (tw + newW) l r
        -- This node is already at or above its threshold: just descend.
        LT | thr <= w -> balanceL x w (go cum newX newW e l) r
        GT | thr <= w -> balanceR x w l (go (cum + totalWeight l + w) newX newW e r)
        -- Something was merged higher up already: descend without merging.
        LT | e -> balanceL x w (go cum newX newW e l) r
        LT -> case l of
            -- No left child: merge what fits into this node; any remainder
            -- becomes a new node on the left.
            Nil -> case mrw of
                Nothing -> node' s nx nw (tw + newW) Nil r
                Just rw -> balanceL nx nw (go cum newX rw True Nil) r
            Node _ _ _ _ _ _
                -- The new mean lies between the left subtree's maximum and
                -- this node, and is closer to this node: merge here.
                | lmax < newX && abs (newX - x) < abs (newX - lmax) -> case mrw of
                    -- NOTE(review): with mrw == Nothing, nw is w + newW, so
                    -- tw + nw - w equals tw + newW algebraically; the sibling
                    -- branches spell it tw + newW. Consider unifying.
                    Nothing -> node' s nx nw (tw + nw - w) l r
                    Just rw -> balanceL nx nw (go cum newX rw True l) r
                | otherwise -> balanceL x w (go cum newX newW e l) r
              where
                lmax = maximumValue l
        -- Mirror image of the LT cases for the right side.
        GT | e -> balanceR x w l (go (cum + totalWeight l + w) newX newW True r)
        GT -> case r of
            Nil -> case mrw of
                Nothing -> node' s nx nw (tw + newW) l Nil
                Just rw -> balanceR nx nw l (go (cum + totalWeight l + nw) newX rw True Nil)
            Node _ _ _ _ _ _
                | rmin > newX && abs (newX - x) < abs (newX - rmin) -> case mrw of
                    Nothing -> node' s nx nw (tw + newW) l r
                    Just rw -> balanceR nx nw l (go (cum + totalWeight l + nw) newX rw True r)
                | otherwise -> balanceR x w l (go (cum + totalWeight l + w) newX newW e r)
              where
                rmin = minimumValue r
      where
        -- Quantile estimate for this node: the weight to its left plus half
        -- its own weight, divided by the new total weight.
        cum' = cum + totalWeight l
        q = (w / 2 + cum') / n
        -- Largest weight this node may hold at that quantile.
        thr = threshold n q compression
        -- dw:  share of the new weight absorbed by this node.
        -- mrw: weight left over once the node is filled up to thr, if any.
        -- Only demanded by branches reached after the `thr <= w` guards.
        -- `assert` is the three-argument helper in scope from outside this
        -- block; presumably it yields its last argument when the condition
        -- holds (TODO confirm).
        dw :: Weight
        mrw :: Maybe Weight
        (dw, mrw) =
            let diff = assert (thr > w) "threshold should be larger than current node weight"
                     $ w + newW - thr
            in if diff < 0
                then (newW, Nothing)
                else (thr - w, Just $ diff)
        -- Mean and weight of this node after absorbing dw.
        -- NOTE(review): both means passed here are x, so nx stays at x and
        -- only the weight changes; newX looks like the intended third
        -- argument. Confirm before changing, as it alters every digest.
        (nx, nw) = combinedCentroid x w x dw
-- | Build a node from a centroid and two subtrees, computing the cached
-- node count and total weight from the children.
node :: Mean -> Weight -> TDigest comp -> TDigest comp -> TDigest comp
node x w l r = Node nodeCount x w weightSum l r
  where
    nodeCount = 1 + size l + size r
    weightSum = w + totalWeight l + totalWeight r
-- | Build a node from a centroid and two subtrees, rotating left when the
-- right subtree has more than 'balOmega' times as many nodes as the left.
--
-- A single rotation is used when the right child's left subtree is empty
-- or has fewer than 'balAlpha' times the nodes of its right subtree;
-- otherwise a double rotation is used.
balanceR :: Mean -> Weight -> TDigest comp -> TDigest comp -> TDigest comp
balanceR x w l r
    | leftSize + rightSize <= 1 = unrotated
    | rightSize > balOmega * leftSize = rotate r
    | otherwise = unrotated
  where
    leftSize = size l
    rightSize = size r
    unrotated = node x w l r

    rotate Nil = error "balanceR: impossible happened"
    rotate (Node _ rx rw _ rl rr) = case rl of
        -- double left rotation
        Node _ rlx rlw _ rll rlr
            | size rl >= balAlpha * size rr ->
                node rlx rlw (node x w l rll) (node rx rw rlr rr)
        -- single left rotation
        _ -> node rx rw (node x w l rl) rr
-- | Build a node from a centroid and two subtrees, rotating right when the
-- left subtree has more than 'balOmega' times as many nodes as the right.
--
-- A single rotation is used when the left child's right subtree is empty
-- or has fewer than 'balAlpha' times the nodes of its left subtree;
-- otherwise a double rotation is used.
balanceL :: Mean -> Weight -> TDigest comp -> TDigest comp -> TDigest comp
balanceL x w l r
    | leftSize + rightSize <= 1 = unrotated
    | leftSize > balOmega * rightSize = rotate l
    | otherwise = unrotated
  where
    leftSize = size l
    rightSize = size r
    unrotated = node x w l r

    rotate Nil = error "balanceL: impossible happened"
    rotate (Node _ lx lw _ ll lr) = case lr of
        -- double right rotation
        Node _ lrx lrw _ lrl lrr
            | size lr >= balAlpha * size ll ->
                node lrx lrw (node lx lw ll lrl) (node x w lrr r)
        -- single right rotation
        _ -> node lx lw ll (node x w lr r)
-- | Build a node from explicitly supplied cached values (node count and
-- subtree weight); nothing is recomputed from the children.
node' :: Int -> Mean -> Weight -> Weight -> TDigest comp -> TDigest comp -> TDigest comp
node' nodeCount x w subtreeWeight l r = Node nodeCount x w subtreeWeight l r
-- | A tree consisting of exactly one centroid and no children.
singNode :: Mean -> Weight -> TDigest comp
singNode x w = Node nodeCount x w subtreeWeight Nil Nil
  where
    nodeCount = 1
    -- With no children the subtree weight is the centroid's own weight.
    subtreeWeight = w
-- | Merge two centroids: the result carries the summed weight and the
-- weight-averaged mean.
combinedCentroid
    :: Mean -> Weight
    -> Mean -> Weight
    -> Centroid
combinedCentroid x w x' w' =
    let total = w + w'
        weightedSum = x * w + x' * w'
    in (weightedSum / total, total)
-- | Weight limit for a centroid: @4 * n * q * (1 - q) / compression@.
--
-- Arguments, in order: total weight @n@, quantile @q@, and the compression
-- parameter.
threshold
    :: Double
    -> Double
    -> Double
    -> Double
threshold n q compression = scaled / compression
  where
    -- Same left-to-right multiplication order as the one-line formula.
    scaled = 4 * n * q * (1 - q)
-- | Rebuild the digest with 'forceCompress', but only when its node count
-- exceeds both @relMaxSize * compression@ and 'absMaxSize'; otherwise the
-- digest is returned unchanged.
compress :: forall comp. KnownNat comp => TDigest comp -> TDigest comp
compress Nil = Nil
compress td
    | overRelative && overAbsolute = forceCompress td
    | otherwise = td
  where
    nodeCount = size td
    compression = fromInteger $ natVal (Proxy :: Proxy comp)
    overRelative = nodeCount > relMaxSize * compression
    overAbsolute = nodeCount > absMaxSize
-- | Unconditionally rebuild the digest.
--
-- The centroids are written to a vector together with the score computed
-- by 'toMVector' (threshold minus weight), sorted by ascending score, and
-- re-inserted in that order into an empty digest.
forceCompress :: forall comp. KnownNat comp => TDigest comp -> TDigest comp
forceCompress Nil = Nil
forceCompress td =
    foldl' (flip insertCentroid) emptyTDigest (map fst (VU.toList centroids))
  where
    -- Centroids paired with their score, sorted by the score.
    centroids :: VU.Vector (Centroid, Double)
    centroids = runST $ do
        v <- toMVector td
        VHeap.sortBy (comparing snd) v
        -- The mutable vector is not used after this point, so freezing it
        -- without a copy is fine; the frozen vector is the block's result.
        VU.unsafeFreeze v
-- | Write the centroids of the digest, in order, into a fresh mutable
-- unboxed vector. Each centroid is paired with @threshold - weight@,
-- where the threshold is evaluated at the centroid's quantile (the weight
-- before it plus half its own weight, over the total weight).
--
-- The result is wrapped in an 'assert' that the number of written slots
-- equals 'size' and the accumulated weight matches 'totalWeight' to within
-- @1e-6@.
toMVector
    :: forall comp s. KnownNat comp
    => TDigest comp
    -> ST s (VU.MVector s (Centroid, Double))
toMVector td = do
    buf <- MVU.new (size td)
    (written, seen) <- fill buf (0 :: Int) (0 :: Double) td
    pure $ assert (written == size td && abs (seen - totalWeight td) < 1e-6) "traversal in toMVector:" buf
  where
    -- In-order walk threading the next free index and the weight seen so far.
    fill _ ix acc Nil = pure (ix, acc)
    fill buf ix acc (Node _ x w _ l r) = do
        (ix', acc') <- fill buf ix acc l
        MVU.unsafeWrite buf ix' ((x, w), slack w acc')
        fill buf (ix' + 1) (acc' + w) r

    total = totalWeight td
    compression = fromInteger $ natVal (Proxy :: Proxy comp)

    -- Threshold at this centroid's quantile, minus its current weight.
    slack w before = limit - w
      where
        quantile = (w / 2 + before) / total
        limit = threshold total quantile compression
-- | Factor applied to the compression parameter in 'compress': a digest is
-- only rebuilt when its node count exceeds @relMaxSize * compression@.
relMaxSize :: Int
relMaxSize = 25
-- | Absolute node count a digest must also exceed before 'compress'
-- rebuilds it.
absMaxSize :: Int
absMaxSize = 1000
-- | Balance factor: 'balanceL' / 'balanceR' rotate when one subtree has
-- more than @balOmega@ times the nodes of the other.
balOmega :: Int
balOmega = 3
-- | Rotation selector: a single rotation is used when the inner grandchild
-- has fewer than @balAlpha@ times the nodes of the outer one, a double
-- rotation otherwise.
balAlpha :: Int
balAlpha = 2
-- | Print the tree structure to standard output, one node per line,
-- in order (left subtree, node, right subtree), indented three spaces per
-- depth level. Empty subtrees are printed as @Nil@.
debugPrint :: TDigest comp -> IO ()
debugPrint = walk 0
  where
    indent depth = replicate (depth * 3) ' '

    walk depth Nil = putStrLn (indent depth ++ "Nil")
    walk depth (Node s m w tw l r) =
        walk (depth + 1) l
            >> putStrLn (indent depth ++ "Node " ++ show (s,m,w,tw))
            >> walk (depth + 1) r
-- | 'True' exactly when 'validate' reports no problem.
valid :: TDigest comp -> Bool
valid td = case validate td of
    Right _ -> True
    Left _ -> False
-- | Check the structural invariants of every subtree.
--
-- The checks run in this order and the first failing one decides the
-- message: cached sizes, cached weights, ordering of each node against its
-- direct children, and balance. Returns the digest itself when all pass.
validate :: TDigest comp -> Either String (TDigest comp)
validate td = case failures of
    (message : _) -> Left message
    [] -> Right td
  where
    failures = [message | (check, message) <- checks, not (all check subtrees)]

    checks =
        [ (sizeValid, "invalid sizes")
        , (weightValid, "invalid weights")
        , (orderValid, "invalid ordering")
        , (balanced, "tree is ill-balanced")
        ]

    -- Every non-empty subtree, in pre-order.
    subtrees = collect td []
    collect Nil rest = rest
    collect n@(Node _ _ _ _ l r) rest = n : collect l (collect r rest)

    -- Cached node count agrees with the children.
    sizeValid Nil = True
    sizeValid (Node s _ _ _ l r) = s == size l + size r + 1

    -- Cached subtree weight agrees with the children (compared with 'eq').
    weightValid Nil = True
    weightValid (Node _ _ w tw l r) = eq tw $ w + totalWeight l + totalWeight r

    -- Direct children are strictly on the correct side of the node's mean.
    orderValid Nil = True
    orderValid (Node _ x _ _ l r) = leftOk && rightOk
      where
        leftOk = case l of
            Nil -> True
            Node _ lx _ _ _ _ -> lx < x
        rightOk = case r of
            Nil -> True
            Node _ rx _ _ _ _ -> x < rx

    -- Neither child outweighs the other by more than 'balOmega' (in nodes).
    balanced Nil = True
    balanced (Node _ _ _ _ l r) =
        size l <= max 1 (balOmega * size r) &&
        size r <= max 1 (balOmega * size l)
-- | Insert a single sample (with weight 1) and then run 'compress'.
insert
    :: KnownNat comp
    => Double
    -> TDigest comp
    -> TDigest comp
insert x td = compress (insert' x td)
-- | Insert a single sample as a centroid of weight 1, without running
-- 'compress' afterwards.
insert'
    :: KnownNat comp
    => Double
    -> TDigest comp
    -> TDigest comp
insert' x td = insertCentroid (x, 1) td
-- | Digest holding exactly one sample, built by inserting it into the
-- empty digest.
singleton :: KnownNat comp => Double -> TDigest comp
singleton = flip insert emptyTDigest
-- | Build a digest from a collection of samples.
--
-- The samples are processed in chunks of 1000: each chunk is inserted
-- without intermediate compression, and 'compress' is run once per chunk.
tdigest :: (Foldable f, KnownNat comp) => f Double -> TDigest comp
tdigest samples = foldl' addChunk emptyTDigest (chunked (toList samples))
  where
    addChunk acc chunk = compress (foldl' (flip insert') acc chunk)

    chunked [] = []
    chunked rest =
        let (chunk, remaining) = splitAt 1000 rest
        in chunk : chunked remaining