module Language.Haskell.Refact.Utils.GhcUtils (
everywhereM'
, everywhereMStaged'
, everywhereStaged
, everywhereStaged'
, listifyStaged
, zeverywhereStaged
, zopenStaged
, zsomewhereStaged
, transZ
, transZM
, zopenStaged'
, ztransformStagedM
, upUntil
, findAbove
) where
import qualified Data.Generics as SYB
import qualified GHC.SYB.Utils as SYB
import Control.Monad
import Data.Data
import Data.Maybe
import qualified Data.Generics.Zipper as Z
-- | Monadic top-down traversal: run @f@ on a node first, then recurse into
-- the immediate children of the result with @gmapM@.
--
-- With GHC 7.8 and earlier, nodes for which @checkItemStage stage@ holds
-- are returned unchanged and not descended into; with newer GHC the
-- @stage@ argument is not consulted.
everywhereMStaged' :: Monad m => SYB.Stage -> SYB.GenericM m -> SYB.GenericM m
everywhereMStaged' stage f x
#if __GLASGOW_HASKELL__ <= 708
  | checkItemStage stage x = return x
#endif
  | otherwise = f x >>= gmapM (everywhereMStaged' stage f)
-- | Monadic top-down traversal without any stage check: run @f@ on a node,
-- then recurse into the immediate children of the result with @gmapM@.
everywhereM' :: Monad m => SYB.GenericM m -> SYB.GenericM m
everywhereM' f x = f x >>= gmapM (everywhereM' f)
-- | Bottom-up generic transformation: first rebuild the node by recursing
-- into its immediate children with @gmapT@, then apply @f@ to the result.
--
-- With GHC 7.8 and earlier, nodes for which @checkItemStage stage@ holds
-- are returned unchanged; with newer GHC @stage@ is not consulted.
everywhereStaged :: SYB.Stage -> (forall a. Data a => a -> a) -> forall a. Data a => a -> a
everywhereStaged stage f x
#if __GLASGOW_HASKELL__ <= 708
  | checkItemStage stage x = x
#endif
  | otherwise = f (gmapT (everywhereStaged stage f) x)
-- | Top-down generic transformation: apply @f@ to a node first, then
-- recurse into the immediate children of the result with @gmapT@.
--
-- With GHC 7.8 and earlier, nodes for which @checkItemStage stage@ holds
-- are returned unchanged and not descended into; with newer GHC @stage@
-- is not consulted.
everywhereStaged' :: SYB.Stage -> (forall a. Data a => a -> a) -> forall a. Data a => a -> a
everywhereStaged' stage f x
#if __GLASGOW_HASKELL__ <= 708
  | checkItemStage stage x = x
#endif
  -- Recurse with this top-down function itself. Recursing with the
  -- bottom-up everywhereStaged would make only the root top-down.
  | otherwise = gmapT (everywhereStaged' stage f) (f x)
#if __GLASGOW_HASKELL__ <= 708
-- Stage checks used by the staged traversals in this module. This whole
-- group is only compiled for GHC 7.8 and earlier.

-- | True when a node is to be skipped at the given stage (the staged
-- traversals in this module return such nodes unchanged). Combines
-- checkItemStage1 and, for GHC newer than 7.4, checkItemStage2.
-- NOTE(review): presumably these are AST fields GHC has not filled in yet
-- at that stage; confirm against GHC.SYB.Utils.
checkItemStage :: (Typeable a, Data a) => SYB.Stage -> a -> Bool
checkItemStage stage x = (checkItemStage1 stage x)
#if __GLASGOW_HASKELL__ > 704
                         || (checkItemStage2 stage x)
#endif

-- | Type-directed part of the stage check: True for a NameSet at the
-- Parser or TypeChecker stage, for a PostTcType before the TypeChecker
-- stage, and for a Fixity before the Renamer stage; False for a value of
-- any other type.
checkItemStage1 :: (Typeable a) => SYB.Stage -> a -> Bool
checkItemStage1 stage x = (const False `SYB.extQ` postTcType `SYB.extQ` fixity `SYB.extQ` nameSet) x
  where nameSet = const (stage `elem` [SYB.Parser,SYB.TypeChecker]) :: GHC.NameSet -> Bool
        postTcType = const (stage < SYB.TypeChecker ) :: GHC.PostTcType -> Bool
        fixity = const (stage < SYB.Renamer ) :: GHC.Fixity -> Bool
#if __GLASGOW_HASKELL__ > 704

-- | Second part of the stage check (GHC newer than 7.4 only): True for an
-- HsWithBndrs node of any payload type before the Renamer stage, False
-- for everything else.
checkItemStage2 :: Data a => SYB.Stage -> a -> Bool
checkItemStage2 stage x = (const False `SYB.ext1Q` hsWithBndrs) x
  where
    hsWithBndrs = const (stage < SYB.Renamer) :: GHC.HsWithBndrs a -> Bool
#endif

-- | checkItemStage fixed to the Renamer stage.
checkItemRenamer :: (Data a, Typeable a) => a -> Bool
checkItemRenamer x = checkItemStage SYB.Renamer x
#endif
-- | Collect every value of type @a1@ found inside @a@ that satisfies @p@.
-- The traversal is @SYB.everythingStaged@ at the given stage, with list
-- append as the combining function.
listifyStaged
  :: (Data a, Typeable a1) => SYB.Stage -> (a1 -> Bool) -> a -> [a1]
listifyStaged stage p = SYB.everythingStaged stage (++) [] (SYB.mkQ [] keep)
  where
    keep x = if p x then [x] else []
#if __GLASGOW_HASKELL__ <= 708
-- | Full top-down type-unifying traversal. The result of @s@ at the
-- current node is combined (with @mappend@, via @op2TU@) with the results
-- gathered from the children by @allTUGhc'@.
full_tdTUGhc :: (MonadPlus m, Monoid a) => TU a m -> TU a m
full_tdTUGhc s = op2TU mappend s fromChildren
  where
    fromChildren = allTUGhc' (full_tdTUGhc s)
-- | Top-down type-unifying traversal that stops below nodes where @s@
-- succeeds. @s@ is tried first; only when it fails are the children
-- visited (through @allTUGhc'@).
stop_tdTUGhc :: (MonadPlus m, Monoid a) => TU a m -> TU a m
stop_tdTUGhc s = choiceTU s (allTUGhc' (stop_tdTUGhc s))
-- | Top-down type-preserving traversal that stops below nodes where @s@
-- succeeds. When @s@ fails, the children are visited through @allTPGhc@.
stop_tdTPGhc :: MonadPlus m => TP m -> TP m
stop_tdTPGhc s = choiceTP s (allTPGhc (stop_tdTPGhc s))
-- | @allTUGhc@ specialised to a monoid: results from the children are
-- combined with @mappend@, starting from @mempty@.
allTUGhc' :: (MonadPlus m, Monoid a) => TU a m -> TU a m
allTUGhc' s = allTUGhc mappend mempty s
-- | Apply @s@ once, top-down. @s@ is tried at the current node first; if
-- it fails, the search descends into one child through @oneTPGhc@.
once_tdTPGhc :: MonadPlus m => TP m -> TP m
once_tdTPGhc s = choiceTP s (oneTPGhc (once_tdTPGhc s))
-- | Apply @s@ once, bottom-up. The children are tried first (through
-- @oneTPGhc@); @s@ at the current node is only the fallback.
once_buTPGhc :: MonadPlus m => TP m -> TP m
once_buTPGhc s = choiceTP (oneTPGhc (once_buTPGhc s)) s
-- | Stage-aware @oneTP@. It fails outright on nodes matched by
-- @checkItemRenamer'@, so such nodes are never descended into. On all
-- other nodes it behaves as @oneTP s@.
oneTPGhc :: MonadPlus m => TP m -> TP m
oneTPGhc s = ifTP checkItemRenamer' skip (oneTP s)
  where
    skip _ = failTP
-- | Stage-aware @allTP@. It fails on nodes matched by
-- @checkItemRenamer'@, so such nodes are not descended into. On all other
-- nodes it applies @s@ to every immediate child via @allTP@.
--
-- NOTE(review): failing on a skipped node also makes an enclosing @allTP@
-- fail. A traversal such as @stop_tdTPGhc@ may therefore abort when it
-- meets a skipped node on which @s@ fails. @const idTP@ may be the
-- intended fallback; confirm.
allTPGhc :: MonadPlus m => TP m -> TP m
-- Delegate to allTP, not oneTP. The oneTP version was a copy of oneTPGhc
-- and made stop_tdTPGhc visit only a single child.
allTPGhc s = ifTP checkItemRenamer' (const failTP) (allTP s)
#endif
#if __GLASGOW_HASKELL__ <= 708
-- | Stage-aware @allTU@. On nodes matched by @checkItemRenamer'@ the
-- children are not visited and the unit @u@ is returned. On all other
-- nodes it behaves as @allTU op2 u s@.
allTUGhc :: (MonadPlus m) => (a -> a -> a) -> a -> TU a m -> TU a m
allTUGhc op2 u s = ifTU checkItemRenamer' skip (allTU op2 u s)
  where
    skip _ = constTU u
#endif
#if __GLASGOW_HASKELL__ <= 708
-- | Strategy form of the stage check. It succeeds (with @()@) on nodes
-- that are to be skipped at the given stage and fails on every other node.
-- The skipped cases are:
--
-- * a NameSet at the Parser and TypeChecker stages;
-- * a PostTcType before TypeChecker;
-- * a Fixity before Renamer.
checkItemStage' :: forall m. (MonadPlus m) => SYB.Stage -> TU () m
checkItemStage' stage =
  failTU `adhocTU` skipPostTcType `adhocTU` skipFixity `adhocTU` skipNameSet
  where
    -- These signatures refer to the m bound by the forall above.
    skipNameSet :: GHC.NameSet -> m ()
    skipNameSet _ = guard (stage `elem` [SYB.Parser, SYB.TypeChecker])
    skipPostTcType :: GHC.PostTcType -> m ()
    skipPostTcType _ = guard (stage < SYB.TypeChecker)
    skipFixity :: GHC.Fixity -> m ()
    skipFixity _ = guard (stage < SYB.Renamer)

-- | @checkItemStage'@ fixed to the Renamer stage.
checkItemRenamer' :: (MonadPlus m) => TU () m
checkItemRenamer' = checkItemStage' SYB.Renamer
#endif
-- | Bottom-up generic transformation over a zipper. Every child subtree
-- (reached with @Z.downT@ and then successive @Z.leftT@ moves) is
-- transformed before @f@ is applied at the current focus.
--
-- With GHC 7.8 and earlier the zipper is returned unchanged when
-- @checkZipperStaged stage@ holds; with newer GHC @stage@ is only passed
-- through the recursion.
zeverywhereStaged :: (Typeable a) => SYB.Stage -> SYB.GenericT -> Z.Zipper a -> Z.Zipper a
zeverywhereStaged stage f z
#if __GLASGOW_HASKELL__ <= 708
  | checkZipperStaged stage z = z
#endif
  | otherwise = Z.trans f withChildrenDone
  where
    withChildrenDone = Z.downT visitChild z
    -- Transform the subtree at this child, then move on to its left sibling.
    visitChild here = Z.leftT visitChild (zeverywhereStaged stage f here)
-- | Collect zippers focused on the outermost locations where @q@ holds, in
-- left-to-right order at every depth.
--
-- The search does not descend below a matching location. If @q@ holds at
-- the focus of the input zipper, the result is just that zipper. With
-- GHC 7.8 and earlier, @[]@ is returned when @checkZipperStaged stage@
-- holds.
zopenStaged :: (Typeable a) => SYB.Stage -> SYB.GenericQ Bool -> Z.Zipper a -> [Z.Zipper a]
zopenStaged stage q z
#if __GLASGOW_HASKELL__ <= 708
  | checkZipperStaged stage z = []
#endif
  | Z.query q z = [z]
  | otherwise = Z.downQ [] (collect []) z
  where
    -- Z.downQ lands on the rightmost child. Walk leftwards and prepend each
    -- child's results, so the order is left to right at every depth. A
    -- single final reverse would also flip the order inside each child's
    -- group of results.
    collect acc z' = Z.leftQ acc' (collect acc') z'
      where
        acc' = zopenStaged stage q z' ++ acc
-- | Apply a generic monadic transformation somewhere in the zipper.
--
-- The search order is pre-order. @f@ is tried at the focus first; if that
-- yields @mzero@, the children are searched left to right, and each
-- child's whole subtree is searched before moving on to its right sibling.
-- With a Maybe-like monad this applies @f@ once, at the topmost-leftmost
-- location where it succeeds.
--
-- With GHC 7.8 and earlier the zipper is returned unchanged when
-- @checkZipperStaged stage@ holds.
zsomewhereStaged :: (MonadPlus m) => SYB.Stage -> SYB.GenericM m -> Z.Zipper a -> m (Z.Zipper a)
zsomewhereStaged stage f z
#if __GLASGOW_HASKELL__ <= 708
  | checkZipperStaged stage z = return z
#endif
  | otherwise = Z.transM f z `mplus` Z.downM mzero (g . Z.leftmost) z
  where
    -- Search the subtree rooted at this child (z', not the parent z), and
    -- fall back to the next sibling on the right.
    g z' = zsomewhereStaged stage f z' `mplus` Z.rightM mzero g z'
-- | Apply @t@ (passing it @stage@) to the zipper when the query @q@ holds
-- at the current focus; otherwise return the zipper untouched.
transZ :: SYB.Stage -> SYB.GenericQ Bool -> (SYB.Stage -> Z.Zipper a -> Z.Zipper a) -> Z.Zipper a -> Z.Zipper a
transZ stage q t z = if Z.query q z then t stage z else z
-- | Monadic counterpart of @transZ@. Run @t stage@ on the zipper when @q@
-- holds at the focus; otherwise return the zipper unchanged.
transZM :: Monad m
        => SYB.Stage
        -> SYB.GenericQ Bool
        -> (SYB.Stage -> Z.Zipper a -> m (Z.Zipper a))
        -> Z.Zipper a
        -> m (Z.Zipper a)
transZM stage q t z
  | not (Z.query q z) = return z
  | otherwise = t stage z
-- | Climb from the current focus towards the root and return the first
-- zipper whose focus satisfies @q@. The input itself is returned if it
-- already matches; @Nothing@ is returned once there is no parent left.
upUntil :: SYB.GenericQ Bool -> Z.Zipper a -> Maybe (Z.Zipper a)
upUntil q = climb
  where
    climb z
      | Z.query q z = Just z
      | otherwise = Z.upQ Nothing climb z
-- | Climb from the focus (inclusive) to the nearest node of type @a@ that
-- satisfies @cond@, and return that node. Returns @Nothing@ if there is
-- no such node.
--
-- Note that the searched type is the same as the zipper root type @a@.
findAbove :: (Data a) => (a -> Bool) -> Z.Zipper a -> Maybe a
findAbove cond z = upUntil (SYB.mkQ False cond) z >>= Z.getHole
-- | Collect the outermost locations where the Maybe-valued query @q@
-- returns @Just@. Each zipper is paired with the value found there, in
-- left-to-right order at every depth.
--
-- The search does not descend below a match. @stage@ is only passed
-- through the recursive calls.
zopenStaged' :: (Typeable a)
             => SYB.Stage
             -> SYB.GenericQ (Maybe b)
             -> Z.Zipper a
             -> [(Z.Zipper a,b)]
zopenStaged' stage q z =
  case Z.query q z of
    Just hit -> [(z, hit)]
    Nothing  -> Z.downQ [] (collect []) z
  where
    -- Z.downQ lands on the rightmost child. Walk leftwards and prepend each
    -- child's results, so the order is left to right at every depth
    -- (a final reverse would flip the order inside each child's group).
    collect acc z' = Z.leftQ acc' (collect acc') z'
      where
        acc' = zopenStaged' stage q z' ++ acc
-- | Find the outermost locations where @q@ yields a transformation, using
-- @zopenStaged'@.
--
-- If there is exactly one such location, run that transformation on the
-- zipper focused there (passing it @stage@) and return its result. When
-- there are none or several, return the input zipper unchanged.
--
-- NOTE(review): on success the result is whatever the transformation
-- returns for the zipper focused at the match. It is not moved back to
-- the focus of the input zipper.
ztransformStagedM :: (Typeable a,Monad m)
                  => SYB.Stage
                  -> SYB.GenericQ (Maybe (SYB.Stage -> Z.Zipper a -> m (Z.Zipper a)))
                  -> Z.Zipper a
                  -> m (Z.Zipper a)
ztransformStagedM stage q z =
  case zopenStaged' stage q z of
    [(focused, transform)] -> transform stage focused
    _                      -> return z