module Cryptol.TypeCheck.CheckModuleInstance (checkModuleInstance) where
import Data.Map ( Map )
import qualified Data.Map as Map
import Control.Monad(unless)
import Cryptol.Parser.Position(Located(..))
import qualified Cryptol.Parser.AST as P
import Cryptol.ModuleSystem.Name (nameIdent, nameLoc)
import Cryptol.ModuleSystem.InstantiateModule(instantiateModule)
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Monad
import Cryptol.TypeCheck.Infer
import Cryptol.TypeCheck.Subst
import Cryptol.TypeCheck.Error
import Cryptol.Utils.Panic
-- | Instantiate the parameterized module @func@ with the definitions
-- found in @inst@.
--
-- The type parameters are resolved first ('checkTyParams'), then the
-- value parameters ('checkValParams'); both maps are handed to
-- 'instantiateModule'.  Every constraint it returns is added as a goal
-- whose source is 'CtModuleInstance' for the instance's name.
--
-- The result keeps the name and exports of the instantiated module.
-- Its imports, type synonyms, newtypes, primitive types and declarations
-- are combined with those of @inst@, and the parameter fields are taken
-- from @inst@.
checkModuleInstance :: Module {- ^ parameterized module -} ->
                       Module {- ^ module supplying the parameter definitions -} ->
                       InferM Module -- ^ the instantiated module
checkModuleInstance func inst =
  do tyDefs    <- checkTyParams func inst
     valDefs   <- checkValParams func tyDefs inst
     (ctrs, m) <- instantiateModule func instName tyDefs valDefs

     -- Each constraint produced by the instantiation becomes a goal.
     addGoals [ Goal { goalSource = CtModuleInstance instName
                     , goalRange  = srcRange c
                     , goal       = thing c
                     }
              | c <- ctrs ]

     pure Module
       { mName             = mName m
       , mExports          = mExports m

         -- The instance's entries come first; 'Map.union' is left-biased,
         -- so on a name clash the instance's definition is the one kept.
       , mImports          = mImports inst ++ mImports m
       , mTySyns           = Map.union (mTySyns inst)    (mTySyns m)
       , mNewtypes         = Map.union (mNewtypes inst)  (mNewtypes m)
       , mPrimTypes        = Map.union (mPrimTypes inst) (mPrimTypes m)

         -- Parameters of the result are those of the instance.
       , mParamTypes       = mParamTypes inst
       , mParamConstraints = mParamConstraints inst
       , mParamFuns        = mParamFuns inst

       , mDecls            = mDecls inst ++ mDecls m
       }
  where
  instName = mName inst
-- | For each type parameter of @func@, find its definition in @inst@ and
-- compute the type to use in its place.
--
-- Definitions are matched by identifier (via 'nameIdent'), not by 'Name'.
-- They are searched in this order: type synonyms of @inst@, newtypes of
-- @inst@, and finally the type parameters of @inst@ itself.  If nothing
-- matches, a 'MissingModTParam' error is recorded and the parameter is
-- mapped to itself as a bound type variable.
checkTyParams :: Module -> Module -> InferM (Map TParam Type)
checkTyParams func inst =
  do defs <- mapM checkTParamDefined (Map.elems (mParamTypes func))
     pure (Map.fromList defs)
  where
  -- Candidate definitions from the instance, keyed by identifier.
  tySyns  = Map.mapKeys nameIdent (mTySyns inst)
  newTys  = Map.mapKeys nameIdent (mNewtypes inst)
  tParams = Map.fromList
              [ (tpId p, p) | p <- map mtpParam (Map.elems (mParamTypes inst)) ]

  -- The name of a type parameter; panics if the parameter has none.
  tpName' p
    | Just n <- tpName p = n
    | otherwise          = panic "inferModuleInstance.tpId" ["Missing name"]

  tpId p = nameIdent (tpName' p)

  -- Record a 'KindMismatch' error unless the two kinds agree.
  requireSameKind k1 k2 =
    unless (k1 == k2) (recordError (KindMismatch Nothing k1 k2))

  -- Find a definition for one type parameter of @func@.
  checkTParamDefined tp0
    | Just ts  <- Map.lookup ident tySyns  = checkTySynDef tp ts
    | Just nt  <- Map.lookup ident newTys  = checkNewTyDef tp nt
    | Just tp1 <- Map.lookup ident tParams = checkTP tp tp1
    | otherwise =
        do recordError (MissingModTParam missing)
           pure (tp, TVar (TVBound tp))
    where
    tp      = mtpParam tp0
    ident   = tpId tp
    missing = Located { srcRange = nameLoc (tpName' tp), thing = ident }

  -- The parameter is defined by a type synonym: check the kinds, add the
  -- synonym's constraints as goals, and use the synonym applied to no
  -- arguments.
  checkTySynDef tp ts =
    do requireSameKind (kindOf tp) (kindOf ts)
       mapM_ (newGoal (CtPartialTypeFun synName)) (tsConstraints ts)
       pure (tp, TUser synName [] (tsDef ts))
    where
    synName = tsName ts

  -- The parameter is defined by a newtype: check the kinds, add the
  -- newtype's constraints as goals, and use the newtype's type constructor
  -- applied to no arguments.
  checkNewTyDef tp nt =
    do requireSameKind (kindOf tp) ntKind
       mapM_ (newGoal (CtPartialTypeFun newName)) (ntConstraints nt)
       pure (tp, TCon (TC (TCNewtype (UserTC newName ntKind))) [])
    where
    newName = ntName nt
    ntKind  = kindOf nt

  -- The parameter is defined by a type parameter of the instance: check
  -- the kinds and use that parameter as a bound type variable.
  checkTP tp tp1 =
    do requireSameKind (kindOf tp) (kindOf tp1)
       pure (tp, TVar (TVBound tp1))
-- | For each value parameter of @func@, find a definition in @inst@ and
-- build the expression to use for it.
--
-- Definitions are matched by identifier (via 'nameIdent').  Candidates are
-- the declarations of @inst@ and the value parameters of @inst@; when both
-- provide the same identifier, the declaration is used.  The parameter's
-- schema is instantiated with @tMap@ before it is compared with the
-- definition (see 'makeValParamDef').
--
-- If no candidate exists, a 'MissingModVParam' error is recorded and the
-- parameter is mapped to a panicking placeholder expression.
checkValParams :: Module          {- ^ parameterized module -} ->
                  Map TParam Type {- ^ instantiation of the type parameters -} ->
                  Module          {- ^ module supplying the definitions -} ->
                  InferM (Map Name Expr)
                  -- ^ definitions for the value parameters
checkValParams func tMap inst =
  do defs <- mapM checkParam (Map.elems (mParamFuns func))
     pure (Map.fromList defs)
  where
  -- 'Map.union' is left-biased, so declarations shadow parameters.
  valMap = Map.union (Map.fromList fromDecls) (Map.fromList fromParams)

  fromDecls =
    [ (nameIdent (dName d), (dName d, dSignature d))
    | grp <- mDecls inst
    , d   <- groupDecls grp
    ]

  fromParams =
    [ (nameIdent n, (n, mvpType p))
    | (n, p) <- Map.toList (mParamFuns inst)
    ]

  su = listParamSubst (Map.toList tMap)

  checkParam pr =
    case Map.lookup pIdent valMap of
      Nothing ->
        do recordError
             (MissingModVParam
                Located { srcRange = nameLoc pName, thing = pIdent })
           -- The error has been recorded; this value is a placeholder.
           pure (pName, panic "checkValParams" ["Should not use this"])

      Just (defName, defSchema) ->
        do e <- makeValParamDef defName defSchema (apSubst su (mvpType pr))
           pure (pName, e)
    where
    pName  = mvpName pr
    pIdent = nameIdent pName
-- | Build the expression that defines a value parameter.
--
-- A synthetic binding named @x@ whose body is just the variable @x@ is
-- checked with 'checkSigB' against the parameter's schema @pDef@, with @x@
-- in scope at schema @sDef@ (via 'withVar').  The expression of the
-- resulting declaration is returned.
--
-- The declaration's definition is matched lazily against 'DExpr', so a
-- definition of any other form fails only when the expression is demanded.
makeValParamDef :: Name   {- ^ name of the definition -} ->
                   Schema {- ^ schema of the definition -} ->
                   Schema {- ^ schema of the parameter -} ->
                   InferM Expr {- ^ expression to use for the parameter -}
makeValParamDef x sDef pDef =
  withVar x sDef (declExpr <$> checkSigB synthetic (pDef, []))
  where
  -- Lazy projection of the expression out of a checked declaration.
  declExpr d = let ~(DExpr e) = dDefinition d in e

  -- Attach the location of @x@ to a value.
  atName a = P.Located { P.thing = a, P.srcRange = nameLoc x }

  synthetic =
    P.Bind
      { P.bName      = atName x
      , P.bDef       = atName (P.DExpr (P.EVar x))
      , P.bParams    = []
      , P.bSignature = Nothing
      , P.bPragmas   = []
      , P.bFixity    = Nothing
      , P.bInfix     = False
      , P.bMono      = False
      , P.bDoc       = Nothing
      }