module Myxine.ConjMap
  ( ConjMap
  , empty
  , lookup
  , insert
  , union
  ) where

import Prelude hiding (lookup)
import Data.Maybe
import Data.Hashable
import qualified Data.HashSet as HashSet
import qualified Data.HashMap.Strict as HashMap
import Data.HashMap.Strict (HashMap)

-- | A collection of items, each filed under a set of keys, from which every
-- item whose key set is contained in a query key set can be retrieved.
--
-- Represented as a trie over keys. The first field holds the items whose keys
-- were all consumed on the path from the root to this node. The second field
-- maps the next key on a path to the corresponding child node.
--
-- NOTE(review): the derived 'Eq' and 'Ord' compare this internal structure
-- (item order within a node, and which key path an item was filed under), so
-- two maps that answer every 'lookup' identically may still compare unequal.
data ConjMap k a
  = ConjMap [a] (HashMap k (ConjMap k a))
  deriving (Eq, Ord, Show, Functor, Foldable, Traversable)

-- | A map holding no items; every 'lookup' on it yields @[]@.
empty :: ConjMap k a
empty = ConjMap mempty HashMap.empty

-- | Retrieve all items whose keys are a (non-strict) subset of the specified
-- keys.
--
-- Items that were inserted with no keys at all are always returned, and
-- repeated keys in the query are treated as a single occurrence.
lookup :: (Eq k, Hashable k) => [k] -> ConjMap k a -> [a]
lookup keys = search (HashSet.fromList keys)
  where
    -- Items stored at this node, followed by everything reachable through an
    -- edge whose label is one of the keys not yet used on this path.
    search unused (ConjMap here children) =
      here ++ foldMap (follow unused children) unused

    -- Descend along the edge labelled @key@, if there is one, consuming that
    -- key so it cannot be used again further down the same path.
    follow unused children key =
      maybe [] (search (HashSet.delete key unused)) (HashMap.lookup key children)

-- | Add an item such that it can be retrieved only by giving a (non-strict)
-- superset of all the specified keys.
--
-- The item is filed at the end of a path from the root that uses every
-- distinct key in @patList@ exactly once. Existing child edges are reused
-- wherever possible, so items with overlapping key sets share structure. An
-- empty @patList@ files the item at the root, where every 'lookup' finds it.
insert :: (Eq k, Hashable k) => [k] -> a -> ConjMap k a -> ConjMap k a
insert patList a =
  goTree (HashSet.fromList patList)
  where
    -- Invariant: no pattern appears twice in a branch of a tree, although it
    -- can occur multiple times in separate branches
    goTree pats (ConjMap universal specific) =
      case HashSet.toList pats of
        -- Every key has been consumed: the item belongs at this node.
        [] -> ConjMap (a : universal) specific
        -- Otherwise descend one level, consuming one key. Prefer the first
        -- existing child edge labelled with a remaining key; failing that,
        -- open a new edge with an arbitrary remaining key. Updating the map
        -- with a single 'HashMap.insert' avoids rebuilding the whole child
        -- map, and starting from an empty child makes new branches total
        -- (no separate "fresh chain" helper that could be handed no keys).
        fallback : _ ->
          let k = fromMaybe fallback . listToMaybe $
                    filter (`HashSet.member` pats) (HashMap.keys specific)
              child = fromMaybe (ConjMap [] HashMap.empty) (HashMap.lookup k specific)
          in ConjMap universal
               (HashMap.insert k (goTree (HashSet.delete k pats) child) specific)

-- | Merge two maps into one holding the items of both.
--
-- Items stored at the same node are concatenated (left operand first), and
-- child subtrees reached through the same key are merged recursively.
union :: (Eq k, Hashable k) => ConjMap k a -> ConjMap k a -> ConjMap k a
union left right =
  case (left, right) of
    (ConjMap leftItems leftKids, ConjMap rightItems rightKids) ->
      ConjMap
        (leftItems ++ rightItems)
        (HashMap.unionWith union leftKids rightKids)

-- | Combining two maps is 'union': items from both operands are kept, those
-- of the left operand first at each shared node.
instance (Eq k, Hashable k) => Semigroup (ConjMap k a) where
  left <> right = union left right

-- | The identity is the map holding no items.
instance (Eq k, Hashable k) => Monoid (ConjMap k a) where
  mempty = ConjMap [] HashMap.empty