{-# LANGUAGE FlexibleInstances    #-}

-- Syntactic equality of types, up to renaming of forall-bound type variables.

module Language.Haskell.Liquid.Types.Equality where

import qualified Language.Fixpoint.Types as F
import           Language.Haskell.Liquid.Types
import qualified Liquid.GHC.API as Ghc

import Control.Monad.Writer.Lazy
-- import Control.Monad
import qualified Data.List as L

-- | Refinement types are compared structurally, modulo a consistent renaming
-- of type variables; all of the work is delegated to 'compareRType'.
instance REq SpecType where
  (=*=) = compareRType

-- | Syntactic equality of two 'SpecType's, modulo a consistent renaming of
-- type variables.
--
-- The traversal ('go') compares the two types constructor by constructor:
--
--   * for 'RAllT', the right-hand binder is substituted by the left-hand one
--     before descending into the bodies;
--   * for 'RVar', the two variables are not compared directly; instead the
--     pair is recorded in a 'Writer' log;
--   * any pair of mismatching constructors compares unequal.
--
-- The types are equal when the traversal succeeds AND the recorded pairs
-- describe a one-to-one correspondence between left and right variables.
compareRType :: SpecType -> SpecType -> Bool
compareRType i1 i2 = res && isRenaming ys
  where
    (res, ys) = runWriter (go i1 i2)

    -- A valid renaming relates each left variable to exactly one right
    -- variable and vice versa.  Checking both directions also keeps the
    -- comparison symmetric (a one-way check accepted @a -> b@ vs @a -> a@
    -- but rejected @a -> a@ vs @a -> b@).
    isRenaming :: [(RTyVar, RTyVar)] -> Bool
    isRenaming vs = functional vs && functional [ (y, x) | (x, y) <- vs ]

    -- Every occurrence of a given first component must be paired with the
    -- same second component, wherever it appears in the list.  Grouping with
    -- 'L.groupBy' only inspected *adjacent* pairs, so non-adjacent conflicts
    -- such as @a -> b -> a@ vs @a -> b -> b@ (log [(a,a),(b,b),(a,b)]) were
    -- wrongly accepted.
    functional :: [(RTyVar, RTyVar)] -> Bool
    functional = check []
      where
        check :: [(RTyVar, RTyVar)] -> [(RTyVar, RTyVar)] -> Bool
        check _    []              = True
        check seen ((x, y) : rest) = case L.lookup x seen of
          Just y' -> y == y' && check seen rest
          Nothing -> check ((x, y) : seen) rest

    go :: SpecType -> SpecType -> Writer [(RTyVar, RTyVar)] Bool
    go (RAllT x1 t1 r1) (RAllT x2 t2 r2)
      | RTV v1 <- ty_var_value x1
      , RTV v2 <- ty_var_value x2
      , r1 =*= r2
      -- align the bound variables: rename v2 to v1 in the right-hand body
      = go t1 (subt (v2, Ghc.mkTyVarTy v1) t2)

    go (RVar v1 r1) (RVar v2 r2)
      -- variables are not required to be identical here; the pair is logged
      -- and checked for consistency by 'isRenaming' once traversal is done
      = do tell [(v1, v2)]
           return (r1 =*= r2)

    go (RFun x1 _ t11 t12 r1) (RFun x2 _ t21 t22 r2)
      | x1 == x2 && r1 =*= r2
      = liftM2 (&&) (go t11 t21) (go t12 t22)

    go (RAllP x1 t1) (RAllP x2 t2)
      | x1 == x2
      = go t1 t2

    go (RApp x1 ts1 ps1 r1) (RApp x2 ts2 ps2 r2)
      | x1 == x2
      -- 'zipWith'/'zipWithM' truncate silently, so differing arities must be
      -- rejected explicitly (e.g. @Either Int@ vs @Either Int Bool@)
      , length ts1 == length ts2
      , length ps1 == length ps2
      , r1 =*= r2
      , and (zipWith (=*=) ps1 ps2)
      = and <$> zipWithM go ts1 ts2

    go (RAllE x1 t11 t12) (RAllE x2 t21 t22)
      | x1 == x2
      = liftM2 (&&) (go t11 t21) (go t12 t22)

    go (REx x1 t11 t12) (REx x2 t21 t22)
      | x1 == x2
      = liftM2 (&&) (go t11 t21) (go t12 t22)

    go (RExprArg e1) (RExprArg e2)
      = return (e1 =*= e2)

    go (RAppTy t11 t12 r1) (RAppTy t21 t22 r2)
      | r1 =*= r2
      = liftM2 (&&) (go t11 t21) (go t12 t22)

    -- only the underlying types are compared; the other 'RRTy' fields are
    -- ignored
    go (RRTy _ _ _ r1) (RRTy _ _ _ r2)
      = return (r1 =*= r2)

    go (RHole r1) (RHole r2)
      = return (r1 =*= r2)

    go _t1 _t2
      = return False

-- | Syntactic equality, as used when comparing refinement types and their
-- components.  Each instance decides what "syntactically equal" means for
-- its type (e.g. ignoring source locations or renaming binders).
class REq t where
  -- | @x =*= y@ is 'True' when @x@ and @y@ are considered syntactically equal.
  (=*=) :: t -> t -> Bool

-- | Two 'RProp's are compared on their bodies only; the argument lists
-- (first field) are not inspected.
instance REq t2 => REq (Ref t1 t2) where
  RProp _ lhsBody =*= RProp _ rhsBody = lhsBody =*= rhsBody

-- | The 'F.Reft' components are compared with their 'REq' instance; the
-- 'Predicate' components must be equal according to ('==').
instance REq (UReft F.Reft) where
  MkUReft reftA predA =*= MkUReft reftB predB
    | reftA =*= reftB = predA == predB
    | otherwise       = False

-- | @{v1 | e1}@ and @{v2 | e2}@ are equal when @e1@, with its binder @v1@
-- replaced by the variable @v2@, is 'REq'-equal to @e2@.
instance REq F.Reft where
  F.Reft (v1, e1) =*= F.Reft (v2, e2) = e1Renamed =*= e2
    where
      e1Renamed = F.subst1 e1 (v1, F.EVar v2)

-- | Expressions are compared with ('==') after passing both sides through
-- 'F.simplify'.  The result is routed through 'F.notracepp' together with a
-- rendering of both simplified sides (presumably a disabled debug trace).
instance REq F.Expr where
  lhs =*= rhs = F.notracepp msg (simpleL == simpleR)
    where
      simpleL = F.simplify lhs
      simpleR = F.simplify rhs
      msg     = "comparing " ++ showpp (F.toFix simpleL, F.toFix simpleR)

-- | Located values are compared on their payloads ('val'); the location
-- information is ignored.
instance REq r => REq (Located r) where
  x =*= y = (=*=) (val x) (val y)