{-# Language OverloadedStrings #-}
module Cryptol.TypeCheck.Instantiate (instantiateWith) where
import Cryptol.ModuleSystem.Name (nameIdent)
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Monad
import Cryptol.TypeCheck.Subst (listParamSubst, apSubst)
import Cryptol.TypeCheck.Error
import Cryptol.Parser.Position (Located(..))
import Cryptol.Utils.Ident (Ident)
import Cryptol.Utils.Panic(panic)
import Control.Monad(zipWithM)
import Data.Function (on)
import Data.List(sortBy, groupBy, find)
import Data.Maybe(mapMaybe,isJust)
import Data.Either(partitionEithers)
import qualified Data.Set as Set
-- | Instantiate the schema of an expression with explicitly supplied type
-- arguments.  Every argument is either named (@Just ident@) or positional
-- (@Nothing@).
--
-- * Only positional arguments (or no arguments at all): delegate to
--   'instantiateWithPos'.
-- * Only named arguments: delegate to 'instantiateWithNames'.
-- * A mixture of both: record 'CannotMixPositionalAndNamedTypeParams' and
--   continue using the named arguments only.
instantiateWith :: Name -> Expr -> Schema -> [Located (Maybe Ident,Type)]
                -> InferM (Expr,Type)
instantiateWith nm e s ts =
  case (byName, byPosition) of
    ([], _) -> instantiateWithPos nm e s byPosition
    (_, []) -> instantiateWithNames nm e s byName
    _       -> do recordError CannotMixPositionalAndNamedTypeParams
                  instantiateWithNames nm e s byName
  where
  (byName, byPosition) = partitionEithers [ split arg | arg <- ts ]

  -- Named arguments keep their location; positional ones keep only the type.
  split arg =
    case thing arg of
      (Just ident, ty) -> Left arg { thing = (ident, ty) }
      (Nothing, ty)    -> Right ty
-- | Instantiate a schema using positional type arguments.
--
-- The schema's type parameters are walked left to right, counting from 1:
--
-- * a parameter for which 'tpName' is 'Nothing' never consumes an argument;
--   it is instantiated with a fresh type from 'newType';
-- * a named parameter consumes the next argument.  If the kinds differ,
--   a 'KindMismatch' is recorded, the argument is discarded, and a fresh
--   type is used instead;
-- * parameters left over once the arguments run out get fresh types;
-- * arguments left over once the parameters run out cause
--   'TooManyPositionalTypeParams' to be recorded.
--
-- The resulting parameter/type pairs are handed to 'doInst'.
instantiateWithPos :: Name -> Expr -> Schema -> [Type] -> InferM (Expr,Type)
instantiateWithPos nm e (Forall as ps t) ts =
  do pairs <- go (1::Int) [] as ts
     doInst pairs e ps t
  where
  -- The accumulator is built in reverse and flipped at the end.
  go _ acc [] [] = return (reverse acc)
  go _ acc [] (_ : _) =
    do recordError TooManyPositionalTypeParams
       return (reverse acc)
  go ix acc (tp : tps) [] =
    do fresh <- freshFor ix tp
       go (ix + 1) (fresh : acc) tps []
  go ix acc (tp : tps) args@(arg : rest)
    | not (isJust (tpName tp)) = skipWith args
    | wanted == given          = go (ix + 1) ((tp, arg) : acc) tps rest
    | otherwise                = do recordError (KindMismatch wanted given)
                                    skipWith rest
    where
    wanted = kindOf tp
    given  = kindOf arg

    -- Give the current parameter a fresh type and carry on with the
    -- given remaining arguments.
    skipWith remaining =
      do fresh <- freshFor ix tp
         go (ix + 1) (fresh : acc) tps remaining

  -- Pair a parameter with a fresh type of the parameter's kind.
  freshFor ix tp = (,) tp <$> newType origin (kindOf tp)
    where
    -- The source description is looked up by index (counting from 1)
    -- in the schema's full parameter list.
    origin =
      case drop (ix - 1) as of
        p : _ ->
          case tpFlav p of
            TPOther (Just a) -> TypeParamInstNamed nm (nameIdent a)
            _                -> TypeParamInstPos nm ix
        [] -> panic "instantiateWithPos"
                [ "Invalid parameter index", show ix, show as ]
-- | Instantiate a schema using named type arguments.
--
-- Errors are recorded, in this order, for:
--
-- 1. arguments that share an identifier ('RepeatedTypeParameter', one per
--    group of duplicates, listing the source ranges of all of them);
-- 2. arguments whose identifier matches no named parameter of the schema
--    ('UndefinedTypeParameter').
--
-- Each schema parameter is then instantiated with the first argument that
-- carries its identifier, provided the kinds agree.  On a kind mismatch
-- 'KindMismatch' is recorded (within the argument's source range) and a
-- fresh type is used; parameters with no matching argument also get a
-- fresh type.  The resulting pairs are handed to 'doInst'.
instantiateWithNames :: Name -> Expr -> Schema -> [Located (Ident,Type)]
                     -> InferM (Expr,Type)
instantiateWithNames nm e (Forall as ps t) xs =
  do sequence_ duplicateErrors
     mapM_ (\arg -> recordError (UndefinedTypeParameter (fmap fst arg)))
           unknownArgs
     pairs <- zipWithM chooseType [ 1 .. ] as
     doInst pairs e ps t
  where
  argName arg = fst (thing arg)

  -- One error action per identifier that was supplied more than once.
  duplicateErrors =
    [ recordError (RepeatedTypeParameter (argName first) (map srcRange grp))
    | grp@(first : _ : _) <- groupBy ((==) `on` argName)
                               (sortBy (compare `on` argName) xs)
    ]

  -- Identifiers of the schema parameters that have a name.
  knownIdents = map nameIdent (mapMaybe tpName as)

  -- Arguments whose identifier is not a parameter of the schema.
  unknownArgs = filter (\arg -> argName arg `notElem` knownIdents) xs

  -- Pick the type for the parameter at (1-based) position @ix@.
  chooseType ix tp =
    do ty <- maybe fresh fromArg (tpName tp >>= supplied)
       return (tp, ty)
    where
    wanted = tpKind tp

    origin = maybe (TypeParamInstPos nm ix)
                   (TypeParamInstNamed nm . nameIdent)
                   (tpName tp)

    fresh = newType origin wanted

    -- Parameters and arguments are matched by identifier only.
    supplied name = find (\arg -> argName arg == nameIdent name) xs

    fromArg arg
      | given == wanted = return ty
      | otherwise       = inRange (srcRange arg) $
                            do recordError (KindMismatch wanted given)
                               fresh
      where
      ty    = snd (thing arg)
      given = kindOf ty
-- | Carry out an instantiation, given the type chosen for each parameter.
--
-- * The schema's props, with the substitution applied, are emitted as new
--   goals tagged with @CtInst e@.
-- * For every parameter that occurs in the set attached to some free
--   variable ('TVFree') of the schema's type or props, the parameter's
--   variable is unified with its chosen type; the props produced by
--   'unify' are emitted as further goals.
-- * The result is the expression applied to the chosen types (in order),
--   wrapped in one 'EProofApp' per prop of the schema, together with the
--   substituted type.
doInst :: [(TParam, Type)] -> Expr -> [Prop] -> Type -> InferM (Expr,Type)
doInst su' e ps t =
  do newGoals (CtInst e) [ apSubst su p | p <- ps ]
     extra <- mapM checkInst su'
     newGoals (CtInst e) (concat extra)
     return (withProofs, apSubst su t)
  where
  su = listParamSubst su'

  -- Apply the expression to each chosen type argument.
  withTypes = foldl ETApp e (map snd su')

  -- Mark one proof application for each prop of the schema.
  withProofs = foldl (\acc _ -> EProofApp acc) withTypes ps

  -- Free type variables of the schema's type and props.
  frees = Set.unions (map fvs (t : ps))

  -- Union of the sets stored in the free ('TVFree') variables above;
  -- bound variables contribute nothing.
  bounds = Set.unions [ scopeOf v | v <- Set.toList frees ]
    where
    scopeOf v =
      case v of
        TVFree _ _ vs _ -> vs
        TVBound _       -> Set.empty

  -- Unify a parameter with its chosen type only when it is in 'bounds'.
  checkInst :: (TParam, Type) -> InferM [Prop]
  checkInst (tp, ty)
    | tp `Set.member` bounds = unify (TVar (tpVar tp)) ty
    | otherwise              = return []