{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-|

Module defining e-nodes ('ENode'), the e-node function symbol ('Operator'), and
mappings from e-nodes ('NodeMap').

-}
module Data.Equality.Graph.Nodes where

import Data.Foldable
import Data.Bifunctor

import Data.Kind

import Control.Monad (void)

import qualified Data.Map.Strict as M

import Data.Equality.Graph.Classes.Id


-- * E-node

-- | An e-node is a function symbol paired with its children e-classes.
--
-- We define an e-node to be the base functor of some recursive data type
-- parametrized over 'ClassId', i.e. every recursive field holds an e-class id
-- instead of a recursive occurrence of the data type itself.
newtype ENode l = Node { unNode :: l ClassId }

deriving instance Eq (l ClassId) => Eq (ENode l)
deriving instance Ord (l ClassId) => Ord (ENode l)
deriving instance Show (l ClassId) => Show (ENode l)

-- | Get the children e-class ids of an e-node, in the order in which the
-- 'Foldable' instance of @l@ visits them.
children :: Traversable l => ENode l -> [ClassId]
children = toList . unNode
{-# INLINE children #-}

-- * Operator

-- | An operator is solely the function symbol part of an e-node: the e-node's
-- shape with every child e-class id replaced by @()@, so the children are
-- ignored.
newtype Operator l = Operator { unOperator :: l () }

deriving instance Eq (l ()) => Eq (Operator l)
deriving instance Ord (l ()) => Ord (Operator l)
deriving instance Show (l ()) => Show (Operator l)

-- | Get the operator (function symbol) of an e-node by erasing its children:
-- each child e-class id is replaced with @()@.
operator :: Traversable l => ENode l -> Operator l
operator = Operator . void . unNode
{-# INLINE operator #-}

-- * Node Map

-- | A mapping from e-nodes of @l@ to @a@, backed by a strict
-- 'Data.Map.Strict.Map' keyed on 'ENode'.
--
-- The 'Semigroup' instance is the underlying map's union, which is
-- left-biased: when both operands contain the same e-node, the value from the
-- left operand is kept.
newtype NodeMap (l :: Type -> Type) a = NodeMap { unNodeMap :: M.Map (ENode l) a }
  deriving (Functor, Foldable, Traversable)
-- TODO: Investigate whether it would be worth it requiring a trie-map for the
-- e-node definition. Probably it isn't better since e-nodes aren't recursive.

deriving instance (Show a, Show (l ClassId)) => Show (NodeMap l a)
deriving instance Ord (l ClassId) => Semigroup (NodeMap l a)
deriving instance Ord (l ClassId) => Monoid (NodeMap l a)

-- | Insert a value given an e-node in a 'NodeMap'. If the e-node is already
-- present, its previous value is replaced.
insertNM :: Ord (l ClassId) => ENode l -> a -> NodeMap l a -> NodeMap l a
insertNM e v (NodeMap m) = NodeMap (M.insert e v m)
{-# INLINE insertNM #-}

-- | Lookup an e-node in a 'NodeMap', returning 'Nothing' if it is absent.
lookupNM :: Ord (l ClassId) => ENode l -> NodeMap l a -> Maybe a
lookupNM e = M.lookup e . unNodeMap
{-# INLINE lookupNM #-}

-- | Delete an e-node in a 'NodeMap'. If the e-node is absent, the map is
-- returned unchanged.
deleteNM :: Ord (l ClassId) => ENode l -> NodeMap l a -> NodeMap l a
deleteNM e (NodeMap m) = NodeMap (M.delete e m)
{-# INLINE deleteNM #-}

-- | Insert a value at an e-node in a 'NodeMap' and look up what was there
-- before.
--
-- Returns the value previously associated with the e-node (or 'Nothing' if
-- it was absent) together with the updated map. An existing value is
-- overwritten by the new one.
insertLookupNM :: Ord (l ClassId) => ENode l -> a -> NodeMap l a -> (Maybe a, NodeMap l a)
insertLookupNM e v (NodeMap m) =
  -- 'M.insertLookupWithKey' passes (key, new value, old value) to the
  -- combining function; keeping the new value gives plain overwrite semantics.
  second NodeMap (M.insertLookupWithKey (\_ new _old -> new) e v m)
{-# INLINE insertLookupNM #-}

-- | As 'Data.Map.Strict.foldlWithKey'' but in a 'NodeMap': a strict left fold
-- over the entries, in ascending e-node order.
foldlWithKeyNM' :: Ord (l ClassId) => (b -> ENode l -> a -> b) -> b -> NodeMap l a -> b
foldlWithKeyNM' f z = M.foldlWithKey' f z . unNodeMap
{-# INLINE foldlWithKeyNM' #-}

-- | As 'Data.Map.Strict.foldrWithKey'' but in a 'NodeMap': a strict right
-- fold over the entries.
foldrWithKeyNM' :: Ord (l ClassId) => (ENode l -> a -> b -> b) -> b -> NodeMap l a -> b
foldrWithKeyNM' f z = M.foldrWithKey' f z . unNodeMap
{-# INLINE foldrWithKeyNM' #-}

-- | Get the number of entries in a 'NodeMap'.
--
-- This operation takes constant time (__O(1)__)
sizeNM :: NodeMap l a -> Int
sizeNM = M.size . unNodeMap
{-# INLINE sizeNM #-}

-- | As 'Data.Map.Strict.traverseWithKey' but in a 'NodeMap': run an
-- applicative action on every entry, with access to its e-node, and rebuild
-- the map from the results.
traverseWithKeyNM :: Applicative t => (ENode l -> a -> t b) -> NodeMap l a -> t (NodeMap l b)
traverseWithKeyNM f (NodeMap m) = NodeMap <$> M.traverseWithKey f m
{-# INLINE traverseWithKeyNM #-}

-- Node Set

-- newtype NodeSet l a = NodeSet { unNodeSet :: IM.IntMap (a, ENode l) }
--   deriving (Semigroup, Monoid)

-- insertNS :: Hashable1 l => ENode l -> NodeSet l -> NodeSet l
-- insertNS v = NodeSet . IM.insert (hashNode v) v . unNodeSet