{-# LANGUAGE FlexibleContexts #-}

-- | A class for unification
module Hyper.Class.Unify
    ( Unify (..)
    , UVarOf
    , UnifyGen (..)
    , BindingDict (..)
    , applyBindings
    , semiPruneLookup
    , occursError
    ) where

import Control.Monad (unless)
import Control.Monad.Error.Class (MonadError (..))
import Control.Monad.Trans.Class (MonadTrans (..))
import Control.Monad.Trans.State (get, put, runStateT)
import Data.Kind (Type)
import Hyper.Class.Nodes (HNodes (..), (#>))
import Hyper.Class.Optic (HSubset (..), HSubset')
import Hyper.Class.Recursive
import Hyper.Class.Traversable (htraverse)
import Hyper.Class.ZipMatch (ZipMatch)
import Hyper.Type (HyperType, type (#))
import Hyper.Type.Pure (Pure, _Pure)
import Hyper.Unify.Constraints
import Hyper.Unify.Error (UnifyError (..))
import Hyper.Unify.QuantifiedVar (HasQuantifiedVar (..), MonadQuantify (..))
import Hyper.Unify.Term (UTerm (..), UTermBody (..), uBody)

import Hyper.Internal.Prelude

-- | Unification variable type for a unification monad.
--
-- Each unification monad @m@ chooses its own representation of unification
-- variables; 'Unify' instances for @m@ bind and look them up through a
-- 'BindingDict' over @'UVarOf' m@.
type family UVarOf (m :: Type -> Type) :: HyperType

-- | BindingDict implements unification variables for a type in a unification monad.
--
-- It is parameterized on:
--
-- * @v@: The unification variable 'HyperType'
-- * @m@: The 'Monad' to bind in
-- * @t@: The unified term's 'HyperType'
--
-- Has 2 implementations in hypertypes:
--
-- * 'Hyper.Unify.Binding.bindingDict' for pure state based unification
-- * 'Hyper.Unify.Binding.ST.stBinding' for 'Control.Monad.ST.ST' based unification
data BindingDict v m t = BindingDict
    { lookupVar :: !(v # t -> m (UTerm v # t))
    -- ^ Read the current binding state of a unification variable
    , newVar :: !(UTerm v # t -> m (v # t))
    -- ^ Allocate a new unification variable with the given initial state
    , bindVar :: !(v # t -> UTerm v # t -> m ())
    -- ^ Replace the binding state of an existing unification variable
    }

-- | @Unify m t@ enables 'Hyper.Unify.unify' to perform unification for @t@ in the 'Monad' @m@.
--
-- The 'unifyRecursive' method represents the constraint that @Unify m@ applies to all recursive child nodes.
-- It replaces context for 'Unify' to avoid @UndecidableSuperClasses@.
class
    ( Eq (UVarOf m # t)
    , RTraversable t
    , ZipMatch t
    , HasTypeConstraints t
    , HasQuantifiedVar t
    , Monad m
    , MonadQuantify (TypeConstraintsOf t) (QVar t) m
    ) =>
    Unify m t
    where
    -- | The implementation for unification variables binding and lookup
    binding :: BindingDict (UVarOf m) m t

    -- | Handles a unification error.
    --
    -- If 'unifyError' is called then unification has failed.
    -- A compiler implementation may present an error message based on the provided 'UnifyError' when this occurs.
    --
    -- The default implementation resolves every unification variable in the
    -- error with 'applyBindings', embeds the resolved error into @e@ using
    -- 'hSubset', and throws it with 'throwError'.
    unifyError :: UnifyError t # UVarOf m -> m a
    default unifyError ::
        (MonadError (e # Pure) m, HSubset' e (UnifyError t)) =>
        UnifyError t # UVarOf m ->
        m a
    -- 'unifyRecursive' supplies @Unify m@ for every child node type of @t@,
    -- which the traversal needs in order to call 'applyBindings' on them.
    unifyError e =
        htraverse (Proxy @(Unify m) #> applyBindings) e
            >>= throwError
                . (hSubset #)
            \\ unifyRecursive (Proxy @m) (Proxy @t)

    -- | What to do when top-levels of terms being unified do not match.
    --
    -- Usually this will cause a 'unifyError'.
    --
    -- Some AST terms could be equivalent despite not matching structurally,
    -- like record field extensions with the fields ordered differently.
    -- Those would override the default implementation to handle the unification of mismatching structures.
    --
    -- The first argument unifies two child variables (of any node type with a
    -- 'Unify' instance) and returns the unified variable; the default
    -- implementation ignores it and reports a 'Mismatch'.
    structureMismatch ::
        (forall c. Unify m c => UVarOf m # c -> UVarOf m # c -> m (UVarOf m # c)) ->
        t # UVarOf m ->
        t # UVarOf m ->
        m ()
    structureMismatch _ x y = unifyError (Mismatch x y)

    -- TODO: Putting documentation here causes duplication in the haddock documentation
    unifyRecursive :: Proxy m -> RecMethod (Unify m) t
    {-# INLINE unifyRecursive #-}
    default unifyRecursive :: HNodesConstraint t (Unify m) => Proxy m -> RecMethod (Unify m) t
    unifyRecursive _ _ = Dict

-- The recursion witness for 'Unify' is taken from the class's own
-- 'unifyRecursive' method, applied to the node type carried by the proxy.
instance Recursive (Unify m) where
    {-# INLINE recurse #-}
    recurse = unifyRecursive (Proxy @m) . proxyArgument

-- | A class for unification monads with scope levels
class Unify m t => UnifyGen m t where
    -- | Get the current scope constraint
    scopeConstraints :: Proxy t -> m (TypeConstraintsOf t)

    -- Witness that @UnifyGen m@ holds for all recursive child nodes of @t@.
    -- The default applies when the node constraint is directly available.
    unifyGenRecursive :: Proxy m -> RecMethod (UnifyGen m) t
    {-# INLINE unifyGenRecursive #-}
    default unifyGenRecursive ::
        HNodesConstraint t (UnifyGen m) => Proxy m -> RecMethod (UnifyGen m) t
    unifyGenRecursive _ _ = Dict

-- The recursion witness for 'UnifyGen' is taken from the class's own
-- 'unifyGenRecursive' method, applied to the node type carried by the proxy.
instance Recursive (UnifyGen m) where
    {-# INLINE recurse #-}
    recurse = unifyGenRecursive (Proxy @m) . proxyArgument

-- | Look up a variable, following variable-to-variable links, and return the
-- last variable in the chain together with its binding state.
--
-- Every variable visited on the way is re-bound to point directly at that
-- last variable (path compression, as in union-find), shortening later lookups.
{-# INLINE semiPruneLookup #-}
semiPruneLookup ::
    Unify m t =>
    UVarOf m # t ->
    m (UVarOf m # t, UTerm (UVarOf m) # t)
semiPruneLookup v0 =
    lookupVar binding v0
        >>= \case
            UToVar v1 ->
                do
                    (v, r) <- semiPruneLookup v1
                    -- Short-circuit v0 straight to the end of the chain.
                    bindVar binding v0 (UToVar v)
                    pure (v, r)
            t -> pure (v0, t)

-- | Resolve a term from a unification variable.
--
-- Note that this must be done after
-- all unifications involving the term and its children are done,
-- as it replaces unification state with cached resolved terms.
--
-- Unbound variables and skolems are resolved to fresh quantified variables
-- created with 'newQuantifiedVariable' from their constraints.
-- Reaching a variable that is currently being resolved ('UResolving') means
-- the term contains itself, which is reported via 'occursError'.
{-# INLINE applyBindings #-}
applyBindings ::
    forall m t.
    Unify m t =>
    UVarOf m # t ->
    m (Pure # t)
applyBindings v0 =
    do
        (v1, x) <- semiPruneLookup v0
        -- Cache a resolved term in the variable and return it.
        let result r = r <$ bindVar binding v1 (UResolved r)
        -- Resolve to (and cache) a fresh quantified variable with the given constraints.
        let quantify c =
                newQuantifiedVariable c
                    <&> (_Pure . quantifiedVar #)
                    >>= result
        case x of
            UResolving t -> occursError v1 t
            UResolved t -> pure t
            UUnbound c -> quantify c
            USkolem c -> quantify c
            UTerm b ->
                do
                    -- Resolve each child. Before the first child is visited, v1 is
                    -- marked as 'UResolving' so that a child referring back to v1
                    -- is detected as an occurs-check failure. The Bool state
                    -- records whether any child was visited at all.
                    (r, anyChild) <-
                        htraverse
                            ( Proxy @(Unify m) #>
                                \child ->
                                    do
                                        get >>= lift . (`unless` bindVar binding v1 (UResolving b))
                                        put True
                                        applyBindings child & lift
                            )
                            (b ^. uBody)
                            & (`runStateT` False)
                            \\ unifyRecursive (Proxy @m) (Proxy @t)
                    -- Only terms that have children are cached back into v1.
                    _Pure # r & if anyChild then result else pure
            UToVar{} -> error "lookup not expected to result in var"
            UConverted{} -> error "conversion state not expected in applyBindings"
            UInstantiated{} ->
                -- This can happen in alphaEq,
                -- where UInstantiated marks that var from one side matches var in the other.
                quantify mempty

-- | Report an occurs-check failure via 'unifyError'.
--
-- @v@ is the variable found to occur inside its own term body @b@.
-- Before reporting, @v@ is bound to a resolved fresh quantified variable
-- (created from the body's constraints). This replaces whatever state @v@ had,
-- e.g. 'UResolving' when called from 'applyBindings'.
-- The same quantified variable stands in for @v@ in the reported 'Occurs' error.
occursError ::
    Unify m t =>
    UVarOf m # t ->
    UTermBody (UVarOf m) # t ->
    m a
occursError v (UTermBody c b) =
    do
        q <- newQuantifiedVariable c
        bindVar binding v (UResolved (_Pure . quantifiedVar # q))
        unifyError (Occurs (quantifiedVar # q) b)