module Language.Haskell.TH.Desugar.Util (
newUniqueName,
impossible,
nameOccursIn, allNamesIn, mkTypeName, mkDataName, isDataName,
stripVarP_maybe, extractBoundNamesStmt,
concatMapM, mapMaybeM, expectJustM,
stripPlainTV_maybe,
thirdOf3, splitAtList, extractBoundNamesDec,
extractBoundNamesPat,
tvbName, tvbToType, nameMatches, freeNamesOfTypes, thdOf3, firstMatch,
unboxedSumDegree_maybe, unboxedSumNameDegree_maybe,
tupleDegree_maybe, tupleNameDegree_maybe, unboxedTupleDegree_maybe,
unboxedTupleNameDegree_maybe, splitTuple_maybe,
topEverywhereM
) where
import Prelude hiding (mapM, foldl, concatMap, any)
import Language.Haskell.TH hiding ( cxt )
import Language.Haskell.TH.Syntax
import qualified Data.Set as S
import Data.Foldable
import Data.Generics hiding ( Fixity )
import Data.Traversable
import Data.Maybe
import Data.Monoid
-- | Produce a fresh 'Name' in two steps: first ask 'qNewName' for a name
-- based on the given string, then ask 'qNewName' again using the 'show'
-- rendering of that first name as the base. The second name is returned.
--
-- NOTE(review): presumably the double generation makes the base string of
-- the result unique on its own -- confirm against callers.
newUniqueName :: Quasi q => String -> q Name
newUniqueName str = qNewName str >>= qNewName . show
-- | Turn a string into a type-level 'Name'.
--
-- The string is first looked up with @'qLookupName' True@; if that finds a
-- name, it is returned as is. Otherwise a global type-constructor name is
-- built with 'mkNameG_tc', using the package and module reported by
-- 'qLocation'.
mkTypeName :: Quasi q => String -> q Name
mkTypeName str = qLookupName True str >>= maybe fabricate return
  where
    -- Lookup failed: manufacture a name rooted at the current location.
    fabricate = do
      Loc { loc_package = pkg, loc_module = modu } <- qLocation
      return (mkNameG_tc pkg modu str)
-- | Turn a string into a data-level 'Name'.
--
-- The string is first looked up with @'qLookupName' False@; if that finds a
-- name, it is returned as is. Otherwise a global data-constructor name is
-- built with 'mkNameG_d', using the package and module reported by
-- 'qLocation'.
mkDataName :: Quasi q => String -> q Name
mkDataName str = qLookupName False str >>= maybe fabricate return
  where
    -- Lookup failed: manufacture a name rooted at the current location.
    fabricate = do
      Loc { loc_package = pkg, loc_module = modu } <- qLocation
      return (mkNameG_d pkg modu str)
-- | 'True' exactly when the name's flavour is 'NameG' tagged with the
-- 'DataName' namespace; every other name yields 'False'.
isDataName :: Name -> Bool
isDataName (Name _ flavour) =
  case flavour of
    NameG DataName _ _ -> True
    _                  -> False
-- | Return the bound name if the pattern is a plain variable pattern
-- ('VarP'); 'Nothing' for any other pattern.
stripVarP_maybe :: Pat -> Maybe Name
stripVarP_maybe pat =
  case pat of
    VarP name -> Just name
    _         -> Nothing
-- | Return the variable name if the binder is a 'PlainTV' (no kind
-- annotation); 'Nothing' for any other binder.
stripPlainTV_maybe :: TyVarBndr -> Maybe Name
stripPlainTV_maybe tvb =
  case tvb of
    PlainTV n -> Just n
    _         -> Nothing
-- | Abort the current monadic computation via 'fail'. The message is the
-- caller's text followed by a fixed notice asking the user to report the
-- problem.
impossible :: Monad q => String -> q a
impossible err = fail (err ++ notice)
  where
    notice = "\n This should not happen in Haskell.\n Please email rae@cs.brynmawr.edu with your code if you see this."
-- | Extract the variable name from a type-variable binder, whether or not
-- it carries a kind annotation.
tvbName :: TyVarBndr -> Name
tvbName tvb =
  case tvb of
    PlainTV n    -> n
    KindedTV n _ -> n
-- | Convert a type-variable binder into the type variable ('VarT') it
-- binds. Any kind annotation on the binder is dropped.
tvbToType :: TyVarBndr -> Type
tvbToType tvb = VarT (tvbName tvb)
-- | Loose comparison of two names.
--
-- * If either name has the 'NameS' flavour, only the occurrence names are
--   compared.
-- * If both carry a module (any mix of 'NameQ' and 'NameG' except
--   'NameG' against 'NameG'), the modules and the occurrence names must
--   both agree.
-- * Every remaining combination falls back to ordinary equality on 'Name'.
nameMatches :: Name -> Name -> Bool
nameMatches n1@(Name occ1 flav1) n2@(Name occ2 flav2) =
  case (flav1, flav2) of
    (NameS, _)                   -> sameOcc
    (_, NameS)                   -> sameOcc
    (NameQ mod1, NameQ mod2)     -> mod1 == mod2 && sameOcc
    (NameQ mod1, NameG _ _ mod2) -> mod1 == mod2 && sameOcc
    (NameG _ _ mod1, NameQ mod2) -> mod1 == mod2 && sameOcc
    _                            -> n1 == n2
  where
    sameOcc = occ1 == occ2
-- | Recognise the textual form of a boxed tuple constructor and report its
-- degree.
--
-- The string must be an opening parenthesis, zero or more commas, and a
-- closing parenthesis, with nothing else. @"()"@ has degree 0; otherwise
-- the degree is the number of commas plus one. Any other string gives
-- 'Nothing'.
tupleDegree_maybe :: String -> Maybe Int
tupleDegree_maybe s =
  case s of
    '(' : rest
      | (commas, ")") <- span (== ',') rest
      -> Just (if null commas then 0 else length commas + 1)
    _ -> Nothing
-- | Like 'tupleDegree_maybe', but applied to the 'nameBase' of a 'Name'.
tupleNameDegree_maybe :: Name -> Maybe Int
tupleNameDegree_maybe name = tupleDegree_maybe (nameBase name)
-- | Recognise the textual form of an unboxed sum constructor (alternatives
-- separated by @|@) and report its degree, by delegating to
-- 'unboxedSumTupleDegree_maybe' with @'|'@ as the separator.
unboxedSumDegree_maybe :: String -> Maybe Int
unboxedSumDegree_maybe str = unboxedSumTupleDegree_maybe '|' str
-- | Like 'unboxedSumDegree_maybe', but applied to the 'nameBase' of a
-- 'Name'.
unboxedSumNameDegree_maybe :: Name -> Maybe Int
unboxedSumNameDegree_maybe name = unboxedSumDegree_maybe (nameBase name)
-- | Recognise the textual form of an unboxed tuple constructor (components
-- separated by commas) and report its degree, by delegating to
-- 'unboxedSumTupleDegree_maybe' with @','@ as the separator.
unboxedTupleDegree_maybe :: String -> Maybe Int
unboxedTupleDegree_maybe str = unboxedSumTupleDegree_maybe ',' str
-- | Shared recogniser for unboxed sum and unboxed tuple constructor names.
--
-- The string must be @(#@, then zero or more copies of the separator
-- character, then @#)@, with nothing else. With no separators the degree
-- is 0; otherwise it is the number of separators plus one. Any other
-- string gives 'Nothing'.
unboxedSumTupleDegree_maybe :: Char -> String -> Maybe Int
unboxedSumTupleDegree_maybe sep s =
  case s of
    '(' : '#' : rest
      | (seps, "#)") <- span (== sep) rest
      -> Just (if null seps then 0 else length seps + 1)
    _ -> Nothing
-- | Like 'unboxedTupleDegree_maybe', but applied to the 'nameBase' of a
-- 'Name'.
unboxedTupleNameDegree_maybe :: Name -> Maybe Int
unboxedTupleNameDegree_maybe name = unboxedTupleDegree_maybe (nameBase name)
-- | If the type is a fully applied boxed tuple, return its component types
-- in order; otherwise 'Nothing'.
--
-- The type is peeled from the outside in: each application contributes its
-- argument, and kind signatures are skipped. The head must be either a
-- 'ConT' whose name is recognised by 'tupleNameDegree_maybe' or a
-- 'TupleT', and the number of collected arguments must equal the tuple's
-- degree.
splitTuple_maybe :: Type -> Maybe [Type]
splitTuple_maybe ty = collect ty []
  where
    -- The accumulator holds the arguments seen so far, leftmost first.
    collect (AppT fun arg) acc = collect fun (arg : acc)
    collect (SigT inner _kind) acc = collect inner acc
    collect (ConT con) acc
      | Just degree <- tupleNameDegree_maybe con
      , saturates degree acc
      = Just acc
    collect (TupleT degree) acc
      | saturates degree acc
      = Just acc
    collect _ _ = Nothing

    saturates degree acc = length acc == degree
-- | Generic traversal that reports whether the given 'Name' appears
-- anywhere inside a value: every 'Name'-typed subterm is compared against
-- it and the results are combined with '||'.
nameOccursIn :: Data a => Name -> a -> Bool
nameOccursIn n x = everything (||) (False `mkQ` isTarget) x
  where
    isTarget candidate = candidate == n
-- | Generic traversal that collects every 'Name'-typed subterm of a value
-- into a list. Duplicates are kept.
allNamesIn :: Data a => a -> [Name]
allNamesIn x = everything (++) ([] `mkQ` single) x
  where
    single n = [n]
-- | Collect the names bound by a single statement.
--
-- A bind statement contributes the names in its pattern, a @let@
-- contributes the names of its declarations, a bare expression contributes
-- nothing, and a parallel statement contributes the names of all of its
-- branches. Only these four 'Stmt' constructors are matched.
extractBoundNamesStmt :: Stmt -> S.Set Name
extractBoundNamesStmt stmt =
  case stmt of
    BindS pat _   -> extractBoundNamesPat pat
    LetS decs     -> foldMap extractBoundNamesDec decs
    NoBindS _     -> S.empty
    ParS branches -> foldMap (foldMap extractBoundNamesStmt) branches
-- | Collect the names bound by a declaration. A function declaration binds
-- its own name, a value declaration binds the names in its pattern, and
-- every other declaration is treated as binding nothing.
extractBoundNamesDec :: Dec -> S.Set Name
extractBoundNamesDec dec =
  case dec of
    FunD name _  -> S.singleton name
    ValD pat _ _ -> extractBoundNamesPat pat
    _            -> S.empty
-- | Collect every variable name bound by a pattern.
--
-- Variable patterns and the alias in an as-pattern contribute a name.
-- Literals and wildcards contribute nothing. All other constructors just
-- pass on the names bound by their sub-patterns. Constructor names, record
-- field names, view-pattern expressions and signature types are ignored.
extractBoundNamesPat :: Pat -> S.Set Name
extractBoundNamesPat = go
  where
    -- Union of the names bound by a list of sub-patterns.
    goAll ps = foldMap go ps

    go pat = case pat of
      LitP _             -> S.empty
      VarP name          -> S.singleton name
      TupP pats          -> goAll pats
      UnboxedTupP pats   -> goAll pats
      ConP _ pats        -> goAll pats
      InfixP lhs _ rhs   -> go lhs `S.union` go rhs
      UInfixP lhs _ rhs  -> go lhs `S.union` go rhs
      ParensP inner      -> go inner
      TildeP inner       -> go inner
      BangP inner        -> go inner
      AsP name inner     -> S.singleton name `S.union` go inner
      WildP              -> S.empty
      RecP _ field_pats  -> goAll (map snd field_pats)
      ListP pats         -> goAll pats
      SigP inner _       -> go inner
      ViewP _ inner      -> go inner
#if __GLASGOW_HASKELL__ >= 801
      UnboxedSumP inner _ _ -> go inner
#endif
-- | Compute the set of free type-variable names occurring in a list of
-- types.
--
-- Variables bound by a @forall@ are removed from what its context and body
-- contribute. Kind annotations ('SigT') are skipped without inspecting the
-- kind, and only 'ForallT', 'AppT', 'SigT' and 'VarT' are examined; any
-- other type form contributes no names.
freeNamesOfTypes :: [Type] -> S.Set Name
freeNamesOfTypes = foldMap fvs
  where
    fvs (ForallT binders preds body) =
      S.difference (fvs body <> foldMap fvsPred preds)
                   (S.fromList (map tvbName binders))
    fvs (AppT fun arg) = fvs fun <> fvs arg
    fvs (SigT ty _)    = fvs ty
    fvs (VarT n)       = S.singleton n
    fvs _              = S.empty

    -- Predicates are ordinary types from GHC 7.10 (__GLASGOW_HASKELL__ 709)
    -- onwards; before that they have their own constructors.
#if __GLASGOW_HASKELL__ >= 709
    fvsPred = fvs
#else
    fvsPred (ClassP _ tys) = freeNamesOfTypes tys
    fvsPred (EqualP t1 t2) = fvs t1 <> fvs t2
#endif
-- | Split the second list at the length of the first list, without
-- computing that length. The first list serves only as a ruler: one
-- element of the second list goes into the front part for each element of
-- the first. If the second list runs out first, the remainder is empty.
splitAtList :: [a] -> [b] -> ([b], [b])
splitAtList ruler items =
  case (ruler, items) of
    ([], _)                -> ([], items)
    (_ : _, [])            -> ([], [])
    (_ : ruler', y : rest) -> (y : front, back)
      where (front, back) = splitAtList ruler' rest
-- | Project the third component out of a triple.
thdOf3 :: (a,b,c) -> c
thdOf3 = \(_, _, z) -> z
-- | Apply a function to the third component of a triple, leaving the first
-- two components untouched.
thirdOf3 :: (a -> b) -> (c, d, a) -> (c, d, b)
thirdOf3 f triple =
  case triple of
    (x, y, z) -> (x, y, f z)
-- | Monadic analogue of 'foldMap': run the action on every element of a
-- traversable container (via 'mapM'), then combine all the monoidal
-- results with 'fold'.
concatMapM :: (Monad monad, Monoid monoid, Traversable t)
           => (a -> monad monoid) -> t a -> monad monoid
concatMapM fn list = mapM fn list >>= return . fold
-- | Monadic analogue of 'mapMaybe': run the action on each list element
-- from left to right and keep only the 'Just' results, in order.
mapMaybeM :: Monad m => (a -> m (Maybe b)) -> [a] -> m [b]
mapMaybeM f = go
  where
    go [] = return []
    go (x : xs) = do
      -- The head's action runs before the tail is processed.
      found <- f x
      kept <- go xs
      return (maybe kept (: kept) found)
-- | Unwrap a 'Maybe' inside a monad: a 'Just' value is returned, while
-- 'Nothing' aborts via 'fail' with the supplied error message.
expectJustM :: Monad m => String -> Maybe a -> m a
expectJustM err m_val = maybe (fail err) return m_val
-- | Apply the function to the list elements in order and return the first
-- 'Just' result, or 'Nothing' if no element produces one. Elements after
-- the first hit are not examined.
firstMatch :: (a -> Maybe b) -> [a] -> Maybe b
firstMatch f xs = listToMaybe [hit | Just hit <- map f xs]
-- | Top-down monadic generic transformation that stops at the first match.
--
-- At each node: if the node has the handler's type, the handler is applied
-- and its result is returned without descending into it further.
-- Otherwise the traversal recurses into the node's immediate children via
-- 'gmapM'.
topEverywhereM :: (Typeable a, Data b, Monad m) => (a -> m a) -> b -> m b
topEverywhereM handler = descend `extM` handler
  where
    -- Default case: keep the node and visit each immediate child.
    descend node = gmapM (topEverywhereM handler) node