{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE Trustworthy #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module GHC.TypeLits.KnownNat.Solver (plugin) where
import Control.Arrow ((&&&), first)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Control.Monad.Trans.Writer.Strict
import Data.Maybe (catMaybes,mapMaybe)
import GHC.TcPluginM.Extra (lookupModule, lookupName, newWanted,
tracePlugin)
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra (flattenGivens, mkSubst', substType)
#endif
import GHC.TypeLits.Normalise.SOP (SOP (..), Product (..), Symbol (..))
import GHC.TypeLits.Normalise.Unify (CType (..),normaliseNat,reifySOP)
import Class (Class, classMethods, className, classTyCon)
import FamInst (tcInstNewTyCon_maybe)
import FastString (fsLit)
import Id (idType)
import InstEnv (instanceDFunId,lookupUniqueInstEnv)
#if MIN_VERSION_ghc(8,5,0)
import MkCore (mkNaturalExpr)
#endif
import Module (mkModuleName, moduleName, moduleNameString)
import Name (nameModule_maybe, nameOccName)
import OccName (mkTcOcc, occNameString)
import Plugins (Plugin (..), defaultPlugin)
#if MIN_VERSION_ghc(8,6,0)
import Plugins (purePlugin)
#endif
import PrelNames (knownNatClassName)
#if MIN_VERSION_ghc(8,5,0)
import TcEvidence (EvTerm (..), EvExpr, evDFunApp, mkEvCast, mkTcSymCo, mkTcTransCo)
#else
import TcEvidence (EvTerm (..), EvLit (EvNum), mkEvCast, mkTcSymCo, mkTcTransCo)
#endif
#if MIN_VERSION_ghc(8,5,0)
import TcPluginM (unsafeTcPluginTcM)
#endif
#if !MIN_VERSION_ghc(8,4,0)
import TcPluginM (zonkCt)
#endif
import TcPluginM (TcPluginM, tcLookupClass, getInstEnvs)
import TcRnTypes (Ct, TcPlugin(..), TcPluginResult (..), ctEvidence, ctEvLoc,
#if MIN_VERSION_ghc(8,5,0)
ctEvPred, ctEvExpr, ctLoc, ctLocSpan, isWanted,
#else
ctEvPred, ctEvTerm, ctLoc, ctLocSpan, isWanted,
#endif
mkNonCanonical, setCtLoc, setCtLocSpan)
import TcTypeNats (typeNatAddTyCon, typeNatSubTyCon)
import Type
(EqRel (NomEq), PredTree (ClassPred,EqPred), PredType, classifyPredType,
dropForAlls, eqType, funResultTy, mkNumLitTy, mkStrLitTy, mkTyConApp,
piResultTys, splitFunTys, splitTyConApp_maybe, tyConAppTyCon_maybe)
import TyCon (tyConName)
import TyCoRep (Type (..), TyLit (..))
import Var (DFunId)
-- | Lookup from an arity to a class. The value built by 'lookupKnownNatDefs'
-- maps 1, 2 and 3 to the @KnownNat1@, @KnownNat2@ and @KnownNat3@ classes
-- and every other number to 'Nothing'.
type KnownNatDefs = Int -> Maybe Class
-- | A class constraint split into its pieces, as produced by
-- 'toKnConstraint' for constraints whose class is @KnownNat@.
type KnConstraint =
  ( Ct     -- the constraint as handed to the plugin
  , Class  -- the class of the constraint
  , Type   -- the single type argument of the class
  )
-- | The plugin entry point: 'defaultPlugin' with the type-checker plugin
-- set to 'normalisePlugin'. The argument passed to 'tcPlugin' is ignored.
-- On GHC 8.6 and later 'pluginRecompile' is set to 'purePlugin'.
plugin :: Plugin
plugin = defaultPlugin
  { tcPlugin = \_ignored -> Just normalisePlugin
#if MIN_VERSION_ghc(8,6,0)
  , pluginRecompile = purePlugin
#endif
  }
-- | The type-checker plugin, wrapped in 'tracePlugin' under the name
-- @ghc-typelits-knownnat@. Initialisation looks up the @KnownNatN@ classes,
-- solving is delegated to 'solveKnownNat', and stopping does nothing.
normalisePlugin :: TcPlugin
normalisePlugin = tracePlugin "ghc-typelits-knownnat" solver
  where
    solver = TcPlugin
      { tcPluginInit  = lookupKnownNatDefs
      , tcPluginSolve = solveKnownNat
      , tcPluginStop  = \_ -> return ()
      }
-- | Solver callback of the plugin.
--
-- Picks the wanted @KnownNat@ constraints out of the last argument and tries
-- to build evidence for each of them with 'constraintToEvTerm'. The result
-- is always 'TcPluginOk': the solved constraints paired with their evidence,
-- plus all new wanted constraints that were emitted along the way.
-- Constraints for which no evidence could be built are simply left out.
solveKnownNat :: KnownNatDefs -> [Ct] -> [Ct] -> [Ct]
              -> TcPluginM TcPluginResult
solveKnownNat _defs _givens _deriveds [] = return (TcPluginOk [] [])
solveKnownNat defs givens _deriveds wanteds
  | null knWanteds = return (TcPluginOk [] [])
  | otherwise = do
#if MIN_VERSION_ghc(8,4,0)
      let givenMap = map toGivenEntry (flattenGivens givens)
#else
      givenMap <- mapM (fmap toGivenEntry . zonkCt) givens
#endif
      (solved,new) <- fmap (unzip . catMaybes)
                           (mapM (constraintToEvTerm defs givenMap) knWanteds)
      return (TcPluginOk solved (concat new))
  where
    -- GHC 7.10 puts deriveds with the wanteds, so filter them out
    onlyWanteds = filter (isWanted . ctEvidence) wanteds
#if MIN_VERSION_ghc(8,4,0)
    -- Apply the substitution obtained from the givens to the argument of
    -- every wanted KnownNat constraint.
    givenSubst = map fst (mkSubst' givens)
    knWanteds  = [ (c,k,substType givenSubst t)
                 | (c,k,t) <- mapMaybe toKnConstraint onlyWanteds
                 ]
#else
    knWanteds  = mapMaybe toKnConstraint onlyWanteds
#endif
-- | Recognise a @KnownNat@ constraint: a class predicate with exactly one
-- argument whose class name is 'knownNatClassName'. Returns the constraint
-- together with its class and that argument, or 'Nothing' for anything else.
toKnConstraint :: Ct -> Maybe KnConstraint
toKnConstraint ct
  | ClassPred cls [ty] <- classifyPredType (ctEvPred (ctEvidence ct))
  , className cls == knownNatClassName
  = Just (ct,cls,ty)
  | otherwise
  = Nothing
-- | Turn a constraint into a lookup entry: its predicate type (wrapped in
-- 'CType') paired with the evidence carried by the constraint.
#if MIN_VERSION_ghc(8,5,0)
toGivenEntry :: Ct -> (CType,EvExpr)
#else
toGivenEntry :: Ct -> (CType,EvTerm)
#endif
toGivenEntry ct = (CType (ctEvPred ctEv), evidence)
  where
    ctEv = ctEvidence ct
#if MIN_VERSION_ghc(8,5,0)
    evidence = ctEvExpr ctEv
#else
    evidence = ctEvTerm ctEv
#endif
-- | Find the classes @KnownNat1@, @KnownNat2@ and @KnownNat3@ in module
-- @GHC.TypeLits.KnownNat@ of package @ghc-typelits-knownnat@ and return a
-- function mapping 1, 2 and 3 to the respective class, and any other number
-- to 'Nothing'.
lookupKnownNatDefs :: TcPluginM KnownNatDefs
lookupKnownNatDefs = do
    md  <- lookupModule (mkModuleName "GHC.TypeLits.KnownNat")
                        (fsLit "ghc-typelits-knownnat")
    kn1 <- classNamed md "KnownNat1"
    kn2 <- classNamed md "KnownNat2"
    kn3 <- classNamed md "KnownNat3"
    return (\arity -> lookup arity [(1,kn1),(2,kn2),(3,kn3)])
  where
    -- Resolve a type-constructor-namespace name in the module to a class.
    classNamed md s = lookupName md (mkTcOcc s) >>= tcLookupClass
-- | Try to build evidence for one wanted @KnownNat@ constraint.
--
-- The local helper @offset@ is tried first; when it yields nothing, the
-- local helper @go@ is run on the wanted type itself. On success the result
-- holds the evidence paired with the original constraint, together with the
-- new wanted constraints that the evidence depends on.
constraintToEvTerm
  :: KnownNatDefs
  -- Lookup of the KnownNatN class for a given arity
#if MIN_VERSION_ghc(8,5,0)
  -> [(CType,EvExpr)]
#else
  -> [(CType,EvTerm)]
#endif
  -- Predicates of the given constraints with their evidence
  -> KnConstraint
  -> TcPluginM (Maybe ((EvTerm,Ct),[Ct]))
constraintToEvTerm defs givens (ct,cls,op) = do
    viaOffset <- offset op
    evM       <- maybe (go op) (return . Just) viaOffset
    return (fmap (first (,ct)) evM)
  where
    -- Build evidence for "KnownNat ty", cast to evidence for "KnownNat op".
    go :: Type -> TcPluginM (Maybe (EvTerm,[Ct]))
    -- A given KnownNat constraint for exactly this type exists.
    go (go_other -> Just ev) = return (Just (ev,[]))
    -- A type constructor application: look for a unique instance of the
    -- KnownNatN class of matching arity, whose first argument is the
    -- qualified name of the type constructor as a type-level string.
    go ty@(TyConApp tc args)
      | Just m   <- nameModule_maybe tcNm
      , Just knN <- defs (length args)
      = do let qualNm  = moduleNameString (moduleName m)
                           ++ "." ++ occNameString (nameOccName tcNm)
               clsArgs = mkStrLitTy (fsLit qualNm) : args
           instEnvs <- getInstEnvs
           case lookupUniqueInstEnv instEnvs knN clsArgs of
             Right (inst,_) -> do
               let dfun      = instanceDFunId inst
                   -- Instantiate the type of the dictionary function at
                   -- args; the argument types of the resulting function
                   -- type are the predicates that need evidence.
                   ctxtPreds = fst (splitFunTys (idType dfun `piResultTys` args))
               (evs,new) <- unzip <$> mapM go_arg ctxtPreds
               return ((,concat new) <$> makeOpDict (knN,dfun) cls clsArgs op evs)
             _ -> return ((,[]) <$> go_other ty)
      where
        tcNm = tyConName tc
    go (LitTy (NumTyLit i))
      -- The wanted type is itself a literal: produce no evidence here.
      | LitTy _ <- op
      = return Nothing
      -- The wanted type is not a literal but led to this literal.
      | otherwise
#if MIN_VERSION_ghc(8,5,0)
      = fmap (fmap (,[])) (makeLitDict cls op i)
#else
      = return (fmap (,[]) (makeLitDict cls op i))
#endif
    go _ = return Nothing

    -- Evidence for one predicate required by a dictionary function: reuse
    -- the evidence of an identical given, otherwise emit a new wanted.
#if MIN_VERSION_ghc(8,5,0)
    go_arg :: PredType -> TcPluginM (EvExpr,[Ct])
#else
    go_arg :: PredType -> TcPluginM (EvTerm,[Ct])
#endif
    go_arg ty = maybe fresh (\ev -> return (ev,[])) (lookup (CType ty) givens)
      where
        fresh = fmap (\(ev,wanted) -> (ev,[wanted])) (makeWantedEv ct ty)

    -- Look for a given "KnownNat ty" and turn its evidence into evidence
    -- for "KnownNat op"; a coercion is only needed when ty and op differ.
    go_other :: Type -> Maybe EvTerm
    go_other ty =
        lookup (CType (mkTyConApp (classTyCon cls) [ty])) givens >>= recast
      where
        recast
          | CType ty == CType op
#if MIN_VERSION_ghc(8,6,0)
          = Just . EvExpr
#else
          = Just
#endif
          | otherwise
          = makeKnCoercion cls ty op

    -- Look for a given KnownNat constraint (or one reachable through a
    -- given nominal equality) whose argument minus the wanted type
    -- normalises to a single integer or a single variable, and solve the
    -- wanted via that given plus or minus the difference.
    offset :: Type -> TcPluginM (Maybe (EvTerm,[Ct]))
    offset want = runMaybeT $ do
        -- Only the first suitable candidate is used; an empty list makes
        -- the pattern match fail, which yields Nothing in MaybeT.
        ((h,corr):_) <- pure interesting
        MaybeT (go (corrected h corr))
      where
        givenPreds = map (unCType . fst) givens

        -- Argument of a given KnownNat constraint.
        knownNatArg p = case classifyPredType p of
          ClassPred cls' [n]
            | className cls' == knownNatClassName -> Just n
          _ -> Nothing

        -- Both sides of a given nominal equality.
        nominalEq p = case classifyPredType p of
          EqPred NomEq lhs rhs -> Just (lhs,rhs)
          _                    -> Nothing

        rewrites = mapMaybe nominalEq givenPreds
        knowns   = mapMaybe knownNatArg givenPreds

        -- The other side of an equality if one side equals the type.
        rewriteTy tyK (lhs,rhs)
          | lhs `eqType` tyK = Just rhs
          | rhs `eqType` tyK = Just lhs
          | otherwise        = Nothing

        -- Known types rewritten with every applicable given equality.
        knownsR = [ t' | t <- knowns, eq <- rewrites, Just t' <- [rewriteTy t eq] ]

        -- Normal form of (known - want), paired with the known type.
        diffToWant known =
          fst (runWriter (normaliseNat (mkTyConApp typeNatSubTyCon [known,want])))
        exploded = map (diffToWant &&& id) (knowns ++ knownsR)

        -- Keep candidates whose difference is one integer or one variable.
        singleSymbol (sop,entire) = case sop of
          S [P [s@(I _)]] -> Just (entire,s)
          S [P [s@(V _)]] -> Just (entire,s)
          _               -> Nothing
        interesting = mapMaybe singleSymbol exploded

        -- With known - want = corr, rebuild want as known - corr; a
        -- negative integer difference is expressed as an addition instead.
        corrected base corr = case corr of
          I 0 -> base
          I i
            | i < 0     -> mkTyConApp typeNatAddTyCon [base,mkNumLitTy (negate i)]
            | otherwise -> mkTyConApp typeNatSubTyCon [base,mkNumLitTy i]
          _ -> mkTyConApp typeNatSubTyCon [base,reifySOP (S [P [corr]])]
-- | Create a new wanted constraint for the given predicate type, using the
-- location of the given constraint. Returns the evidence of the new wanted
-- together with the new constraint, whose location is that of its own
-- evidence with the source span replaced by the span of the given
-- constraint.
makeWantedEv
  :: Ct
  -> Type
#if MIN_VERSION_ghc(8,5,0)
  -> TcPluginM (EvExpr,Ct)
#else
  -> TcPluginM (EvTerm,Ct)
#endif
makeWantedEv ct ty = do
  wantedCtEv <- newWanted (ctLoc ct) ty
  let originSpan = ctLocSpan (ctLoc ct)
      relocated  = setCtLoc (mkNonCanonical wantedCtEv)
                            (setCtLocSpan (ctEvLoc wantedCtEv) originSpan)
#if MIN_VERSION_ghc(8,5,0)
  return (ctEvExpr wantedCtEv,relocated)
#else
  return (ctEvTerm wantedCtEv,relocated)
#endif
-- | Build @KnownNat z@ evidence out of an instance of a @KnownNatN@ class.
--
-- The dictionary function of the instance is applied to all but the first
-- of the class arguments and to the evidence for its own arguments; the
-- result is then cast to the @KnownNat@ class by unwrapping the newtypes of
-- the operation class and its method result, and wrapping those of the
-- @KnownNat@ class in the opposite direction.
--
-- Arguments, in order:
--
-- * the @KnownNatN@ class and the dictionary function of the instance found
-- * the @KnownNat@ class
-- * the arguments of the @KnownNatN@ class; the first one is not passed to
--   the dictionary function
-- * the type for which @KnownNat@ evidence is wanted
-- * evidence for the arguments of the dictionary function
--
-- Returns 'Nothing' when any of the newtype unwrappings fails, when either
-- class does not have exactly one method, or when the class argument list
-- is empty.
makeOpDict :: (Class,DFunId)
           -> Class
           -> [Type]
           -> Type
#if MIN_VERSION_ghc(8,5,0)
           -> [EvExpr]
#else
           -> [EvTerm]
#endif
           -> Maybe EvTerm
makeOpDict (opCls,dfid) knCls tyArgs z evArgs
  | Just (_, kn_co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe
                       $ funResultTy
                       $ dropForAlls
                       $ idType kn_meth
  , Just (_, kn_co_rep) <- tcInstNewTyCon_maybe kn_tcRep [z]
  , Just (_, op_co_dict) <- tcInstNewTyCon_maybe (classTyCon opCls) tyArgs
  , [ op_meth ] <- classMethods opCls
  , Just (op_tcRep,op_args) <- splitTyConApp_maybe
                                 $ funResultTy
                                 $ (`piResultTys` tyArgs)
                                 $ idType op_meth
  , Just (_, op_co_rep) <- tcInstNewTyCon_maybe op_tcRep op_args
    -- Total replacement for "tail tyArgs": an empty argument list makes
    -- this guard fail instead of raising an exception.
  , (_ : dfunTyArgs) <- tyArgs
#if MIN_VERSION_ghc(8,5,0)
    -- Matched as a pattern guard rather than an irrefutable let pattern, so
    -- a non-EvExpr result leads to Nothing instead of a pattern failure.
  , EvExpr dfun_inst <- evDFunApp dfid dfunTyArgs evArgs
#else
  , let dfun_inst = EvDFunApp dfid dfunTyArgs evArgs
#endif
  , let op_to_kn = mkTcTransCo (mkTcTransCo op_co_dict op_co_rep)
                               (mkTcSymCo (mkTcTransCo kn_co_dict kn_co_rep))
        ev_tm    = mkEvCast dfun_inst op_to_kn
  = Just ev_tm
  | otherwise
  = Nothing
-- | Cast evidence for @KnownNat x@ into evidence for @KnownNat z@.
--
-- The coercion unwraps the class newtype and the newtype of the method
-- result at @x@, then wraps both again at @z@. Returns 'Nothing' when any of
-- the unwrappings fails or the class does not have exactly one method.
makeKnCoercion :: Class
               -> Type
               -> Type
#if MIN_VERSION_ghc(8,5,0)
               -> EvExpr
#else
               -> EvTerm
#endif
               -> Maybe EvTerm
makeKnCoercion knCls x z xEv
  | Just (_, dictZ) <- unwrapDict z
  , [ natMeth ] <- classMethods knCls
  , Just repTc <- tyConAppTyCon_maybe (funResultTy (dropForAlls (idType natMeth)))
  , Just (_, repZ) <- tcInstNewTyCon_maybe repTc [z]
  , Just (_, repX) <- tcInstNewTyCon_maybe repTc [x]
  , Just (_, dictX) <- unwrapDict x
  , let fromX = dictX `mkTcTransCo` repX
        toZ   = mkTcSymCo (dictZ `mkTcTransCo` repZ)
  = Just (mkEvCast xEv (fromX `mkTcTransCo` toZ))
  | otherwise
  = Nothing
  where
    -- Newtype coercion of the class dictionary at the given argument.
    unwrapDict ty = tcInstNewTyCon_maybe (classTyCon knCls) [ty]
-- | Build evidence for the given class at the given type from an integer
-- literal: the literal is cast with the symmetric of the two newtype
-- coercions (class dictionary, then method result). Yields 'Nothing' when an
-- unwrapping fails or the class does not have exactly one method. On GHC 8.5
-- and later the literal expression is created in the type-checker monad, so
-- the result lives in 'TcPluginM'.
#if MIN_VERSION_ghc(8,5,0)
makeLitDict :: Class -> Type -> Integer -> TcPluginM (Maybe EvTerm)
#else
makeLitDict :: Class -> Type -> Integer -> Maybe EvTerm
#endif
makeLitDict clas ty i
  | Just (_, dictCo) <- tcInstNewTyCon_maybe (classTyCon clas) [ty]
  , [ meth ] <- classMethods clas
  , Just repTc <- tyConAppTyCon_maybe (funResultTy (dropForAlls (idType meth)))
  , Just (_, repCo) <- tcInstNewTyCon_maybe repTc [ty]
  , let litToDict = mkTcSymCo (mkTcTransCo dictCo repCo)
#if MIN_VERSION_ghc(8,5,0)
  = fmap (\natExpr -> Just (mkEvCast natExpr litToDict))
         (unsafeTcPluginTcM (mkNaturalExpr i))
  | otherwise
  = return Nothing
#else
  = Just (mkEvCast (EvLit (EvNum i)) litToDict)
  | otherwise
  = Nothing
#endif