{-# LANGUAGE AllowAmbiguousTypes #-} -- joinA
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-|

E-class analysis, which allows the concise expression of a program analysis over
the e-graph.

An e-class analysis resembles abstract interpretation lifted to the e-graph
level, attaching analysis data from a semilattice to each e-class.

The e-graph maintains and propagates this data as e-classes get merged and new
e-nodes are added.

Analysis data can be used directly to modify the e-graph, to inform how or if
rewrites apply their right-hand sides, or to determine the cost of terms during
the extraction process.

References: https://arxiv.org/pdf/2004.03082.pdf

-}
module Data.Equality.Analysis where

import Data.Kind (Type)
import Control.Arrow ((***))

import Data.Equality.Utils
import Data.Equality.Language
import Data.Equality.Graph.Classes

-- | An e-class analysis with domain @domain@ defined for a language @l@.
--
-- The @domain@ is the type of the domain of the e-class analysis, that is, the
-- type of the data stored in an e-class according to this e-class analysis.
class Eq domain => Analysis domain (l :: Type -> Type) where

    -- | When a new e-node is added into a new, singleton e-class, construct a
    -- new value of the domain to be associated with the new e-class, by
    -- accessing the associated data of the node's children.
    --
    -- The argument is the e-node term populated with its children's data.
    --
    -- === Example
    --
    -- @
    -- -- domain = Maybe Double
    -- makeA :: Expr (Maybe Double) -> Maybe Double
    -- makeA = \case
    --     BinOp Div e1 e2 -> liftA2 (/) e1 e2
    --     BinOp Sub e1 e2 -> liftA2 (-) e1 e2
    --     BinOp Mul e1 e2 -> liftA2 (*) e1 e2
    --     BinOp Add e1 e2 -> liftA2 (+) e1 e2
    --     Const x -> Just x
    --     Sym _ -> Nothing
    -- @
    makeA :: l domain -> domain

    -- | When e-classes @c1@ and @c2@ are being merged into @c@, join their
    -- data @d_c1@ and @d_c2@ into a new value @d_c@ to be associated with the
    -- merged e-class @c@.
    --
    -- Because the analysis data is meant to come from a semilattice, this is
    -- expected to behave as the semilattice join (associative, commutative
    -- and idempotent).
    joinA :: domain -> domain -> domain

    -- | Optionally modify the e-class @c@ (based on @d_c@), typically by adding
    -- an e-node to @c@. Modify should be idempotent if no other changes occur
    -- to the e-class, i.e., @modify(modify(c)) = modify(c)@.
    --
    -- The return value of the modify function is both the modified class and
    -- the expressions (in their fixed-point form) to add to this class. We
    -- can't add them manually: not only would that skip some of the internal
    -- steps of representing + merging, but it is also impossible to add any
    -- expression with depth > 0 without access to the e-graph (since every
    -- sub-expression must be represented in the e-graph first).
    --
    -- That's why we must return the modified class and the expressions to add
    -- to this class.
    --
    -- The default implementation returns the e-class unchanged and adds no
    -- expressions.
    --
    -- === Example
    --
    -- Pruning all nodes of an e-class with a constant value except for the
    -- leaf nodes, and adding a constant value node:
    --
    -- @
    --  -- Prune all except leaf e-nodes
    --  modifyA cl =
    --    case cl^._data of
    --      Nothing -> (cl, [])
    --      Just d -> ((_nodes %~ S.filter (F.null . unNode)) cl, [Fix (Const d)])
    -- @
    modifyA :: EClass domain l -> (EClass domain l, [Fix l])
    modifyA c = (c, [])
    {-# INLINE modifyA #-}


-- | The simplest analysis that defines the domain to be () and does nothing
-- otherwise
instance forall l. Analysis () l where
  -- Every e-class carries the trivial value; the children are ignored.
  makeA :: l () -> ()
  makeA _ = ()

  -- Joining two units is the unit ('Semigroup' instance of '()').
  joinA :: () -> () -> ()
  joinA = (<>)


-- This instance is not necessarily well behaved for every pair of analyses, so
-- care must be taken when using it.
--
-- A possible criterion is:
--
-- For two analyses whose 'modifyA' functions are @m1@ and @m2@ respectively,
-- this instance is well behaved if @m1@ and @m2@ commute.
--
-- That is, if @m1@ and @m2@ satisfy the following law:
-- @
-- m1 . m2 = m2 . m1
-- @
--
-- A simple criterion that should suffice for commutativity: if each modify
-- function
--
--  * depends only on its own analysis value, and
--  * does not change that analysis value,
--
-- then any two such functions commute.
--
-- Note: there are weaker (or at least different) criteria for this instance to
-- be well behaved.
instance (Language l, Analysis a l, Analysis b l) => Analysis (a, b) l where

  -- Build each component independently: analysis @a@ sees the first
  -- projection of the children's data, analysis @b@ the second.
  makeA :: l (a, b) -> (a, b)
  makeA g = (makeA @a (fst <$> g), makeA @b (snd <$> g))

  -- Join componentwise. @joinA \@a \@l x@ and @joinA \@b \@l y@ are partially
  -- applied joins, and '(***)' applies them to the two halves of the second
  -- argument.
  joinA :: (a,b) -> (a,b) -> (a,b)
  joinA (x,y) = joinA @a @l x *** joinA @b @l y

  -- Run both analyses' 'modifyA' on the same e-class, each seeing only its
  -- own component of the data. Then recombine the results:
  --
  --  * the class id is kept from the input class;
  --  * the node sets are combined with '(<>)' on 'Set' (union);
  --  * the data is paired back up from the two modified classes;
  --  * the parents are combined with '(<>)';
  --  * the expressions each analysis asks to add are concatenated.
  --
  -- NOTE(review): if '(<>)' on the parents container concatenates, and
  -- neither analysis touches the parents, the original parents appear twice
  -- in the result. TODO confirm whether the e-graph tolerates duplicates here.
  modifyA :: EClass (a, b) l -> (EClass (a, b) l, [Fix l])
  modifyA c =
    let (ca, la) = modifyA @a (c { eClassData = fst (eClassData c) })
        (cb, lb) = modifyA @b (c { eClassData = snd (eClassData c) })
     in ( EClass (eClassId c)
                 (eClassNodes ca <> eClassNodes cb)
                 (eClassData ca, eClassData cb)
                 (eClassParents ca <> eClassParents cb)
        , la <> lb
        )