{-# LANGUAGE CPP #-}

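-- | Create clusters of things: things that (transitively) share a
--   characteristic end up in the same cluster, so that distinct clusters
--   are non-overlapping.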

module Agda.Utils.Cluster
  ( C
  , cluster
  , cluster'
  ) where

import Control.Monad

-- An imperative union-find library:
import Data.Equivalence.Monad (runEquivT, equateAll, classDesc)
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NonEmpty

import qualified Data.IntMap as IntMap
#if __GLASGOW_HASKELL__ < 804
import Data.Semigroup
#endif

import Agda.Utils.Functor
import Agda.Utils.Singleton
import Agda.Utils.Fail

-- | Characteristic identifiers.
--
--   Elements are clustered according to the characteristics they share
--   (see 'cluster' and 'cluster''). Being a synonym for 'Int', a
--   characteristic (or class descriptor) can be used directly as an
--   'IntMap' key.
type C = Int

-- | Given a function @f :: a -> NonEmpty C@ which returns a non-empty list of
--   characteristics @C@ of @a@, partition a list of @a@s into groups.
--
--   Two elements end up in the same group exactly when they are connected
--   by a chain of elements in which every neighbouring pair shares at least
--   one characteristic. That is, the groups are the connected components of
--   the \"shares a characteristic\" relation. An element that shares no
--   characteristic with any other element forms a group on its own.
--
--   This is 'cluster'' applied to each element paired with its
--   characteristics; see there for ordering details.
cluster :: (a -> NonEmpty C) -> [a] -> [NonEmpty a]
cluster f as = cluster' $ map (\ a -> (a, f a)) as

-- | Partition a list of @a@s paired with a non-empty list of
--   characteristics @C@ into groups.
--
--   Two elements end up in the same group exactly when they are connected
--   by a chain of elements in which every neighbouring pair shares at least
--   one characteristic. That is, the groups are the connected components of
--   the \"shares a characteristic\" relation. An element that shares no
--   characteristic with any other element forms a group on its own.
--
--   Within a group, elements keep their relative order from the input list.
--   The groups themselves are returned in ascending order of their class
--   descriptor (they come out of an 'IntMap'). Which characteristic becomes
--   a class's descriptor is decided by the union-find library when classes
--   are merged, so callers should not rely on the order of the groups.
cluster' :: [(a, NonEmpty C)] -> [NonEmpty a]
cluster' acs = runFail_ $ runEquivT id const $ do
  -- A fresh class for characteristic @c@ is described by @c@ itself ('id');
  -- when two classes are merged, 'const' keeps one of the two descriptors.

  -- Construct the equivalence classes of characteristics:
  -- all characteristics of a single element are put into the same class.
  forM_ acs $ \ (_, c :| cs) -> equateAll $ c:cs
  -- Pair each element with its class. All characteristics of an element
  -- are in one class by now, so looking up the first one suffices.
  cas <- forM acs $ \ (a, c :| _) ->
    classDesc c <&> \ k -> IntMap.singleton k (singleton a)
  -- Create a map from class to elements. 'IntMap.unionsWith' combines the
  -- maps from left to right, so @(<>)@ keeps each class's elements in
  -- input order.
  let m = IntMap.unionsWith (<>) cas
  -- Return the values of the map.
  return $ IntMap.elems m