{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings         #-}
{-# LANGUAGE PatternGuards             #-}
{-# LANGUAGE ScopedTypeVariables       #-}

{-# OPTIONS_GHC -Wno-name-shadowing    #-}

module Language.Fixpoint.Solver.Rewrite
  ( getRewrite
  , subExprs
  , unify
  , ordConstraints
  , convert
  , passesTerminationCheck
  , RewriteArgs(..)
  , RWTerminationOpts(..)
  , SubExpr
  , TermOrigin(..)
  , OCType
  , RESTOrdering(..)
  ) where

import           Control.Monad.State (guard)
import           Control.Monad.Trans.Maybe
import qualified Data.Char            as Char
import           Data.Hashable
import qualified Data.HashMap.Strict  as M
import qualified Data.List            as L
import qualified Data.Text as TX
import           GHC.IO.Handle.Types (Handle)
import           GHC.Generics
import           Text.PrettyPrint (text)
import           Language.Fixpoint.Types.Config (RESTOrdering(..))
import           Language.Fixpoint.Types hiding (simplify)
import           Language.REST
import           Language.REST.KBO (kbo)
import           Language.REST.LPO (lpo)
import           Language.REST.OCAlgebra as OC
import           Language.REST.OCToAbstract (lift)
import           Language.REST.Op
import           Language.REST.SMT (SMTExpr)
import           Language.REST.WQOConstraints.ADT (ConstraintsADT, adtOC)
import qualified Language.REST.RuntimeTerm as RT

-- | @(e, f)@ asserts that @e@ is a subexpression of @f e@.
--
-- @f@ is the context around @e@: applying it to a replacement for @e@
-- rebuilds the enclosing expression, so @f e@ is the original enclosing
-- expression itself.
type SubExpr = (Expr, Expr -> Expr)

-- | Tags which mechanism a term came from. @RW@ presumably denotes an
-- autorewrite step and @PLE@ the PLE evaluator (inferred from the names
-- only; confirm against callers).
data TermOrigin = PLE | RW deriving (Show, Eq)

-- | Pretty-prints a 'TermOrigin' as its constructor name (via 'show');
-- the tidiness level is ignored.
instance PPrint TermOrigin where
  pprintTidy _ = text . show


-- | Whether autorewrites are subject to the termination
-- (ordering-constraint) check.
data RWTerminationOpts
  = RWTerminationCheckEnabled   -- ^ consult the ordering constraints
  | RWTerminationCheckDisabled  -- ^ skip the check entirely

-- | Configuration consulted when applying autorewrites.
data RewriteArgs = RWArgs
 { isRWValid          :: Expr -> IO Bool
   -- ^ Validity oracle: a rewrite only fires when every instantiated
   -- argument refinement is reported valid by this action.
 , rwTerminationOpts  :: RWTerminationOpts
   -- ^ Whether rewrites are subject to the termination (ordering) check.
 }

-- | Ordering constraints, monomorphized so we don't litter PLE with type
-- variables. This also helps because GHC doesn't support impredicative
-- polymorphism (yet). Each constructor wraps the constraints of the
-- ordering that 'ordConstraints' builds for the matching 'RESTOrdering'.
data OCType =
    RPO (ConstraintsADT Op)   -- ^ constraints of the RPO ordering ('RESTRPO')
  | LPO (ConstraintsADT Op)   -- ^ constraints of the LPO ordering ('RESTLPO')
  | KBO (SMTExpr Bool)        -- ^ constraints of the KBO ordering ('RESTKBO')
  | Fuel Int                  -- ^ state of the fuel-based ordering ('RESTFuel')
  deriving (Eq, Show, Generic, Hashable)

-- | Builds the ordering-constraint algebra selected by the 'RESTOrdering',
-- with its constraints wrapped in the matching 'OCType' constructor.
--
-- The solver handles are passed to the RPO, KBO and LPO constructions;
-- 'RESTFuel' ignores them.
--
-- Each projection back out of 'OCType' is partial. An algebra built here
-- only ever produces constraints of its own constructor, so hitting the
-- fallback means constraints from two different algebras were mixed.
ordConstraints :: RESTOrdering -> (Handle, Handle) -> OCAlgebra OCType RT.RuntimeTerm IO
ordConstraints RESTRPO solver =
  bimapConstraints RPO asRPO (adtRPO solver)
  where
    asRPO :: OCType -> ConstraintsADT Op
    asRPO (RPO t) = t
    asRPO _       = error "ordConstraints: expected RPO constraints"
ordConstraints RESTKBO solver =
  bimapConstraints KBO asKBO (kbo solver)
  where
    asKBO :: OCType -> SMTExpr Bool
    asKBO (KBO t) = t
    asKBO _       = error "ordConstraints: expected KBO constraints"
ordConstraints RESTLPO solver =
  bimapConstraints LPO asLPO (lift (adtOC solver) lpo)
  where
    asLPO :: OCType -> ConstraintsADT Op
    asLPO (LPO t) = t
    asLPO _       = error "ordConstraints: expected LPO constraints"
ordConstraints (RESTFuel n) _ =
  bimapConstraints Fuel asFuel (fuelOC n)
  where
    asFuel :: OCType -> Int
    asFuel (Fuel m) = m
    asFuel _        = error "ordConstraints: expected Fuel constraints"


-- | Converts a fixpoint 'Expr' into a REST 'RT.RuntimeTerm'.
--
-- The encoding works as follows:
--
--  * An application whose head is a variable (possibly under casts) becomes
--    an application of that variable's name.
--  * Other constructs become applications of reserved @$@-prefixed
--    operators.
--  * Casts are dropped.
--  * @PIff@ is encoded as an equality atom.
--  * @PImp p q@ is encoded as @POr [PNot p, q]@.
--
-- Calls 'error' on any expression form not handled below.
convert :: Expr -> RT.RuntimeTerm
convert (EIte i t e)   = RT.App "$ite" $ map convert [i, t, e]
convert e@EApp{}
  | (f, terms) <- splitEAppThroughECst e
  , EVar fName <- dropECst f
  = RT.App (Op (symbolText fName)) $ map convert terms
convert (EVar s)       = RT.App (Op (symbolText s)) []
convert (PNot e)       = RT.App "$not" [convert e]
convert (PAnd es)      = RT.App "$and" $ map convert es
convert (POr es)       = RT.App "$or"  $ map convert es
convert (PAtom s l r)  = RT.App (taggedOp "$atom" s) [convert l, convert r]
convert (EBin o l r)   = RT.App (taggedOp "$ebin" o) [convert l, convert r]
convert (ECon c)       = RT.App (taggedOp "$econ" c) []
convert (ESym (SL tx)) = RT.App (Op tx) []
convert (ECst t _)     = convert t
convert (PIff e0 e1)   = convert (PAtom Eq e0 e1)
convert (PImp e0 e1)   = convert (POr [PNot e0, e1])
convert e              = error ("convert: unsupported expression " ++ show e)

-- | An operator whose name is a reserved prefix immediately followed by the
-- 'show'n payload, e.g. @"$atom"@ followed by the shown relation.
taggedOp :: Show a => TX.Text -> a -> Op
taggedOp prefix x = Op (prefix `TX.append` TX.pack (show x))

-- | With the termination check enabled, returns @isSat aoc c@, i.e. whether
-- the accumulated ordering constraints @c@ are satisfiable.
-- With the check disabled, always returns 'True'.
passesTerminationCheck :: OCAlgebra oc a IO -> RewriteArgs -> oc -> IO Bool
passesTerminationCheck aoc rwArgs c =
  case rwTerminationOpts rwArgs of
    RWTerminationCheckEnabled  -> isSat aoc c
    RWTerminationCheckDisabled -> pure True

-- | Yields the result of rewriting an expression with an autorewrite equation.
--
-- The inputs are a subexpression @subE@ and the context @toE@ it sits in.
-- The equation's left-hand side is unified with @subE@, binding only the
-- equation's argument variables. The instantiated right-hand side is then
-- plugged back into the context.
--
-- On success it returns:
--
--  * the instantiated equation @(lhs', rhs')@;
--  * the rewritten whole expression;
--  * the ordering constraints. With the termination check enabled these are
--    refined via @refine aoc c subE rhs'@; otherwise they are passed through
--    unchanged.
--
-- Yields nothing if:
--
--  * The left-hand side does not unify with the subexpression
--  * The result of the rewrite is identical to the original expression
--  * Any of the arguments of the autorewrite has a refinement type which is
--    not satisfied in the current context (as judged by 'isRWValid')
--
getRewrite ::
     OCAlgebra oc Expr IO
  -> RewriteArgs
  -> oc
  -> SubExpr
  -> AutoRewrite
  -> MaybeT IO ((Expr, Expr), Expr, oc)
getRewrite aoc rwArgs c (subE, toE) (AutoRewrite args lhs rhs) = do
  su <- MaybeT (pure (unify freeVars lhs subE))
  let subE' = subst su rhs
  guard (subE /= subE')
  mapM_ (checkSubst su) exprs
  let expr' = toE subE'
      eqn   = (subst su lhs, subE')
      c'    = case rwTerminationOpts rwArgs of
                RWTerminationCheckEnabled  -> refine aoc c subE subE'
                RWTerminationCheckDisabled -> c
  pure (eqn, expr', c')
  where
    -- Each argument symbol paired with its refinement predicate.
    exprs :: [(Symbol, Expr)]
    exprs = [ (s, e) | RR _ (Reft (s, e)) <- args ]

    -- Only the autorewrite's arguments may be bound by unification.
    freeVars :: [Symbol]
    freeVars = map fst exprs

    -- Fails (yields nothing) unless the oracle deems the predicate valid.
    check :: Expr -> MaybeT IO ()
    check e = do
      valid <- MaybeT (Just <$> isRWValid rwArgs e)
      guard valid

    -- Instantiates an argument's refinement under the unifier, also binding
    -- the symbol "VV" (presumably the refinement's value variable) to the
    -- instantiated argument, and checks that the result is valid.
    checkSubst :: Subst -> (Symbol, Expr) -> MaybeT IO ()
    checkSubst su (s, e) = do
      let su' = catSubst su $ mkSubst [("VV", subst su (EVar s))]
      check $ subst (catSubst su su') e
e


-- | All subexpressions of an expression that are candidates for rewriting.
-- Each is paired with the context that plugs a replacement back into the
-- whole expression. The expression itself comes first, with the identity
-- context.
subExprs :: Expr -> [SubExpr]
subExprs e = (e, id) : subExprs' e

-- | The proper subexpressions of an expression that are candidates for
-- rewriting (see 'subExprs'). Notably:
--
--  * For @EIte c t e@ only the condition @c@ is explored.
--  * Applications headed by the ProofCombinators operators @===@, @==.@ or
--    @?@ are not explored at all.
--  * For a cast @ECst e t@, @e@ itself is not returned. Only @subExprs' e@
--    is returned, each re-wrapped in the cast.
--  * Any other form yields no subexpressions.
subExprs' :: Expr -> [SubExpr]
subExprs' (EIte c lhs rhs) =
  [ (e, \e' -> EIte (f e') lhs rhs) | (e, f) <- subExprs c ]
subExprs' (EBin op lhs rhs)  = binarySubExprs (EBin op) lhs rhs
subExprs' (PImp lhs rhs)     = binarySubExprs PImp lhs rhs
subExprs' (PIff lhs rhs)     = binarySubExprs PIff lhs rhs
subExprs' (PAtom op lhs rhs) = binarySubExprs (PAtom op) lhs rhs
subExprs' e@EApp{}
  | f `elem` proofCombinators = []
  | otherwise                 = concatMap replace (zip [0 ..] es)
  where
    (f, es) = splitEApp e

    -- Subexpressions of the i-th argument, with a context that rebuilds the
    -- full application around the rewritten argument.
    replace (i, arg) =
      [ (subArg, \subArg' -> eApps f (take i es ++ toArg subArg' : drop (i + 1) es))
      | (subArg, toArg) <- subExprs arg
      ]

    proofCombinators :: [Expr]
    proofCombinators =
      [ EVar "Language.Haskell.Liquid.ProofCombinators.==="
      , EVar "Language.Haskell.Liquid.ProofCombinators.==."
      , EVar "Language.Haskell.Liquid.ProofCombinators.?"
      ]
subExprs' (ECst e t) =
  [ (e', \subE -> ECst (toE subE) t) | (e', toE) <- subExprs' e ]
subExprs' (PAnd es) = [ (e, PAnd . f) | (e, f) <- subs es ]
subExprs' (POr es)  = [ (e, POr . f)  | (e, f) <- subs es ]
subExprs' _ = []

-- | Subexpressions of both operands of a binary constructor @mk@. Those of
-- the left operand come first, then those of the right. Each comes with a
-- context that rebuilds the whole @mk lhs rhs@.
binarySubExprs :: (Expr -> Expr -> Expr) -> Expr -> Expr -> [SubExpr]
binarySubExprs mk lhs rhs =
     [ (e, \e' -> mk (f e') rhs) | (e, f) <- subExprs lhs ]
  ++ [ (e, \e' -> mk lhs (f e')) | (e, f) <- subExprs rhs ]

-- | Computes the subexpressions of a list of expressions.
-- Each subexpression comes with a function that rebuilds the
-- list in which the subexpression occurs:
--
-- > and [ es == f e | (e, f) <- subs es ]
--
-- Subexpressions of the head come first, then those of the tail, in order.
-- The empty tail contributes nothing, so a singleton list needs no special
-- case.
subs :: [Expr] -> [(Expr, Expr -> [Expr])]
subs []       = []
subs (x : xs) =
     [ (s, \e -> f e : xs) | (s, f) <- subExprs x ]
  ++ [ (s, \e -> x : g e)  | (s, g) <- subs xs ]


-- | Unifies a list of templates with a list of seen expressions pointwise,
-- left to right.
--
-- Each substitution found is applied to the remaining templates and seen
-- expressions, and its variables stop being free for later pairs. The
-- individual substitutions are combined into the result.
--
-- Yields @Nothing@ if any pair fails to unify, or if the two lists have
-- different lengths (e.g. conjunctions with different numbers of
-- conjuncts).
unifyAll :: [Symbol] -> [Expr] -> [Expr] -> Maybe Subst
unifyAll _ [] [] = Just (Su M.empty)
unifyAll freeVars (template : xs) (seen : ys) = do
  rs@(Su s1) <- unify freeVars template seen
  let xs' = map (subst rs) xs
      ys' = map (subst rs) ys
  Su s2 <- unifyAll (freeVars L.\\ M.keys s1) xs' ys'
  pure (Su (M.union s1 s2))
-- Length mismatch: previously 'undefined', which crashed 'unify' on e.g.
-- @PAnd [a, b]@ vs @PAnd [a, b, c]@ instead of reporting no match.
unifyAll _ _ _ = Nothing

-- | @unify vs template e = Just su@ yields a substitution @su@
-- such that @subst su template == e@.
--
-- Moreover, @su@ is constrained to only substitute variables in @vs@.
--
-- Casts on the template are ignored. A free variable binds to the seen
-- expression including its casts; in all other cases the seen casts are
-- discarded too. Two variables match when their names agree after
-- stripping any module qualifier.
--
-- Yields @Nothing@ if no substitution exists.
--
unify :: [Symbol] -> Expr -> Expr -> Maybe Subst
unify _ template seenExpr | template == seenExpr = Just (Su M.empty)
unify freeVars template seenExpr = case (dropECst template, seenExpr) of
  -- preserve seen casts if possible
  (EVar rwVar, _) | rwVar `elem` freeVars ->
    Just (Su (M.singleton rwVar seenExpr))
  -- otherwise discard the seen casts
  (template', _) -> case (template', dropECst seenExpr) of
    (EVar lhs, EVar rhs) | removeModName lhs == removeModName rhs ->
      Just (Su M.empty)
    (EApp templateF templateBody, EApp seenF seenBody) ->
      unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
    (ENeg rw, ENeg seen) ->
      unify freeVars rw seen
    (EBin op rwLeft rwRight, EBin op' seenLeft seenRight) | op == op' ->
      unifyAll freeVars [rwLeft, rwRight] [seenLeft, seenRight]
    (EIte cond rwLeft rwRight, EIte seenCond seenLeft seenRight) ->
      unifyAll freeVars [cond, rwLeft, rwRight] [seenCond, seenLeft, seenRight]
    (ECst rw _, seen) ->
      unify freeVars rw seen
    (ETApp rw _, ETApp seen _) ->
      unify freeVars rw seen
    (ETAbs rw _, ETAbs seen _) ->
      unify freeVars rw seen
    (PAnd rw, PAnd seen) ->
      unifyAll freeVars rw seen
    (POr rw, POr seen) ->
      unifyAll freeVars rw seen
    (PNot rw, PNot seen) ->
      unify freeVars rw seen
    (PImp templateF templateBody, PImp seenF seenBody) ->
      unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
    (PIff templateF templateBody, PIff seenF seenBody) ->
      unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
    (PAtom rel templateF templateBody, PAtom rel' seenF seenBody) | rel == rel' ->
      unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
    (PAll _ rw, PAll _ seen) ->
      unify freeVars rw seen
    (PExist _ rw, PExist _ seen) ->
      unify freeVars rw seen
    (PGrad _ _ _ rw, PGrad _ _ _ seen) ->
      unify freeVars rw seen
    (ECoerc _ _ rw, ECoerc _ _ seen) ->
      unify freeVars rw seen
    _ -> Nothing
  where
    removeModName :: Symbol -> String
    removeModName = dropModulePrefix . symbolString

-- | Strips a leading Haskell module qualifier from a name. Each leading
-- segment is removed if it starts with an upper-case letter, consists of
-- identifier characters, and is followed by a dot and a non-empty
-- remainder.
--
-- Unlike splitting on the last dot, this keeps operators that contain dots
-- intact, so @.&.@ and @.|.@ (or @.@ and @==.@) no longer collapse to the
-- same empty name.
--
-- > dropModulePrefix "Data.Maybe.Just"                               == "Just"
-- > dropModulePrefix "GHC.Base.."                                    == "."
-- > dropModulePrefix "Language.Haskell.Liquid.ProofCombinators.==."  == "==."
-- > dropModulePrefix "x"                                             == "x"
dropModulePrefix :: String -> String
dropModulePrefix (c : cs)
  | Char.isUpper c
  , (_, '.' : name@(_ : _)) <- span isModuleChar cs
  = dropModulePrefix name
  where
    isModuleChar x = Char.isAlphaNum x || x == '_' || x == '\''
dropModulePrefix s = s