{-# LANGUAGE CPP #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
module Language.Fixpoint.Horn.Transformations (
uniq
, flatten
, elim
, elimPis
, solveEbs
, cstrToExpr
) where
import Language.Fixpoint.Horn.Types
import Language.Fixpoint.Horn.Info
import Language.Fixpoint.Smt.Theories as F
import qualified Language.Fixpoint.Types as F
import qualified Language.Fixpoint.Types.Config as F
import Language.Fixpoint.Graph as FG
import qualified Data.HashMap.Strict as M
import Data.String (IsString (..))
import Data.Either (partitionEithers, rights)
import Data.List (nub)
import qualified Data.Set as S
import qualified Data.HashSet as HS
import qualified Data.Graph as DG
import Control.Monad.State
import Data.Maybe (catMaybes, mapMaybe, fromMaybe)
import Language.Fixpoint.Types.Visitor as V
import System.Console.CmdArgs.Verbosity
import Data.Bifunctor (second)
import System.IO (hFlush, stdout)
-- | Disabled tracing hook with the same shape as 'Debug.Trace.trace':
-- the message is ignored and the value is returned unchanged.
trace :: String -> a -> a
trace _ = id
-- | Pretty-print every pi variable together with its formal arguments and its
-- defining constraint, flushing stdout after each entry.
printPiSols :: (F.PPrint a1, F.PPrint a2, F.PPrint a3) =>
               M.HashMap a1 ((a4, a2), a3) -> IO ()
printPiSols piSols = mapM_ printEntry (M.toList piSols)
  where
    printEntry (piName, ((_, formals), defCstr)) = do
      putStr (F.showpp piName)
      putStr " := "
      putStrLn (F.showpp formals)
      putStrLn (F.showpp defCstr)
      putStr "\n"
      hFlush stdout
-- | Solve an existential Horn query.
--
-- The constraint is first normalised ('hornify', 'pruneTauts', 'flatten').
-- If it contains no existential binders ('isNNF'), the normalised query is
-- returned directly.
--
-- Otherwise the following steps run in order:
--
--   * skolemise existentials into pi variables ('pokec');
--   * split the result into horn and side parts;
--   * eliminate the acyclic kvars, then the cut kvars;
--   * extract and solve each pi variable's defining constraint;
--   * substitute those solutions into both parts.
--
-- Intermediate results are printed when verbosity is loud.
solveEbs :: (F.PPrint a) => F.Config -> Query a -> IO (Query a)
solveEbs cfg query@(Query qs vs c cons dist) = do
  loud "Normalized EHC:"
  loud (F.showpp normalized)
  if isNNF c
    then pure (Query qs vs normalized cons dist)
    else do
      let kvars = boundKvars normalized
      loud "Skolemized:"
      let poked = pokec normalized
      loud (F.showpp poked)
      loud "Skolemized + split:"
      -- pokec leaves both universal and existential parts, so both halves
      -- are expected to be present here.
      let (Just hornPart, Just sidePart) = split poked
          horn = flatten (pruneTauts hornPart)
          side = flatten (pruneTauts sidePart)
      loud (F.showpp (horn, side))
      -- Pi variables are exactly the kvars introduced by skolemisation.
      let piVars = boundKvars poked `S.difference` kvars
          cuts = calculateCuts cfg query (forgetPiVars piVars horn)
          acyclicKs = kvars `S.difference` cuts
      loud "solved acyclic kvars:"
      let (hornAcyclic, sideAcyclic) = elimKs' (S.toList acyclicKs) (horn, side)
      loud (F.showpp hornAcyclic)
      loud (F.showpp sideAcyclic)
      -- Cut kvars are eliminated with an empty solution set.
      let dropCut k cstr = doelim k [] cstr
          hornNoCuts = foldr dropCut hornAcyclic cuts
          sideNoCuts = foldr dropCut sideAcyclic cuts
      loud "pi defining constraints:"
      let piSols = M.fromList [ (pv, piDefConstr pv hornNoCuts) | pv <- S.toList piVars ]
      whenLoud $ printPiSols piSols
      loud "solved pis:"
      let solvedPiCstrs = solPis (S.fromList (M.keys cons ++ M.keys dist)) piSols
      loud (F.showpp solvedPiCstrs)
      loud "solved horn:"
      let solvedHorn = substPiSols solvedPiCstrs hornNoCuts
      loud (F.showpp solvedHorn)
      loud "solved side:"
      let solvedSide = substPiSols solvedPiCstrs sideNoCuts
      loud (F.showpp solvedSide)
      pure (Query qs vs (CAnd [solvedHorn, solvedSide]) cons dist)
  where
    loud = whenLoud . putStrLn
    normalized = flatten . pruneTauts $ hornify c
-- | Find the defining constraint of pi variable @k@ inside @c@.
--
-- A binder @n : Var k xs@ marks a definition. The result pairs the first such
-- binder name and its arguments (with @n@ itself removed) with the conjunction
-- of every constraint guarded by such a binder. Each of those constraints is
-- re-wrapped in the enclosing universal binders.
--
-- Calls 'error' if @k@ has no defining binder.
piDefConstr :: F.Symbol -> Cstr a -> ((F.Symbol, [F.Symbol]), Cstr a)
piDefConstr k c = ((head binders, head formals), body)
  where
    (binders, formals, body) = case collect c of
      (bs, fs, Just d) -> (bs, fs, d)
      (_, _, Nothing) -> error $ "pi variable " <> F.showpp k <> " has no defining constraint."

    collect :: Cstr a -> ([F.Symbol], [[F.Symbol]], Maybe (Cstr a))
    collect (CAnd cs) = merge (unzip3 (collect <$> cs))
    collect (All b@(Bind n _ p) rest)
      | Var k' xs <- p, k' == k =
          ([n], [S.toList (S.fromList xs `S.difference` S.singleton n)], Just rest)
      | otherwise = under b (collect rest)
    collect _ = ([], [], Nothing)

    merge (nss, fss, ms) = (concat nss, concat fss, cAndMaybes ms)
    under b (ns, fs, m) = (ns, fs, All b <$> m)

-- | Conjoin the present constraints, or 'Nothing' if none are present.
cAndMaybes :: [Maybe (Cstr a)] -> Maybe (Cstr a)
cAndMaybes maybeCs
  | null present = Nothing
  | otherwise = Just (CAnd present)
  where present = catMaybes maybeCs
#if !MIN_VERSION_base(4,14,0)
-- | Compatibility shim for base < 4.14, which lacks a Functor instance for
-- triples. It maps over the last component only, like the instance in newer base.
instance Functor ((,,) a b) where
  fmap g (first, second', third) = (first, second', g third)
#endif
-- | Solve every pi variable in the map with 'solPi', keeping the same keys.
solPis :: S.Set F.Symbol -> M.HashMap F.Symbol ((F.Symbol, [F.Symbol]), Cstr a) -> M.HashMap F.Symbol Pred
solPis measures piSols = M.mapWithKey solveOne piSols
  where
    solveOne piName ((n, xs), c) = solPi measures piName n (S.fromList xs) piSols c

-- | Solve one pi variable.
--
-- Equalities are gathered from the defining constraint @c@. When a guard
-- names another pi variable, that variable's defining constraint is followed
-- too; each pi variable is visited at most once. The equalities are then
-- rewritten into a conjunction of well-formed facts about @n@ and @args@.
solPi :: S.Set F.Symbol -> F.Symbol -> F.Symbol -> S.Set F.Symbol -> M.HashMap F.Symbol ((F.Symbol, [F.Symbol]), Cstr a) -> Cstr a -> Pred
solPi measures basePi n args piSols c = trace debugMsg (PAnd rewritten)
  where
    rewritten = rewriteWithEqualities measures n args equalities
    equalities = nub (fst (gather (S.singleton basePi) c))

    -- The graph and 'sols' below feed only the debug trace message.
    edges = eqEdges args mempty equalities
    (eGraph, vf, lookupVertex) = DG.graphFromEdges edges
    sols x = case lookupVertex x of
      Nothing -> []
      Just vertex -> nub $ filter (/= F.EVar x) $ mconcat [es | ((_, es), _, _) <- vf <$> DG.reachable eGraph vertex]
    debugMsg = "\n\nsolPi: " <> F.showpp basePi <> "\n\n" <> F.showpp n <> "\n" <> F.showpp (S.toList args) <> "\n" <> F.showpp ((\(node, _, succs) -> (node, succs)) <$> edges) <> "\n" <> F.showpp (sols n) <> "\n" <> F.showpp rewritten <> "\n" <> F.showpp c <> "\n\n"

    -- Collect equalities, threading the set of pi variables already expanded.
    gather :: S.Set F.Symbol -> Cstr a -> ([(F.Symbol, F.Expr)], S.Set F.Symbol)
    gather seen (Head p _) = (collectEqualities p, seen)
    gather seen (CAnd cs) = foldl step (mempty, seen) cs
      where
        step (acc, seen') child =
          let (found, seen'') = gather seen' child
          in (found <> acc, seen'')
    gather seen (All (Bind _ _ p) body) = case p of
      Var piName _
        | S.member piName seen -> gather seen body
        | otherwise ->
            let (_, defC) = piSols M.! piName
                (fromDef, seen') = gather (S.insert piName seen) defC
                (fromBody, seen'') = gather seen' body
            in (fromDef <> fromBody, seen'')
      _ ->
        let (found, seen') = gather seen body
        in (found <> collectEqualities p, seen')
    gather _ Any{} = error "exists should not be present in piSols"
-- | Skolemise existential binders.
--
-- Each @Any (x:t | p) c@ becomes a conjunction of two parts:
--
--   * a universal binder @x:t@ guarded by a fresh pi variable, whose body
--     requires @p@ and the (recursively skolemised) @c@;
--   * the original existential, whose head is now that pi variable.
--
-- The pi variable's arguments are @x@ followed by every binder in scope,
-- innermost first.
pokec :: Cstr a -> Cstr a
pokec = skolemize []
  where
    skolemize bound cstr = case cstr of
      Head p l -> Head p l
      CAnd cs -> CAnd (skolemize bound <$> cs)
      All b body -> All b (skolemize (bSym b : bound) body)
      Any b@(Bind x t p) body ->
        let l = cLabel body
            piPred = piVar x bound
        in CAnd [ All (Bind x t piPred) (CAnd [Head p l, skolemize (x : bound) body])
                , Any b (Head piPred l)
                ]

-- | The pi-variable predicate for binder @x@ with @xs@ in scope.
piVar :: F.Symbol -> [F.Symbol] -> Pred
piVar x = Var (piSym x) . (x :)
-- | Name of the pi variable introduced for binder @s@: @s@ prefixed with "π".
piSym :: F.Symbol -> F.Symbol
piSym = fromString . ("π" ++) . F.symbolString
-- | Separate a constraint into its existential-free part and its existential
-- part.
--
-- 'Head's go to the first component and 'Any' nodes to the second. 'All'
-- binders are copied into both components, and empty parts become 'Nothing'.
split :: Cstr a -> (Maybe (Cstr a), Maybe (Cstr a))
split cstr = case cstr of
  CAnd cs ->
    let (hornParts, sideParts) = unzip (split <$> cs)
    in (andMaybes hornParts, andMaybes sideParts)
  All b body ->
    let (hornPart, sidePart) = split body
    in (All b <$> hornPart, All b <$> sidePart)
  Any{} -> (Nothing, Just cstr)
  Head{} -> (Just cstr, Nothing)
-- | Conjoin the present constraints.
--
-- Returns 'Nothing' when none are present and the constraint itself when
-- exactly one is present.
andMaybes :: [Maybe (Cstr a)] -> Maybe (Cstr a)
andMaybes maybes
  | [single] <- present = Just single
  | null present = Nothing
  | otherwise = Just (CAnd present)
  where present = catMaybes maybes
-- | Eliminate the pi variables for the given binder names, one at a time.
--
-- For each name @n@, its definition is looked up in the horn constraint with
-- 'defs'. That definition then replaces every occurrence of @piSym n@ in both
-- the horn and side constraints.
elimPis :: [F.Symbol] -> (Cstr a, Cstr a) -> (Cstr a, Cstr a)
elimPis [] cc = cc
elimPis (n:ns) (horn, side) = elimPis ns (apply horn, apply side)
  where
    -- Bound lazily, like the previous irrefutable pattern. A missing
    -- definition therefore only raises an error if the solution is demanded,
    -- but the error now names the offending binder.
    nSol = fromMaybe (error $ "elimPis: no defining constraint for " <> F.showpp n) (defs n horn)
    apply = applyPi (piSym n) nSol

-- | Replace pi variable @k@ with the expression form of @sol@.
--
-- Replacement happens in binder guards that are exactly @Var k _@ (the body
-- under such a binder is left untouched) and in heads that are exactly
-- @Var k _@.
applyPi :: F.Symbol -> Cstr a -> Cstr a -> Cstr a
applyPi k sol (All (Bind x t (Var k' _xs)) c)
  | k == k'
  = All (Bind x t (Reft $ cstrToExpr sol)) c
applyPi k sol (CAnd cs)
  = CAnd $ applyPi k sol <$> cs
applyPi k sol (All b c)
  = All b (applyPi k sol c)
applyPi k sol (Any b c)
  = Any b (applyPi k sol c)
applyPi k sol (Head (Var k' _xs) a)
  | k == k'
  = Head (Reft $ cstrToExpr sol) a
applyPi _ _ (Head p a) = Head p a

-- | Collect the bodies of all binders named @x@, conjoined with 'andMaybes'.
-- Calls 'error' on 'Any', which must already have been removed.
defs :: F.Symbol -> Cstr a -> Maybe (Cstr a)
defs x (CAnd cs) = andMaybes $ defs x <$> cs
defs x (All (Bind x' _ _) c)
  | x' == x
  = pure c
defs x (All _ c) = defs x c
defs _ (Head _ _) = Nothing
defs _ (Any _ _) = error "defs should be run only after noside and poke"
-- | Translate a Horn constraint into a fixpoint expression.
--
-- Conjunctions map to 'F.PAnd'. A binder @x:t | p@ over body @c@ maps to a
-- quantifier over @x:t@ of @p => c@: 'F.PAll' for 'All' and 'F.PExist' for
-- 'Any'.
--
-- NOTE(review): an existential is usually read as @exists x. p /\ c@;
-- 'substPiSols' conjoins @p@ for 'Any'. Confirm that the implication is
-- intended here.
cstrToExpr :: Cstr a -> F.Expr
cstrToExpr cstr = case cstr of
  Head p _ -> predToExpr p
  CAnd cs -> F.PAnd (map cstrToExpr cs)
  All (Bind x t p) c -> F.PAll [(x, t)] (guarded p c)
  Any (Bind x t p) c -> F.PExist [(x, t)] (guarded p c)
  where
    guarded p c = F.PImp (predToExpr p) (cstrToExpr c)
-- | Translate a Horn predicate into a fixpoint expression.
--
-- A kvar application @Var k xs@ becomes 'F.PKVar', substituting each actual
-- argument for the corresponding formal from 'kargs'.
predToExpr :: Pred -> F.Expr
predToExpr p = case p of
  Reft e -> e
  Var k xs -> F.PKVar (F.KV k) (F.Su (M.fromList (zip (kargs k) (map F.EVar xs))))
  PAnd ps -> F.PAnd (map predToExpr ps)
-- | Eliminate the given kvars in order from a (no-side, side) pair of
-- constraints.
--
-- Each kvar's solution is computed from its 'scope' in the no-side
-- constraint, applied with 'doelim' to both constraints, and the results are
-- then 'simplify'-ed.
elimKs' :: [F.Symbol] -> (Cstr a, Cstr a) -> (Cstr a, Cstr a)
elimKs' [] cstrs = cstrs
elimKs' (k : rest) (noside, side) = elimKs' (trace msg rest) (eliminate noside, eliminate side)
  where
    sol = sol1 k (scope k noside)
    eliminate = simplify . doelim k sol
    msg = "solved kvar " <> F.showpp k <> ":\n" <> F.showpp sol
-- | Visit the fixpoint expressions inside a predicate. Kvar applications are
-- returned unchanged.
instance V.Visitable Pred where
  visit v ctx p = case p of
    PAnd ps -> PAnd <$> traverse (visit v ctx) ps
    Reft e -> Reft <$> visit v ctx e
    _ -> pure p
-- | Visit every predicate of a constraint: heads and binder guards.
--
-- The quantifier structure is preserved. Previously, existential ('Any')
-- binders were rebuilt as universal ('All') ones, which silently changed the
-- meaning of any constraint passed through a visitor.
instance V.Visitable (Cstr a) where
  visit v ctx (CAnd cs) = CAnd <$> mapM (visit v ctx) cs
  visit v ctx (Head p a) = Head <$> visit v ctx p <*> pure a
  visit v ctx (All (Bind x t p) c) = All <$> (Bind x t <$> visit v ctx p) <*> visit v ctx c
  visit v ctx (Any (Bind x t p) c) = Any <$> (Bind x t <$> visit v ctx p) <*> visit v ctx c
-- | Turn a list of equalities into well-formed equality predicates about @n@
-- and each symbol in @args@.
--
-- For each such symbol @x@, the candidate right-hand sides are all
-- expressions reachable from @x@ in the equality graph built by 'eqEdges'.
-- Candidates mentioning symbols outside @args@, the theory symbols, and
-- @measures@ are rewritten by substituting those symbols' own reachable
-- expressions. This repeats for up to 15 rounds, after which ill-formed
-- candidates are dropped.
--
-- Each surviving expression @e@ yields @Reft (x = e)@. The result order
-- follows the 'nub' and list-comprehension order below.
rewriteWithEqualities :: S.Set F.Symbol -> F.Symbol -> S.Set F.Symbol -> [(F.Symbol, F.Expr)] -> [Pred]
rewriteWithEqualities measures n args equalities = preds
  where
    (eGraph, vf, lookupVertex) = DG.graphFromEdges $ eqEdges args mempty equalities
    -- Rewriting depth is capped at 15 substitution rounds.
    nResult = (n, makeWellFormed 15 $ sols n)
    argResults = map (\arg -> (arg, makeWellFormed 15 $ sols arg)) (S.toList args)
    preds = (mconcat $ (\(x, es) -> mconcat $ mkEquality x <$> es) <$> (nResult:argResults))
    mkEquality x e = [Reft (F.PAtom F.Eq (F.EVar x) e)]

    -- Expressions reachable from x in the equality graph, excluding x itself.
    sols :: F.Symbol -> [F.Expr]
    sols x = case lookupVertex x of
      Nothing -> []
      Just vertex -> nub $ filter (/= F.EVar x) $ mconcat [es | ((_, es), _, _) <- vf <$> DG.reachable eGraph vertex]

    -- Symbols allowed to appear in a solution.
    argsAndPrims = args `S.union` (S.fromList $ map fst $ F.toListSEnv $ F.theorySymbols []) `S.union`measures
    isWellFormed :: F.Expr -> Bool
    isWellFormed e = (S.fromList $ F.syms e) `S.isSubsetOf` argsAndPrims

    -- Each round keeps well-formed expressions. For ill-formed ones, every
    -- disallowed symbol is replaced by each of its candidate expressions,
    -- producing all combinations.
    makeWellFormed :: Int -> [F.Expr] -> [F.Expr]
    makeWellFormed 0 es = filter isWellFormed es
    makeWellFormed n es = makeWellFormed (n - 1) $ mconcat $ go <$> es
      where
        go e = if isWellFormed e then [e] else rewrite rewrites [e]
          where
            needSolving = (S.fromList $ F.syms e) `S.difference` argsAndPrims
            rewrites = (\x -> (x, filter (/= F.EVar x) $ sols x)) <$> S.toList needSolving
            rewrite [] es = es
            rewrite ((x, rewrites):rewrites') es = rewrite rewrites' $ [F.subst (F.mkSubst [(x, e')]) e | e' <- rewrites, e <- es]
-- | Build adjacency entries for 'DG.graphFromEdges' from a list of
-- equalities @(x, e)@, starting from an existing edge map.
--
-- Per equality:
--
--   * @e@ is a variable @y@ in @args@: edge @x -> y@, and @y@ is recorded as
--     an expression for @x@;
--   * @e@ is any other variable @y@: edge only;
--   * otherwise: @e@ is recorded as an expression for @x@, with no edge.
eqEdges :: S.Set F.Symbol ->
           M.HashMap F.Symbol ([F.Symbol], [F.Expr]) ->
           [(F.Symbol, F.Expr)] ->
           [((F.Symbol, [F.Expr]), F.Symbol, [F.Symbol])]
eqEdges args edgeMap eqs = M.foldrWithKey toEntry [] (foldl addEq edgeMap eqs)
  where
    toEntry x (ys, es) acc = ((x, es), x, ys) : acc
    addEq m (x, e) = M.insertWith (<>) x (entryFor e) m
    entryFor (F.EVar y)
      | S.member y args = ([y], [F.EVar y])
      | otherwise = ([y], [])
    entryFor e = ([], [e])
-- | Gather oriented equalities @(x, e)@ from the top-level @=@ atoms of a
-- predicate, looking through 'PAnd' and 'F.PAnd'. Kvar applications
-- contribute nothing.
collectEqualities :: Pred -> [(F.Symbol, F.Expr)]
collectEqualities pr = case pr of
  Reft e -> fromExpr e
  PAnd ps -> foldMap collectEqualities ps
  _ -> []
  where
    fromExpr (F.PAtom F.Eq l r) = extractEquality l r
    fromExpr (F.PAnd es) = foldMap fromExpr es
    fromExpr _ = []

-- | Orient an equality @left = right@ with a variable on the left of each
-- result.
--
-- Variable-variable equalities produce both directions, and trivial @x = x@
-- produces nothing. Equalities without any variable side are ignored.
extractEquality :: F.Expr -> F.Expr -> [(F.Symbol, F.Expr)]
extractEquality left right = case (left, right) of
  (F.EVar x, F.EVar y)
    | x == y -> []
    | otherwise -> [(x, right), (y, left)]
  (F.EVar x, _) -> [(x, right)]
  (_, F.EVar y) -> [(y, left)]
  _ -> []
-- | Substitute solved pi variables into a constraint.
--
-- Binder guards that are a solved kvar are replaced by its solution.
--
-- An existential @Any (n | p) (Head (Var pi _))@ whose pi variable is solved
-- becomes a head. If the solution gives an equation for @n@, that head
-- conjoins @p@ and the solution, both with @n@ substituted away. Otherwise the
-- head is trivially true.
--
-- Any other existential is an 'error'.
substPiSols :: M.HashMap F.Symbol Pred -> Cstr a -> Cstr a
substPiSols piSols cstr = case cstr of
  Head{} -> cstr
  CAnd cs -> CAnd (substPiSols piSols <$> cs)
  All (Bind x t p) c -> All (Bind x t (solvedGuard p)) (substPiSols piSols c)
  Any (Bind n _ p) c
    | Head (Var k _) label <- c, Just sol <- M.lookup k piSols ->
        case findSol n sol of
          Just e -> Head (flatten (PAnd [F.subst1 q (n, e) | q <- [p, sol]])) label
          Nothing -> Head (Reft (F.PAnd [])) label
    | otherwise -> error "missing piSol"
  where
    solvedGuard p@(Var k _) = M.lookupDefault p k piSols
    solvedGuard p = p
-- | Find the first expression @e@ such that @x = e@ or @e = x@ appears as a
-- top-level equality atom in the predicate, searching conjunctions left to
-- right.
findSol :: F.Symbol -> Pred -> Maybe F.Expr
findSol x = search
  where
    search p = case p of
      Reft e -> matchEq e
      Var{} -> Nothing
      PAnd ps -> case mapMaybe search ps of
        (found : _) -> Just found
        [] -> Nothing
    matchEq (F.PAtom F.Eq l r)
      | F.EVar y <- l, y == x = Just r
      | F.EVar y <- r, y == x = Just l
    matchEq _ = Nothing
-- | Renaming state for 'uniq'. For each source symbol it stores:
--
--   * the last index handed out by 'pushName';
--   * a stack of the indices of the binders currently in scope. The head is
--     the innermost binder; it is the one 'rename' uses, and 'popName'
--     removes it.
type RenameMap = M.HashMap F.Symbol (Integer, [Integer])
-- | Alpha-rename a constraint so that binders with the same name get distinct
-- names, starting from an empty 'RenameMap'.
uniq :: Cstr a -> Cstr a
uniq = (`evalState` M.empty) . uniq'
-- | Stateful worker for 'uniq'.
--
-- A binder gets a fresh name that stays in scope while its body is renamed,
-- and is popped afterwards. Heads are renamed using the current innermost
-- names.
uniq' :: Cstr a -> State RenameMap (Cstr a)
uniq' cstr = case cstr of
  Head p a -> Head <$> gets (rename p) <*> pure a
  CAnd cs -> CAnd <$> traverse uniq' cs
  All b body -> scoped All b body
  Any b body -> scoped Any b body
  where
    scoped mk b@(Bind x _ _) body = do
      b' <- uBind b
      body' <- uniq' body
      modify (popName x)
      pure (mk b' body')
-- | Leave the scope of the innermost binder named @x@ by dropping the top of
-- its index stack.
popName :: F.Symbol -> RenameMap -> RenameMap
popName = M.adjust (second tail)
-- | Allocate the next index for a symbol and push it onto the symbol's stack.
-- The first binder of a name gets index 0.
pushName :: Maybe (Integer, [Integer]) -> Maybe (Integer, [Integer])
pushName = Just . maybe (0, [0]) (\(i, is) -> let j = i + 1 in (j, j : is))
-- | Give a binder a fresh name, then rename its guard.
--
-- The guard is renamed after the fresh name is pushed, so occurrences of the
-- binder's own name in the guard refer to the new binder.
uBind :: Bind -> State RenameMap Bind
uBind (Bind x t p) = Bind <$> uVariable x <*> pure t <*> gets (rename p)
-- | Push a fresh index for @x@ and return its numbered name (see 'numSym').
uVariable :: IsString a => F.Symbol -> State RenameMap a
uVariable x = do
  modify (M.alter pushName x)
  newest <- gets $ \m -> head (snd (m M.! x))
  pure (numSym x newest)
-- | Rename the symbols of a predicate to their innermost in-scope numbered
-- names. Symbols with an empty stack are left alone.
rename :: Pred -> RenameMap -> Pred
rename e m = substPred (M.mapMaybeWithKey current m) e
  where
    current k (_, n : _) = Just (numSym k n)
    current _ _ = Nothing
-- | Numbered variant of a symbol.
--
-- Index 0 keeps the original name; any other index @i@ appends @#i@.
numSym :: IsString a => F.Symbol -> Integer -> a
numSym s i = fromString (F.symbolString s ++ suffix)
  where
    suffix
      | i == 0 = ""
      | otherwise = "#" ++ show i
-- | Apply a symbol-to-symbol renaming to a predicate: inside refinement
-- expressions and to kvar arguments.
substPred :: M.HashMap F.Symbol F.Symbol -> Pred -> Pred
substPred su p = case p of
  Reft e -> Reft (F.subst (F.Su (F.EVar <$> su)) e)
  PAnd ps -> PAnd (substPred su <$> ps)
  Var k xs -> Var k [M.lookupDefault x x su | x <- xs]
-- | Eliminate every kvar of a constraint, one at a time, with 'elim1'.
-- Calls 'error' if any kvar remains afterwards.
elim :: Cstr a -> Cstr a
elim c
  | S.null (boundKvars res) = res
  | otherwise = error "called elim on cyclic constraint"
  where
    res = S.foldl elim1 c (boundKvars c)

-- | Eliminate one kvar.
--
-- Its solution is computed from the kvar's 'scope' in @c@, substituted in
-- with 'doelim', and the result is simplified.
elim1 :: Cstr a -> F.Symbol -> Cstr a
elim1 c k = simplify (doelim k (sol1 k (scope k c)) c)
-- | Find the smallest sub-constraint that mentions kvar @k@.
--
-- This is either a head @Var k _@ or a binder whose guard mentions @k@. At a
-- conjunction:
--
--   * if exactly one child mentions @k@, search continues inside that child;
--   * if several do, the whole conjunction is returned.
--
-- When @k@ does not occur, the result is @true@, labelled with the label of
-- the constraint where the search ended.
scope :: F.Symbol -> Cstr a -> Cstr a
scope k cstr = either (Head (Reft F.PTrue)) id (go cstr)
  where
    go c = case c of
      Head (Var k' _) _ | k' == k -> Right c
      Head _ l -> Left l
      All (Bind _ _ p) body
        | k `S.member` pKVars p -> Right c
        | otherwise -> go body
      Any{} -> error "any should not appear after poke"
      CAnd cs -> case rights (go <$> cs) of
        [] -> Left (cLabel c)
        [onlyChild] -> Right onlyChild
        _ -> Right c
-- | Compute the solution cubes of kvar @k@.
--
-- There is one cube per head @Var k ys@. Each cube holds the binders on the
-- path to that head and the equalities @karg_i = y_i@ for its arguments.
sol1 :: F.Symbol -> Cstr a -> [([Bind], [F.Expr])]
sol1 k cstr = case cstr of
  CAnd cs -> concatMap (sol1 k) cs
  All b c -> [(b : bs, eqs) | (bs, eqs) <- sol1 k c]
  Head (Var k' ys) _
    | k == k' -> [([], zipWith (\x y -> F.PAtom F.Eq (F.EVar x) (F.EVar y)) (kargs k) ys)]
  Head _ _ -> []
  Any _ _ -> error "ebinds don't work with old elim"
-- | Infinite list of formal argument names for kvar @k@:
-- @κarg$k#1@, @κarg$k#2@, and so on.
kargs :: F.Symbol -> [F.Symbol]
kargs k = [fromString (prefix ++ show i) | i <- [(1 :: Integer) ..]]
  where
    prefix = "κarg$" ++ F.symbolString k ++ "#"
-- | Substitute the solution cubes @bss@ of kvar @k@ into a constraint.
--
-- * A head that is exactly @Var k _@ becomes @true@.
-- * A universal binder whose guard mentions @k@ is expanded by 'demorgan'
--   into one copy per cube.
-- * An existential binder whose guard mentions @k@ keeps only the
--   non-@k@ conjuncts of its guard.
doelim :: F.Symbol -> [([Bind], [F.Expr])] -> Cstr a -> Cstr a
doelim k bss (CAnd cs)
  = CAnd $ doelim k bss <$> cs
doelim k bss (All (Bind x t p) c) =
  case findKVarInGuard k p of
    Right _ -> All (Bind x t p) (doelim k bss c)
    Left (kvars, preds) -> demorgan x t kvars preds (doelim k bss c) bss
  where
    -- Build one copy of the binder per cube @(bs, eqs)@. Each copy is
    -- nested under the cube's binders @bs@ and guarded by:
    --
    --   * the cube's equalities, with each kvar formal ('kargs') replaced by
    --     the actual argument at this occurrence;
    --   * the remaining non-@k@ guard predicates, with the same replacement.
    --
    -- The copies are conjoined. With no cubes (e.g. the @doelim k []@ calls
    -- used for cut kvars) this gives @CAnd []@, dropping the guarded body.
    demorgan :: F.Symbol -> F.Sort -> [(F.Symbol, [F.Symbol])] -> [Pred] -> Cstr a -> [([Bind], [F.Expr])] -> Cstr a
    demorgan x t kvars preds c bss = mkAnd $ cubeSol <$> bss
      where su = F.Su $ M.fromList $ concat $ map (\(k, xs) -> zip (kargs k) (F.EVar <$> xs)) kvars
            mkAnd [c] = c
            mkAnd cs = CAnd cs
            cubeSol ((b:bs), eqs) = All b $ cubeSol (bs, eqs)
            cubeSol ([], eqs) = All (Bind x t (PAnd $ (Reft <$> F.subst su eqs) ++ (F.subst su <$> preds))) c
doelim k _ (Head (Var k' _) a)
  | k == k'
  = Head (Reft F.PTrue) a
doelim _ _ (Head p a) = Head p a
doelim k bss (Any (Bind x t p) c) =
  case findKVarInGuard k p of
    Right _ -> Any (Bind x t p) (doelim k bss c)
    -- Existential guards simply lose their @k@ conjuncts.
    Left (_, rights) -> Any (Bind x t (PAnd rights)) (doelim k bss c)
-- | Look for kvar @k@ in a guard.
--
-- Returns @Left (occurrences, rest)@ if @k@ occurs, where @occurrences@ are
-- the @(k, args)@ applications found and @rest@ are the other conjuncts.
-- Returns @Right p@ otherwise.
findKVarInGuard :: F.Symbol -> Pred -> Either ([(F.Symbol, [F.Symbol])], [Pred]) Pred
findKVarInGuard k p = case p of
  PAnd ps -> case partitionEithers (findKVarInGuard k <$> ps) of
    ([], _) -> Right p
    (hits, misses) -> Left (concatMap fst hits, concatMap snd hits ++ misses)
  Var k' xs | k == k' -> Left ([(k', xs)], [])
  _ -> Right p
-- | All kvars mentioned anywhere in a constraint, in heads as well as binder
-- guards.
boundKvars :: Cstr a -> S.Set F.Symbol
boundKvars cstr = case cstr of
  Head p _ -> pKVars p
  CAnd cs -> foldMap boundKvars cs
  All (Bind _ _ p) c -> pKVars p <> boundKvars c
  Any (Bind _ _ p) c -> pKVars p <> boundKvars c
-- | Kvars applied in a predicate, looking through 'PAnd'.
pKVars :: Pred -> S.Set F.Symbol
pKVars p = case p of
  Var k _ -> S.singleton k
  PAnd ps -> foldMap pKVars ps
  _ -> S.empty
-- | True when the constraint contains no existential ('Any') binder.
isNNF :: Cstr a -> Bool
isNNF cstr = case cstr of
  Any{} -> False
  All _ c -> isNNF c
  CAnd cs -> all isNNF cs
  Head{} -> True
-- | Kvars chosen as cuts by fixpoint's dependency analysis.
--
-- The analysis runs on the query with its constraint replaced by @nnf@. The
-- kvars are returned as plain symbols.
calculateCuts :: F.Config -> Query a -> Cstr a -> S.Set F.Symbol
calculateCuts cfg (Query qs vs _ cons dist) nnf =
  S.fromList (F.kv <$> HS.toList (FG.depCuts deps))
  where
    (_, deps) = elimVars cfg (hornFInfo (Query qs vs nnf cons dist))
-- | Replace binder guards that are exactly a pi variable in @pis@ with the
-- empty conjunction (@true@). 'Any' must not occur.
forgetPiVars :: S.Set F.Symbol -> Cstr a -> Cstr a
forgetPiVars pis cstr = case cstr of
  Head{} -> cstr
  CAnd cs -> CAnd (forgetPiVars pis <$> cs)
  All (Bind x t p) c -> All (Bind x t (forget p)) (forgetPiVars pis c)
  Any{} -> error "shouldn't be present"
  where
    forget (Var k _) | k `S.member` pis = PAnd []
    forget p = p
-- | Drop repeated binders, prune tautologies, then flatten nested
-- conjunctions.
simplify :: Cstr a -> Cstr a
simplify c = flatten (pruneTauts (removeDuplicateBinders c))
-- | Structural normalisation. The instances below splice nested conjunctions
-- into their parent and drop conjuncts that are trivially true.
class Flatten a where
  flatten :: a -> a
-- | Flatten a constraint recursively. A conjunction that ends up with a
-- single member is replaced by that member.
instance Flatten (Cstr a) where
  flatten cstr = case cstr of
    CAnd cs -> case flatten cs of
      [single] -> single
      flat -> CAnd flat
    Head p l -> Head (flatten p) l
    All (Bind x t p) c -> All (Bind x t (flatten p)) (flatten c)
    Any (Bind x t p) c -> Any (Bind x t (flatten p)) (flatten c)
-- | Flatten a list of conjuncts.
--
-- Nested 'CAnd's are spliced in place, and heads whose refinement is a
-- tautology are dropped.
instance Flatten [Cstr a] where
  flatten = foldr step []
    where
      step (CAnd inner) rest = flatten inner ++ rest
      step c rest = case flatten c of
        Head (Reft p) _ | F.isTautoPred p -> rest
        fc -> fc : rest
-- | Flatten a predicate conjunction, collapsing a single-member result to
-- that member.
instance Flatten Pred where
  flatten p = case p of
    PAnd ps -> collapse (flatten ps)
    _ -> p
    where
      collapse [single] = single
      collapse many = PAnd many
-- | Flatten a list of predicate conjuncts.
--
-- Nested 'PAnd's are spliced in place, and tautological refinements are
-- dropped.
instance Flatten [Pred] where
  flatten = foldr step []
    where
      step (PAnd inner) rest = flatten inner ++ rest
      step p rest = case flatten p of
        Reft e | F.isTautoPred e -> rest
        fp -> fp : rest
-- | Flatten an expression conjunction, collapsing a single-member result to
-- that member.
instance Flatten F.Expr where
  flatten e = case e of
    F.PAnd es -> collapse (flatten es)
    _ -> e
    where
      collapse [single] = single
      collapse many = F.PAnd many
-- | Flatten a list of expression conjuncts.
--
-- Nested 'F.PAnd's are spliced in place, and tautologies are dropped.
instance Flatten [F.Expr] where
  flatten = foldr step []
    where
      step (F.PAnd inner) rest = flatten inner ++ rest
      step e rest
        | F.isTautoPred fe = rest
        | otherwise = fe : rest
        where fe = flatten e
-- | Split every conjunctive head into separate heads.
--
-- Each kvar conjunct gets its own head, and all concrete conjuncts share one
-- head. The concrete head comes first; within each group, conjuncts appear in
-- reverse of their original order.
hornify :: Cstr a -> Cstr a
hornify cstr = case cstr of
  Head (PAnd ps) l ->
    let (ks, qs) = partitionRev isKVar (flatten ps)
    in CAnd (flip Head l <$> (PAnd qs : ks))
  Head (Reft r) l ->
    let (ks, qs) = partitionRev isPKVar (F.splitPAnd r)
    in CAnd (flip Head l <$> (Reft (F.PAnd qs) : (Reft <$> ks)))
  Head h l -> Head h l
  All b c -> All b (hornify c)
  Any b c -> Any b (hornify c)
  CAnd cs -> CAnd (hornify <$> cs)
  where
    -- Partition by a predicate, reversing both halves (accumulator order).
    partitionRev keep = foldl (\(yes, no) x -> if keep x then (x : yes, no) else (yes, x : no)) ([], [])
    isKVar Var{} = True
    isKVar _ = False
    isPKVar F.PKVar{} = True
    isPKVar _ = False
-- | Drop a universal binder, together with its guard, when a binder of the
-- same name is already in scope on the path from the root.
-- Existential binders are kept and not recorded.
removeDuplicateBinders :: Cstr a -> Cstr a
removeDuplicateBinders c = dedup S.empty c
  where
    dedup seen cstr = case cstr of
      Head{} -> cstr
      CAnd cs -> CAnd (dedup seen <$> cs)
      All b@(Bind x _ _) body
        | x `S.member` seen -> dedup seen body
        | otherwise -> All b (dedup (S.insert x seen) body)
      Any b body -> Any b (dedup seen body)
-- | Remove trivially-true parts of a constraint.
--
-- Removed parts are tautological refinements, conjunctions whose members
-- were all removed, and universal binders whose body was removed.
-- Existentials are kept as-is. An entirely trivial constraint becomes
-- @CAnd []@.
pruneTauts :: Cstr a -> Cstr a
pruneTauts c = fromMaybe (CAnd []) (pruneC c)
  where
    pruneC cstr = case cstr of
      Head p l -> flip Head l <$> pruneP p
      CAnd cs -> case mapMaybe pruneC cs of
        [] -> Nothing
        kept -> Just (CAnd kept)
      All b body -> All b <$> pruneC body
      Any{} -> Just cstr
    pruneP p = case p of
      Reft e
        | F.isTautoPred e -> Nothing
        | otherwise -> Just p
      Var{} -> Just p
      PAnd ps -> case mapMaybe pruneP ps of
        [] -> Nothing
        kept -> Just (PAnd kept)