{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Equality.Extraction
(
extractBest
, CostFunction
, Cost
, depthCost
) where
import qualified Data.Set as S
import qualified Data.IntMap.Strict as IM
import Data.Equality.Utils
import Data.Equality.Graph
-- | Extract the best expression (@Fix lang@) represented by an e-class, where
-- /best/ means lowest according to the given 'CostFunction'.
--
-- A best (cost, expression) pair is first computed for every e-class in the
-- graph (all e-classes are visited, not only those reachable from the
-- requested one), and the entry for the requested e-class is then looked up.
--
-- Calls 'error' when no entry exists for the requested e-class, e.g. when none
-- of its nodes could ever be given a cost.
extractBest :: forall lang. Language lang
            => EGraph lang        -- ^ The e-graph out of which we are extracting an expression
            -> CostFunction lang  -- ^ The cost function that defines /best/
            -> ClassId            -- ^ The e-class to extract from (canonicalised with 'find' first)
            -> Fix lang           -- ^ The /best/ expression, in fixed-point form
extractBest g@EGraph{classes = eclasses'} cost (flip find g -> i) =
    let allCosts :: ClassIdMap (CostWithExpr lang)
        allCosts = findCosts eclasses' mempty
     in case findBest i allCosts of
          Just (CostWithExpr (_, n)) -> n
          Nothing -> error $ "Couldn't find a best node for e-class " <> show i
  where
    -- Compute the cheapest known (cost, expression) for every e-class by
    -- sweeping over all e-classes, repeating until a sweep changes nothing.
    --
    -- An e-class gets an entry once at least one of its nodes has all of its
    -- children costed; an existing entry is only replaced by a strictly
    -- cheaper one.
    findCosts :: ClassIdMap (EClass lang) -> ClassIdMap (CostWithExpr lang) -> ClassIdMap (CostWithExpr lang)
    findCosts eclasses current =
      let
          -- Strict left fold: the lazy 'IM.foldlWithKey' would first build a
          -- chain of thunks as long as the number of e-classes before
          -- evaluating any step of the sweep.
          (modified, updated) = IM.foldlWithKey' f (False, current) eclasses

          {-# INLINE f #-}
          f :: (Bool, ClassIdMap (CostWithExpr lang)) -> Int -> EClass lang -> (Bool, ClassIdMap (CostWithExpr lang))
          f acc@(_, beingUpdated) i' (EClass _ nodes _ _) =
            let currentCost = IM.lookup i' beingUpdated

                -- Cheapest node of this e-class whose children are all costed
                -- already (ties keep the node seen first).
                newCost = S.foldl'
                  (\c n -> case (c, nodeTotalCost beingUpdated n) of
                      (Nothing, Nothing) -> Nothing
                      (Nothing, Just nc) -> Just nc
                      (Just oc, Nothing) -> Just oc
                      (Just oc, Just nc) -> Just (oc `min` nc))
                  Nothing
                  nodes
             in case (currentCost, newCost) of
                  (Nothing, Just new) -> (True, IM.insert i' new beingUpdated)
                  (Just (CostWithExpr old), Just (CostWithExpr new))
                    | fst new < fst old -> (True, IM.insert i' (CostWithExpr new) beingUpdated)
                  -- No improvement: keep the accumulator (and its flag) as is.
                  _ -> acc
       in if modified
            then findCosts eclasses updated
            else updated

    -- Cost and expression for a single e-node, or 'Nothing' if any of its
    -- children's canonical e-classes has no entry in the given map yet.
    nodeTotalCost :: Traversable lang => ClassIdMap (CostWithExpr lang) -> ENode lang -> Maybe (CostWithExpr lang)
    nodeTotalCost m (Node n) = do
        expr <- traverse ((`IM.lookup` m) . flip find g) n
        return $ CostWithExpr (cost ((fst . unCWE) <$> expr), Fix $ (snd . unCWE) <$> expr)
    {-# INLINE nodeTotalCost #-}
{-# SCC extractBest #-}
-- | The cost assigned to an expression. Extraction keeps the lowest one.
type Cost = Int

-- | Assigns a cost to one layer of an expression, given the costs already
-- computed for each of its children.
type CostFunction l = l Cost -> Cost
-- | A simple 'CostFunction': one plus the sum of the children's costs.
--
-- NOTE(review): despite the name, summing (rather than taking the maximum of)
-- the children's costs makes the total the number of nodes in the expression,
-- i.e. its size, not its depth. This is kept as-is because the function is
-- exported.
depthCost :: Language l => CostFunction l
depthCost children = 1 + sum children
{-# INLINE depthCost #-}
-- | Look up the recorded best (cost, expression) entry for an e-class, if it
-- has one.
findBest :: ClassId -> ClassIdMap (CostWithExpr lang) -> Maybe (CostWithExpr lang)
findBest cid costs = IM.lookup cid costs
{-# INLINE findBest #-}
-- | A cost paired with the expression it was computed for.
--
-- The 'Eq' and 'Ord' instances for this type look at the cost only; the
-- expression is ignored when comparing.
newtype CostWithExpr lang
  = CostWithExpr
      { unCWE :: (Cost, Fix lang)
      }
-- | Two entries are equal when their costs are equal, regardless of the
-- expressions they carry.
instance Eq (CostWithExpr lang) where
  CostWithExpr (costL, _) == CostWithExpr (costR, _) = costL == costR
  {-# INLINE (==) #-}
-- | Entries are ordered by cost alone, so 'min' and '<' on entries compare
-- costs and never inspect the expressions.
instance Ord (CostWithExpr lang) where
  compare (CostWithExpr (costL, _)) (CostWithExpr (costR, _)) = compare costL costR
  {-# INLINE compare #-}