{-# LANGUAGE CPP, DeriveDataTypeable, RankNTypes, ScopedTypeVariables, TupleSections #-}
#if __GLASGOW_HASKELL__ >= 800
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE TypeApplications #-}
#endif
module Language.Haskell.TH.Desugar.Util (
newUniqueName,
impossible,
nameOccursIn, allNamesIn, mkTypeName, mkDataName, mkNameWith, isDataName,
stripVarP_maybe, extractBoundNamesStmt,
concatMapM, mapAccumLM, mapMaybeM, expectJustM,
stripPlainTV_maybe,
thirdOf3, splitAtList, extractBoundNamesDec,
extractBoundNamesPat,
tvbToType, tvbToTypeWithSig, tvbToTANormalWithSig,
nameMatches, thdOf3, firstMatch,
unboxedSumDegree_maybe, unboxedSumNameDegree_maybe,
tupleDegree_maybe, tupleNameDegree_maybe, unboxedTupleDegree_maybe,
unboxedTupleNameDegree_maybe, splitTuple_maybe,
topEverywhereM, isInfixDataCon,
isTypeKindName, typeKindName,
mkExtraKindBindersGeneric, unravelType, unSigType, unfoldType,
TypeArg(..), applyType, filterTANormals, unSigTypeArg, probablyWrongUnTypeArg
#if __GLASGOW_HASKELL__ >= 800
, bindIP
#endif
) where
import Prelude hiding (mapM, foldl, concatMap, any)
import Language.Haskell.TH hiding ( cxt )
import Language.Haskell.TH.Datatype (tvName)
import qualified Language.Haskell.TH.Desugar.OSet as OS
import Language.Haskell.TH.Desugar.OSet (OSet)
import Language.Haskell.TH.Syntax
import Control.Monad ( replicateM )
import qualified Control.Monad.Fail as Fail
import Data.Foldable
import Data.Generics hiding ( Fixity )
import Data.Traversable
import Data.Maybe
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid
#endif
#if __GLASGOW_HASKELL__ >= 800
import qualified Data.Kind as Kind
import GHC.Classes ( IP )
import Unsafe.Coerce ( unsafeCoerce )
#endif
-- | Generate a fresh 'Name' in two steps. First call @qNewName str@, then
-- pass the 'show'n form of that name to a second 'qNewName' call.
newUniqueName :: Quasi q => String -> q Name
newUniqueName str = qNewName str >>= qNewName . show
-- | Resolve a string to a 'Name' using the given lookup action. If the lookup
-- yields 'Nothing', build a name with the given constructor instead. The
-- constructor receives the current splice location's package and module,
-- followed by the original string.
mkNameWith :: Quasi q => (String -> q (Maybe Name))
                      -> (String -> String -> String -> Name)
                      -> String -> q Name
mkNameWith lookup_fun mkName_fun str =
  lookup_fun str >>= maybe fromCurrentLocation return
  where
    -- Fallback: qualify the string with the location of the current splice.
    fromCurrentLocation = do
      Loc { loc_package = pkg, loc_module = modu } <- qLocation
      return (mkName_fun pkg modu str)
-- | Resolve a type-level name: look it up in the type namespace and, failing
-- that, make a global type-constructor name in the current module.
mkTypeName :: Quasi q => String -> q Name
mkTypeName str = mkNameWith (qLookupName True) mkNameG_tc str

-- | Resolve a term-level name: look it up in the value namespace and, failing
-- that, make a global data-constructor name in the current module.
mkDataName :: Quasi q => String -> q Name
mkDataName str = mkNameWith (qLookupName False) mkNameG_d str
-- | Is this a global ('NameG') name in the data-constructor namespace?
-- Every other name flavour yields 'False'.
isDataName :: Name -> Bool
isDataName name = case name of
  Name _ (NameG DataName _ _) -> True
  _                           -> False
-- | Return the bound variable if the pattern is a bare 'VarP'.
stripVarP_maybe :: Pat -> Maybe Name
stripVarP_maybe pat = case pat of
  VarP name -> Just name
  _         -> Nothing

-- | Return the binder's name if it carries no kind annotation ('PlainTV').
stripPlainTV_maybe :: TyVarBndr -> Maybe Name
stripPlainTV_maybe tvb = case tvb of
  PlainTV n -> Just n
  _         -> Nothing
-- | Fail (via 'Fail.fail') with the given message. A fixed suffix is
-- appended, saying the situation should be unreachable and asking for a bug
-- report.
impossible :: Fail.MonadFail q => String -> q a
impossible err = Fail.fail fullMessage
  where
    fullMessage = err ++ "\n This should not happen in Haskell.\n Please email rae@cs.brynmawr.edu with your code if you see this."
-- | Turn a binder into a bare type variable, dropping any kind annotation.
tvbToType :: TyVarBndr -> Type
tvbToType tvb = VarT (tvName tvb)

-- | Turn a binder into a type variable. A kind annotation, if present, is kept
-- as a 'SigT'.
tvbToTypeWithSig :: TyVarBndr -> Type
tvbToTypeWithSig tvb = case tvb of
  PlainTV n    -> VarT n
  KindedTV n k -> SigT (VarT n) k

-- | Like 'tvbToTypeWithSig', but wrapped as an ordinary ('TANormal') type
-- argument.
tvbToTANormalWithSig :: TyVarBndr -> TypeArg
tvbToTANormalWithSig tvb = TANormal (tvbToTypeWithSig tvb)
-- | Loose comparison of two names.
--
-- * If either side is unqualified ('NameS'), only the base strings are
--   compared.
-- * A qualified name ('NameQ') matches another 'NameQ', or a global
--   'NameG', when both the module and the base string agree.
-- * All other flavour combinations fall back to ordinary '=='.
nameMatches :: Name -> Name -> Bool
nameMatches n1@(Name occ1 flav1) n2@(Name occ2 flav2) =
  case (flav1, flav2) of
    (NameS, _)               -> sameOcc
    (_, NameS)               -> sameOcc
    (NameQ m1, NameQ m2)     -> m1 == m2 && sameOcc
    (NameQ m1, NameG _ _ m2) -> m1 == m2 && sameOcc
    (NameG _ _ m1, NameQ m2) -> m1 == m2 && sameOcc
    _                        -> n1 == n2
  where
    sameOcc = occ1 == occ2
-- | Parse the arity out of a boxed tuple name such as @"(,,)"@.
--
-- * @"()"@ yields @Just 0@.
-- * A name with @c > 0@ commas yields @Just (c + 1)@.
-- * Any other string yields 'Nothing'.
tupleDegree_maybe :: String -> Maybe Int
tupleDegree_maybe ('(' : rest) =
  case span (== ',') rest of
    ("", ")")     -> Just 0
    (commas, ")") -> Just (length commas + 1)
    _             -> Nothing
tupleDegree_maybe _ = Nothing
-- | 'tupleDegree_maybe' applied to the base (unqualified) part of a 'Name'.
tupleNameDegree_maybe :: Name -> Maybe Int
tupleNameDegree_maybe name = tupleDegree_maybe (nameBase name)
-- | Parse the arity out of an unboxed sum type name such as @"(#||#)"@, a
-- 3-ary sum.
--
-- Unboxed sums always have at least two alternatives, so a genuine sum name
-- contains at least one @|@. The separator-free name @"(##)"@ is the unboxed
-- unit /tuple/, not a sum, and so yields 'Nothing' rather than @Just 0@.
unboxedSumDegree_maybe :: String -> Maybe Int
unboxedSumDegree_maybe s =
  case unboxedSumTupleDegree_maybe '|' s of
    Just 0 -> Nothing  -- "(##)": zero separators means no sum at all
    result -> result
-- | 'unboxedSumDegree_maybe' applied to the base part of a 'Name'.
unboxedSumNameDegree_maybe :: Name -> Maybe Int
unboxedSumNameDegree_maybe name = unboxedSumDegree_maybe (nameBase name)
-- | Parse the arity out of an unboxed tuple name such as @"(#,,#)"@.
unboxedTupleDegree_maybe :: String -> Maybe Int
unboxedTupleDegree_maybe str = unboxedSumTupleDegree_maybe ',' str
-- | Shared parser for names of the form @(#@ followed by a run of separator
-- characters and then @#)@.
--
-- * No separators yields @Just 0@.
-- * @c > 0@ separators yield @Just (c + 1)@.
-- * Any other shape yields 'Nothing'.
unboxedSumTupleDegree_maybe :: Char -> String -> Maybe Int
unboxedSumTupleDegree_maybe sep ('(' : '#' : rest) =
  case span (== sep) rest of
    ("", "#)")   -> Just 0
    (seps, "#)") -> Just (length seps + 1)
    _            -> Nothing
unboxedSumTupleDegree_maybe _ _ = Nothing
-- | 'unboxedTupleDegree_maybe' applied to the base part of a 'Name'.
unboxedTupleNameDegree_maybe :: Name -> Maybe Int
unboxedTupleNameDegree_maybe name = unboxedTupleDegree_maybe (nameBase name)
-- | If the type is a boxed tuple type constructor applied to exactly as many
-- arguments as its arity, return those arguments in order. The constructor
-- may be a 'TupleT' or a 'ConT' with a tuple name. Kind signatures on the
-- application spine are skipped.
splitTuple_maybe :: Type -> Maybe [Type]
splitTuple_maybe = collect []
  where
    collect :: [Type] -> Type -> Maybe [Type]
    collect args ty = case ty of
      AppT fun arg -> collect (arg : args) fun
      SigT inner _ -> collect args inner
      ConT con_name
        | Just degree <- tupleNameDegree_maybe con_name
        , length args == degree
        -> Just args
      TupleT degree
        | length args == degree
        -> Just args
      _ -> Nothing
-- | Create one fresh kinded binder for each argument kind of @k@.
--
-- The supplied @unravel@ function splits a kind into
-- (binders, context, argument kinds, result). Only the argument kinds are
-- used. Each binder gets a fresh name generated from @"a"@ and is paired with
-- its argument kind via @mkKindedTV@.
mkExtraKindBindersGeneric
  :: Quasi q
  => (kind -> ([tyVarBndr], [pred], [kind], kind))
  -> (Name -> kind -> tyVarBndr)
  -> kind -> q [tyVarBndr]
mkExtraKindBindersGeneric unravel mkKindedTV k = do
  let (_, _, argKinds, _) = unravel k
  freshNames <- mapM (const (qNewName "a")) argKinds
  return (zipWith mkKindedTV freshNames argKinds)
-- | Split a type into four parts:
--
-- 1. its @forall@-bound binders,
-- 2. its contexts,
-- 3. the argument types of its arrow spine,
-- 4. the final result type.
--
-- Foralls that appear to the right of arrows have their binders and
-- contexts accumulated as well.
unravelType :: Type -> ([TyVarBndr], [Pred], [Type], Type)
unravelType ty = case ty of
  ForallT tvbs preds body ->
    let (moreTvbs, morePreds, argTys, res) = unravelType body
    in (tvbs ++ moreTvbs, preds ++ morePreds, argTys, res)
  AppT (AppT ArrowT argTy) resTy ->
    let (tvbs, preds, argTys, res) = unravelType resTy
    in (tvbs, preds, argTy : argTys, res)
  _ -> ([], [], [], ty)
-- | Remove all explicit kind signatures ('SigT') from a 'Type'.
--
-- Signatures are stripped at the top, when nested directly inside another
-- signature (e.g. @((t :: k1) :: k2)@), and inside the body of a signature.
-- The traversal also descends into applications and @forall@s (including
-- their contexts). On newer GHCs it further covers infix forms, parentheses,
-- visible kind applications and implicit-parameter types.
--
-- Kind annotations on @forall@-bound binders are left untouched.
unSigType :: Type -> Type
-- Recurse into the body: it may itself contain further signatures.
unSigType (SigT t _) = unSigType t
unSigType (AppT f x) = AppT (unSigType f) (unSigType x)
unSigType (ForallT tvbs ctxt t) =
  ForallT tvbs (map unSigPred ctxt) (unSigType t)
#if __GLASGOW_HASKELL__ >= 800
unSigType (InfixT t1 n t2) = InfixT (unSigType t1) n (unSigType t2)
unSigType (UInfixT t1 n t2) = UInfixT (unSigType t1) n (unSigType t2)
unSigType (ParensT t) = ParensT (unSigType t)
#endif
#if __GLASGOW_HASKELL__ >= 807
unSigType (AppKindT t k) = AppKindT (unSigType t) (unSigType k)
unSigType (ImplicitParamT n t) = ImplicitParamT n (unSigType t)
#endif
unSigType t = t
-- | Remove explicit kind signatures from a constraint.
--
-- On GHC 7.10 and later a 'Pred' is just a 'Type', so this defers to
-- 'unSigType'. Older GHCs have a separate 'Pred' representation, which is
-- handled constructor by constructor.
unSigPred :: Pred -> Pred
#if __GLASGOW_HASKELL__ >= 710
unSigPred p = unSigType p
#else
unSigPred p = case p of
  ClassP n tys -> ClassP n (map unSigType tys)
  EqualP t1 t2 -> EqualP (unSigType t1) (unSigType t2)
#endif
-- | Decompose a type into its head and the arguments applied to it, in order.
--
-- While walking the application spine:
--
-- * @forall@s are discarded, together with their binders and contexts.
-- * Kind signatures are skipped.
-- * On GHC 8.0+, parentheses are skipped.
-- * On GHC 8.8+, visible kind applications are collected as 'TyArg's.
unfoldType :: Type -> (Type, [TypeArg])
unfoldType = peel []
  where
    peel :: [TypeArg] -> Type -> (Type, [TypeArg])
    peel args ty = case ty of
      ForallT _ _ body -> peel args body
      AppT fun arg     -> peel (TANormal arg : args) fun
      SigT inner _     -> peel args inner
#if __GLASGOW_HASKELL__ >= 800
      ParensT inner    -> peel args inner
#endif
#if __GLASGOW_HASKELL__ >= 807
      AppKindT inner ki -> peel (TyArg ki : args) inner
#endif
      _                -> (ty, args)
-- | An argument in a type application. 'unfoldType' produces lists of these
-- and 'applyType' re-applies them to a head type.
data TypeArg
  = TANormal Type
    -- ^ An ordinary visible type argument, as in @f x@.
  | TyArg Kind
    -- ^ A kind argument from a visible kind application ('AppKindT'), as in
    -- @f \@k@.
  deriving (Eq, Show, Typeable, Data)
-- | Apply a head type to a list of arguments, left to right.
--
-- * 'TANormal' arguments become 'AppT'.
-- * 'TyArg' arguments become 'AppKindT' when @__GLASGOW_HASKELL__ >= 807@.
--   On older compilers they are silently dropped.
applyType :: Type -> [TypeArg] -> Type
applyType ty [] = ty
applyType ty (arg : args) = applyType (apply ty arg) args
  where
    apply :: Type -> TypeArg -> Type
    apply f (TANormal x) = f `AppT` x
    apply f (TyArg _x) =
#if __GLASGOW_HASKELL__ >= 807
      f `AppKindT` _x
#else
      f
#endif
-- | Keep only the ordinary type arguments ('TANormal'), in order, and discard
-- kind arguments ('TyArg').
filterTANormals :: [TypeArg] -> [Type]
filterTANormals args = [t | TANormal t <- args]
-- | Strip kind signatures (via 'unSigType') from the payload of a 'TypeArg',
-- keeping its constructor.
unSigTypeArg :: TypeArg -> TypeArg
unSigTypeArg arg = case arg of
  TANormal t -> TANormal (unSigType t)
  TyArg k    -> TyArg (unSigType k)

-- | Return the payload of a 'TypeArg', forgetting whether it was a type or a
-- kind argument. NOTE(review): the name suggests callers treat this as a
-- lossy stopgap.
probablyWrongUnTypeArg :: TypeArg -> Type
probablyWrongUnTypeArg arg = case arg of
  TANormal t -> t
  TyArg k    -> k
-- | Does the given 'Name' occur anywhere in the value? The whole structure is
-- searched generically with 'everything'.
nameOccursIn :: Data a => Name -> a -> Bool
nameOccursIn n x = everything (||) (mkQ False (== n)) x

-- | Collect every 'Name' in the value, duplicates included, in the order
-- 'everything' visits them.
allNamesIn :: Data a => a -> [Name]
allNamesIn x = everything (++) (mkQ [] (\name -> [name])) x
-- | The variables bound by a single statement.
--
-- * Bind statements contribute their pattern's variables.
-- * @let@ statements contribute their declarations' binders.
-- * Parallel and (on GHC 8.8+) recursive groups contribute the union of
--   their inner statements.
-- * Plain expression statements bind nothing.
extractBoundNamesStmt :: Stmt -> OSet Name
extractBoundNamesStmt stmt = case stmt of
  BindS pat _   -> extractBoundNamesPat pat
  LetS decs     -> foldMap extractBoundNamesDec decs
  NoBindS _     -> OS.empty
  ParS branches -> foldMap (foldMap extractBoundNamesStmt) branches
#if __GLASGOW_HASKELL__ >= 807
  RecS stmts    -> foldMap extractBoundNamesStmt stmts
#endif
-- | The variables bound by a declaration.
--
-- * Function declarations bind their name.
-- * Value declarations bind their pattern's variables.
-- * Every other declaration form contributes nothing.
extractBoundNamesDec :: Dec -> OSet Name
extractBoundNamesDec dec = case dec of
  FunD name _  -> OS.singleton name
  ValD pat _ _ -> extractBoundNamesPat pat
  _            -> OS.empty
-- | The variables bound by a pattern, collected recursively through every
-- sub-pattern. As-patterns contribute their own name as well as the names
-- in the inner pattern.
extractBoundNamesPat :: Pat -> OSet Name
extractBoundNamesPat pat = case pat of
  LitP _              -> OS.empty
  WildP               -> OS.empty
  VarP name           -> OS.singleton name
  AsP name inner      -> OS.singleton name `OS.union` extractBoundNamesPat inner
  TupP pats           -> unionsOf pats
  UnboxedTupP pats    -> unionsOf pats
  ConP _ pats         -> unionsOf pats
  ListP pats          -> unionsOf pats
  RecP _ fieldPats    -> unionsOf (map snd fieldPats)
  InfixP p1 _ p2      -> extractBoundNamesPat p1 `OS.union` extractBoundNamesPat p2
  UInfixP p1 _ p2     -> extractBoundNamesPat p1 `OS.union` extractBoundNamesPat p2
  ParensP inner       -> extractBoundNamesPat inner
  TildeP inner        -> extractBoundNamesPat inner
  BangP inner         -> extractBoundNamesPat inner
  SigP inner _        -> extractBoundNamesPat inner
  ViewP _ inner       -> extractBoundNamesPat inner
#if __GLASGOW_HASKELL__ >= 801
  UnboxedSumP inner _ _ -> extractBoundNamesPat inner
#endif
  where
    -- Union of the names bound by each sub-pattern, in order.
    unionsOf :: [Pat] -> OSet Name
    unionsOf = foldMap extractBoundNamesPat
#if __GLASGOW_HASKELL__ >= 800
-- | Newtype wrapper around a value that needs the implicit-parameter
-- constraint @IP name a@. Wrapping the constrained value lets 'bindIP' hand
-- it to 'unsafeCoerce' as an ordinary, unconstrained value.
newtype MagicIP name a r = MagicIP (IP name a => r)

-- | Run @k@, which needs the implicit parameter @name@ (selected with a type
-- application: @bindIP \@name val k@), supplying @val@ as that parameter's
-- value.
--
-- 'unsafeCoerce' reinterprets the wrapped computation as a plain function
-- @a -> r@ and applies it to @val@.
--
-- NOTE(review): this relies on two GHC implementation details, not language
-- guarantees:
--
-- * a @MagicIP@ value is represented as a function taking the @IP name a@
--   dictionary;
-- * that dictionary is represented by the bare value of type @a@.
--
-- Re-check this on GHC upgrades.
bindIP :: forall name a r. a -> (IP name a => r) -> r
bindIP val k = (unsafeCoerce (MagicIP @name k) :: a -> r) val
#endif
-- | Split the second list after as many elements as the first list has. This
-- is like 'splitAt', but the count is given by a list's length and is
-- consumed lazily, so the first list may be infinite. If the second list is
-- shorter, it is returned whole as the prefix.
splitAtList :: [a] -> [b] -> ([b], [b])
splitAtList counter xs = case (counter, xs) of
  ([], _) -> ([], xs)
  (_ : rest, y : ys) ->
    let (front, back) = splitAtList rest ys
    in (y : front, back)
  (_ : _, []) -> ([], [])
-- | The third component of a triple.
thdOf3 :: (a,b,c) -> c
thdOf3 triple = case triple of
  (_, _, z) -> z

-- | Apply a function to the third component of a triple.
thirdOf3 :: (a -> b) -> (c, d, a) -> (c, d, b)
thirdOf3 f triple = case triple of
  (x, y, z) -> (x, y, f z)
-- | Monadic 'foldMap'. Run the action on every element in traversal order,
-- then combine all the results with the 'Monoid' instance.
concatMapM :: (Monad monad, Monoid monoid, Traversable t)
           => (a -> monad monoid) -> t a -> monad monoid
concatMapM fn list = mapM fn list >>= return . fold
-- | Monadic counterpart of @Data.List.mapAccumL@. An accumulator is threaded
-- left to right through the list, and the step's effects run in list order.
-- The result is the final accumulator together with the collected outputs.
mapAccumLM :: Monad m
           => (acc -> x -> m (acc, y))
           -> acc
           -> [x]
           -> m (acc, [y])
mapAccumLM step = loop
  where
    loop acc [] = return (acc, [])
    loop acc (x : rest) = do
      (acc', y) <- step acc x
      (accFinal, ys) <- loop acc' rest
      return (accFinal, y : ys)
-- | Monadic 'mapMaybe'. Run the action on each element in order and keep only
-- the 'Just' results, preserving their order.
mapMaybeM :: Monad m => (a -> m (Maybe b)) -> [a] -> m [b]
mapMaybeM f = go
  where
    go [] = return []
    go (x : xs) = do
      result <- f x
      rest <- go xs
      return (maybe rest (: rest) result)
-- | Unwrap a 'Just', or fail (via 'Fail.fail') with the given message on
-- 'Nothing'.
expectJustM :: Fail.MonadFail m => String -> Maybe a -> m a
expectJustM err = maybe (Fail.fail err) return
-- | The first 'Just' produced by applying the function along the list, or
-- 'Nothing'. Evaluation stops at the first match, so infinite lists work
-- whenever a match exists.
firstMatch :: (a -> Maybe b) -> [a] -> Maybe b
firstMatch _ [] = Nothing
firstMatch f (x : xs) = case f x of
  Nothing -> firstMatch f xs
  found   -> found
-- | Monadic, top-down generic traversal that rewrites only the outermost
-- values of type @a@.
--
-- * At each node, if the node has the handler's type, run the handler and do
--   not descend further.
-- * Otherwise, recurse into the node's immediate children with 'gmapM'.
topEverywhereM :: (Typeable a, Data b, Monad m) => (a -> m a) -> b -> m b
topEverywhereM handler = extM (gmapM (topEverywhereM handler)) handler
-- | Does the constructor name start with a colon? Symbolic data constructors
-- such as @:+:@ do, so they are used infix.
isInfixDataCon :: String -> Bool
isInfixDataCon str = case str of
  ':' : _ -> True
  _       -> False
-- | Is this the name of the kind of types?
--
-- * This always includes 'typeKindName'.
-- * When @__GLASGOW_HASKELL__ < 805@, it also includes the legacy @*@ and
--   @★@ spellings.
isTypeKindName :: Name -> Bool
isTypeKindName n = isModernTypeKind || isLegacyStarKind
  where
    isModernTypeKind = n == typeKindName
#if __GLASGOW_HASKELL__ < 805
    isLegacyStarKind = n == starKindName || n == uniStarKindName
#else
    isLegacyStarKind = False
#endif
-- | The 'Name' of the kind of types.
--
-- * On GHC 8.0 and later this is the quoted name of 'Kind.Type' from
--   "Data.Kind".
-- * On earlier GHCs the name of @*@ is used instead.
typeKindName :: Name
#if __GLASGOW_HASKELL__ >= 800
typeKindName = ''Kind.Type
#else
typeKindName = starKindName
#endif
-- Legacy spellings of the kind of types. These are only defined when
-- __GLASGOW_HASKELL__ < 805, where isTypeKindName must recognise them
-- alongside typeKindName.
#if __GLASGOW_HASKELL__ < 805
-- | The 'Name' of @*@.
--
-- * On GHC 8.0+ it is the quoted name from "Data.Kind".
-- * Before that, "Data.Kind" is unavailable, so the global name is built by
--   hand.
starKindName :: Name
#if __GLASGOW_HASKELL__ >= 800
starKindName = ''(Kind.*)
#else
starKindName = mkNameG_tc "ghc-prim" "GHC.Prim" "*"
#endif

-- | The 'Name' of the Unicode spelling @★@.
--
-- * On GHC 8.0+ it is the quoted name from "Data.Kind".
-- * Before that, it is simply an alias for 'starKindName'.
uniStarKindName :: Name
#if __GLASGOW_HASKELL__ >= 800
uniStarKindName = ''(Kind.★)
#else
uniStarKindName = starKindName
#endif
#endif