{-# LANGUAGE CPP, MultiWayIf #-}
module TmOracle (
PmExpr(..), PmLit(..), SimpleEq, ComplexEq, PmVarEnv, falsePmExpr,
eqPmLit, filterComplex, isNotPmExprOther, runPmPprM, lhsExprToPmExpr,
hsExprToPmExpr, pprPmExprWithParens,
tmOracle, TmState, initialTmState, solveOneEq, extendSubst, canDiverge,
toComplex, exprDeepLookup, pmLitType, flattenPmVarEnv
) where
#include "HsVersions.h"
import PmExpr
import Id
import Name
import Type
import HsLit
import TcHsSyn
import MonadUtils
import Util
import NameEnv
-- | The term oracle's substitution: maps a variable's 'Name' to the
-- expression it has been equated with.
type PmVarEnv = NameEnv PmExpr
-- | The oracle's environment: a flag, set to 'True' once the solver meets an
-- equation it cannot reason about (e.g. one involving 'PmExprOther'),
-- paired with the current substitution.
type TmOracleEnv = (Bool, PmVarEnv)
-- | Whether variable @x@ may still be unevaluated in the given oracle state.
--
-- This is only possible when @x@, after following the substitution, is still
-- a plain variable @y@, and neither @x@ nor @y@ occurs in any deferred
-- (standby) equation. If @x@ resolves to anything other than a variable, the
-- answer is 'False'.
canDiverge :: Name -> TmState -> Bool
canDiverge x (standby, (_unhandled, env)) =
  case varDeepLookup env x of
    PmExprVar y -> not (mentionedInStandby x || mentionedInStandby y)
    _           -> False
  where
    -- Does the variable occur in some equation on the standby list?
    mentionedInStandby :: Name -> Bool
    mentionedInStandby v = any (isForcedByEq v) standby

    -- Does the variable occur on either side of the equation?
    isForcedByEq :: Name -> ComplexEq -> Bool
    isForcedByEq v (lhs, rhs) = any (varIn v) [lhs, rhs]

-- | Does the variable occur anywhere inside the expression?
-- Literals and 'PmExprOther' are treated as containing no variables.
varIn :: Name -> PmExpr -> Bool
varIn v (PmExprVar w)      = v == w
varIn v (PmExprCon _ args) = any (varIn v) args
varIn v (PmExprEq lhs rhs) = varIn v lhs || varIn v rhs
varIn _ (PmExprLit _)      = False
varIn _ (PmExprOther _)    = False
-- | Resolve every binding of the substitution through 'exprDeepLookup', so
-- each variable maps directly to its fully-substituted expression.
flattenPmVarEnv :: PmVarEnv -> PmVarEnv
flattenPmVarEnv env = mapNameEnv resolve env
  where
    resolve = exprDeepLookup env
-- | Complete state of the term oracle: equations deferred because they cannot
-- be decided yet (the "standby" list, re-examined when a variable they
-- mention gets bound), paired with the 'TmOracleEnv'.
type TmState = ([ComplexEq], TmOracleEnv)
-- | The empty oracle state: no deferred equations, nothing unhandled yet,
-- and an empty substitution.
initialTmState :: TmState
initialTmState = (pending, (unhandled, subst))
  where
    pending   = []
    unhandled = False
    subst     = emptyNameEnv
-- | Add one equation to the oracle state.
--
-- The current substitution is applied to both sides, the result is
-- simplified, and then it is solved. 'Nothing' means the equation is
-- inconsistent with the state.
solveOneEq :: TmState -> ComplexEq -> Maybe TmState
solveOneEq state@(_, (_, subst)) =
  solveComplexEq state . simplifyComplexEq . applySubstComplexEq subst
-- | Solve one equation whose sides already have the current substitution
-- applied and have been simplified (as 'solveOneEq' arranges).
--
-- Outcomes:
--
--   * 'Nothing' when the equation is definitely unsatisfiable: distinct
--     constructors, or literals that 'eqPmLit' reports as different.
--   * The "unhandled" flag is set for equations the solver cannot reason
--     about (e.g. anything involving 'PmExprOther').
--   * Equalities whose truth value is not known yet are pushed onto the
--     standby list.
solveComplexEq :: TmState -> ComplexEq -> Maybe TmState
solveComplexEq solver_state@(standby, (unhandled, env)) eq@(e1, e2) = case eq of
  -- Nothing can be learned from opaque expressions.
  (PmExprOther _, _) -> Just (standby, (True, env))
  (_, PmExprOther _) -> Just (standby, (True, env))

  (PmExprLit l1, PmExprLit l2) -> case eqPmLit l1 l2 of
    True  -> Just solver_state
    False -> Nothing

  -- Same constructor: decompose into pairwise argument equations.
  -- Each argument equation goes through 'solveOneEq', so bindings made while
  -- solving earlier arguments are applied to later ones. Without this,
  -- @C x x ~ C 1 2@ would silently rebind x instead of failing, and
  -- @C x y ~ C y x@ would build a cyclic substitution that makes
  -- 'varDeepLookup' loop.
  (PmExprCon c1 ts1, PmExprCon c2 ts2)
    | c1 == c2  -> foldlM solveOneEq solver_state (zip ts1 ts2)
    | otherwise -> Nothing

  -- A nullary constructor compared with an equality. Only True/False are
  -- handled; if neither guard holds, matching continues with the
  -- alternatives below.
  (PmExprCon _ [], PmExprEq t1 t2)
    | isTruePmExpr e1  -> solveComplexEq solver_state (t1, t2)
    | isFalsePmExpr e1 -> Just (eq:standby, (unhandled, env))
  (PmExprEq t1 t2, PmExprCon _ [])
    | isTruePmExpr e2  -> solveComplexEq solver_state (t1, t2)
    | isFalsePmExpr e2 -> Just (eq:standby, (unhandled, env))

  (PmExprVar x, PmExprVar y)
    | x == y    -> Just solver_state
    | otherwise -> extendSubstAndSolve x e2 solver_state
  (PmExprVar x, _) -> extendSubstAndSolve x e2 solver_state
  (_, PmExprVar x) -> extendSubstAndSolve x e1 solver_state

  -- Two undecided equalities: defer until more is known.
  (PmExprEq _ _, PmExprEq _ _) -> Just (eq:standby, (unhandled, env))

  _ -> Just (standby, (True, env))
-- | Bind @x@ to @e@ and re-solve the standby equations affected by it.
--
-- The standby list is split with 'substComplexEq' into equations it rewrote
-- (@changed@) and those it left alone (@unchanged@); this presumably
-- separates the equations that mention @x@ — 'substComplexEq' is defined
-- elsewhere. The @unchanged@ equations stay on standby. The @changed@ ones
-- are re-solved against the extended state.
--
-- Each re-solved equation goes through 'solveOneEq', so it also sees any
-- bindings made while solving the previous ones. This keeps the substitution
-- consistent; previously those bindings could be silently overwritten or
-- made cyclic.
extendSubstAndSolve :: Name -> PmExpr -> TmState -> Maybe TmState
extendSubstAndSolve x e (standby, (unhandled, env))
  = foldlM solveOneEq new_incr_state changed
  where
    (changed, unchanged) = partitionWith (substComplexEq x e) standby
    new_incr_state       = (unchanged, (unhandled, extendNameEnv env x e))
-- | Record that the 'Id' @y@ equals @e@, without solving anything.
--
-- @e@ is first substituted and simplified. If the result is a 'PmExprOther'
-- (as judged by 'isNotPmExprOther'), nothing is bound and the "unhandled"
-- flag is set instead.
extendSubst :: Id -> PmExpr -> TmState -> TmState
extendSubst y e (standby, (unhandled, env)) =
  if isNotPmExprOther simplified
    then (standby, (unhandled, extendNameEnv env (idName y) simplified))
    else (standby, (True, env))
  where
    simplified = fst (simplifyPmExpr (exprDeepLookup env e))
-- | Simplify both sides of an equation independently, discarding the
-- progress flags reported by 'simplifyPmExpr'.
simplifyComplexEq :: ComplexEq -> ComplexEq
simplifyComplexEq (lhs, rhs) = (simp lhs, simp rhs)
  where
    simp = fst . simplifyPmExpr
-- | Simplify an expression.
--
-- The 'Bool' is 'True' when some embedded equality was simplified, as
-- reported by 'simplifyEqExpr'. Constructor applications are simplified
-- argument-wise; every other expression is returned unchanged with 'False'.
simplifyPmExpr :: PmExpr -> (PmExpr, Bool)
simplifyPmExpr (PmExprCon c args) = (PmExprCon c args', or progress)
  where
    (args', progress) = mapAndUnzip simplifyPmExpr args
simplifyPmExpr (PmExprEq lhs rhs) = simplifyEqExpr lhs rhs
simplifyPmExpr other              = (other, False)
-- | Try to decide the equality @e1 == e2@.
--
-- Returns the resulting expression together with a flag that is 'True'
-- when progress was made. Decidable cases:
--
--   * the same variable on both sides is true;
--   * two literals are compared with 'eqPmLit';
--   * distinct head constructors are false;
--   * for the same head constructor, the arguments are simplified and
--     compared pairwise. A verdict is reached only if simplifying the
--     arguments made progress.
--
-- If either side is itself an equality, both sides are simplified and the
-- comparison is retried when that made progress. Anything else is returned
-- as an undecided 'PmExprEq' with 'False'.
simplifyEqExpr :: PmExpr -> PmExpr -> (PmExpr, Bool)
simplifyEqExpr e1 e2 = case (e1, e2) of
  (PmExprVar x, PmExprVar y)
    | x == y -> (truePmExpr, True)

  (PmExprLit l1, PmExprLit l2)
    | eqPmLit l1 l2 -> (truePmExpr,  True)
    | otherwise     -> (falsePmExpr, True)

  (PmExprEq {}, _) -> retrySimplified
  (_, PmExprEq {}) -> retrySimplified

  (PmExprCon c1 ts1, PmExprCon c2 ts2)
    | c1 == c2  -> sameConstructor c1 ts1 c2 ts2
    | otherwise -> (falsePmExpr, True)

  _ -> (PmExprEq e1 e2, False)
  where
    -- Simplify both sides. Recurse only if at least one side made progress;
    -- otherwise give up with the simplified but undecided equality.
    retrySimplified =
      let (e1', progress1) = simplifyPmExpr e1
          (e2', progress2) = simplifyPmExpr e2
      in if progress1 || progress2
           then simplifyEqExpr e1' e2'
           else (PmExprEq e1' e2', False)

    -- Equal head constructors: decide from the pairwise argument equalities.
    sameConstructor c1 ts1 c2 ts2
      | not (or bs1 || or bs2) = (undecided, False)  -- arguments made no progress
      | all isTruePmExpr tss   = (truePmExpr, True)
      | any isFalsePmExpr tss  = (falsePmExpr, True)
      | otherwise              = (undecided, False)
      where
        (ts1', bs1) = mapAndUnzip simplifyPmExpr ts1
        (ts2', bs2) = mapAndUnzip simplifyPmExpr ts2
        tss         = fst (zipWithAndUnzip simplifyEqExpr ts1' ts2')
        undecided   = PmExprEq (PmExprCon c1 ts1') (PmExprCon c2 ts2')
-- | Apply the substitution, transitively, to both sides of an equation.
applySubstComplexEq :: PmVarEnv -> ComplexEq -> ComplexEq
applySubstComplexEq env (lhs, rhs) = (deref lhs, deref rhs)
  where
    deref = exprDeepLookup env
-- | Look a variable up in the substitution, following bindings transitively
-- through 'exprDeepLookup'. An unbound variable is returned as itself.
-- NOTE: there is no cycle check; a cyclic substitution makes this loop.
varDeepLookup :: PmVarEnv -> Name -> PmExpr
varDeepLookup env x =
  maybe (PmExprVar x) (exprDeepLookup env) (lookupNameEnv env x)
{-# INLINE varDeepLookup #-}
-- | Replace every variable in an expression by its transitive binding in the
-- substitution, descending into constructor arguments and both sides of
-- equalities. Literals and 'PmExprOther' are returned unchanged.
exprDeepLookup :: PmVarEnv -> PmExpr -> PmExpr
exprDeepLookup env = go
  where
    go (PmExprVar x)      = varDeepLookup env x
    go (PmExprCon c args) = PmExprCon c (map go args)
    go (PmExprEq lhs rhs) = PmExprEq (go lhs) (go rhs)
    go other              = other
-- | Feed equations to the oracle left to right, starting from the given
-- state. The result is 'Nothing' as soon as one equation turns out to be
-- unsatisfiable.
tmOracle :: TmState -> [ComplexEq] -> Maybe TmState
tmOracle = foldlM solveOneEq
-- | The type of a pattern-match literal. Simple literals are typed with
-- 'hsLitType' and overloaded ones with 'overLitType'.
pmLitType :: PmLit -> Type
pmLitType lit = case lit of
  PmSLit l   -> hsLitType l
  PmOLit _ l -> overLitType l