module Control.Super.Plugin.Separation
( ConstraintGroup
, separateContraints
, componentTopTyCons
, componentTopTcVars
, componentMonoTyCon
) where
import Data.List ( tails )
import Data.Maybe ( fromJust, fromMaybe, mapMaybe )
import qualified Data.Set as Set
import Data.Graph.Inductive.Graph
  ( LNode, Node, Edge
  , mkGraph, toLEdge, lab )
import Data.Graph.Inductive.PatriciaTree ( Gr )
import Data.Graph.Inductive.Query.DFS ( components )
import TcRnTypes ( Ct )
import TyCon ( TyCon )
import Type ( Type, TyVar )
import TcType ( isAmbiguousTyVar )
import Class ( Class )
import qualified Control.Super.Plugin.Collection.Set as S
import Control.Super.Plugin.Constraint
  ( WantedCt
  , constraintClassTyArgs
  , isAnyClassConstraint )
import Control.Super.Plugin.Utils
  ( collectTopTyCons, collectTopTcVars )
-- | A wanted constraint paired with its node index; this is the node type
-- of the graph built in 'separateContraints'.
type SCNode = LNode WantedCt
-- | A group of wanted constraints. 'separateContraints' returns one such
-- group per connected component of its constraint graph.
type ConstraintGroup = [WantedCt]
-- | Restrict a constraint group to constraints of the given classes (via
-- 'isAnyClassConstraint'). Return the single top-level type constructor
-- those constraints mention, provided that:
--
-- * exactly one such type constructor occurs, and
-- * no top-level type variable that is /not/ ambiguous occurs.
--
-- In every other case the result is 'Nothing'.
componentMonoTyCon :: [Class] -> ConstraintGroup -> Maybe TyCon
componentMonoTyCon relevantClasses group
  | [tyCon] <- S.toList topCons
  , Set.null fixedVars
  = Just tyCon
  | otherwise
  = Nothing
  where
    relevant  = filter (isAnyClassConstraint relevantClasses) group
    -- Ambiguous variables are tolerated; any other top-level variable
    -- disqualifies the group.
    fixedVars = Set.filter (not . isAmbiguousTyVar) (componentTopTcVars relevant)
    topCons   = componentTopTyCons relevant
-- | Union of the top-level type constructors ('collectTopTyCons') found in
-- the class type arguments of each constraint. A constraint for which
-- 'constraintClassTyArgs' yields 'Nothing' contributes the empty set.
componentTopTyCons :: [WantedCt] -> S.Set TyCon
componentTopTyCons wanteds =
  S.unions [ maybe S.empty collectTopTyCons (constraintClassTyArgs w) | w <- wanteds ]
-- | Union of the top-level type variables ('collectTopTcVars') found in
-- the class type arguments of each constraint. A constraint for which
-- 'constraintClassTyArgs' yields 'Nothing' contributes the empty set.
componentTopTcVars :: [WantedCt] -> Set.Set TyVar
componentTopTcVars wanteds =
  Set.unions [ maybe Set.empty collectTopTcVars (constraintClassTyArgs w) | w <- wanteds ]
-- | Apply an extractor to the class type arguments of every constraint
-- and union the resulting 'Data.Set' sets. A constraint for which
-- 'constraintClassTyArgs' yields 'Nothing' contributes the empty set.
collect :: (Ord a) => ([Type] -> Set.Set a) -> [Ct] -> Set.Set a
collect f = Set.unions . map (maybe Set.empty f . constraintClassTyArgs)
-- | Same as 'collect', but for the plugin's own set type
-- ("Control.Super.Plugin.Collection.Set"). Applies an extractor to the
-- class type arguments of every constraint and unions the results. A
-- constraint for which 'constraintClassTyArgs' yields 'Nothing'
-- contributes the empty set.
collectInternalSet :: ([Type] -> S.Set a) -> [Ct] -> S.Set a
collectInternalSet extract constraints =
  S.unions [ maybe S.empty extract (constraintClassTyArgs c) | c <- constraints ]
-- | Split wanted constraints into independent groups.
--
-- Two constraints are linked when their sets of top-level type variables
-- satisfying 'isAmbiguousTyVar' intersect. These variables are taken from
-- the class type arguments reported by 'constraintClassTyArgs'. Each
-- returned 'ConstraintGroup' is one connected component of the resulting
-- graph (see 'components').
--
-- A constraint for which 'constraintClassTyArgs' yields 'Nothing' is never
-- linked, so it forms a singleton group.
separateContraints :: [WantedCt] -> [ConstraintGroup]
separateContraints wantedCts = mapMaybe (lab g) <$> components g
  -- Every node returned by 'components' belongs to 'g', so 'lab' always
  -- succeeds; 'mapMaybe' keeps this total without resorting to 'fromJust'.
  where
    -- Number the constraints so they can serve as graph nodes.
    nodes :: [SCNode]
    nodes = zip [0 ..] wantedCts

    -- The ambiguous top-level variables of each node. They are computed
    -- once per node rather than once per candidate edge. 'Nothing' marks
    -- nodes without class type arguments; such nodes never get an edge.
    nodeVars :: [(Node, Maybe (Set.Set TyVar))]
    nodeVars = [ (n, ambiguousTopVars ct) | (n, ct) <- nodes ]

    ambiguousTopVars :: WantedCt -> Maybe (Set.Set TyVar)
    ambiguousTopVars ct =
      Set.filter isAmbiguousTyVar . collectTopTcVars <$> constraintClassTyArgs ct

    -- Visit every unordered pair of nodes once. Overlapping variable sets
    -- produce an edge in both directions, so every ordered pair of linked
    -- nodes appears in the edge list.
    edges :: [Edge]
    edges = concat
      [ [(m, n), (n, m)]
      | ((n, Just vn) : rest) <- tails nodeVars
      , (m, Just vm) <- rest
      , not (Set.null (Set.intersection vn vm))
      ]

    g :: Gr WantedCt ()
    g = mkGraph nodes [ toLEdge e () | e <- edges ]