module Control.Super.Plugin.Solving
( solveConstraints
) where
import Data.Maybe
( catMaybes
, isJust, isNothing
, fromJust, fromMaybe )
import Data.List ( partition, nubBy )
import qualified Data.Set as Set
import Control.Monad ( forM, forM_, filterM )
import TcRnTypes ( Ct(..) )
import TyCon ( TyCon )
import Class ( Class, classTyCon )
import Type
( Type, TyVar
, substTyVar, substTys
, eqType )
import TcType ( isAmbiguousTyVar )
import InstEnv ( ClsInst, instanceHead )
import Unify ( tcUnifyTy )
import qualified Outputable as O
import qualified Control.Super.Plugin.Collection.Set as S
import Control.Super.Plugin.Debug ( sDocToStr )
import Control.Super.Plugin.InstanceDict ( InstanceDict )
import Control.Super.Plugin.Wrapper
( TypeVarSubst, mkTypeVarSubst )
import Control.Super.Plugin.Environment
( SupermonadPluginM
, getGivenConstraints, getWantedConstraints
, getInstanceFor
, addTyVarEquality, addTyVarEqualities
, addTypeEqualities
, getTyVarEqualities
, printMsg, printObj, printErr
, addWarning, displayWarnings
, throwPluginError, throwPluginErrorSDoc )
import Control.Super.Plugin.Environment.Lift
( isPotentiallyInstantiatedCt
, partiallyApplyTyCons )
import Control.Super.Plugin.Constraint
( WantedCt
, isClassConstraint
, isAnyClassConstraint
, constraintClassType
, constraintClassTyArgs )
import Control.Super.Plugin.Separation
( ConstraintGroup
, separateContraints
, componentTopTyCons, componentTopTcVars
, componentMonoTyCon )
import Control.Super.Plugin.Utils
( collectTopTcVars
, collectTopTyCons
, collectTyVars
, associations
, allM )
import Control.Super.Plugin.Log
( formatConstraint )
-- | Entry point of the solving phase.
--
--   The wanted constraints are split into independent groups with
--   'separateContraints'. Each group is tagged with 'componentMonoTyCon'.
--   Groups tagged with a type constructor are handled by
--   'solveMonoConstraintGroup'. All remaining groups are handled by
--   'solvePolyConstraintGroup'.
--
--   Afterwards, 'solveSolvedTyConIndices' runs unification against the
--   instances of the now-determined type constructors. Finally, any
--   accumulated warnings are shown.
solveConstraints :: [Class] -> ConstraintGroup -> SupermonadPluginM InstanceDict ()
solveConstraints relevantClss wantedCts = do
  let tagWithMonoTyCon grp = (componentMonoTyCon relevantClss grp, grp)
      taggedGroups = tagWithMonoTyCon <$> separateContraints wantedCts
      (monoTagged, polyTagged) = partition (isJust . fst) taggedGroups
  -- 'partition' guarantees every entry of 'monoTagged' carries a 'Just'.
  mapM_ (solveMonoConstraintGroup relevantClss)
        [ (fromJust mTyCon, grp) | (mTyCon, grp) <- monoTagged ]
  mapM_ (solvePolyConstraintGroup relevantClss . snd) polyTagged
  solveSolvedTyConIndices relevantClss
  displayWarnings
-- | Solve a constraint group that was tagged with a single type constructor.
--
--   The function visits every constraint of the group that belongs to one of
--   @relevantClss@, and within it every ambiguous top-level type-constructor
--   variable of its arguments. For each such variable, the group's type
--   constructor is partially applied to it ('partiallyApplyTyCons'), and the
--   resulting type-variable equality is recorded against that constraint.
--
--   A plugin error is thrown in two cases:
--
--   * the partial application fails;
--   * the partial application does not yield exactly one distinct result.
--
--   An empty group is a no-op.
solveMonoConstraintGroup :: [Class] -> (TyCon, ConstraintGroup) -> SupermonadPluginM s ()
solveMonoConstraintGroup _relevantClss (_, []) = return ()
solveMonoConstraintGroup relevantClss (tyCon, ctGroup) =
  forM_ (filter (isAnyClassConstraint relevantClss) ctGroup) solveCt
  where
    -- Bind every ambiguous top-level type-constructor variable of one constraint.
    solveCt ct = forM_ (ambiguousTopVars ct) (bindToTyCon ct)

    -- Ambiguous top-level type-constructor variables among the class arguments.
    ambiguousTopVars ct = Set.filter isAmbiguousTyVar
                        $ collectTopTcVars
                        $ fromMaybe []
                        $ constraintClassTyArgs ct

    -- Associate a single variable with the group's type constructor.
    bindToTyCon ct tyVar = do
      applied <- partiallyApplyTyCons [(tyVar, Left tyCon)] >>= either throwPluginErrorSDoc return
      case nubBy tyConAssocEq applied of
        [(tv, ty, _)] -> addTyVarEquality ct tv ty
        [] -> throwPluginError "How did this become an empty list?"
        _ -> throwPluginError "How did this become a list with more then one element?"

    -- Structural equality on partial-application results ('Type' has no 'Eq').
    tyConAssocEq :: (TyVar, Type, [TyVar]) -> (TyVar, Type, [TyVar]) -> Bool
    tyConAssocEq (tvL, tyL, flexL) (tvR, tyR, flexR) =
      tvL == tvR && flexL == flexR && eqType tyL tyR
-- | Solve a constraint group that was not tagged with a single type constructor.
--
--   'determineValidConstraintGroupAssocs' computes all valid associations of
--   the group's ambiguous top-level type-constructor variables. Each one is
--   partially applied ('partiallyApplyTyCons'; failure throws a plugin error).
--   The outcome depends on how many associations remain:
--
--   * none: a warning is recorded, but only if some relevant constraint of
--     the group has top-level type-constructor variables;
--   * exactly one: its type-variable equalities are recorded, attributed to
--     the first constraint of the group;
--   * several: they are printed and a plugin error is thrown.
--
--   An empty group is a no-op.
solvePolyConstraintGroup :: [Class] -> ConstraintGroup -> SupermonadPluginM s ()
-- The empty group must be handled before calling
-- 'determineValidConstraintGroupAssocs', because that function throws on
-- an empty group (which left the intended no-op unreachable before).
solvePolyConstraintGroup _relevantClss [] = return ()
solvePolyConstraintGroup relevantClss ctGroup@(firstCt : _) = do
  (_, assocs) <- determineValidConstraintGroupAssocs relevantClss ctGroup
  appliedAssocs <- forM assocs $ \assoc -> either throwPluginErrorSDoc return =<< partiallyApplyTyCons assoc
  case appliedAssocs of
    [] -> do
      let topTcVars = concatMap collectRelevantTopTcVars ctGroup
      -- Nothing to associate at all: stay silent.
      if null topTcVars then do
        return ()
      else do
        addWarning
          "There are no possible associations for the current constraint group!"
          ( O.hang (O.text "There are two possible reasons for this warning:") 2
            $ O.vcat $
              [ O.text "1. Either the group can not be solved or"
              , O.text "2. further iterations between the plugin and type checker "
              , O.text " have to resolve for sufficient information to arise."
              ] ++ fmap (O.text . formatConstraint) ctGroup)
    [appliedAssoc] ->
      forM_ appliedAssoc $ \(tv, ty, _flexVars) ->
        addTyVarEquality firstCt tv ty
    _ -> do
      printMsg "Possible associations:"
      forM_ appliedAssocs printObj
      throwPluginError "There is more then one possible association for the current constraint group!"
  where
    -- Top-level type-constructor variables of a constraint that belongs to
    -- one of the relevant classes; empty for any other constraint.
    collectRelevantTopTcVars :: Ct -> [TyVar]
    collectRelevantTopTcVars ct
      | isAnyClassConstraint relevantClss ct
      , Just (_cls, tyArgs) <- constraintClassType ct
      = Set.toList $ collectTopTcVars tyArgs
      | otherwise = []
-- | A pair of a set of type constructors and a set of type variables.
--   'determineValidConstraintGroupAssocs' uses it to describe the "base" of a
--   set of constraints: their top-level type constructors together with their
--   non-ambiguous top-level type-constructor variables.
type TcTvSet = (S.Set TyCon, Set.Set TyVar)
-- | Component-wise intersection: type constructors with type constructors,
--   type variables with type variables.
tctvIntersection :: TcTvSet -> TcTvSet -> TcTvSet
tctvIntersection (tyConsL, tyVarsL) (tyConsR, tyVarsR) =
  ( tyConsL `S.intersection` tyConsR
  , tyVarsL `Set.intersection` tyVarsR
  )
-- | 'True' exactly when both the type-constructor set and the type-variable
--   set are empty.
tctvNull :: TcTvSet -> Bool
tctvNull (tyCons, tyVars)
  | S.null tyCons = Set.null tyVars
  | otherwise     = False
-- | Component-wise union of two 'TcTvSet's.
tctvUnion :: TcTvSet -> TcTvSet -> TcTvSet
tctvUnion (tyConsL, tyVarsL) (tyConsR, tyVarsR) =
  ( tyConsL `S.union` tyConsR
  , tyVarsL `Set.union` tyVarsR
  )
-- | Flatten a 'TcTvSet' into one list. All type constructors come first
--   (as 'Left'), followed by all type variables (as 'Right'), each in the
--   order produced by its set's @toList@.
tctvToList :: TcTvSet -> [Either TyCon TyVar]
tctvToList (tyCons, tyVars) =
  [ Left tc | tc <- S.toList tyCons ] ++ [ Right tv | tv <- Set.toList tyVars ]
-- | Compute every association of the group's ambiguous top-level
--   type-constructor variables to candidate type constructors or type
--   variables under which all constraints of the group are potentially
--   instantiated, as judged by 'isPotentiallyInstantiatedCt'.
--
--   The candidates ("base") come from two sources:
--
--   * the relevant wanted constraints: their top-level type constructors and
--     non-ambiguous top-level type-constructor variables;
--   * every relevant given constraint whose own base overlaps the wanted base:
--     its base as well.
--
--   Returns the input group unchanged together with the valid associations.
--   Throws a plugin error for an empty group.
determineValidConstraintGroupAssocs :: [Class] -> ConstraintGroup -> SupermonadPluginM s ([WantedCt], [[(TyVar, Either TyCon TyVar)]])
determineValidConstraintGroupAssocs _relevantClss [] = throwPluginError "Solving received an empty constraint group!"
determineValidConstraintGroupAssocs relevantClss ctGroup = do
  givenCts <- getGivenConstraints
  let relevantWanted = keepRelevant ctGroup
      -- The variables that still need to be associated with something.
      ambTyConVars = Set.toList
                   $ Set.filter isAmbiguousTyVar
                   $ componentTopTcVars relevantWanted
      wantedBase = baseOf relevantWanted
      -- A given constraint only contributes if it shares part of the wanted base.
      sharesBase ct = not $ tctvNull $ tctvIntersection (baseOf [ct]) wantedBase
      givenBase = baseOf $ filter sharesBase $ keepRelevant givenCts
      candidates = tctvToList $ tctvUnion wantedBase givenBase
      assocs = filter (not . null)
             $ associations [ (tv, candidates) | tv <- ambTyConVars ]
  validAssocs <- filterM (\assoc -> allM (\ct -> isPotentiallyInstantiatedCt ct assoc) ctGroup) assocs
  return (ctGroup, validAssocs)
  where
    -- Only constraints of the relevant classes are considered.
    keepRelevant :: [Ct] -> [Ct]
    keepRelevant = filter (isAnyClassConstraint relevantClss)

    -- Top-level type constructors and non-ambiguous top-level
    -- type-constructor variables of the relevant constraints among @cts@.
    baseOf :: [Ct] -> TcTvSet
    baseOf cts =
      let relevantCts = keepRelevant cts
      in ( componentTopTyCons relevantCts
         , Set.filter (not . isAmbiguousTyVar) $ componentTopTcVars relevantCts
         )
-- | Final solving step: unification against the instances of already
--   determined top-level type constructors.
--
--   First, the type-variable equalities recorded so far
--   ('getTyVarEqualities') are turned into a substitution and applied to the
--   arguments of every wanted class constraint ('prepCt').
--
--   Then, for each class in @relevantClss@, certain constraints of that class
--   are unified argument-wise with the instance that 'getInstanceFor' returns
--   for their unique top-level type constructor. A constraint is selected when
--   its substituted arguments
--
--   * still contain an ambiguous type variable, but
--   * contain no top-level type-constructor variable.
--
--   The derived equalities are recorded with 'addTyVarEqualities' and
--   'addTypeEqualities'. The following problems are reported with 'printErr'
--   and skip the constraint rather than aborting:
--
--   * no unique top-level type constructor;
--   * no instance;
--   * non-unifiable arguments.
solveSolvedTyConIndices :: [Class] -> SupermonadPluginM InstanceDict ()
solveSolvedTyConIndices relevantClss = do
  tyVarEqs <- getTyVarEqualities
  let tvSubst = mkTypeVarSubst $ fmap (\(_ct, tv, ty) -> (tv, ty)) tyVarEqs
  wantedCts <- getWantedConstraints
  let prepWantedCts = catMaybes $ fmap (prepCt tvSubst) wantedCts
  printMsg "Unification solve constraints..."
  forM_ relevantClss $ \cls ->
    unificationSolve prepWantedCts (return . isClassConstraint cls) (\tc -> getInstanceFor tc cls)
  where
    -- Unify every selected constraint with the instance of its top-level
    -- type constructor and record the derived equalities.
    unificationSolve :: [(Ct, TyCon, [Type])]
                     -> (Ct -> SupermonadPluginM s Bool)
                     -> (TyCon -> SupermonadPluginM s (Maybe ClsInst))
                     -> SupermonadPluginM s ()
    unificationSolve prepWantedCts isRequiredConstraint getTyConInst = do
      cts <- filterTopTyConSolvedConstraints prepWantedCts isRequiredConstraint
      forM_ cts $ \ct -> do
        eResult <- withTopTyCon ct getTyConInst $ \_topTyCon _ctArgs inst ->
          case deriveUnificationConstraints ct inst of
            Left err -> printErr $ sDocToStr err
            Right (tvTyEqs, tyEqs) -> do
              addTyVarEqualities tvTyEqs
              addTypeEqualities tyEqs
        case eResult of
          Left err -> printErr $ sDocToStr err
          Right () -> return ()

    -- Keep the constraints accepted by @p@ whose arguments contain at least
    -- one ambiguous type variable but no top-level type-constructor variable.
    filterTopTyConSolvedConstraints :: [(WantedCt, TyCon, [Type])]
                                    -> (WantedCt -> SupermonadPluginM s Bool)
                                    -> SupermonadPluginM s [(WantedCt, TyCon, [Type])]
    filterTopTyConSolvedConstraints cts p = do
      predFilteredCts <- filterM (\(ct, _tc, _args) -> p ct) cts
      let hasAmbiguousVar (_ct, _tc, args) =
            not $ Set.null
                $ Set.filter isAmbiguousTyVar
                $ Set.unions
                $ fmap collectTyVars args
          hasNoTopTcVar (_ct, _tc, args) = Set.null $ collectTopTcVars args
      return $ filter hasNoTopTcVar $ filter hasAmbiguousVar predFilteredCts

    -- Run @process@ with the instance of the constraint's unique top-level
    -- type constructor. Returns 'Left' with a diagnostic if there is no
    -- unique top-level type constructor or it has no instance.
    withTopTyCon :: (Ct, TyCon, [Type])
                 -> (TyCon -> SupermonadPluginM s (Maybe ClsInst))
                 -> (TyCon -> [Type] -> ClsInst -> SupermonadPluginM s a)
                 -> SupermonadPluginM s (Either O.SDoc a)
    withTopTyCon (ct, _ctClsTyCon, ctArgs) getTyConInst process =
      case S.toList $ collectTopTyCons ctArgs of
        [topTyCon] -> do
          mInst <- getTyConInst topTyCon
          case mInst of
            Just inst -> Right <$> process topTyCon ctArgs inst
            Nothing -> return $ Left
              $ O.text "Constraints top type constructor does not have an associated instance:"
              O.$$ O.ppr topTyCon
        _ -> return $ Left
          $ O.text "Constraint misses a unqiue top-level type constructor:"
          O.$$ O.ppr ct

    -- Unify each instance argument with the corresponding constraint
    -- argument and derive two lists from the per-argument substitutions:
    --
    --   * equalities binding ambiguous constraint variables to types;
    --   * equalities between the different types one instance variable
    --     was bound to.
    --
    -- Fails if any argument pair is not unifiable.
    deriveUnificationConstraints :: (Ct, TyCon, [Type]) -> ClsInst -> Either O.SDoc ([(Ct, TyVar, Type)], [(Ct, Type, Type)])
    deriveUnificationConstraints (ct, _ctClsTyCon, ctArgs) inst
      | any isNothing mSubsts =
          Left $ O.hang (O.text "Unification constraint solving not possible, because instance and constraint are not unifiable!") 2
               $ (O.hang (O.text "Instance:") 2 $ O.ppr inst) O.$$
                 (O.hang (O.text "Constraint:") 2 $ O.ppr ct) O.$$
                 (O.hang (O.text "Constraint arguments:") 2 $ O.ppr ctArgs)
      | otherwise = Right (ctVarEqCts, instVarEqCts)
      where
        (instVars, _instCls, instArgs) = instanceHead inst
        ctVars = Set.toList $ Set.unions $ fmap collectTyVars ctArgs
        -- One independent unification per argument position.
        mSubsts = zipWith tcUnifyTy instArgs ctArgs
        substs = catMaybes mSubsts
        instVarEqCts = concatMap (mkEqGroup ct . snd) $ collectEqualityGroup substs instVars
        ctVarEqCts = mkEqStarGroup ct $ collectEqualityGroup substs $ filter isAmbiguousTyVar ctVars

    -- Equate the first type of a group with each of the remaining ones.
    mkEqGroup :: Ct -> [Type] -> [(Ct, Type, Type)]
    mkEqGroup _ [] = []
    mkEqGroup baseCt (ty : tys) = fmap (\ty' -> (baseCt, ty, ty')) tys

    -- Equate each variable with every type of its group.
    mkEqStarGroup :: Ct -> [(TyVar, [Type])] -> [(Ct, TyVar, Type)]
    mkEqStarGroup baseCt eqGroups = concatMap (\(tv, tys) -> fmap (\ty -> (baseCt, tv, ty)) tys) eqGroups

    -- For every variable: the distinct types the substitutions map it to,
    -- dropping any type that mentions the variable itself.
    collectEqualityGroup :: [TypeVarSubst] -> [TyVar] -> [(TyVar, [Type])]
    collectEqualityGroup substs tvs =
      [ (tv, nubBy eqType $ filter (notElem tv . collectTyVars) [ substTyVar subst tv | subst <- substs ])
      | tv <- tvs
      ]
-- | Prepare a class constraint for unification solving. The result pairs the
--   constraint with its class's type constructor and its class arguments,
--   with the given substitution applied to those arguments.
--   Returns 'Nothing' when 'constraintClassType' yields 'Nothing' for the
--   constraint.
prepCt :: TypeVarSubst -> Ct -> Maybe (Ct, TyCon, [Type])
prepCt subst ct =
  case constraintClassType ct of
    Nothing -> Nothing
    Just (cls, args) -> Just (ct, classTyCon cls, substTys subst args)