module Control.Super.Plugin.Environment
(
SupermonadPluginM
, runSupermonadPlugin
, runSupermonadPluginAndReturn
, runTcPlugin
, getGivenConstraints, getWantedConstraints
, getInstEnvs
, getClassDictionary
, getClass
, isOptionalClass
, getCustomState, putCustomState, modifyCustomState
, getInstanceFor
, addTypeEqualities, addTypeEquality
, addTyVarEqualities, addTyVarEquality
, getTypeEqualities, getTyVarEqualities
, whenNoResults
, addWarning, displayWarnings
, throwPluginError, throwPluginErrorSDoc, catchPluginError
, assert, assertM
, printErr, printMsg, printObj, printWarn
, printConstraints
) where
import Data.List ( groupBy, sortOn )
import Control.Monad ( when, unless, forM_ )
import Control.Monad.Reader ( ReaderT, runReaderT, asks )
import Control.Monad.State ( StateT , runStateT , gets, modify )
import Control.Monad.Except ( ExceptT, runExceptT, throwError, catchError )
import Control.Monad.Trans.Class ( lift )

import Class ( Class )
import InstEnv ( InstEnvs, ClsInst )
import Type ( TyVar, Type )
import TyCon ( TyCon )
import TcRnTypes ( Ct, TcPluginResult(..) )
import TcPluginM ( TcPluginM, tcPluginIO )
import qualified TcPluginM
import Outputable ( Outputable )
import SrcLoc ( srcSpanFileName_maybe )
import FastString ( unpackFS )
import qualified Outputable as O

import qualified Control.Super.Plugin.Log as L
import Control.Super.Plugin.Names ( PluginClassName )
import Control.Super.Plugin.Constraint
  ( GivenCt, WantedCt
  , constraintSourceLocation
  , mkDerivedTypeEqCt, mkDerivedTypeEqCtOfTypes )
import Control.Super.Plugin.ClassDict
  ( ClassDict
  , Optional
  , emptyClsDict
  , lookupClsDictClass )
import qualified Control.Super.Plugin.ClassDict as ClsD
import Control.Super.Plugin.InstanceDict
  ( InstanceDict, lookupInstDict )
-- | Errors raised inside the plugin monad are GHC pretty-printer documents.
type SupermonadError = O.SDoc

-- | The plugin monad. From outermost to innermost layer:
--
--   * a read-only 'SupermonadPluginEnv',
--   * a mutable 'SupermonadPluginState' carrying custom state of type @s@,
--   * short-circuiting 'SupermonadError's,
--   * GHC's 'TcPluginM'.
type SupermonadPluginM s =
  ReaderT SupermonadPluginEnv
    (StateT (SupermonadPluginState s)
      (ExceptT SupermonadError TcPluginM))
-- | Read-only environment shared by every 'SupermonadPluginM' computation.
data SupermonadPluginEnv = SupermonadPluginEnv
  { smEnvGivenConstraints :: [GivenCt]
    -- ^ Given constraints the plugin run was started with.
  , smEnvWantedConstraints :: [WantedCt]
    -- ^ Wanted constraints the plugin run was started with.
  , smEnvClassDictionary :: ClassDict
    -- ^ Classes known to the plugin. 'runSupermonadPlugin' sets this to
    --   'emptyClsDict' while the initialisation action runs, then replaces
    --   it with the dictionary that action returned.
  }
-- | Mutable state threaded through a 'SupermonadPluginM' computation.
--   The type parameter @s@ is the plugin-specific custom state.
data SupermonadPluginState s = SupermonadPluginState
  { smStateTyVarEqualities :: [(Ct, TyVar, Type)]
    -- ^ Recorded equalities between a type variable and a type. Each one is
    --   paired with the constraint later used as the base constraint when it
    --   is turned into a derived equality. Newest entries come first.
  , smStateTypeEqualities :: [(Ct, Type, Type)]
    -- ^ Recorded equalities between two types, in the same format
    --   (base constraint first). Newest entries come first.
  , smStateWarningQueue :: [(String, O.SDoc)]
    -- ^ Pending warnings as (summary message, details document) pairs.
    --   New warnings are prepended, so the list is newest-first.
  , smStateCustom :: s
    -- ^ Plugin-specific custom state.
  }
-- | Run a plugin computation and package the equalities it recorded as a
--   'TcPluginResult'.
--
--   The initialisation action always runs. If there are no wanted
--   constraints, the main computation is skipped and an empty result is
--   returned. Otherwise @pluginM@ runs (its result is discarded), and every
--   recorded type-variable and type equality is emitted as a derived
--   equality constraint; no constraints are reported as solved.
--
--   A 'SupermonadError' from either phase is printed and replaced by an
--   empty result instead of being propagated.
runSupermonadPluginAndReturn
  :: [GivenCt]                           -- ^ given constraints
  -> [WantedCt]                          -- ^ wanted constraints
  -> SupermonadPluginM () (ClassDict, s) -- ^ initialisation: class dictionary and initial custom state
  -> SupermonadPluginM s a               -- ^ main plugin computation
  -> TcPluginM TcPluginResult
runSupermonadPluginAndReturn givenCts wantedCts initStateM pluginM =
  runSupermonadPlugin givenCts wantedCts initStateM solve >>= either onError return
  where
    noResult = TcPluginOk [] []

    solve
      | null wantedCts = return noResult
      | otherwise = do
          _ <- pluginM
          tyVarEqs <- getTyVarEqualities
          tyEqs <- getTypeEqualities
          return $ TcPluginOk [] $
            [ mkDerivedTypeEqCt ct tv ty | (ct, tv, ty) <- tyVarEqs ]
              ++ [ mkDerivedTypeEqCtOfTypes ct ta tb | (ct, ta, tb) <- tyEqs ]

    onError err = do
      L.printErr $ L.sDocToStr err
      return noResult
-- | Run the initialisation action, then the main plugin computation.
--
--   The initialisation action sees the given/wanted constraints, an empty
--   class dictionary and unit custom state. It returns the class dictionary
--   installed into the environment for @pluginM@, plus the initial custom
--   state. Equalities and warnings recorded during initialisation are
--   carried over into the main phase.
--
--   The first error raised by either phase is returned as 'Left'; on
--   success the result of @pluginM@ is returned and the final plugin state
--   is discarded.
runSupermonadPlugin
  :: [GivenCt]
  -> [WantedCt]
  -> SupermonadPluginM () (ClassDict, s)
  -> SupermonadPluginM s a
  -> TcPluginM (Either SupermonadError a)
runSupermonadPlugin givenCts wantedCts initStateM pluginM = runExceptT $ do
  -- An error here short-circuits the ExceptT block, so the main phase never runs.
  ((clsDict, customState), stateAfterInit) <-
    runStateT (runReaderT initStateM baseEnv) baseState
  let pluginEnv = baseEnv { smEnvClassDictionary = clsDict }
      -- Type-changing record update: only the custom-state field mentions @s@.
      pluginState = stateAfterInit { smStateCustom = customState }
  fst <$> runStateT (runReaderT pluginM pluginEnv) pluginState
  where
    baseEnv = SupermonadPluginEnv
      { smEnvGivenConstraints = givenCts
      , smEnvWantedConstraints = wantedCts
      , smEnvClassDictionary = emptyClsDict
      }

    baseState :: SupermonadPluginState ()
    baseState = SupermonadPluginState
      { smStateTyVarEqualities = []
      , smStateTypeEqualities = []
      , smStateWarningQueue = []
      , smStateCustom = ()
      }
-- | Lift a 'TcPluginM' action through the error, state and reader layers
--   of 'SupermonadPluginM'.
runTcPlugin :: TcPluginM a -> SupermonadPluginM s a
runTcPlugin tcAction = lift (lift (lift tcAction))
-- | The class dictionary stored in the environment.
getClassDictionary :: SupermonadPluginM s ClassDict
getClassDictionary = asks smEnvClassDictionary

-- | Read the plugin-specific custom state.
getCustomState :: SupermonadPluginM s s
getCustomState = gets smStateCustom

-- | Overwrite the plugin-specific custom state.
putCustomState :: s -> SupermonadPluginM s ()
putCustomState newS = modifyCustomState (const newS)

-- | Apply a function to the plugin-specific custom state.
modifyCustomState :: (s -> s) -> SupermonadPluginM s ()
modifyCustomState sf = modify $ \st -> st { smStateCustom = sf (smStateCustom st) }

-- | Look up the GHC 'Class' registered under the given plugin class name,
--   if the class dictionary contains one.
getClass :: PluginClassName -> SupermonadPluginM s (Maybe Class)
getClass clsName = lookupClsDictClass clsName <$> getClassDictionary

-- | Query the class dictionary for whether the named class is optional.
isOptionalClass :: PluginClassName -> SupermonadPluginM s Optional
isOptionalClass clsName = ClsD.isOptionalClass clsName <$> getClassDictionary
-- | Given constraints the plugin run was started with.
getGivenConstraints :: SupermonadPluginM s [GivenCt]
getGivenConstraints = asks smEnvGivenConstraints

-- | Wanted constraints the plugin run was started with.
getWantedConstraints :: SupermonadPluginM s [WantedCt]
getWantedConstraints = asks smEnvWantedConstraints

-- | Instance environments, fetched from GHC via 'TcPluginM.getInstEnvs'.
getInstEnvs :: SupermonadPluginM s InstEnvs
getInstEnvs = runTcPlugin TcPluginM.getInstEnvs

-- | Look up the instance of class @cls@ for type constructor @tc@ in the
--   'InstanceDict' that serves as the custom state.
getInstanceFor :: TyCon -> Class -> SupermonadPluginM InstanceDict (Maybe ClsInst)
getInstanceFor tc cls = do
  instDict <- getCustomState
  return (lookupInstDict tc cls instDict)
-- | Record the equality @tv ~ ty@. The constraint @ct@ is kept alongside it
--   as its base constraint. The entry is prepended to the stored list.
addTyVarEquality :: Ct -> TyVar -> Type -> SupermonadPluginM s ()
addTyVarEquality ct tv ty = modify $ \st ->
  st { smStateTyVarEqualities = (ct, tv, ty) : smStateTyVarEqualities st }

-- | Record each type-variable equality of the list in turn.
addTyVarEqualities :: [(Ct, TyVar, Type)] -> SupermonadPluginM s ()
addTyVarEqualities eqs = forM_ eqs $ \(ct, tv, ty) -> addTyVarEquality ct tv ty

-- | Record the equality @ta ~ tb@. The constraint @ct@ is kept alongside it
--   as its base constraint. The entry is prepended to the stored list.
addTypeEquality :: Ct -> Type -> Type -> SupermonadPluginM s ()
addTypeEquality ct ta tb = modify $ \st ->
  st { smStateTypeEqualities = (ct, ta, tb) : smStateTypeEqualities st }

-- | Record each type equality of the list in turn.
addTypeEqualities :: [(Ct, Type, Type)] -> SupermonadPluginM s ()
addTypeEqualities eqs = forM_ eqs $ \(ct, ta, tb) -> addTypeEquality ct ta tb

-- | All type-variable equalities recorded so far (newest first).
getTyVarEqualities :: SupermonadPluginM s [(Ct, TyVar, Type)]
getTyVarEqualities = gets smStateTyVarEqualities

-- | All type equalities recorded so far (newest first).
getTypeEqualities :: SupermonadPluginM s [(Ct, Type, Type)]
getTypeEqualities = gets smStateTypeEqualities
-- | Queue a warning (summary message plus details document) so that
--   'displayWarnings' can print it later. It is prepended to the queue.
addWarning :: String -> O.SDoc -> SupermonadPluginM s ()
addWarning msg details = modify enqueue
  where
    enqueue st = st { smStateWarningQueue = (msg, details) : smStateWarningQueue st }

-- | Run @action@ only if no type-variable equalities and no type
--   equalities have been recorded so far.
whenNoResults :: SupermonadPluginM s () -> SupermonadPluginM s ()
whenNoResults action = do
  noTyVarEqs <- null <$> getTyVarEqualities
  noTyEqs <- null <$> getTypeEqualities
  when (noTyVarEqs && noTyEqs) action
-- | Print every queued warning, oldest first, provided no equalities have
--   been recorded (see 'whenNoResults').
--
--   'addWarning' prepends to the queue, so the list is reversed here to
--   report the warnings in the order they were raised. The queue itself is
--   left untouched.
displayWarnings :: SupermonadPluginM s ()
displayWarnings = whenNoResults $ do
  warns <- gets smStateWarningQueue
  forM_ (reverse warns) $ \(msg, details) -> do
    printWarn msg
    internalPrint $ L.smObjMsg $ L.sDocToStr details
-- | Turn a plain message into a 'SupermonadError' text document.
stringToSupermonadError :: String -> SupermonadError
stringToSupermonadError msg = O.text msg

-- | Throw a plugin error carrying @msg@ unless @cond@ holds.
assert :: Bool -> String -> SupermonadPluginM s ()
assert cond msg
  | cond = return ()
  | otherwise = throwPluginError msg

-- | Monadic 'assert': run @condM@, then throw @msg@ if it produced 'False'.
assertM :: SupermonadPluginM s Bool -> String -> SupermonadPluginM s ()
assertM condM msg = condM >>= \cond -> assert cond msg

-- | Abort the computation with a plain-text error message.
throwPluginError :: String -> SupermonadPluginM s a
throwPluginError msg = throwError (stringToSupermonadError msg)

-- | Abort the computation with an already formatted error document.
throwPluginErrorSDoc :: O.SDoc -> SupermonadPluginM s a
throwPluginErrorSDoc doc = throwError doc

-- | Run @action@, passing any 'SupermonadError' it throws to @handler@.
--   Because 'StateT' sits above 'ExceptT' in 'SupermonadPluginM', the
--   handler starts from the state as it was before @action@ ran.
catchPluginError :: SupermonadPluginM s a -> (SupermonadError -> SupermonadPluginM s a) -> SupermonadPluginM s a
catchPluginError action handler = catchError action handler
-- | Pretty-print a GHC 'Outputable' value as an object message.
printObj :: Outputable o => o -> SupermonadPluginM s ()
printObj obj = printFormattedObj (L.pprToStr obj)

-- | Print a debug message.
printMsg :: String -> SupermonadPluginM s ()
printMsg msg = internalPrint (L.smDebugMsg msg)

-- | Print an error message.
printErr :: String -> SupermonadPluginM s ()
printErr msg = internalPrint (L.smErrMsg msg)

-- | Print a warning message.
printWarn :: String -> SupermonadPluginM s ()
printWarn msg = internalPrint (L.smWarnMsg msg)

-- | Write a string to standard output with 'putStr', lifted into the plugin
--   monad through 'tcPluginIO'. No newline is appended here; presumably the
--   @L.sm*Msg@ formatters supply line breaks (confirm in the Log module).
internalPrint :: String -> SupermonadPluginM s ()
internalPrint str = runTcPlugin (tcPluginIO (putStr str))

-- | Print an already rendered string as an object message.
printFormattedObj :: String -> SupermonadPluginM s ()
printFormattedObj str = internalPrint (L.smObjMsg str)
-- | Print constraints grouped by the source file they originate from.
--
--   Each group starts with a @From <file>:@ header, or @From unknown file:@
--   when the constraint's location has no file name. The header is followed
--   by one formatted line per constraint.
--
--   'groupBy' only merges /adjacent/ elements, so the constraints are first
--   stably sorted by file name. Without the sort, constraints from one file
--   interleaved with constraints from other files would be printed under
--   repeated headers. Groups therefore appear in file-name order (unknown
--   file first); constraints keep their relative order within a group.
printConstraints :: [Ct] -> SupermonadPluginM s ()
printConstraints cts =
  forM_ groupedCts $ \(file, ctGroup) -> do
    printFormattedObj $ maybe "From unknown file:" (("From " ++) . (++ ":") . unpackFS) file
    mapM_ (printFormattedObj . L.formatConstraint) ctGroup
  where
    getCtFile = srcSpanFileName_maybe . constraintSourceLocation
    eqFileName ct1 ct2 = getCtFile ct1 == getCtFile ct2
    -- Sort on the unpacked String key (it has an Ord instance) so that
    -- constraints from the same file become adjacent; 'sortOn' is stable.
    sortedCts = sortOn (fmap unpackFS . getCtFile) cts
    -- 'groupBy' never yields an empty group, so this pattern always
    -- matches; matching here avoids the partial 'head'.
    groupedCts = [ (getCtFile ct, ctGroup) | ctGroup@(ct : _) <- groupBy eqFileName sortedCts ]