module Test.QuickSpec.Reasoning.CongruenceClosure(CC, newSym, (=:=), (=?=), rep, evalCC, execCC, runCC, ($$), S, funUse, argUse, lookup, initial, frozen) where
import Prelude hiding (lookup)
import Control.Monad
import Control.Monad.Trans.State.Strict
import Data.IntMap(IntMap)
import qualified Data.IntMap as IntMap
import Test.QuickSpec.Reasoning.UnionFind(UF, Replacement((:>)))
import qualified Test.QuickSpec.Reasoning.UnionFind as UF
import Data.Maybe
import Data.List(foldl')
import Text.Printf
-- | Look up the value stored under the key pair @(k1, k2)@ in a two-level
-- map. Yields 'Nothing' when either the outer key @k1@ or the inner key @k2@
-- is absent.
lookup2 :: Int -> Int -> IntMap (IntMap a) -> Maybe a
lookup2 k1 k2 m = IntMap.lookup k1 m >>= IntMap.lookup k2
-- | Store @v@ under the key pair @(k1, k2)@. The inner map for @k1@ is
-- created when missing. A value already stored at @(k1, k2)@ is overwritten,
-- and the other inner entries are kept.
insert2 :: Int -> Int -> a -> IntMap (IntMap a) -> IntMap (IntMap a)
insert2 k1 k2 v = IntMap.alter (Just . IntMap.insert k2 v . fromMaybe IntMap.empty) k1
-- | Remove the entry at the key pair @(k1, k2)@. When @k1@ is absent the
-- map is returned unchanged. An inner map that becomes empty is left in
-- place rather than removed.
delete2 :: Int -> Int -> IntMap (IntMap a) -> IntMap (IntMap a)
delete2 k1 k2 = IntMap.alter (fmap (IntMap.delete k2)) k1
-- | A flat equation written @(a, b) := c@. It is not referenced anywhere
-- else in this module.
-- NOTE(review): presumably reads "applying a to b gives c" — confirm
-- before relying on that meaning.
data FlatEqn
  = (Int, Int) := Int
  deriving (Eq, Ord)
-- | State of the congruence-closure engine.
--
-- All fields are strict except 'uf'. 'lookup' is rewritten by a record
-- update on every 'modify'. As a lazy field it would accumulate a chain of
-- unevaluated updates until something next inspected it. It is now strict,
-- like the two use indexes.
data S = S {
  -- | Keyed by a function representative @f@. Each entry @(x, c)@ was
  -- recorded when node @c@ was indexed as the application of @f@ to @x@.
  -- The stored @x@ may later stop being a representative.
  funUse :: !(IntMap [(Int, Int)]),
  -- | Keyed by an argument representative @x@. Each entry @(f, c)@ was
  -- recorded when node @c@ was indexed as the application of @f@ to @x@.
  argUse :: !(IntMap [(Int, Int)]),
  -- | Application table. The outer key is the argument representative, the
  -- inner key is the function representative, and the value is the node
  -- for that application.
  lookup :: !(IntMap (IntMap Int)),
  -- | Underlying union-find state.
  uf :: UF.S
  }

-- | The congruence-closure monad: a strict state monad over 'S'.
type CC = State S
-- | Run a union-find action against the 'uf' component of the state. The
-- updated union-find state is written back and the action's result is
-- returned. No other field of 'S' is touched.
liftUF :: UF a -> CC a
liftUF action = state $ \st ->
  let (result, uf') = UF.runUF (uf st) action
  in (result, st { uf = uf' })
-- | Hook for internal consistency checks. The argument labels the call site.
-- It is currently a no-op: the label is ignored and nothing is checked.
invariant :: String -> CC ()
invariant = const (return ())
-- | Apply @g@ to the function-use index ('funUse').
modifyFunUse g = modify $ \st -> st { funUse = g (funUse st) }

-- | Apply @g@ to the argument-use index ('argUse').
modifyArgUse g = modify $ \st -> st { argUse = g (argUse st) }

-- | Prepend @entries@ to the function-use list of symbol @sym@. The list is
-- created if absent.
addFunUses entries sym =
  modify $ \st -> st { funUse = IntMap.insertWith (++) sym entries (funUse st) }

-- | Prepend @entries@ to the argument-use list of symbol @sym@. The list is
-- created if absent.
addArgUses entries sym =
  modify $ \st -> st { argUse = IntMap.insertWith (++) sym entries (argUse st) }

-- | Apply @g@ to the application table ('lookup').
modifyLookup g = modify $ \st -> st { lookup = g (lookup st) }

-- | Replace the application table wholesale.
putLookup table = modify $ \st -> st { lookup = table }
-- | Allocate a fresh symbol by delegating to the union-find's 'UF.newSym'.
newSym :: CC Int
newSym = liftUF $ UF.newSym
-- | Return the node that represents applying @f@ to @x@.
--
-- Both operands are first replaced by their class representatives. If the
-- application table already holds a node for that pair, the node is
-- returned. Otherwise a fresh symbol is allocated and recorded in the
-- table and in both use indexes.
($$) :: Int -> Int -> CC Int
f $$ x = do
  invariant (printf "before %s$$%s" (show f) (show x))
  table <- gets lookup
  fun <- rep f
  arg <- rep x
  invariant (printf "at %s$$%s:1" (show f) (show x))
  maybe (fresh table fun arg) return (lookup2 arg fun table)
  where
    -- Cache miss: create a node for (fun, arg) and index it. Only the
    -- union-find was touched since 'table' was read, so inserting into the
    -- snapshot is safe.
    fresh table fun arg = do
      node <- newSym
      invariant (printf "at %s$$%s:2" (show f) (show x))
      putLookup (insert2 arg fun node table)
      addFunUses [(arg, node)] fun
      addArgUses [(fun, node)] arg
      invariant (printf "after %s$$%s" (show f) (show x))
      return node
-- | Assert that two nodes are equal, then propagate any congruences that
-- follow. The result is 'False' when the union-find reports no merge for
-- the pair (presumably because the nodes were already equal — see
-- 'UF.=:=').
(=:=) :: Int -> Int -> CC Bool
(=:=) = curry propagate
-- | Test whether two nodes currently have the same representative. The
-- representative of @t@ is computed before that of @u@.
(=?=) :: Int -> Int -> CC Bool
t =?= u = do
  repT <- rep t
  repU <- rep u
  return (repT == repU)
-- | Merge the given pair with 'propagate1', then recursively merge every
-- pair of nodes that that step reported as newly congruent (depth-first).
-- The result is the merge flag for the initial pair only.
propagate eq =
  propagate1 eq >>= \(merged, consequences) ->
    mapM_ propagate consequences >> return merged
-- | Perform a single union of @a@ and @b@ without recursing.
--
-- Returns @(False, [])@ when the union-find reports no merge. When old
-- representative @r@ is replaced by @r'@, the result is @True@ paired with
-- the node pairs that 'updateUses' found to be newly congruent. That
-- re-indexing is skipped when @r@ has no recorded uses.
propagate1 (a, b) = do
  invariant (printf "before propagate (%s, %s)" (show a) (show b))
  merged <- liftUF (a UF.=:= b)
  case merged of
    Nothing -> return (False, [])
    Just (r :> r') -> do
      asFun <- gets (IntMap.lookup r . funUse)
      asArg <- gets (IntMap.lookup r . argUse)
      if isNothing asFun && isNothing asArg
        then return (True, [])
        else do
          pending <- updateUses r r' (fromMaybe [] asFun) (fromMaybe [] asArg)
          return (True, pending)
-- | Move the use lists of the old representative @r@ onto its new
-- representative @r'@ after a union.
--
-- @funUses@ holds the @(arg, node)@ entries where @r@ was the function.
-- @argUses@ holds the @(fun, node)@ entries where @r@ was the argument.
--
-- Steps:
--
-- 1. Drop @r@'s use lists and every table entry that mentions @r@.
-- 2. Re-insert each entry under @r'@, with its other key canonicalised via
--    'rep'.
-- 3. When the slot is already taken by some node @k@, emit @(node, k)@ as a
--    pending equation instead of inserting.
--
-- Returns the pending equations for the caller to propagate.
updateUses r r' funUses argUses = do
  modifyFunUse (IntMap.delete r)
  modifyArgUse (IntMap.delete r)
  -- Entries where r is the argument live under outer key r.
  modifyLookup (IntMap.delete r)
  -- Entries where r is the function live under each argument's
  -- (current) representative.
  forM_ funUses $ \(arg, _) -> do
    arg' <- rep arg
    modifyLookup (delete2 arg' r)
  invariant (printf "after deleting %s" (show r))
  funUses' <- mapM canon funUses
  argUses' <- mapM canon argUses
  table <- gets lookup
  let (funPending, funFresh, table') =
        reinsert (\k c t -> insert2 k r' c t)
                 (\k t -> lookup2 k r' t)
                 [] table funUses'
      (pending, argFresh, argTable) =
        reinsert IntMap.insert IntMap.lookup funPending
                 (IntMap.findWithDefault IntMap.empty r' table')
                 argUses'
  addFunUses funFresh r'
  addArgUses argFresh r'
  putLookup (if IntMap.null argTable then table' else IntMap.insert r' argTable table')
  invariant (printf "after updateUses (%s, %s)" (show r) (show r'))
  return pending
  where
    -- Replace the key of a use entry by its current representative.
    canon (key, node) = do
      key' <- rep key
      return (key', node)
    -- Fold use entries into a table. Returns the extended pending list, the
    -- entries that were inserted (new uses), and the resulting table.
    reinsert ins probe pending0 table0 = foldl' step (pending0, [], table0)
      where
        step (pend, fresh, tbl) (key, node) =
          case probe key tbl of
            Just k -> ((node, k) : pend, fresh, tbl)
            Nothing -> (pend, (key, node) : fresh, ins key node tbl)
-- | Current class representative of a symbol, as reported by the union-find.
rep :: Int -> CC Int
rep = liftUF . UF.rep
-- | Run a computation from the given state, returning its result and the
-- final state.
runCC :: S -> CC a -> (a, S)
runCC = flip runState

-- | Run a computation from the given state and keep only its result.
evalCC :: S -> CC a -> a
evalCC = flip evalState

-- | Run a computation from the given state and keep only the final state.
execCC :: S -> CC a -> S
execCC = flip execState
-- | State with empty use indexes and an empty application table. The
-- union-find starts from @UF.initial n@ (presumably @n@ pre-allocated
-- symbols — confirm against the UnionFind module).
initial :: Int -> S
initial n = S
  { funUse = IntMap.empty
  , argUse = IntMap.empty
  , lookup = IntMap.empty
  , uf = UF.initial n
  }
-- | Run a computation against a snapshot of the current state and return
-- its result. Any state changes it makes are discarded.
frozen :: CC a -> CC a
frozen x = gets (evalState x)