module Control.Super.Plugin.Utils (
errIndent
, collectTopTyCons
, collectTopTcVars
, collectTopTcVarsWithArity
, collectTyVars
, mkTcVarSubst
, skolemVarsBindFun
, eqTyVar, eqTyVar'
, getTyConName, getClassName
, isAmbiguousType
, partiallyApplyTyCons
, applyTyCon
, splitKindFunOfTcTv
, atIndex
, t1st, t2nd, t3rd
, associations
, removeDup, removeDupByIndex
, removeDupByIndexEq
, removeDupUnique, removeDupByIndexUnique
, lookupBy
, allM, anyM
, fromLeft, fromRight
, partitionM
) where
import Data.Maybe ( listToMaybe, catMaybes )
import Data.List ( find )
import qualified Data.Set as Set
import qualified Data.Map.Strict as Map
import Control.Monad ( forM )
import Control.Arrow ( second )
import Unique ( Uniquable )
import BasicTypes ( Arity )
import Name ( nameOccName )
import OccName ( occNameString )
import Type
( Type, TyVar
, getTyVar_maybe
, tyConAppTyCon_maybe
, splitTyConApp_maybe, splitFunTy_maybe, splitAppTy_maybe
, getEqPredTys_maybe
, splitAppTys
, mkTyConTy, mkTyVarTy, mkAppTys
, eqType )
import TyCon
( TyCon
, tyConKind, tyConName )
import Var ( tyVarKind )
import TcType ( isAmbiguousTyVar )
import Kind ( Kind )
import Class ( Class, classTyCon )
import Unify ( BindFlag(..) )
import InstEnv ( instanceBindFun )
import TcPluginM ( TcPluginM, newFlexiTyVar )
import Outputable ( ($$) )
import qualified Outputable as O
import Control.Super.Plugin.Collection.Set ( Set )
import qualified Control.Super.Plugin.Collection.Set as S
import qualified Control.Super.Plugin.Collection.Map as M
import Control.Super.Plugin.Wrapper
( TypeVarSubst
, mkTypeVarSubst
, splitKindFunTys
, fromLeft, fromRight
)
-- | Indentation amount for error output.
--   NOTE(review): presumably the number of columns callers use to indent
--   nested error messages — confirm at the use sites.
errIndent :: Int
errIndent = 4
-- | The set of type constructors heading the given types. Types that are
--   not type-constructor applications (as judged by
--   'tyConAppTyCon_maybe') contribute nothing.
collectTopTyCons :: [Type] -> Set TyCon
collectTopTyCons tys = S.fromList [ tc | Just tc <- map tyConAppTyCon_maybe tys ]
-- | The type variables appearing in head position of the given types,
--   i.e. 'collectTopTcVarsWithArity' with the arities dropped.
collectTopTcVars :: [Type] -> Set.Set TyVar
collectTopTcVars tys = Set.map fst (collectTopTcVarsWithArity tys)
-- | For each type whose application head (after 'splitAppTys') is a type
--   variable, collect that variable together with its arity, i.e. the
--   number of argument kinds in the variable's kind.
collectTopTcVarsWithArity :: [Type] -> Set.Set (TyVar, Arity)
collectTopTcVarsWithArity tys =
  Set.fromList [ (tv, tyVarArity tv) | Just tv <- map headTyVar tys ]
  where
    -- The head of an application chain, if that head is a type variable.
    headTyVar :: Type -> Maybe TyVar
    headTyVar = getTyVar_maybe . fst . splitAppTys
-- | Collect the type variables occurring in a type. The function descends
--   through type-constructor applications, function types, type
--   applications and equality predicates, tried in that order. Any other
--   form contributes no variables, and the kinds of variables are not
--   traversed.
--   NOTE(review): whether e.g. foralls or casts fall into the "other" case
--   depends on the GHC splitters in use — confirm if that matters.
collectTyVars :: Type -> Set.Set TyVar
collectTyVars t
  | Just tv <- getTyVar_maybe t = Set.singleton tv
  | Just (_tc, args) <- splitTyConApp_maybe t = Set.unions (map collectTyVars args)
  | Just (ta, tb) <- splitFunTy_maybe t = both ta tb
  | Just (ta, tb) <- splitAppTy_maybe t = both ta tb
  | Just (_r, ta, tb) <- getEqPredTys_maybe t = both ta tb
  | otherwise = Set.empty
  where
    -- Union of the variables of two sub-types.
    both a b = collectTyVars a `Set.union` collectTyVars b
-- | Build a substitution that maps each type variable to the (unapplied)
--   type of its paired type constructor.
mkTcVarSubst :: [(TyVar, TyCon)] -> TypeVarSubst
mkTcVarSubst substs = mkTypeVarSubst [ (tv, mkTyConTy tc) | (tv, tc) <- substs ]
-- | Binding function for unification: variables in the given list are
--   treated as 'Skolem'. Every other variable is decided by
--   'instanceBindFun'.
skolemVarsBindFun :: [TyVar] -> TyVar -> BindFlag
skolemVarsBindFun tvs var
  | any (var ==) tvs = Skolem
  | otherwise = instanceBindFun var
-- | 'True' iff both types are type variables and they are the same
--   variable. 'False' whenever the first type is not a variable.
eqTyVar :: Type -> Type -> Bool
eqTyVar ty ty' = maybe False (\tv -> eqTyVar' tv ty') (getTyVar_maybe ty)
-- | 'True' iff the type is a type variable equal to the given one.
eqTyVar' :: TyVar -> Type -> Bool
eqTyVar' tv = maybe False (tv ==) . getTyVar_maybe
-- | Number of argument kinds in a type variable's kind, as split by
--   'splitKindFunTys'.
tyVarArity :: TyVar -> Arity
tyVarArity tv = length kindArgs
  where
    (kindArgs, _resultKind) = splitKindFunTys (tyVarKind tv)
-- | The occurrence-name string of a type constructor, i.e. its name
--   without module qualification.
getTyConName :: TyCon -> String
getTyConName tc = occNameString (nameOccName (tyConName tc))
-- | The name of a class, taken from its underlying type constructor.
getClassName :: Class -> String
getClassName = getTyConName . classTyCon
-- | Safe, zero-based list indexing. Returns 'Nothing' when the index is
--   negative or not smaller than the length of the list.
atIndex :: [a] -> Int -> Maybe a
atIndex xs i
  -- 'drop' treats a negative count as 0, so without this guard a negative
  -- index would wrongly return the first element.
  | i < 0 = Nothing
  | otherwise = listToMaybe (drop i xs)
-- | First component of a triple.
t1st :: (a, b, c) -> a
t1st = \(x, _, _) -> x

-- | Second component of a triple.
t2nd :: (a, b, c) -> b
t2nd = \(_, y, _) -> y

-- | Third component of a triple.
t3rd :: (a, b, c) -> c
t3rd = \(_, _, z) -> z
-- | 'True' iff the type is a type variable that 'isAmbiguousTyVar'
--   accepts. Non-variable types are never ambiguous.
isAmbiguousType :: Type -> Bool
isAmbiguousType ty = case getTyVar_maybe ty of
  Just tv -> isAmbiguousTyVar tv
  Nothing -> False
-- | For each pair @(tv, tc)@, apply @tc@ to fresh flexible type variables
--   so that the result has the same kind as @tv@.
--
--   The kind of @tc@ must have at least as many argument kinds as the kind
--   of @tv@. Its trailing argument kinds and its result kind must match
--   those of @tv@ (compared with 'eqType'). The surplus leading argument
--   kinds of @tc@ are filled with fresh variables via 'applyTyCon'.
--
--   Each entry yields @(tv, appliedType, freshVars)@. The first pair (in
--   list order) that fails a check produces a 'Left' diagnostic. The tail
--   of the list is processed before the current entry's fresh variables
--   are created.
partiallyApplyTyCons :: [(TyVar, Either TyCon TyVar)] -> TcPluginM (Either O.SDoc [(TyVar, Type, [TyVar])])
partiallyApplyTyCons [] = return (Right [])
partiallyApplyTyCons ((tv, tc) : assocs)
  | length tcKindArgs < length tvKindArgs =
      return $ Left $ O.text "Kind mismatch between type constructor and type variable: "
                   $$ O.ppr tcKindArgs $$ O.ppr tvKindArgs
  | not kindsAgree =
      return $ Left $ O.text "Kind mismatch between type constructor and type variable: "
                   $$ O.ppr tc $$ O.ppr tv
  | otherwise = do
      eRest <- partiallyApplyTyCons assocs
      case eRest of
        Left err -> return (Left err)
        Right appliedRest -> do
          (appliedTcTy, argVars) <- applyTyCon (tc, take surplusArgs tcKindArgs)
          return $ Right $ (tv, appliedTcTy, argVars) : appliedRest
  where
    (tvKindArgs, tvKindRes) = splitKindFunOfTcTv (Right tv)
    (tcKindArgs, tcKindRes) = splitKindFunOfTcTv tc
    -- Number of leading tc arguments that must be filled with fresh variables.
    surplusArgs = length tcKindArgs - length tvKindArgs
    -- Trailing argument kinds (aligned from the right) and result kinds agree.
    kindsAgree =
      and (zipWith eqType (reverse tvKindArgs) (reverse tcKindArgs))
        && eqType tcKindRes tvKindRes
-- | Create one fresh flexible type variable per given kind (via
--   'newFlexiTyVar') and apply the type constructor or type variable to
--   them. Returns the applied type together with the fresh variables, in
--   argument order.
applyTyCon :: (Either TyCon TyVar, [Kind]) -> TcPluginM (Type, [TyVar])
applyTyCon (eTcTv, ks) = do
  freshVars <- mapM newFlexiTyVar ks
  let headTy = either mkTyConTy mkTyVarTy eTcTv
      appliedTy = mkAppTys headTy (map mkTyVarTy freshVars)
  return (appliedTy, freshVars)
-- | Split the kind of a type constructor or type variable into its
--   argument kinds and result kind, using 'splitKindFunTys'.
splitKindFunOfTcTv :: Either TyCon TyVar -> ([Kind], Kind)
splitKindFunOfTcTv = splitKindFunTys . either tyConKind tyVarKind
-- | All ways of choosing one value per key: the cartesian product of the
--   value lists, each choice paired with its key. Earlier keys vary
--   slowest. An empty input yields @[[]]@. Any key with no values makes
--   the whole result empty.
associations :: [(key , [value])] -> [[(key, value)]]
associations = traverse (\(k, vs) -> map ((,) k) vs)
-- | Remove duplicates. The result is sorted in ascending order.
removeDup :: (Ord a) => [a] -> [a]
removeDup xs = Set.elems (Set.fromList xs)
-- | Remove duplicates (by 'Uniquable' identity) using the plugin's set type.
--   NOTE(review): result order is whatever 'S.toList' produces — confirm
--   no caller depends on it.
removeDupUnique :: (Uniquable a) => [a] -> [a]
removeDupUnique xs = S.toList (S.fromList xs)
-- | Keep one pair per key. The result is sorted ascending by key. When a
--   key occurs several times, the value from its LAST occurrence survives
--   (the behaviour of 'Map.fromList').
--   NOTE(review): 'removeDupByIndexEq' keeps the FIRST occurrence instead —
--   confirm callers never rely on which duplicate wins.
removeDupByIndex :: (Ord a) => [(a,b)] -> [(a,b)]
removeDupByIndex pairs = Map.assocs (Map.fromList pairs)
-- | Keep one pair per key, using only 'Eq' on the keys. The FIRST
--   occurrence of each key wins and input order is preserved. This is
--   quadratic in the length of the list.
removeDupByIndexEq :: (Eq a) => [(a,b)] -> [(a,b)]
removeDupByIndexEq [] = []
removeDupByIndexEq (entry@(key, _) : rest) =
  entry : removeDupByIndexEq [ e | e@(k, _) <- rest, k /= key ]
-- | Keep one pair per key (by 'Uniquable' identity) using the plugin's map
--   type.
--   NOTE(review): which duplicate value survives and the result order are
--   determined by 'M.fromList' / 'M.toList' — confirm against that module.
removeDupByIndexUnique :: (Uniquable a) => [(a,b)] -> [(a,b)]
removeDupByIndexUnique pairs = M.toList (M.fromList pairs)
-- | Like 'lookup', but with a custom key comparison. The comparison is
--   called as @eq x key@ with the searched key first. Returns the value of
--   the first matching pair.
lookupBy :: (a -> a -> Bool) -> a -> [(a, b)] -> Maybe b
lookupBy eq x = fmap snd . find (eq x . fst)
-- | Monadic 'all'. The predicate is run on every element (left to right,
--   no short-circuiting) before the results are combined.
allM :: (Monad m) => (a -> m Bool) -> [a] -> m Bool
allM p xs = quantM (&&) True p xs

-- | Monadic 'any'. The predicate is run on every element (left to right,
--   no short-circuiting) before the results are combined.
anyM :: (Monad m) => (a -> m Bool) -> [a] -> m Bool
anyM p xs = quantM (||) False p xs

-- | Run the predicate on every element in order, then right-fold the
--   results with the given combinator and unit.
quantM :: (Monad m) => (Bool -> Bool -> Bool) -> Bool -> (a -> m Bool) -> [a] -> m Bool
quantM combine unit p xs = foldr combine unit <$> mapM p xs
-- | Monadic 'Data.List.partition': elements for which the predicate yields
--   'True' go to the first list and the rest go to the second. Both lists
--   preserve input order.
--
--   The predicate's effects run left to right (like 'Control.Monad.filterM').
--   The original ran them right to left because it recursed before
--   testing the current element.
partitionM :: (Monad m) => (a -> m Bool) -> [a] -> m ([a], [a])
partitionM _ [] = return ([], [])
partitionM p (x : xs) = do
  b <- p x
  (ts, fs) <- partitionM p xs
  return $ if b then (x : ts, fs) else (ts, x : fs)