{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Multiset (
Multiset, Group,
empty, singleton, replicate,
fromList, fromGroupList,
fromCountMap,
null,
size, distinctSize,
member, notMember,
isSubsetOf, isProperSubsetOf,
count, (!),
insert, remove, removeAll, modify,
map, mapGroups,
filter, filterGroups,
max, min, difference, unionWith, intersectionWith,
toSet,
toGroupList, toGrowingGroupList, toShrinkingGroupList,
toCountMap,
mostCommon
) where
import Prelude hiding (filter, foldr, map, max, min, null, replicate)
import qualified Prelude as Prelude
import Data.Binary (Binary(..))
import Data.Foldable (foldl', foldr, toList)
import Data.List (groupBy, sortOn)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Semigroup (Semigroup, (<>))
import Data.Set (Set)
import qualified Data.Set as Set
import qualified GHC.Exts
-- | A collection that remembers how many times each element occurs.
--
-- The functions in this module only ever store strictly positive counts
-- and keep '_size' equal to the sum of all stored counts.
data Multiset v = Multiset
  { _toMap :: !(Map v Int)
    -- ^ Occurrence count of every distinct element.
  , _size :: !Int
    -- ^ Cached total number of occurrences.
  }
  deriving (Eq, Ord, Read, Show)
-- | An element paired with the number of times it occurs.
type Group v = (v, Int)
-- | Combining two multisets adds up the counts of their elements.
instance Ord v => Semigroup (Multiset v) where
  ms1 <> ms2 = unionWith' (+) ms1 ms2
-- | The neutral element is the multiset without any elements.
instance Ord v => Monoid (Multiset v) where
  mempty = Multiset Map.empty 0
-- | Folds visit every element once per occurrence, walking the distinct
-- elements in ascending order.
instance Foldable Multiset where
  foldr f r0 (Multiset m _) = Map.foldrWithKey go r0 m
    where
      -- The qualified 'Prelude.replicate' is essential: the unqualified
      -- name is this module's own 'replicate', which yields a 'Multiset'
      -- and would make this 'foldr' call itself forever.
      go v n r1 = foldr f r1 (Prelude.replicate n v)
-- | Serialises the count map followed by the cached size. Decoding reads
-- the two fields back in the same order and does not re-validate them.
instance Binary v => Binary (Multiset v) where
  put ms = put (_toMap ms) <> put (_size ms)
  get = do
    m <- get
    s <- get
    return (Multiset m s)
#if __GLASGOW_HASKELL__ >= 708
-- | List conversions for 'GHC.Exts.IsList', delegating to this module's
-- 'fromList' and to the 'Foldable' instance.
instance Ord v => GHC.Exts.IsList (Multiset v) where
  type Item (Multiset v) = v
  -- Qualified names make explicit which function each method forwards to.
  fromList = Data.Multiset.fromList
  toList = Data.Foldable.toList
#endif
-- | 'True' when the multiset holds no elements.
null :: Multiset v -> Bool
null (Multiset m _) = Map.null m
-- | Total number of occurrences, read from the cached size field.
size :: Multiset v -> Int
size (Multiset _ s) = s
-- | Number of distinct elements, ignoring how often each one occurs.
distinctSize :: Multiset v -> Int
distinctSize ms = Map.size (_toMap ms)
-- | The multiset without any elements.
empty :: Multiset v
empty = Multiset { _toMap = Map.empty, _size = 0 }
-- | A multiset holding exactly one occurrence of the given element.
singleton :: v -> Multiset v
singleton v = replicate 1 v
-- | A multiset holding @n@ occurrences of a single element.
-- A non-positive @n@ yields the empty multiset.
replicate :: Int -> v -> Multiset v
replicate n v
  | n <= 0 = empty
  | otherwise = Multiset (Map.singleton v n) n
-- | Build a multiset from a map of element counts.
--
-- Entries whose count is not strictly positive are dropped. The size is
-- the sum of the remaining counts.
fromCountMap :: Ord v => Map v Int -> Multiset v
fromCountMap counts = Multiset positive (Map.foldl' (+) 0 positive)
  where
    -- One filtering pass keeps the valid entries; this avoids re-inserting
    -- every key one at a time into a fresh map.
    positive = Map.filter (> 0) counts
-- | Build a multiset by inserting the list's elements one at a time.
fromList :: Ord v => [v] -> Multiset v
fromList vs = foldl' (\ms v -> insert v ms) empty vs
-- | Build a multiset from @(element, count)@ pairs.
--
-- Counts of repeated elements are accumulated through 'modify', so a
-- running count never drops below zero.
fromGroupList :: Ord v => [Group v] -> Multiset v
fromGroupList = foldl' (\ms (v, n) -> modify (+ n) v ms) empty
-- | 'True' when the element is stored in the multiset.
member :: Ord v => v -> Multiset v -> Bool
member v ms = Map.member v (_toMap ms)
-- | 'True' when the element is not stored in the multiset.
notMember :: Ord v => v -> Multiset v -> Bool
notMember v ms = Map.notMember v (_toMap ms)
-- | Number of occurrences of an element; 0 when it is absent.
count :: Ord v => v -> Multiset v -> Int
count v (Multiset m _) = Map.findWithDefault 0 v m
-- | Infix form of 'count' with the multiset on the left.
(!) :: Ord v => Multiset v -> v -> Int
ms ! v = count v ms
-- | Replace the count of an element with the result of a function.
--
-- The function receives the current count (0 when the element is absent).
-- Negative results are clamped to 0, and an element whose new count is 0
-- is removed. The cached size is adjusted by the change in count.
modify :: Ord v => (Int -> Int) -> v -> Multiset v -> Multiset v
modify f v ms@(Multiset m s) = Multiset updated (s - old + new)
  where
    old = count v ms
    new = Prelude.max 0 (f old)
    updated
      | new > 0 = Map.insert v new m
      | otherwise = Map.delete v m
-- | Add one occurrence of an element.
insert :: Ord v => v -> Multiset v -> Multiset v
insert v ms = modify (+ 1) v ms
-- | Take away one occurrence of an element. Because 'modify' clamps at
-- zero, removing an absent element leaves the multiset unchanged.
remove :: Ord v => v -> Multiset v -> Multiset v
remove v ms = modify (\n -> n - 1) v ms
-- | Take away every occurrence of an element.
removeAll :: Ord v => v -> Multiset v -> Multiset v
removeAll v ms = modify (\_ -> 0) v ms
-- | Keep only the elements satisfying the predicate, with their counts.
filter :: Ord v => (v -> Bool) -> Multiset v -> Multiset v
filter p = filterGroups (\(v, _) -> p v)
-- | Keep only the @(element, count)@ groups satisfying the predicate.
filterGroups :: Ord v => (Group v -> Bool) -> Multiset v -> Multiset v
filterGroups f (Multiset m _) = fromCountMap kept
  where
    -- Surviving entries keep their counts; 'fromCountMap' recomputes the size.
    kept = Map.filterWithKey (\v n -> f (v, n)) m
-- | Apply a function to every element.
--
-- Elements that become equal have their counts added together, so the
-- total size is carried over unchanged.
map :: (Ord v1, Ord v2) => (v1 -> v2) -> Multiset v1 -> Multiset v2
map f ms = Multiset (Map.mapKeysWith (+) f (_toMap ms)) (_size ms)
-- | Transform every @(element, count)@ group and rebuild the multiset
-- from the results with 'fromGroupList'.
mapGroups :: Ord v => (Group v -> Group v) -> Multiset v -> Multiset v
mapGroups f = fromGroupList . fmap f . toGroupList
-- | Union in which each element occurs as often as in whichever argument
-- holds more of it.
max :: Ord v => Multiset v -> Multiset v -> Multiset v
max ms1 ms2 = unionWith' Prelude.max ms1 ms2
-- | Intersection in which each shared element occurs as often as in
-- whichever argument holds fewer of it.
min :: Ord v => Multiset v -> Multiset v -> Multiset v
min ms1 ms2 = intersectionWith Prelude.min ms1 ms2
-- | Combine the counts of every element found in either multiset.
--
-- The combining function always receives both counts, with 0 standing in
-- for an element missing on one side. Elements whose combined count is
-- not positive are left out of the result.
unionWith :: Ord v => (Int -> Int -> Int) -> Multiset v -> Multiset v -> Multiset v
unionWith f ms1 ms2 =
  fromGroupList [(v, f (count v ms1) (count v ms2)) | v <- Set.toList keys]
  where
    keys = toSet ms1 `Set.union` toSet ms2
-- | Combine the counts of the elements found in both multisets.
-- Elements whose combined count is not positive are left out.
intersectionWith :: Ord v => (Int -> Int -> Int) -> Multiset v -> Multiset v -> Multiset v
intersectionWith f ms1 ms2 =
  fromCountMap (Map.intersectionWith f (_toMap ms1) (_toMap ms2))
-- | Subtract the counts of the second multiset from the first.
-- Elements whose remaining count is not positive disappear.
difference :: Ord v => Multiset v -> Multiset v -> Multiset v
difference ms1 ms2 =
  fromCountMap (Map.differenceWith leftover (_toMap ms1) (_toMap ms2))
  where
    leftover n1 n2
      | rest > 0 = Just rest
      | otherwise = Nothing
      where
        rest = n1 - n2
-- | 'True' when the second multiset holds at least as many occurrences
-- of every element as the first one does.
isSubsetOf :: Ord v => Multiset v -> Multiset v -> Bool
isSubsetOf (Multiset m _) ms = all covered (Map.toList m)
  where
    covered (v, n) = count v ms >= n
-- | Like 'isSubsetOf', but the first multiset must also be strictly
-- smaller. The size comparison runs first and short-circuits the check.
isProperSubsetOf :: Ord v => Multiset v -> Multiset v -> Bool
isProperSubsetOf ms1 ms2
  | size ms1 >= size ms2 = False
  | otherwise = ms1 `isSubsetOf` ms2
-- | The underlying map from each element to its count.
toCountMap :: Multiset v -> Map v Int
toCountMap (Multiset m _) = m
-- | The set of distinct elements, with the counts discarded.
toSet :: Multiset v -> Set v
toSet ms = Map.keysSet (_toMap ms)
-- | All @(element, count)@ groups, in ascending element order.
toGroupList :: Multiset v -> [Group v]
toGroupList (Multiset m _) = Map.toList m
-- | All groups ordered from the rarest element to the most frequent.
-- The sort is stable, so equal counts keep the order of 'toGroupList'.
toGrowingGroupList :: Multiset v -> [Group v]
toGrowingGroupList ms = sortOn (\(_, n) -> n) (toGroupList ms)
-- | All groups ordered from the most frequent element to the rarest.
-- The sort is stable, so equal counts keep the order of 'toGroupList'.
toShrinkingGroupList :: Multiset v -> [Group v]
toShrinkingGroupList ms = sortOn (\(_, n) -> negate n) (toGroupList ms)
-- | Elements bucketed by how often they occur, most frequent bucket first.
--
-- Each entry pairs a count with every element occurring exactly that many
-- times.
mostCommon :: Multiset v -> [(Int, [v])]
mostCommon ms =
  -- 'groupBy' never produces an empty bucket, so the cons pattern always
  -- matches; a comprehension keeps the function total without a fallback
  -- 'error' branch.
  [ (n, v : fmap fst rest)
  | (v, n) : rest <- groupBy sameCount (toShrinkingGroupList ms)
  ]
  where
    sameCount (_, n1) (_, n2) = n1 == n2
-- | Internal union that combines counts only for elements found in both
-- multisets; an element found on just one side keeps its count as is.
-- Elements whose resulting count is not positive are left out.
unionWith' :: Ord v => (Int -> Int -> Int) -> Multiset v -> Multiset v -> Multiset v
unionWith' f ms1 ms2 = fromCountMap (Map.unionWith f (_toMap ms1) (_toMap ms2))