{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TemplateHaskell #-}

module Hyper.Class.Infer
    ( InferOf
    , Infer (..)
    , InferChild (..)
    , _InferChild
    , InferredChild (..)
    , inType
    , inRep
    ) where

import qualified Control.Lens as Lens
import GHC.Generics
import Hyper
import Hyper.Class.Unify
import Hyper.Recurse

import Hyper.Internal.Prelude

-- | @InferOf e@ is the 'HyperType' of the inference result for terms of type @e@.
--
-- Most commonly it is just an inferred type, declared using
--
-- > type instance InferOf MyTerm = ANode MyType
--
-- But it may also be something richer, for example:
--
-- * An inferred value (for types inside terms)
-- * An inferred type together with a scope
--
-- The instances for the "GHC.Generics" representation types ('M1', 'Rec1', ':+:')
-- forward to a wrapped type (the left summand for ':+:'), which is what lets the
-- generic default of 'inferBody' work.
type family InferOf (t :: HyperType) :: HyperType

-- | A 'HyperType' containing an inferred child node.
--
-- This is what running an 'InferChild' action yields: the rebuilt child node
-- together with its inference result. 'InferChild' instantiates @v@ with
-- 'UVarOf' of the inference monad.
data InferredChild v h t = InferredChild
    { _inRep :: !(h t)
    -- ^ Inferred node.
    --
    -- An 'inferBody' implementation needs to place this value in the corresponding child node of the inferred term body.
    , _inType :: !(InferOf (GetHyperType t) # v)
    -- ^ The inference result for the child node.
    --
    -- An 'inferBody' implementation may use it to perform unifications.
    }

-- Generates the exported 'inRep' and 'inType' lenses.
makeLenses ''InferredChild

-- | A 'HyperType' containing an inference action.
--
-- The caller may modify the scope before invoking the action via
-- 'Hyper.Class.Infer.Env.localScopeType' or 'Hyper.Infer.ScopeLevel.localLevel'.
newtype InferChild m h t = InferChild
    { inferChild :: m (InferredChild (UVarOf m) h t)
    -- ^ The action that infers the child node.
    }

-- Generates the exported '_InferChild' prism.
makePrisms ''InferChild

-- | @Infer m t@ enables 'Hyper.Infer.infer' to perform type-inference for @t@ in the 'Monad' @m@.
--
-- The 'inferContext' method represents the following constraints on @t@:
--
-- * @HNodesConstraint (InferOf t) (UnifyGen m)@ - 'UnifyGen' @m@ is available for the child nodes of the inference result
-- * @HNodesConstraint t (Infer m)@ - @Infer m@ is also available for child nodes
--
-- It stands in for a superclass context on the 'Infer' class, to avoid @UndecidableSuperClasses@.
--
-- Instances usually don't need to implement this method as the default implementation works for them,
-- but infinitely polymorphic trees such as 'Hyper.Type.AST.NamelessScope.Scope' do need to implement the method,
-- because the required context is infinite.
class (Monad m, HFunctor t) => Infer m t where
    -- | Infer the body of an expression given the inference actions for its child nodes.
    inferBody ::
        t # InferChild m h ->
        m (t # h, InferOf t # UVarOf m)
    -- The default goes through the generic representation: convert the term
    -- with 'from1', infer it there (see the 'M1', 'Rec1' and ':+:' instances),
    -- then convert the inferred term back with 'to1'.
    default inferBody ::
        (Generic1 t, Infer m (Rep1 t), InferOf t ~ InferOf (Rep1 t)) =>
        t # InferChild m h ->
        m (t # h, InferOf t # UVarOf m)
    inferBody = fmap (Lens._1 %~ to1) . inferBody . from1

    -- TODO: Putting documentation here causes duplication in the haddock documentation
    inferContext ::
        proxy0 m ->
        proxy1 t ->
        Dict (HNodesConstraint t (Infer m), HNodesConstraint (InferOf t) (UnifyGen m))
    {-# INLINE inferContext #-}
    default inferContext ::
        (HNodesConstraint t (Infer m), HNodesConstraint (InferOf t) (UnifyGen m)) =>
        proxy0 m ->
        proxy1 t ->
        Dict (HNodesConstraint t (Infer m), HNodesConstraint (InferOf t) (UnifyGen m))
    -- Both constraints come straight from the default signature's context.
    inferContext _ _ = Dict

-- | 'Infer' is recursive: 'inferContext' supplies @Infer m@ for every child node.
instance Recursive (Infer m) where
    {-# INLINE recurse #-}
    -- Bring the dictionary from 'inferContext' into scope; it contains the
    -- required @HNodesConstraint h (Infer m)@.
    recurse p = Dict \\ inferContext (Proxy @m) (proxyArgument p)

-- A generic sum uses its left summand's inference result; the 'Infer'
-- instance below requires both summands to agree on it.
type instance InferOf (a :+: _) = InferOf a

-- | Inference for generic sums: delegate to whichever alternative is present
-- and re-wrap the inferred term in the same constructor.
instance (InferOf a ~ InferOf b, Infer m a, Infer m b) => Infer m (a :+: b) where
    {-# INLINE inferBody #-}
    inferBody (L1 x) = inferBody x <&> Lens._1 %~ L1
    inferBody (R1 x) = inferBody x <&> Lens._1 %~ R1

    {-# INLINE inferContext #-}
    -- The sum's context is assembled from both summands' contexts.
    inferContext p _ =
        Dict
            \\ inferContext p (Proxy @a)
            \\ inferContext p (Proxy @b)

-- Generic metadata wrappers are transparent to inference.
type instance InferOf (M1 _ _ h) = InferOf h

-- | Inference for generic metadata wrappers: infer the wrapped node and
-- re-wrap the result in 'M1'.
instance Infer m h => Infer m (M1 i c h) where
    {-# INLINE inferBody #-}
    inferBody (M1 x) = inferBody x <&> Lens._1 %~ M1

    {-# INLINE inferContext #-}
    -- The wrapper's context is exactly the wrapped node's context.
    inferContext p _ = Dict \\ inferContext p (Proxy @h)

-- Generic recursive-occurrence wrappers are transparent to inference.
type instance InferOf (Rec1 h) = InferOf h

-- | Inference for generic 'Rec1' wrappers: infer the wrapped node and
-- re-wrap the result in 'Rec1'.
instance Infer m h => Infer m (Rec1 h) where
    {-# INLINE inferBody #-}
    inferBody (Rec1 x) = inferBody x <&> Lens._1 %~ Rec1

    {-# INLINE inferContext #-}
    -- The wrapper's context is exactly the wrapped node's context.
    inferContext p _ = Dict \\ inferContext p (Proxy @h)