{-# LANGUAGE FlexibleContexts #-}
module Util.Expr
( substitute, eqComAssoc, isLiteral, isSubExprOf
, wildcard, wildcardsToVars, restoreWildcards
) where
import Control.Applicative
import Control.Monad.State
import Data.Bool
import Domain.Math.Expr
import Ideas.Common.Rewriting.Term
import Ideas.Common.View
import Ideas.Utils.Uniplate
import Util.Monad
import qualified Data.Map as M
-- | Replace every variable bound in the map by its value, substituting
-- recursively inside the inserted values as well.
--
-- A binding whose value mentions its own variable (according to 'vars') is
-- ignored. A variable whose binding is currently being expanded is also left
-- untouched when it is met again. As a result a cyclic map such as
-- @x -> y, y -> x@ terminates instead of recursing forever. For maps without
-- such cycles the result is the same as plain recursive substitution.
substitute :: M.Map String Expr -> Expr -> Expr
substitute m = go []
  where
    -- 'expanding' lists the variables whose bound values we are currently inside.
    go expanding (Var s)
      | s `notElem` expanding
      , Just e <- M.lookup s m
      , s `notElem` vars e
      = go (s : expanding) e
    go _ v@(Var _) = v
    go expanding e = descend (go expanding) e
-- | The reserved symbol that marks a pattern wildcard (see 'wildcard').
wildcardSymbol :: Symbol
wildcardSymbol = newSymbol symbolName
  where
    symbolName = "wildcard"
-- | Does the given symbol equal 'wildcardSymbol'?
isWildcardSymbol :: Symbol -> Bool
isWildcardSymbol sym = sym == wildcardSymbol
-- | Build a named wildcard: the wildcard symbol applied to a single variable
-- carrying the wildcard's name.
wildcard :: String -> Expr
wildcard = Sym wildcardSymbol . pure . Var
-- | Prepend the given character to the name of every variable.
-- Any wildcard-symbol application is left untouched, including the variable
-- that names the wildcard.
prefixVariables :: Char -> Expr -> Expr
prefixVariables c = go
  where
    go (Var name) = Var (c : name)
    go e@(Sym sym _) | isWildcardSymbol sym = e
    go e = descend go e
-- | Replace each wildcard @Sym wildcardSymbol [Var v]@ by the plain variable
-- @'w':v@. All other nodes are traversed recursively.
instantiateWildcards :: Expr -> Expr
instantiateWildcards expr = case expr of
  Sym sym [Var name] | isWildcardSymbol sym -> Var ('w' : name)
  _ -> descend instantiateWildcards expr
-- | Encode wildcards as ordinary variables.
--
-- Existing variables get a @'v'@ prefix and wildcards become variables with a
-- @'w'@ prefix, so the two can never collide. 'restoreWildcards' reverses this.
wildcardsToVars :: Expr -> Expr
wildcardsToVars e = instantiateWildcards (prefixVariables 'v' e)
-- | Inverse of 'wildcardsToVars'.
--
-- A variable prefixed with @'v'@ loses the prefix. A variable prefixed with
-- @'w'@ becomes a 'wildcard' again. Anything else is traversed recursively.
restoreWildcards :: Expr -> Expr
restoreWildcards expr = case expr of
  Var ('v' : name) -> Var name
  Var ('w' : name) -> wildcard name
  _ -> descend restoreWildcards expr
-- | Is the expression a 'Nat' or 'Number' literal, possibly wrapped in any
-- number of 'Negate's?
isLiteral :: Expr -> Bool
isLiteral expr = case expr of
  Nat _ -> True
  Number _ -> True
  Negate inner -> isLiteral inner
  _ -> False
-- | Backtracking wrapper around 'eq''.
--
-- Attempts the match and, if it fails, restores the state saved before the
-- attempt, so a failed match leaves no partial wildcard bindings behind.
-- The state pairs the wildcard bindings with optional per-wildcard predicates.
eq :: MonadState (M.Map String Expr, M.Map String (Expr -> Bool)) m => Expr -> Expr -> m Bool
eq x y = do
  snapshot <- get
  matched <- eq' x y
  if matched
    then return True
    else put snapshot >> return False
-- | Structural matching of two expressions with wildcards.
--
-- * Two wildcards never match each other. A wildcard on either side is
--   handled as a pattern variable. If it is already bound (first state map),
--   its binding is compared with the other side. Otherwise the other side is
--   bound to it, provided the wildcard's predicate (second state map), if
--   any, accepts it.
-- * 'Nat', 'Number' and 'Var' leaves are compared with '=='.
-- * Binary applications of 'plusSymbol' or 'timesSymbol' match in either
--   argument order. Each order is tried with its own state rollback, so
--   bindings made by a failed ordering do not leak into the other one.
-- * Any other applications must share the head symbol and arity and match
--   argument-wise.
--
-- Callers should go through 'eq', which rolls back the state on failure.
eq' :: MonadState (M.Map String Expr, M.Map String (Expr -> Bool)) m => Expr -> Expr -> m Bool
eq' (Sym s1 _) (Sym s2 _) | isWildcardSymbol s1 && isWildcardSymbol s2 = return False
eq' a b@(Sym s _) | isWildcardSymbol s = eq' b a
eq' (Sym s [Var v]) b | isWildcardSymbol s = do
  (me, pe) <- get
  case (M.lookup v me, M.lookup v pe) of
    (Just a, _) -> eq' a b
    (_, Just p) | not (p b) -> return False
    _ -> put (M.insert v b me, pe) >> return True
eq' (Nat a) (Nat b) = return (a == b)
eq' (Number a) (Number b) = return (a == b)
eq' (Var a) (Var b) = return (a == b)
eq' a b
  | Just (sa, [a1, a2]) <- getFunction a
  , Just (sb, [b1, b2]) <- getFunction b
  , sa == sb
  , sa == plusSymbol || sa == timesSymbol
  = do
      straight <- tryPair a1 b1 a2 b2
      if straight then return True else tryPair a1 b2 a2 b1
  where
    -- Match (x1, y1) and then (x2, y2). On failure, restore the state from
    -- before the attempt so the alternative ordering starts clean.
    tryPair x1 y1 x2 y2 = do
      saved <- get
      ok <- eq' x1 y1 >>= \r -> if r then eq' x2 y2 else return False
      unless ok (put saved)
      return ok
eq' a b
  | Just (sa, as) <- getFunction a
  , Just (sb, bs) <- getFunction b
  , sa == sb
  -- zipWithM truncates, so without this check f(x) would match f(x, y).
  , length as == length bs
  = and <$> zipWithM eq' as bs
eq' _ _ = return False
-- | Run @extractSum norm e se@ against the current state.
--
-- On success, report 'True' and keep the state it produced (its wildcard
-- bindings). On failure, report 'False' and leave the state unchanged.
-- The leftover members returned by 'extractSum' are ignored.
isSubExprOf :: (Expr -> Expr) -> Expr -> Expr -> State (M.Map String Expr, M.Map String (Expr -> Bool)) Bool
isSubExprOf norm se e = state $ \st ->
  case runStateT (extractSum norm e se) st of
    Nothing -> (False, st)
    Just (_, st') -> (True, st')
-- | Run @extractSum norm e se@ and report whether it succeeded with no
-- leftover members of @se@.
--
-- A successful extraction keeps its resulting state, even when leftovers
-- make the answer 'False'. A failed extraction leaves the state unchanged.
eqComAssoc :: (Expr -> Expr) -> Expr -> Expr -> State (M.Map String Expr, M.Map String (Expr -> Bool)) Bool
eqComAssoc norm e se = state $ \st ->
  case runStateT (extractSum norm e se) st of
    Nothing -> (False, st)
    Just (leftover, st') -> (null leftover, st')
-- | Remove the summands of @e@ from the summands of @s@ and return what is left.
--
-- First tries a whole-expression match via 'eqOrNorm' (leaving nothing).
-- Otherwise it folds over the summands of @e@ (via 'sumView') and removes
-- each one from the remaining summands of @s@ with 'extractSumMember'.
extractSum :: (Expr -> Expr) -> Expr -> Expr -> StateT (M.Map String Expr, M.Map String (Expr -> Bool)) Maybe [Expr]
extractSum norm e s = wholeMatch <|> memberwise
  where
    wholeMatch = mIf (eqOrNorm norm e s) (return []) empty
    memberwise = foldM (extractSumMember norm) (from sumView s) (from sumView e)
-- | Remove one summand (@member@) from the list of remaining summands.
--
-- A wildcard member is first tried against the sum of everything remaining,
-- consuming it all. Otherwise a remaining summand that fully matches
-- @member@ under 'extractProduct' is deleted via 'deleteByM'. The step fails
-- if nothing was deleted.
extractSumMember :: (Expr -> Expr) -> [Expr] -> Expr -> StateT (M.Map String Expr, M.Map String (Expr -> Bool)) Maybe [Expr]
extractSumMember norm remaining member
  | isWildcard member = mIf (eq member (to sumView remaining)) (return []) removeMember
  | otherwise = removeMember
  where
    isWildcard (Sym sym [Var _]) = isWildcardSymbol sym
    isWildcard _ = False
    removeMember = do
      rest <- deleteByM (\x -> fmap null . extractProduct norm x) member remaining
      guard (length rest < length remaining)
      return rest
-- | Remove the factors of @e@ from the factors of @s@ and return what is left.
--
-- The member-wise removal (via 'productView' and 'extractProductMember')
-- ignores the signs reported by 'productView', so it is only attempted when
-- both signs agree. The whole-expression 'eqOrNorm' comparison does not depend
-- on that split and is always available as a fallback.
--
-- The two alternatives are grouped explicitly. Written as
-- @guard c >> a <|> b@, '<|>' (infixl 3) binds tighter than '>>' (infixl 1),
-- which would wrongly gate the fallback behind the sign check as well.
extractProduct :: (Expr -> Expr) -> Expr -> Expr -> StateT (M.Map String Expr, M.Map String (Expr -> Bool)) Maybe [Expr]
extractProduct norm e s = memberwise <|> wholeMatch
  where
    (eEven, eProductMembers) = from productView e
    (sEven, sProductMembers) = from productView s
    memberwise = do
      guard (eEven == sEven)
      -- extractProductMember ignores its Bool argument; True is a placeholder.
      foldM (extractProductMember norm True) sProductMembers eProductMembers
    wholeMatch = mIf (eqOrNorm norm e s) (return []) empty
-- | Delete one factor of @sv@ that matches @r@ under 'eqOrNorm' (via
-- 'deleteByM') and return the rest. Fails if nothing was deleted.
-- The 'Bool' argument is unused.
extractProductMember :: (Expr -> Expr) -> Bool -> [Expr] -> Expr -> StateT (M.Map String Expr, M.Map String (Expr -> Bool)) Maybe [Expr]
extractProductMember norm _ sv r =
  deleteByM (eqOrNorm norm) r sv >>= \remaining ->
    remaining <$ guard (length remaining < length sv)
-- | Match @x@ against @y@ with 'eq'. If that fails, retry with both sides
-- passed through @norm@.
eqOrNorm :: MonadState (M.Map String Expr, M.Map String (Expr -> Bool)) m => (Expr -> Expr) -> Expr -> Expr -> m Bool
eqOrNorm norm x y = eq x y >>= bool (eq (norm x) (norm y)) (return True)
-- | Monadic if-then-else.
--
-- @mIf cond thenBranch elseBranch@ runs @cond@ first, then runs exactly one
-- branch: @thenBranch@ when it yields 'True', @elseBranch@ otherwise.
--
-- This needs 'Monad' rather than 'Applicative'. An applicative formulation
-- sequences the effects of the condition and of both branches, so e.g. an
-- 'empty' else-branch in @StateT _ Maybe@ would make every call fail.
mIf :: Monad m => m Bool -> m a -> m a -> m a
mIf cond thenBranch elseBranch = cond >>= \b -> if b then thenBranch else elseBranch