{-# LANGUAGE FlexibleInstances, GeneralizedNewtypeDeriving, UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts, TemplateHaskell, EmptyCase #-}
module AST.Term.Nominal
( NominalDecl(..), nParams, nScheme, KWitness(..)
, NominalInst(..), nId, nArgs
, ToNom(..), tnId, tnVal
, FromNom(..), _FromNom
, HasNominalInst(..)
, NomVarTypes
, MonadNominals(..)
, LoadedNominalDecl, loadNominalDecl
) where
import AST
import AST.Class.Has (HasChild(..))
import AST.Class.Traversable (ContainedK(..))
import AST.Class.ZipMatch (ZipMatch(..))
import AST.Combinator.Flip (_Flip)
import AST.Infer
import AST.Recurse
import AST.Term.FuncType (FuncType(..))
import AST.Term.Map (TermMap(..), _TermMap)
import AST.Term.Scheme
import AST.TH.Internal.Instances (makeCommonInstances)
import AST.Unify
import AST.Unify.Generalize (GTerm(..), _GMono, instantiateWith, instantiateForAll)
import AST.Unify.New (newTerm)
import AST.Unify.QuantifiedVar (HasQuantifiedVar(..), OrdQVar)
import AST.Unify.Term (UTerm(..))
import Control.Applicative (Alternative(..))
import Control.DeepSeq (NFData)
import Control.Lens (Prism', makeLenses, makePrisms)
import qualified Control.Lens as Lens
import Control.Lens.Operators
import Control.Monad.Trans.Writer (execWriterT)
import Data.Binary (Binary)
import Data.Foldable (traverse_)
import Data.Kind (Type)
import Data.Proxy (Proxy(..))
import qualified Data.Map as Map
import Generics.Constraints (Constraints)
import GHC.Generics (Generic)
import Text.PrettyPrint ((<+>))
import qualified Text.PrettyPrint as Pretty
import Text.PrettyPrint.HughesPJClass (Pretty(..), maybeParens)
import Prelude.Compat
-- | Open type family: for a type AST @t@, the knot-indexed container type
-- that holds a nominal's type variables (used for '_nParams', '_nArgs' and
-- the parameters/foralls of 'LoadedNominalDecl').
-- No instances are visible in this part of the module.
type family NomVarTypes (t :: Knot -> Type) :: Knot -> Type
-- | A declaration of a nominal type over the type AST @typ@.
data NominalDecl typ k = NominalDecl
-- | The quantified variables that are the nominal's parameters.
{ _nParams :: Tree (NomVarTypes typ) QVars
-- | The scheme (foralls plus type) of the nominal's body;
-- 'loadNominalDecl' destructures it as @Scheme foralls typ@.
, _nScheme :: Scheme (NomVarTypes typ) typ k
} deriving Generic
-- | An instantiation of a nominal: its identifier together with the
-- values given to its variables.
data NominalInst nomId varTypes k = NominalInst
-- | Identifier of the nominal being instantiated.
{ _nId :: nomId
-- | For each variable kind in @varTypes@, a 'QVarInstances' map from
-- quantified variable to its instance in the knot @k@.
, _nArgs :: Tree varTypes (QVarInstances (GetKnot k))
} deriving Generic
-- | A term that wraps a child @term@ value with the nominal identified by
-- '_tnId'. Its inferred result ('InferOf') is a 'NominalInst'.
data ToNom nomId term k = ToNom
-- | Identifier of the nominal to wrap with.
{ _tnId :: nomId
-- | The wrapped child term.
, _tnVal :: k # term
} deriving Generic
-- | A leaf term that refers to a nominal by its identifier only.
-- @term@ and @k@ are phantom parameters. Its inferred result ('InferOf')
-- is a 'FuncType' from the nominal's instance type to its body type.
--
-- The kind annotation uses 'Type' (from "Data.Kind") rather than the legacy
-- @*@ spelling, matching the 'NomVarTypes' declaration in this module and
-- staying valid under @NoStarIsType@.
newtype FromNom nomId (term :: Knot -> Type) (k :: Knot) = FromNom nomId
    deriving newtype (Eq, Ord, Binary, NFData)
    deriving stock (Show, Generic)
-- | A nominal declaration in the form produced by 'loadNominalDecl':
-- its variables are replaced with instances in the knot @v@ and its body is
-- a 'GTerm'.
data LoadedNominalDecl typ v = LoadedNominalDecl
-- | Instances created for the declaration's parameters.
{ _lnParams :: Tree (NomVarTypes typ) (QVarInstances (GetKnot v))
-- | Instances created for the foralls of the declaration's scheme.
, _lnForalls :: Tree (NomVarTypes typ) (QVarInstances (GetKnot v))
-- | The declaration's body type as a generalized term.
, _lnType :: Tree (GTerm (GetKnot v)) typ
} deriving Generic
-- Template Haskell splices. Their position matters: declarations below
-- them can use the generated names, and they can only see types declared
-- above.
-- Lenses for the record fields (nParams, nScheme, nId, nArgs, tnId, tnVal).
makeLenses ''NominalDecl
makeLenses ''NominalInst
makeLenses ''ToNom
-- Generates '_FromNom' for the single-constructor newtype.
makePrisms ''FromNom
-- NOTE(review): presumably derives the project's standard instance set
-- (Eq, Ord, Show, Binary, NFData, ...) for these types - confirm against
-- AST.TH.Internal.Instances.
makeCommonInstances [''NominalDecl, ''NominalInst, ''ToNom, ''LoadedNominalDecl]
-- NOTE(review): by their names these derive KTraversable (and KApply for
-- the "Apply" variant) together with the superclass instances - confirm.
makeKTraversableAndBases ''NominalDecl
makeKTraversableApplyAndBases ''ToNom
makeKTraversableApplyAndBases ''FromNom
-- | The child nodes of a 'NominalInst' are exactly those of its variable
-- container @v@: both the node constraint and constraint lifting are
-- delegated to @v@.
instance KNodes v => KNodes (NominalInst n v) where
type KNodesConstraint (NominalInst n v) c = KNodesConstraint v c
-- A witness for a 'NominalInst' node is a wrapped witness of @v@.
data KWitness (NominalInst n v) c = E_NominalInst_k (KWitness v c)
{-# INLINE kLiftConstraint #-}
kLiftConstraint (E_NominalInst_k w) = kLiftConstraint w
-- | Maps over every variable instance stored in the args of a
-- 'NominalInst', leaving the nominal identifier untouched.
instance KFunctor v => KFunctor (NominalInst n v) where
    {-# INLINE mapK #-}
    mapK f =
        Lens.over nArgs
        -- The lambda stays inline: it must remain polymorphic in the node type.
        ( mapK
            (\w -> Lens.over (_QVarInstances . Lens.mapped) (f (E_NominalInst_k w)))
        )
-- | Folds over every variable instance stored in the args of a
-- 'NominalInst'.
instance KFoldable v => KFoldable (NominalInst n v) where
    {-# INLINE foldMapK #-}
    foldMapK f =
        foldMapK
        (\w (QVarInstances vars) -> foldMap (f (E_NominalInst_k w)) vars)
        . Lens.view nArgs
-- | Sequences the effects held (via 'ContainedK') in every variable
-- instance of the args, then rebuilds the 'NominalInst' with the same id.
instance KTraversable v => KTraversable (NominalInst n v) where
    {-# INLINE sequenceK #-}
    sequenceK (NominalInst nomId args) =
        NominalInst nomId
        <$> traverseK (\_ -> _QVarInstances (traverse runContainedK)) args
-- | Two nominal instances match only when their identifiers are equal and
-- their args match structurally; matching args are zipped map-by-map by
-- going through 'TermMap'.
instance
    ( Eq nomId
    , ZipMatch varTypes
    , KTraversable varTypes
    , KNodesConstraint varTypes ZipMatch
    , KNodesConstraint varTypes OrdQVar
    ) =>
    ZipMatch (NominalInst nomId varTypes) where
    {-# INLINE zipMatch #-}
    zipMatch (NominalInst xId xArgs) (NominalInst yId yArgs)
        | xId /= yId = Nothing
        | otherwise =
            NominalInst xId
            <$> ( zipMatch xArgs yArgs
                    >>= traverseK
                        ( Proxy @ZipMatch #*# Proxy @OrdQVar #>
                            \(Pair (QVarInstances lhs) (QVarInstances rhs)) ->
                            QVarInstances . (^. _TermMap)
                            <$> zipMatch (TermMap lhs) (TermMap rhs)
                        )
                )
-- | Renders as @nomId# value@: the wrapped value is printed at precedence
-- 11 and the whole is parenthesised when the context precedence exceeds 10.
instance Constraints (ToNom nomId term k) Pretty => Pretty (ToNom nomId term k) where
    pPrintPrec lvl p (ToNom nomId term) =
        maybeParens (p > 10) (tag <+> body)
        where
            tag = pPrint nomId <> Pretty.text "#"
            body = pPrintPrec lvl 11 term
-- | Constraint synonym encoded as a class: both the quantified variables of
-- node @k@ and the child @outer # k@ can be pretty-printed.
-- It is a class so it can be partially applied, as in
-- @KNodesConstraint varTypes (PrettyConstraints k)@.
class (Pretty (QVar k), Pretty (outer # k)) => PrettyConstraints outer k
-- The single catch-all instance: holds whenever the superclasses hold.
instance (Pretty (QVar k), Pretty (outer # k)) => PrettyConstraints outer k
-- | Renders the nominal id followed by its arguments as
-- @[var: value, ...]@. When there are no arguments the brackets are omitted.
instance
    ( Pretty nomId
    , KApply varTypes, KFoldable varTypes
    , KNodesConstraint varTypes (PrettyConstraints k)
    ) =>
    Pretty (NominalInst nomId varTypes k) where
    pPrint (NominalInst n vars) =
        pPrint n <>
        ( case foldMapK (Proxy @(PrettyConstraints k) #> mkArgs) vars of
            [] -> mempty
            args ->
                Pretty.text "[" <>
                Pretty.sep (Pretty.punctuate (Pretty.text ",") args)
                <> Pretty.text "]"
        )
        where
            -- One @var: value@ document per entry of a variables map.
            -- Kept free of outer variables so it stays polymorphic in the node.
            mkArgs (QVarInstances m) =
                [ (pPrint var <> Pretty.text ":") <+> pPrint val
                | (var, val) <- Map.toList m
                ]
-- | Nodes of a 'LoadedNominalDecl': the nodes of its variable containers
-- plus the nodes reached through the body type @t@.
instance (RNodes t, KNodes (NomVarTypes t)) => KNodes (LoadedNominalDecl t) where
-- A constraint must hold for the variable container's nodes, for @t@
-- itself, and recursively ('Recursive').
type KNodesConstraint (LoadedNominalDecl t) c =
( KNodesConstraint (NomVarTypes t) c
, c t
, Recursive c
)
-- A witness points either into the body (a recursive witness of @t@) or
-- into the variable containers.
data KWitness (LoadedNominalDecl t) n where
E_LoadedNominalDecl_Body :: KRecWitness t n -> KWitness (LoadedNominalDecl t) n
E_LoadedNominalDecl_NomVarTypes :: KWitness (NomVarTypes t) n -> KWitness (LoadedNominalDecl t) n
{-# INLINE kLiftConstraint #-}
-- Body witnesses reuse the lifting of the flipped 'GTerm' witness.
kLiftConstraint (E_LoadedNominalDecl_Body w) = kLiftConstraint (E_Flip_GTerm w)
kLiftConstraint (E_LoadedNominalDecl_NomVarTypes w) = kLiftConstraint w
-- | Maps over the parameter instances, the forall instances and the body
-- (the body is mapped through the flipped 'GTerm' view).
instance
    (Recursively KFunctor typ, KFunctor (NomVarTypes typ)) =>
    KFunctor (LoadedNominalDecl typ) where
    {-# INLINE mapK #-}
    mapK f (LoadedNominalDecl params foralls body) =
        LoadedNominalDecl
        (mapVars params)
        (mapVars foralls)
        ( Lens.over (Lens.from _Flip)
            (mapK (\(E_Flip_GTerm w) -> f (E_LoadedNominalDecl_Body w)))
            body
        )
        where
            mapVars =
                mapK
                ( \w ->
                    Lens.over (_QVarInstances . Lens.mapped)
                    (f (E_LoadedNominalDecl_NomVarTypes w))
                )
-- | Folds over the parameter instances, then the forall instances, then the
-- body (in that order).
instance
    (Recursively KFoldable typ, KFoldable (NomVarTypes typ)) =>
    KFoldable (LoadedNominalDecl typ) where
    {-# INLINE foldMapK #-}
    foldMapK f (LoadedNominalDecl params foralls body) =
        foldVars params <> foldVars foralls <> foldBody
        where
            foldVars =
                foldMapK
                ( \w (QVarInstances vars) ->
                    foldMap (f (E_LoadedNominalDecl_NomVarTypes w)) vars
                )
            foldBody =
                foldMapK
                (\(E_Flip_GTerm w) -> f (E_LoadedNominalDecl_Body w))
                (_Flip # body)
-- | Sequences the contained effects of the parameter instances, the forall
-- instances and the body, in that order.
instance
    (RTraversable typ, KTraversable (NomVarTypes typ)) =>
    KTraversable (LoadedNominalDecl typ) where
    {-# INLINE sequenceK #-}
    sequenceK (LoadedNominalDecl params foralls body) =
        LoadedNominalDecl
        <$> seqVars params
        <*> seqVars foralls
        <*> Lens.from _Flip sequenceK body
        where
            seqVars = traverseK (\_ -> _QVarInstances (traverse runContainedK))
{-# INLINE loadBody #-}
-- | Build the generalized term for one layer of a nominal's body.
--
-- * A quantified variable found in the parameters (tried first) or in the
--   foralls becomes 'GPoly' of its instance.
-- * Otherwise, if every child is 'GMono', a new unification term is created
--   from the children and wrapped in 'GMono'.
-- * Otherwise the layer is kept as 'GBody'.
loadBody ::
    ( Unify m typ
    , HasChild varTypes typ
    , Ord (QVar typ)
    ) =>
    Tree varTypes (QVarInstances (UVarOf m)) ->
    Tree varTypes (QVarInstances (UVarOf m)) ->
    Tree typ (GTerm (UVarOf m)) ->
    m (Tree (GTerm (UVarOf m)) typ)
loadBody params foralls x =
    maybe loadMonoOrBody (pure . GPoly) (x ^? quantifiedVar >>= lookupInstance)
    where
        loadMonoOrBody =
            maybe
            (pure (GBody x))
            (fmap GMono . newTerm)
            (traverseK (const (^? _GMono)) x)
        lookupInstance v =
            params ^? getChild . _QVarInstances . Lens.ix v <|>
            foralls ^? getChild . _QVarInstances . Lens.ix v
{-# INLINE loadNominalDecl #-}
-- | Load a pure nominal declaration into the monad @m@.
--
-- Instances are created with 'makeQVarInstances' for the declaration's
-- parameters and then for the foralls of its scheme. The scheme's type is
-- then rebuilt bottom-up with 'loadBody' and the three parts are packed into
-- a 'LoadedNominalDecl'.
loadNominalDecl ::
    forall m typ.
    ( Monad m
    , KTraversable (NomVarTypes typ)
    , KNodesConstraint (NomVarTypes typ) (Unify m)
    , HasScheme (NomVarTypes typ) m typ
    ) =>
    Tree Pure (NominalDecl typ) ->
    m (Tree (LoadedNominalDecl typ) (UVarOf m))
loadNominalDecl (Pure (NominalDecl params (Scheme foralls typ))) =
    do
        loadedParams <- traverseK (Proxy @(Unify m) #> makeQVarInstances) params
        loadedForalls <- traverseK (Proxy @(Unify m) #> makeQVarInstances) foralls
        LoadedNominalDecl loadedParams loadedForalls
            <$> wrapM
                ( Proxy @(HasScheme (NomVarTypes typ) m) #>>
                    loadBody loadedParams loadedForalls
                ) typ
-- | Monads that can look up nominal declarations by identifier.
class MonadNominals nomId typ m where
-- | Fetch the loaded declaration for the given nominal id, with its
-- variables expressed in the monad's unification variables ('UVarOf').
getNominalDecl :: nomId -> m (Tree (LoadedNominalDecl typ) (UVarOf m))
-- | Type ASTs @typ@ that can embed a 'NominalInst' with ids of type @nomId@.
class HasNominalInst nomId typ where
-- | Prism between a @typ@ node and the nominal instance it may hold;
-- the 'FromNom' inference reviews it to build the nominal's type.
nominalInst :: Prism' (Tree typ k) (Tree (NominalInst nomId (NomVarTypes typ)) k)
{-# INLINE lookupParams #-}
-- | Resolve each parameter variable of a loaded nominal to the variable it
-- currently stands for.
--
-- * An instantiated variable yields the variable it was instantiated to.
-- * A skolem yields a fresh unbound variable whose constraints are the
--   current scope constraints combined with the skolem's.
-- * Any other state is treated as a bug and raises 'error'.
lookupParams ::
    forall m varTypes.
    ( Applicative m
    , KTraversable varTypes
    , KNodesConstraint varTypes (Unify m)
    ) =>
    Tree varTypes (QVarInstances (UVarOf m)) ->
    m (Tree varTypes (QVarInstances (UVarOf m)))
lookupParams =
    traverseK (Proxy @(Unify m) #> (_QVarInstances . traverse) lookupParam)
    where
        lookupParam v =
            do
                state <- lookupVar binding v
                case state of
                    UInstantiated r -> pure r
                    USkolem l ->
                        do
                            scope <- scopeConstraints
                            newVar binding (UUnbound (scope <> l))
                    _ -> error "unexpected state at nominal's parameter"
-- | Inferring a 'ToNom' yields the nominal instance it constructs.
type instance InferOf (ToNom n e) = NominalInst n (NomVarTypes (TypeOf e))
-- | Inference for 'ToNom': infer the wrapped value, instantiate the
-- nominal's declaration, and unify the declaration's body type with the
-- value's inferred type.
instance
( MonadScopeLevel m
, MonadNominals nomId (TypeOf expr) m
, KTraversable (NomVarTypes (TypeOf expr))
, KNodesConstraint (NomVarTypes (TypeOf expr)) (Unify m)
, Unify m (TypeOf expr)
, HasInferredType expr
, Infer m expr
) =>
Infer m (ToNom nomId expr) where
{-# INLINE inferBody #-}
inferBody (ToNom nomId val) =
do
-- Everything in this inner block runs inside 'localLevel' (applied at
-- its end via '&').
(InferredChild valI valR, typ, paramsT) <-
do
v <- inferChild val
LoadedNominalDecl params foralls gen <- getNominalDecl nomId
-- Instantiate every forall variable as a skolem. The writer output
-- collected by 'execWriterT' is a collection of actions, run below.
-- NOTE(review): presumably these undo the temporary instantiation of
-- the forall variables - confirm against 'instantiateForAll'.
recover <-
traverseK_
( Proxy @(Unify m) #>
traverse_ (instantiateForAll USkolem) . (^. _QVarInstances)
) foralls
& execWriterT
-- Instantiate the body; 'lookupParams' supplies the parameter
-- variables alongside the resulting type.
(typ, paramsT) <- instantiateWith (lookupParams params) UUnbound gen
-- Run the recover actions only after instantiation has completed.
(v, typ, paramsT) <$ sequence_ recover
& localLevel
-- The result is built up-front; the unification is run for its effect.
(ToNom nomId valI, NominalInst nomId paramsT)
<$ unify typ (valR ^# inferredType (Proxy @expr))
-- | Inferring a 'FromNom' yields a function type.
type instance InferOf (FromNom n e) = FuncType (TypeOf e)
-- | Inference for 'FromNom': instantiate the nominal's declaration and
-- produce a function type from the nominal's instance type to its body
-- type. The declaration's foralls are not used here.
instance
    ( Infer m expr
    , HasNominalInst nomId (TypeOf expr)
    , MonadNominals nomId (TypeOf expr) m
    , KTraversable (NomVarTypes (TypeOf expr))
    , KNodesConstraint (NomVarTypes (TypeOf expr)) (Unify m)
    , Unify m (TypeOf expr)
    ) =>
    Infer m (FromNom nomId expr) where
    {-# INLINE inferBody #-}
    inferBody (FromNom nomId) =
        do
            LoadedNominalDecl params _ gen <- getNominalDecl nomId
            (bodyT, paramsT) <- instantiateWith (lookupParams params) UUnbound gen
            -- A fresh term holding the type of the nominal instance itself.
            nomT <- newTerm (nominalInst # NominalInst nomId paramsT)
            pure (FromNom nomId, FuncType nomT bodyT)