module Control.Super.Plugin.Detect
(
ModuleQuery(..)
, findModuleByQuery
, findModule
, defaultFindEitherModuleErrMsg
, ClassQuery(..)
, isOptionalClassQuery
, queriedClasses, moduleQueryOf
, findClassesByQuery
, findClass
, isClass
, findClassesAndInstancesInScope
, findMonoTopTyConInstances
, InstanceImplication
, (===>), (<==>)
, clsDictInstImp, clsDictInstEquiv
, checkInstanceImplications
, checkInstances
) where
import Data.List ( find )
import Data.Either ( isLeft, isRight )
import Data.Maybe ( isNothing, maybeToList, catMaybes, fromMaybe )
import Control.Monad ( forM )
import BasicTypes ( Arity )
import TcRnTypes
( TcGblEnv(..)
, ImportAvails( imp_mods ) )
import TyCon ( TyCon )
import TcPluginM
( TcPluginM
, getEnvs, getInstEnvs )
import Name
( nameModule
, getOccName )
import OccName
( occNameString )
import Module
( Module, ModuleName
, moduleName
, moduleEnvKeys
, mkModuleName )
import Class
( Class(..)
, className, classArity )
import InstEnv
( ClsInst(..)
, instEnvElts
, ie_global
, classInstances )
import PrelNames ( mAIN_NAME )
import Outputable ( SDoc, ($$), text, vcat, ppr, hang )
import qualified Outputable as O
import Control.Super.Plugin.Collection.Set ( Set )
import qualified Control.Super.Plugin.Collection.Set as S
import qualified Control.Super.Plugin.Collection.Map as M
import Control.Super.Plugin.Wrapper
( UnitId, moduleUnitId )
import Control.Super.Plugin.Instance
( instanceTopTyCons
, isMonoTyConInstance
, isPolyTyConInstance )
import Control.Super.Plugin.Utils
( errIndent
, removeDupByIndexEq
, fromRight, fromLeft
, getClassName, getTyConName )
import Control.Super.Plugin.ClassDict
( ClassDict, Optional
, allClsDictEntries
, lookupClsDictClass )
import Control.Super.Plugin.InstanceDict
( InstanceDict
, insertInstDict, emptyInstDict
, allInstDictTyCons
, lookupInstDictByTyCon )
import Control.Super.Plugin.Names
-- | Runs all instance sanity checks and collects the resulting errors.
--
--   Two kinds of problems are reported:
--
--   * violated instance implications (see 'checkInstanceImplications'),
--     indexed by @Left (tyCon, missingClass)@; the list is passed through
--     'removeDupByIndexEq' (presumably dropping repeated indices — several
--     implications may share a conclusion);
--   * instances of a class from the class dictionary that satisfy
--     'isPolyTyConInstance' (i.e. involve more than one top-level type
--     constructor), indexed by @Right instance@.
checkInstances
  :: ClassDict
  -> InstanceDict
  -> [InstanceImplication]
  -> [(Either (TyCon, Class) ClsInst, SDoc)]
checkInstances clsDict instDict instImplications =
  monoCheckErrMsgs ++ polyCheckErrMsgs
  where
    monoCheckErrMsgs :: [(Either (TyCon, Class) ClsInst, SDoc)]
    monoCheckErrMsgs = fmap (\(idx, msg) -> (Left idx, msg))
                     $ removeDupByIndexEq
                     $ checkInstanceImplications instDict instImplications
    polyCheckErrMsgs :: [(Either (TyCon, Class) ClsInst, SDoc)]
    polyCheckErrMsgs =
      -- Dictionary entries without a class/instances ('Nothing') are
      -- skipped by the pattern match.
      [ (Right polyInst, text "Instance involves more than one top-level type constructor: " $$ ppr polyInst)
      | (_opt, Just (cls, insts)) <- allClsDictEntries clsDict
      , polyInst <- filter (isPolyTyConInstance cls) insts
      ]
-- | A query describing which module(s) to look for among the modules
--   imported by the module currently being type-checked
--   (see 'findModuleByQuery').
data ModuleQuery
  = ThisModule PluginModuleName (Maybe UnitId)
    -- ^ A single module, identified by its name and optionally by the unit
    --   it comes from. With 'Nothing' as unit, a module of any unit matches.
  | EitherModule [ModuleQuery] (Maybe ([Either SDoc Module] -> SDoc))
    -- ^ Exactly one of the sub-queries has to succeed. If none or several
    --   succeed, the optional function builds the error message from all
    --   sub-query results ('defaultFindEitherModuleErrMsg' is used when it
    --   is 'Nothing').
  | AnyModule [ModuleQuery]
    -- ^ At least one of the sub-queries has to succeed; the first
    --   successful one (in list order) is used.
-- | Renders a module query for diagnostics:
--
--   * @ThisModule@ as the module name, followed by the unit id in
--     parentheses when one is given;
--   * @EitherModule@ as @XOR [q1, q2, ...]@;
--   * @AnyModule@ as @OR [q1, q2, ...]@.
instance O.Outputable ModuleQuery where
  ppr query = case query of
      ThisModule mdlName mUnitId ->
        -- Separate the unit id from the module name; previously both were
        -- concatenated with no separator at all.
        O.text mdlName O.<+> maybe O.empty (O.parens . O.ppr) mUnitId
      EitherModule mdlQueries _errF -> O.text "XOR " O.<> queryList mdlQueries
      AnyModule mdlQueries -> O.text "OR " O.<> queryList mdlQueries
    where
      -- Comma-separated, bracketed list of the rendered sub-queries.
      queryList :: [ModuleQuery] -> SDoc
      queryList qs = O.brackets $ O.hcat $ O.punctuate (O.text ", ") $ fmap O.ppr qs
-- | Collects the names of all modules mentioned anywhere in a module query.
--   Unit ids and error-message functions are ignored; only names are kept.
collectModuleNames :: ModuleQuery -> Set ModuleName
collectModuleNames query = case query of
  ThisModule name _unit -> S.singleton (mkModuleName name)
  EitherModule subQueries _errF -> unionOf subQueries
  AnyModule subQueries -> unionOf subQueries
  where
    unionOf :: [ModuleQuery] -> Set ModuleName
    unionOf = S.unions . map collectModuleNames
-- | Checks whether a module's name is mentioned anywhere in the query.
--
--   Only names are compared; unit ids in the query are ignored. The name of
--   the @Main@ module ('mAIN_NAME') is always accepted in addition to the
--   queried names.
--   NOTE(review): presumably so that classes declared in a @Main@ module
--   (e.g. test programs) are found as well — confirm.
isModuleInQuery :: ModuleQuery -> Module -> Bool
isModuleInQuery query mdl = moduleName mdl `S.member` acceptedNames
  where
    acceptedNames :: Set ModuleName
    acceptedNames = S.insert mAIN_NAME (collectModuleNames query)
-- | Resolves a 'ModuleQuery' against the modules imported by the module
--   currently being type-checked.
--
--   Sub-queries are resolved recursively (all of them, in order) before
--   their results are combined. 'AnyModule' queries always use the default
--   error message, as they carry no custom error function.
findModuleByQuery :: ModuleQuery -> TcPluginM (Either SDoc Module)
findModuleByQuery query = case query of
  ThisModule mdlName mdlUnit -> findModule mdlUnit mdlName
  EitherModule subQueries mErrFun ->
    findEitherModule mErrFun <$> mapM findModuleByQuery subQueries
  AnyModule subQueries ->
    findAnyModule Nothing <$> mapM findModuleByQuery subQueries
-- | Looks for an imported module with the given name.
--
--   Only the modules imported by the module currently being type-checked
--   ('imp_mods' of its 'tcg_imports') are searched. If a unit id is given,
--   the module must also come from that unit; otherwise any unit matches.
--   The first matching module is returned, or an error message if there
--   is none.
findModule :: Maybe UnitId -> String -> TcPluginM (Either SDoc Module)
findModule pkgKeyToFind mdlNameToFind = do
  gblEnv <- fst <$> getEnvs
  let importedMdls = moduleEnvKeys (imp_mods (tcg_imports gblEnv))
  return $ maybe (Left notFoundErr) Right (find matchesQuery importedMdls)
  where
    wantedName :: ModuleName
    wantedName = mkModuleName mdlNameToFind
    matchesQuery :: Module -> Bool
    matchesQuery mdl = maybe True (moduleUnitId mdl ==) pkgKeyToFind
                    && moduleName mdl == wantedName
    notFoundErr :: SDoc
    notFoundErr = text $ "Could not find module '" ++ mdlNameToFind ++ "'"
-- | Combines the sub-query results of an 'EitherModule' query.
--
--   Succeeds only if exactly one sub-query found a module. If none or
--   several did, the error message is built from all sub-query results by
--   the given function, or by 'defaultFindEitherModuleErrMsg' when no
--   function is given.
findEitherModule :: Maybe ([Either SDoc Module] -> SDoc) -> [Either SDoc Module] -> (Either SDoc Module)
findEitherModule mErrFun eMdls =
  case fmap fromRight $ filter isRight eMdls of
    [mdl] -> Right mdl
    -- Zero and several matches are both failures and share one error path.
    _ -> Left $ fromMaybe defaultFindEitherModuleErrMsg mErrFun eMdls
-- | Combines the sub-query results of an 'AnyModule' query.
--
--   Returns the first module found (in sub-query order). If no sub-query
--   succeeded, all of their error messages are combined by the given
--   function, or by 'defaultFindAnyModuleErrMsg' when no function is given.
findAnyModule :: Maybe ([SDoc] -> SDoc) -> [Either SDoc Module] -> (Either SDoc Module)
findAnyModule mErrFun eMdls =
  case filter isRight eMdls of
    firstHit : _ -> Right (fromRight firstHit)
    [] -> Left (mkErr (fmap fromLeft eMdls))
  where
    mkErr :: [SDoc] -> SDoc
    mkErr = fromMaybe defaultFindAnyModuleErrMsg mErrFun
-- | Default error message for a failed 'EitherModule' query.
--
--   If no sub-query found a module, the sub-query error messages are
--   listed; otherwise the (several) modules that were found are listed.
defaultFindEitherModuleErrMsg :: [Either SDoc Module] -> SDoc
defaultFindEitherModuleErrMsg mdls
  | null found = hang (text "Failed to find either module!") errIndent (vcat notFound)
  | otherwise = hang (text "Found several modules, unclear which one to use:") errIndent (vcat (map ppr found))
  where
    found :: [Module]
    found = map fromRight (filter isRight mdls)
    notFound :: [SDoc]
    notFound = map fromLeft (filter isLeft mdls)
-- | Default error message for a failed 'AnyModule' query: a header followed
--   by the indented error messages of all sub-queries.
defaultFindAnyModuleErrMsg :: [SDoc] -> SDoc
defaultFindAnyModuleErrMsg =
  hang (text "Could not find any of the modules!") errIndent . vcat
-- | A query for a group of classes (see 'findClassesByQuery'):
--
--   * whether the group is optional: if some class is missing, an optional
--     query yields no classes instead of an error;
--   * the module query restricting which modules the classes may be defined
--     in (checked with 'isModuleInQuery');
--   * the names and arities of the classes to find.
data ClassQuery = ClassQuery Optional ModuleQuery [(PluginClassName, Arity)]
-- | Renders a class query for diagnostics: the module query, then on a new
--   line the (optionally prefixed with "optionally ") list of class names
--   together with their arities.
instance O.Outputable ClassQuery where
  ppr (ClassQuery opt mdlQuery clsNames)
    =      O.hang (O.text "In module:") errIndent (O.ppr mdlQuery)
      -- Stack the two sections vertically; joining them with '<>' glued
      -- "find classes:" onto the last line of the module section.
      O.$$ O.hang (O.text $ (if opt then "optionally " else "") ++ "find classes:") errIndent clsList
    where
      -- Class names are plain strings. Rendering them through 'O.ppr' would
      -- print each name as a list of characters, so use 'O.text' instead.
      clsList :: SDoc
      clsList = O.brackets $ O.fsep $ O.punctuate O.comma
        [ O.parens (O.text clsName O.<> O.comma O.<+> O.ppr clsArity)
        | (clsName, clsArity) <- clsNames
        ]
-- | Whether the query is optional, i.e. whether 'findClassesByQuery'
--   tolerates missing classes (returning none instead of an error).
isOptionalClassQuery :: ClassQuery -> Bool
isOptionalClassQuery query = case query of
  ClassQuery optional _ _ -> optional
-- | The names of all classes requested by the query, in query order.
queriedClasses :: ClassQuery -> [PluginClassName]
queriedClasses (ClassQuery _ _ wanted) = [ name | (name, _arity) <- wanted ]
-- | The module query restricting where the queried classes may be defined.
moduleQueryOf :: ClassQuery -> ModuleQuery
moduleQueryOf query = case query of
  ClassQuery _ mdlQuery _ -> mdlQuery
-- | Looks up every class named in the given 'ClassQuery'.
--
--   A class matches when its name and arity equal the requested ones and it
--   is defined in a module accepted by the query's 'ModuleQuery' (see
--   'isModuleInQuery'). Lookup goes through 'findClass', which locates
--   classes via their instances, so a class with no instance in scope is
--   reported as missing.
--
--   Returns:
--
--   * @Right@ the found classes paired with their requested names, in query
--     order, if every class was found;
--   * @Right []@ if some class is missing but the query is optional;
--   * @Left@ an error listing every missing class otherwise.
findClassesByQuery :: ClassQuery -> TcPluginM (Either SDoc [(PluginClassName, Class)])
findClassesByQuery (ClassQuery opt mdlQuery toFindCls) = do
  lookups <- forM toFindCls $ \(clsName, clsArity) -> do
    mCls <- findClass (isClass (isModuleInQuery mdlQuery) clsName clsArity)
    return (clsName, mCls, clsArity)
  let notFound :: [(PluginClassName, Maybe Class, Arity)]
      notFound = filter (\(_, mCls, _) -> isNothing mCls) lookups
      -- Total pattern match; this list is only returned when 'notFound' is
      -- empty, so no requested class is silently dropped.
      found :: [(PluginClassName, Class)]
      found = [ (n, c) | (n, Just c, _) <- lookups ]
      errMsg :: (PluginClassName, Maybe Class, Arity) -> SDoc
      errMsg (n, _, a) = text $ "Could not find class '" ++ n ++ "' with arity " ++ show a ++ "!"
  return $ case notFound of
    [] -> Right found
    _ | opt -> Right []
      | otherwise -> Left $ vcat $ fmap errMsg notFound
-- | Finds a class satisfying the given predicate.
--
--   Classes are located through their instances. The following are searched
--   in this order, and the class of the first instance whose class satisfies
--   the predicate is returned:
--
--   1. the instance environment of the current global environment
--      ('tcg_inst_env');
--   2. the 'tcg_insts' of the current global environment;
--   3. the global instance environment ('ie_global').
--
--   Consequently, a class without any instance in these environments is
--   not found.
findClass :: (Class -> Bool) -> TcPluginM (Maybe Class)
findClass isCls' = do
  gblEnv <- fst <$> getEnvs
  instEnvs <- getInstEnvs
  let candidates = instEnvElts (tcg_inst_env gblEnv)
                ++ tcg_insts gblEnv
                ++ instEnvElts (ie_global instEnvs)
  return $ is_cls <$> find (isCls' . is_cls) candidates
-- | Checks that a class is defined in an accepted module and has the given
--   (unqualified) name and arity. The checks are evaluated left to right
--   and stop at the first failing one.
isClass :: (Module -> Bool) -> String -> Arity -> Class -> Bool
isClass isModule targetClassName targetArity cls =
  and [ isModule (nameModule name)
      , occNameString (getOccName name) == targetClassName
      , classArity cls == targetArity
      ]
  where
    name = className cls
-- | Like 'findClassesByQuery', but additionally attaches to every found
--   class the instances of it that are in scope ('findInstancesInScope').
--   Lookup errors are passed through unchanged.
findClassesAndInstancesInScope :: ClassQuery -> TcPluginM (Either SDoc [(PluginClassName, Class, [ClsInst])])
findClassesAndInstancesInScope clsQuery =
  findClassesByQuery clsQuery >>= either (return . Left) (fmap Right . mapM attachInstances)
  where
    attachInstances :: (PluginClassName, Class) -> TcPluginM (PluginClassName, Class, [ClsInst])
    attachInstances (name, cls) = do
      insts <- findInstancesInScope cls
      return (name, cls, insts)
-- | All instances of the given class found in the instance environments
--   returned by 'getInstEnvs'.
findInstancesInScope :: Class -> TcPluginM [ClsInst]
findInstancesInScope cls = (`classInstances` cls) <$> getInstEnvs
-- | Builds an instance dictionary of "mono" instances.
--
--   The candidate type constructors are all top-level type constructors
--   ('instanceTopTyCons') of all instances stored in the class dictionary.
--   For every such type constructor and every class in the dictionary, the
--   class's instance is recorded if exactly one of its instances satisfies
--   'isMonoTyConInstance' for that type constructor; otherwise nothing is
--   recorded for that pair. The per-pair dictionaries are combined with
--   'mconcat'.
findMonoTopTyConInstances
  :: ClassDict
  -> InstanceDict
findMonoTopTyConInstances clsDict = mconcat
  [ monoInstanceOf tc cls insts
  | tc <- topTyCons
  , (cls, insts) <- classEntries
  ]
  where
    -- Dictionary entries that actually carry a class and its instances.
    classEntries :: [(Class, [ClsInst])]
    classEntries = catMaybes (map snd (allClsDictEntries clsDict))
    topTyCons :: [TyCon]
    topTyCons = S.toList $ S.unions
      [ instanceTopTyCons inst | (_, clsInsts) <- classEntries, inst <- clsInsts ]
    -- Only an unambiguous (unique) match is recorded.
    monoInstanceOf :: TyCon -> Class -> [ClsInst] -> InstanceDict
    monoInstanceOf tc cls insts = case filter (isMonoTyConInstance tc cls) insts of
      [uniqueInst] -> insertInstDict tc cls uniqueInst emptyInstDict
      _ -> emptyInstDict
-- | @InstanceImplies a b@ states that every type constructor with an
--   instance of class @a@ in the instance dictionary must also have an
--   instance of class @b@ there (checked by 'checkInstanceImplications').
data InstanceImplication = InstanceImplies Class Class
-- | Renders an implication as @A ===> B@ using the names of both classes.
instance O.Outputable InstanceImplication where
  ppr (InstanceImplies premise conclusion) =
    O.text (getClassName premise ++ " ===> " ++ getClassName conclusion)
infix 7 ===>
infix 7 <==>

-- | @a ===> b@: instances of class @a@ require instances of class @b@.
--   Returns a singleton list so that implications can be combined with '++'
--   (whose precedence 5 is lower than the 7 of these operators).
(===>) :: Class -> Class -> [InstanceImplication]
ca ===> cb = [InstanceImplies ca cb]

-- | @a <==> b@: implications in both directions, @a ===> b@ first.
(<==>) :: Class -> Class -> [InstanceImplication]
ca <==> cb = concat [ca ===> cb, cb ===> ca]
-- | The implication @a ===> b@ between the two named classes of the class
--   dictionary, or no implication at all if either class is not present.
clsDictInstImp :: ClassDict -> PluginClassName -> PluginClassName -> [InstanceImplication]
clsDictInstImp clsDict caName cbName = concat
  [ premise ===> conclusion
  | premise <- maybeToList (lookupClsDictClass caName clsDict)
  , conclusion <- maybeToList (lookupClsDictClass cbName clsDict)
  ]
-- | The implications @a <==> b@ (both directions) between the two named
--   classes of the class dictionary, or none if either class is not present.
clsDictInstEquiv :: ClassDict -> PluginClassName -> PluginClassName -> [InstanceImplication]
clsDictInstEquiv clsDict caName cbName = concat
  [ left <==> right
  | left <- maybeToList (lookupClsDictClass caName clsDict)
  , right <- maybeToList (lookupClsDictClass cbName clsDict)
  ]
-- | Checks a list of instance implications against an instance dictionary.
--
--   For every implication @ca ===> cb@ and every type constructor of the
--   dictionary that has an instance of @ca@ but none of @cb@, one error is
--   reported, indexed by that type constructor and the missing class @cb@.
--   Errors are ordered by implication first, then by type constructor.
--
--   Each implication contributes at most one error per type constructor.
--   The same @(TyCon, Class)@ index can still occur repeatedly when several
--   implications share a conclusion, so callers wanting one message per
--   index must de-duplicate.
checkInstanceImplications :: InstanceDict -> [InstanceImplication] -> [((TyCon,Class), SDoc)]
checkInstanceImplications instDict imps =
  -- Every implication is checked independently. (Emitting the results of
  -- the remaining implications once per type constructor would repeat them
  -- exponentially often.)
  [ ((tc, cb), noInstanceErr tc cb)
  | InstanceImplies ca cb <- imps
  , tc <- tyCons
  , let tcDict = lookupInstDictByTyCon tc instDict
  , M.member ca tcDict
  , not (M.member cb tcDict)
  ]
  where
    -- Shared by all implications.
    tyCons :: [TyCon]
    tyCons = S.toList $ allInstDictTyCons instDict
    noInstanceErr :: TyCon -> Class -> SDoc
    noInstanceErr tc missingCls = text $ "There is no unique instance of '" ++ getClassName missingCls ++ "' for the type '" ++ getTyConName tc ++ "'!"