{-# LANGUAGE FlexibleContexts #-}
module Hyper.Class.Unify
( Unify (..)
, UVarOf
, UnifyGen (..)
, BindingDict (..)
, applyBindings
, semiPruneLookup
, occursError
) where
import Control.Monad (unless)
import Control.Monad.Error.Class (MonadError (..))
import Control.Monad.Trans.Class (MonadTrans (..))
import Control.Monad.Trans.State (get, put, runStateT)
import Data.Kind (Type)
import Hyper.Class.Nodes (HNodes (..), (#>))
import Hyper.Class.Optic (HSubset (..), HSubset')
import Hyper.Class.Recursive
import Hyper.Class.Traversable (htraverse)
import Hyper.Class.ZipMatch (ZipMatch)
import Hyper.Type (HyperType, type (#))
import Hyper.Type.Pure (Pure, _Pure)
import Hyper.Unify.Constraints
import Hyper.Unify.Error (UnifyError (..))
import Hyper.Unify.QuantifiedVar (HasQuantifiedVar (..), MonadQuantify (..))
import Hyper.Unify.Term (UTerm (..), UTermBody (..), uBody)
import Hyper.Internal.Prelude
-- | The unification-variable hypertype used by the unification monad @m@
-- (e.g. @UVarOf m # t@ is a variable standing for a term of type @t@).
type family UVarOf (m :: Type -> Type) :: HyperType
-- | Operations on unification variables of type @v@ for terms of type @t@,
-- run in the monad @m@.
--
-- This is a record of functions rather than a type class; each 'Unify'
-- instance supplies one through its 'binding' method.
data BindingDict v m t = BindingDict
    { lookupVar :: !(v # t -> m (UTerm v # t))
    -- ^ Read the current binding of a variable.
    , newVar :: !(UTerm v # t -> m (v # t))
    -- ^ Create a variable bound to the given term (presumably a new one).
    , bindVar :: !(v # t -> UTerm v # t -> m ())
    -- ^ Replace the binding of a variable.
    }
-- | @Unify m t@ provides what is needed to unify terms of hypertype @t@ in
-- the monad @m@:
--
-- * storage for variable bindings ('binding');
-- * error reporting ('unifyError');
-- * a hook for terms whose top-level structures differ ('structureMismatch').
--
-- 'unifyRecursive' witnesses that @Unify m@ also holds for every child node
-- type of @t@. It is consumed through the 'Recursive' instance instead of
-- being stated as a superclass constraint.
class
    ( Eq (UVarOf m # t)
    , RTraversable t
    , ZipMatch t
    , HasTypeConstraints t
    , HasQuantifiedVar t
    , Monad m
    , MonadQuantify (TypeConstraintsOf t) (QVar t) m
    ) =>
    Unify m t
    where
    -- | Operations for looking up, creating and binding unification
    -- variables of @t@ in @m@.
    binding :: BindingDict (UVarOf m) m t

    -- | Report a unification failure.
    --
    -- The result type is polymorphic in @a@. An implementation therefore
    -- cannot return normally with a value and has to abort the computation.
    --
    -- The default resolves every variable inside the error with
    -- 'applyBindings'. It then throws the resolved error with 'throwError',
    -- after embedding it into the monad's error type via 'hSubset'.
    unifyError :: UnifyError t # UVarOf m -> m a
    default unifyError ::
        (MonadError (e # Pure) m, HSubset' e (UnifyError t)) =>
        UnifyError t # UVarOf m ->
        m a
    -- 'applyBindings' on the children needs @Unify m@ for each child node
    -- type, which 'unifyRecursive' brings into scope via '\\'.
    unifyError e =
        htraverse (Proxy @(Unify m) #> applyBindings) e
            >>= throwError . (hSubset #)
            \\ unifyRecursive (Proxy @m) (Proxy @t)

    -- | Called when the top-level structures of two terms being unified do
    -- not match.
    --
    -- The first argument is an operation on pairs of child variables. It is
    -- presumably the unifier itself, for implementations that can still
    -- reconcile the two terms. The default ignores it and reports a
    -- 'Mismatch' through 'unifyError'.
    structureMismatch ::
        (forall c. Unify m c => UVarOf m # c -> UVarOf m # c -> m (UVarOf m # c)) ->
        t # UVarOf m ->
        t # UVarOf m ->
        m ()
    structureMismatch _ x y = unifyError (Mismatch x y)

    -- | Evidence that @Unify m@ holds for all child node types of @t@.
    -- The default obtains it from an 'HNodesConstraint' context.
    unifyRecursive :: Proxy m -> RecMethod (Unify m) t
    {-# INLINE unifyRecursive #-}
    default unifyRecursive :: HNodesConstraint t (Unify m) => Proxy m -> RecMethod (Unify m) t
    unifyRecursive _ _ = Dict
-- | Lets recursion helpers obtain @Unify m@ for the child node types of a
-- node, by delegating to 'unifyRecursive'.
instance Recursive (Unify m) where
    {-# INLINE recurse #-}
    recurse = unifyRecursive (Proxy @m) . proxyArgument
-- | 'Unify' extended with access to scope constraints
-- ('scopeConstraints').
class Unify m t => UnifyGen m t where
    -- | Obtain from the monad the type constraints that apply to @t@.
    -- These are presumably those of the current scope; confirm against
    -- instances.
    scopeConstraints :: Proxy t -> m (TypeConstraintsOf t)

    -- | Evidence that @UnifyGen m@ holds for all child node types of @t@.
    -- The default obtains it from an 'HNodesConstraint' context.
    unifyGenRecursive :: Proxy m -> RecMethod (UnifyGen m) t
    {-# INLINE unifyGenRecursive #-}
    default unifyGenRecursive ::
        HNodesConstraint t (UnifyGen m) => Proxy m -> RecMethod (UnifyGen m) t
    unifyGenRecursive _ _ = Dict
-- | Lets recursion helpers obtain @UnifyGen m@ for the child node types of
-- a node, by delegating to 'unifyGenRecursive'.
instance Recursive (UnifyGen m) where
    {-# INLINE recurse #-}
    recurse = unifyGenRecursive (Proxy @m) . proxyArgument
-- | Look up a variable, following chains of 'UToVar' links to the last
-- variable in the chain.
--
-- Returns that last variable together with its (non-'UToVar') binding.
-- Each variable visited along the way is rebound to point directly at the
-- last variable (path compression), so later lookups are shorter.
{-# INLINE semiPruneLookup #-}
semiPruneLookup ::
    Unify m t =>
    UVarOf m # t ->
    m (UVarOf m # t, UTerm (UVarOf m) # t)
semiPruneLookup v0 =
    lookupVar binding v0
        >>= \case
            UToVar v1 ->
                do
                    (v, r) <- semiPruneLookup v1
                    -- Short-circuit v0 straight to the end of the chain.
                    bindVar binding v0 (UToVar v)
                    pure (v, r)
            t -> pure (v0, t)
-- | Fully resolve a unification variable into a 'Pure' term.
--
-- * Unbound and skolem variables resolve to a fresh quantified variable
--   (via 'newQuantifiedVariable') carrying the variable's constraints.
--   Instantiated variables resolve the same way, but with 'mempty'
--   constraints.
-- * Reaching a variable that is currently being resolved ('UResolving')
--   means the term refers back to itself; this is reported with
--   'occursError'.
-- * Results are cached by rebinding the variable to 'UResolved'. The
--   exceptions are variables that were already resolved and terms without
--   child nodes.
{-# INLINE applyBindings #-}
applyBindings ::
    forall m t.
    Unify m t =>
    UVarOf m # t ->
    m (Pure # t)
applyBindings v0 =
    do
        (v1, x) <- semiPruneLookup v0
        -- Cache a resolved term in v1 and return it.
        let result :: Pure # t -> m (Pure # t)
            result r = r <$ bindVar binding v1 (UResolved r)
        -- Resolve to a fresh quantified variable with the given constraints.
        let quantify :: TypeConstraintsOf t -> m (Pure # t)
            quantify c =
                newQuantifiedVariable c <&> (_Pure . quantifiedVar #)
                    >>= result
        case x of
            UResolving t -> occursError v1 t
            UResolved t -> pure t
            UUnbound c -> quantify c
            USkolem c -> quantify c
            UTerm b ->
                do
                    -- Resolve the children with a Bool state recording
                    -- whether any child has been visited yet. Just before
                    -- the first child, v1 is marked 'UResolving', so a path
                    -- leading back to v1 lands in the 'occursError' case.
                    (r, anyChild) <-
                        htraverse
                            ( Proxy @(Unify m) #>
                                \c ->
                                    do
                                        get >>= lift . (`unless` bindVar binding v1 (UResolving b))
                                        put True
                                        applyBindings c & lift
                            )
                            (b ^. uBody)
                            & (`runStateT` False)
                            \\ unifyRecursive (Proxy @m) (Proxy @t)
                    -- With no children, v1 was never marked and still holds
                    -- its original term. Otherwise the 'UResolving' mark is
                    -- replaced by the resolved result.
                    _Pure # r & if anyChild then result else pure
            UToVar{} -> error "lookup not expected to result in var"
            UConverted{} -> error "conversion state not expected in applyBindings"
            UInstantiated{} -> quantify mempty
-- | Report an occurs-check failure for the variable @v@ with term body
-- @UTermBody c b@.
--
-- The function first creates a fresh quantified variable using the body's
-- constraints @c@. It then binds @v@ to that variable as 'UResolved'. This
-- is presumably so that resolving the error's contents (which the default
-- 'unifyError' does with 'applyBindings') does not reach @v@ in an
-- in-progress state again.
--
-- The error reported is 'Occurs' with that quantified variable and the
-- body's child nodes @b@.
occursError ::
    Unify m t =>
    UVarOf m # t ->
    UTermBody (UVarOf m) # t ->
    m a
occursError v (UTermBody c b) =
    do
        q <- newQuantifiedVariable c
        bindVar binding v (UResolved (_Pure . quantifiedVar # q))
        unifyError (Occurs (quantifiedVar # q) b)