{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeApplications #-}
module Refact.Internal
( apply
, runRefactoring
, applyRefactorings
, Verbosity(..)
, rigidLayout
, removeOverlap
, refactOptions
, type Errors
, onError
, mkErr
) where
import Language.Haskell.GHC.ExactPrint
import Language.Haskell.GHC.ExactPrint.Annotate
import Language.Haskell.GHC.ExactPrint.Delta
import Language.Haskell.GHC.ExactPrint.Parsers
import Language.Haskell.GHC.ExactPrint.Print
import Language.Haskell.GHC.ExactPrint.Types hiding (GhcPs, GhcTc, GhcRn)
import Language.Haskell.GHC.ExactPrint.Utils
import Control.Arrow
import Control.Monad
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Maybe (MaybeT(..))
import Control.Monad.Trans.State
import Data.Char (isAlphaNum)
import Data.Data
import Data.Functor.Identity (Identity(..))
import Data.Generics hiding (GT)
import qualified Data.Map as Map
import Data.Maybe
import Data.List
import Data.Ord
import System.IO
import System.IO.Unsafe
import Debug.Trace
#if __GLASGOW_HASKELL__ >= 810
import GHC.Hs.Expr as GHC hiding (Stmt)
import GHC.Hs.ImpExp
import GHC.Hs hiding (Pat, Stmt)
import Outputable hiding ((<>))
import ErrUtils
import Bag
#else
import HsExpr as GHC hiding (Stmt)
import HsImpExp
import HsSyn hiding (Pat, Stmt, noExt)
#endif
import SrcLoc
import qualified GHC hiding (parseModule)
import qualified Name as GHC
import qualified RdrName as GHC
import Refact.Fixity
import Refact.Types hiding (SrcSpan)
import qualified Refact.Types as R
import Refact.Utils (Stmt, Pat, Name, Decl, M, Module, Expr, Type, FunBind
, modifyAnnKey, replaceAnnKey, Import, toGhcSrcSpan, setSrcSpanFile)
#if __GLASGOW_HASKELL__ >= 810
-- | Error type returned by the parsers used in this module (see 'mkErr').
-- On GHC >= 8.10 this is GHC's bag of error messages.
type Errors = ErrorMessages

-- | Abort with a GHC panic labelled by the given context string, printing
-- every message of the bag together with its location.
onError :: String -> Errors -> a
onError s = pprPanic s . vcat . pprErrMsgBagWithLoc
#else
-- | Error type returned by the parsers used in this module (see 'mkErr').
-- Before GHC 8.10 this is a location paired with a message.
type Errors = (SrcSpan, String)

-- | Abort with 'error'. The context string is prefixed to the shown error so
-- that, as in the GHC >= 8.10 branch, the failure names its call site
-- (previously the label was silently dropped).
onError :: String -> Errors -> a
onError s err = error (s ++ ": " ++ show err)
#endif
#if __GLASGOW_HASKELL__ <= 806
-- Compatibility shims for GHC <= 8.6, which lacks 'HasSrcSpan'. There, nodes
-- are plain 'Located' values, so composing and decomposing a node with its
-- span are both the identity, and the "span-less" type is the type itself.
type SrcSpanLess a = a

composeSrcSpan :: a -> a
composeSrcSpan x = x

decomposeSrcSpan :: a -> a
decomposeSrcSpan x = x
#endif
-- | Exact-print options that render to a 'String' with 'RigidLayout' rigidity.
refactOptions :: PrintOptions Identity String
refactOptions = stringOptions { epRigidity = rigidity }
  where
    rigidity = RigidLayout
-- | Delta options with 'RigidLayout' rigidity; used when parsing the module
-- in 'applyRefactorings'.
rigidLayout :: DeltaOptions
rigidLayout =
  let rigidity = RigidLayout
  in deltaOptions rigidity
-- | Apply hints to an already-parsed module and exact-print the result.
--
-- Overlapping hints are removed first with 'removeOverlap'. When a position
-- is given, only hints with at least one refactoring whose span covers that
-- position are kept. With @step@ set, each hint is offered interactively via
-- 'refactoringLoop'; if that loop fails, the unmodified module is printed.
-- Otherwise all refactorings are applied in order, sharing one seed counter
-- that starts at 0.
apply
  :: Maybe (Int, Int)
  -> Bool
  -> [(String, [Refactoring R.SrcSpan])]
  -> FilePath
  -> Verbosity
  -> Anns
  -> Module
  -> IO String
apply mpos step inp file verb as0 m0 = do
  when (verb >= Normal) $
    traceM ("Applying " ++ show (length refacts) ++ " hints")
  when (verb == Loud) $
    traceM (show selected)
  (as, m) <-
    if step
      then fromMaybe (as0, m0) <$> runMaybeT (refactoringLoop as0 m0 selected)
      else pure (evalState (foldM (uncurry runRefactoring) (as0, m0) refacts) 0)
  pure (runIdentity (exactPrintWithOptions refactOptions m as))
  where
    -- Overlap-free hints, with spans converted to GHC spans in this file.
    converted =
      [ (hint, map (fmap (toGhcSrcSpan file)) rs)
      | (hint, rs) <- removeOverlap verb inp
      ]
    coversPos (_, rs) = case mpos of
      Nothing -> True
      Just p -> any (\r -> spans (pos r) p) rs
    selected = filter coversPos converted
    refacts = concatMap snd selected
-- | One entry of the interactive menu offered by 'refactoringLoop'.
data LoopOption = LoopOption
  { desc :: String
    -- ^ Description shown next to the option key in the help menu.
  , perform :: MaybeT IO (Anns, Module)
    -- ^ Action run when the option is chosen. It yields the resulting
    -- annotations and module, or fails (e.g. with 'mzero') to discard all
    -- changes.
  }
-- | Offer each hint interactively, reading the answer from @/dev/tty@, and
-- apply the accepted ones.
--
-- Options: @y@ applies the current hint, @n@ skips it, @q@ stops and keeps
-- the changes made so far, @d@ fails (so the caller discards every change),
-- @v@ prints the current file, and @?@ (or any unrecognised input) shows the
-- menu. Hints without refactorings are skipped silently.
--
-- The template seed consumed by 'runRefactoring' is threaded through the
-- whole session, as the non-interactive path in 'apply' does. Previously it
-- was restarted at 0 for every accepted hint. Every template was then named
-- @template0@, so spans and annotation keys coming from different templates
-- could coincide when annotations were merged.
refactoringLoop :: Anns -> Module -> [(String, [Refactoring GHC.SrcSpan])]
                -> MaybeT IO (Anns, Module)
refactoringLoop = loopFrom 0
  where
    -- The loop proper; the first argument is the next unused template seed.
    loopFrom :: Int -> Anns -> Module -> [(String, [Refactoring GHC.SrcSpan])]
             -> MaybeT IO (Anns, Module)
    loopFrom _ as m [] = pure (as, m)
    loopFrom seed as m ((_, []) : rss) = loopFrom seed as m rss
    loopFrom seed as m hints@((hintDesc, rs) : rss) = do
        inp <- liftIO $ do
          putStrLn hintDesc
          putStrLn $ "Apply hint [" ++ intercalate ", " (map fst opts) ++ "]"
          withFile "/dev/tty" ReadMode hGetLine
        maybe loopHelp perform (lookup inp opts)
      where
        opts =
          [ ("y", LoopOption "Apply current hint" yAction)
          , ("n", LoopOption "Don't apply the current hint" (loopFrom seed as m rss))
          , ("q", LoopOption "Apply no further hints" (return (as, m)))
          , ("d", LoopOption "Discard previous changes" mzero)
          , ("v", LoopOption "View current file" viewAction)
          , ("?", LoopOption "Show this help menu" loopHelp)
          ]
        viewAction = liftIO (putStrLn (exactPrint m as)) >> loopFrom seed as m hints
        loopHelp = do
          liftIO . putStrLn . unlines . map mkLine $ opts
          loopFrom seed as m hints
        mkLine (c, opt) = c ++ " - " ++ desc opt
        yAction =
          let ((!r1, !r2), seed') =
                runState (foldM (uncurry runRefactoring) (as, m) rs) seed
          in do
            -- Force the printed result now so failures surface on this hint.
            exactPrint r2 r1 `seq` return ()
            loopFrom seed' r1 r2 rss
-- | Parse the file at the given path with 'rigidLayout', apply fixities, and
-- then apply the hints non-interactively with 'Silent' verbosity.
-- A parse failure aborts through 'onError' labelled @"apply"@.
applyRefactorings
  :: Maybe (Int, Int)
  -> [(String, [Refactoring R.SrcSpan])]
  -> FilePath
  -> IO String
applyRefactorings optionsPos inp file =
  parseModuleWithOptions rigidLayout file >>= \parsed ->
    case either (onError "apply") (uncurry applyFixities) parsed of
      (as, m) -> apply optionsPos False inp file Silent as m
-- | How much diagnostic output to produce. Constructor order matters: the
-- derived 'Ord' instance backs checks such as @verb >= Normal@.
data Verbosity
  = Silent
  | Normal
  | Loud
  deriving (Show, Eq, Ord)
-- | Drop every hint whose extent overlaps the extent of a hint kept before it.
--
-- A hint's extent is the smallest span enclosing all of its refactorings.
-- Hints without refactorings are discarded. Hints are ordered by start
-- position, and among equal starts the one ending last comes first. A later
-- hint counts as overlapping when it starts at or before the end position of
-- the hint being kept. Above 'Silent' verbosity, each dropped hint is reported
-- with 'trace'.
removeOverlap :: Verbosity -> [(String, [Refactoring R.SrcSpan])] -> [(String, [Refactoring R.SrcSpan])]
removeOverlap verb = keepDisjoint . sortBy (comparing sortKey) . withExtent
  where
    withExtent :: [(String, [Refactoring R.SrcSpan])] -> [(String, (R.SrcSpan, [Refactoring R.SrcSpan]))]
    withExtent hints =
      [ (hint, (foldr1 enclose (map pos rs), rs))
      | (hint, rs) <- hints
      , not (null rs)
      ]

    -- Smallest span containing both arguments: earliest start, latest end.
    enclose (R.SrcSpan la ca la' ca') (R.SrcSpan lb cb lb' cb') =
      let (sl, sc) = min (la, ca) (lb, cb)
          (el, ec) = max (la', ca') (lb', cb')
      in R.SrcSpan sl sc el ec

    -- Earliest start first; for equal starts, the latest end first.
    sortKey (_, (s, _)) = (startLine s, startCol s, Down (endLine s), Down (endCol s))

    keepDisjoint [] = []
    keepDisjoint (hd : tl) = sweep hd tl

    sweep cur [] = [dropExtent cur]
    sweep cur (next : rest)
      | cur `overlaps` next = report next (sweep cur rest)
      | otherwise = dropExtent cur : sweep next rest

    report x
      | verb > Silent = trace ("Ignoring " ++ show (snd (snd x)) ++ " due to overlap.")
      | otherwise = id

    dropExtent (hint, (_, rs)) = (hint, rs)

    overlaps (_, (a, _)) (_, (b, _)) =
      startLine b < endLine a
        || (startLine b == endLine a && startCol b <= endCol a)
-- | Return the current counter value and advance the counter by one.
getSeed :: State Int Int
getSeed = state (\n -> (n, n + 1))
-- | Apply one 'Refactoring' to the tree @m@, returning the updated
-- annotations and tree.
--
-- The 'State' 'Int' is a counter. Only 'Replace' reads and advances it (via
-- 'getSeed'); the value is passed on to 'replaceWorker' to name the template.
runRefactoring :: Data a => Anns -> a -> Refactoring GHC.SrcSpan -> State Int (Anns, a)
runRefactoring as m r@Replace{} = do
  seed <- getSeed
  -- Choose the template parser according to the kind of syntax being replaced.
  return $ case rtype r of
    Expr -> replaceWorker as m parseExpr seed r
    Decl -> replaceWorker as m parseDecl seed r
    Type -> replaceWorker as m parseType seed r
    Pattern -> replaceWorker as m parsePattern seed r
    Stmt -> replaceWorker as m parseStmt seed r
    Bind -> replaceWorker as m parseBind seed r
    R.Match -> replaceWorker as m parseMatch seed r
    ModuleName -> replaceWorker as m (parseModuleName (pos r)) seed r
    Import -> replaceWorker as m parseImport seed r
-- Replace the text of every comment whose identifier span has the same
-- 'ss2pos' as @pos@. This covers prior comments and 'AnnComment' entries of
-- 'annsDP' in every annotation. The tree itself is returned unchanged.
runRefactoring as m ModifyComment{..} =
  return (Map.map go as, m)
  where
    go a@Ann{ annPriorComments, annsDP } =
      a { annsDP = map changeComment annsDP
        , annPriorComments = map (first change) annPriorComments }
    changeComment (AnnComment d, dp) = (AnnComment (change d), dp)
    changeComment e = e
    change old@Comment{..}= if ss2pos commentIdentifier == ss2pos pos
                             then old { commentContents = newComment}
                             else old
-- Delete the statements or imports whose location is exactly @pos@. Other
-- syntax kinds are not handled, and the tree is returned unchanged.
runRefactoring as m Delete{rtype, pos} = do
  let f = case rtype of
        Stmt -> doDeleteStmt ((/= pos) . getLoc)
        Import -> doDeleteImport ((/= pos) . getLoc)
        _ -> id
  return (as, f m)
-- Attach @newComment@ as a prior comment of the declaration located exactly
-- at @pos@. 'findDecl' calls 'error' if no declaration has that span.
runRefactoring as m InsertComment{..} =
  let exprkey = mkAnnKey (findDecl m pos) in
  return (insertComment exprkey newComment as, m)
-- Remove the @as@ alias from the import declaration located exactly at @pos@.
runRefactoring as m RemoveAsKeyword{..} =
  return (as, removeAsKeyword m)
  where
    removeAsKeyword = everywhere (mkT go)
    go :: LImportDecl GHC.GhcPs -> LImportDecl GHC.GhcPs
    go imp@(GHC.L l i) | l == pos = GHC.L l (i { ideclAs = Nothing })
                       | otherwise = imp
-- | Build a single located error from a plain message, in whichever
-- representation 'Errors' has on the current GHC. The 'GHC.DynFlags' are only
-- needed on GHC >= 8.10.
mkErr :: GHC.DynFlags -> SrcSpan -> String -> Errors
#if __GLASGOW_HASKELL__ >= 810
mkErr dflags loc msg = unitBag $ mkPlainErrMsg dflags loc $ text msg
#else
mkErr _ loc msg = (loc, msg)
#endif
-- | A 'Parser' that always succeeds: it wraps the input text as a module
-- name located at the given span, with annotations computed from empty API
-- annotations. The 'GHC.DynFlags' and file name arguments are ignored.
parseModuleName :: GHC.SrcSpan -> Parser (GHC.Located GHC.ModuleName)
parseModuleName ss _ _ s =
  Right (relativiseApiAnns located (Map.empty, Map.empty), located)
  where
    located = GHC.L ss (GHC.mkModuleName s)
-- | Parse a value binding. The text is parsed as a declaration, and a 'ValD'
-- is unwrapped; any other declaration yields a "Not a HsBind" error at its
-- location. Errors from 'parseDecl' are passed through.
parseBind :: Parser (GHC.LHsBind GHC.GhcPs)
parseBind dyn fname s =
  parseDecl dyn fname s >>= \(as, decl) -> case decl of
    GHC.L l (GHC.ValD _ b) -> Right (as, GHC.L l b)
    GHC.L l _ -> Left (mkErr dyn l "Not a HsBind")
-- | Parse a single function equation. The text is parsed as a binding,
-- which must be a 'GHC.FunBind' with exactly one match. Otherwise the result
-- is "Not a funbind" or "Not a single match" at the binding's location.
parseMatch :: Parser (GHC.LMatch GHC.GhcPs (GHC.LHsExpr GHC.GhcPs))
parseMatch dyn fname s =
  parseBind dyn fname s >>= \(as, bind) -> case bind of
    GHC.L l GHC.FunBind{fun_matches} -> case unLoc (GHC.mg_alts fun_matches) of
      [only] -> Right (as, only)
      _ -> Left (mkErr dyn l "Not a single match")
    GHC.L l _ -> Left (mkErr dyn l "Not a funbind")
-- | Walk a parsed template bottom-up (via 'everywhereM') and substitute
-- placeholder names in types, function names, patterns, statements and
-- expressions. Each placeholder becomes the node found in @m@ at the span
-- bound to that name in @ss@.
substTransform :: (Data a, Data b) => b -> [(String, GHC.SrcSpan)] -> a -> M a
substTransform m ss = everywhereM step
  where
    step :: Data x => x -> M x
    step =
      mkM (typeSub m ss)
        `extM` identSub m ss
        `extM` patSub m ss
        `extM` stmtSub m ss
        `extM` exprSub m ss
-- | Substitution for statements. A body statement that is just a variable is
-- handed to 'resolveRdrName' together with 'findStmt'; anything else is
-- returned unchanged.
stmtSub :: Data a => a -> [(String, GHC.SrcSpan)] -> Stmt -> M Stmt
stmtSub m subs old = case old of
  GHC.L _ (BodyStmt _ (GHC.L _ (HsVar _ (L _ name))) _ _) ->
    resolveRdrName m (findStmt m) old subs name
  _ -> return old
-- | Substitution for patterns. A variable pattern is handed to
-- 'resolveRdrName' together with 'findPat'; anything else is returned
-- unchanged.
patSub :: Data a => a -> [(String, GHC.SrcSpan)] -> Pat -> M Pat
patSub m subs old = case old of
  GHC.L _ (VarPat _ (L _ name)) -> resolveRdrName m (findPat m) old subs name
  _ -> return old
-- | Substitution for types. A type variable is handed to 'resolveRdrName'
-- together with 'findType'; anything else is returned unchanged.
typeSub :: Data a => a -> [(String, GHC.SrcSpan)] -> Type -> M Type
typeSub m subs old = case old of
  GHC.L _ (HsTyVar _ _ (L _ name)) -> resolveRdrName m (findType m) old subs name
  _ -> return old
-- | Substitution for expressions. A variable expression is handed to
-- 'resolveRdrName' together with 'findExpr'; anything else is returned
-- unchanged.
exprSub :: Data a => a -> [(String, GHC.SrcSpan)] -> Expr -> M Expr
exprSub m subs old = case old of
  GHC.L _ (HsVar _ (L _ name)) -> resolveRdrName m (findExpr m) old subs name
  _ -> return old
-- | Substitution for the function name of a 'GHC.FunRhs' match context.
-- If the name is an unqualified key of @subs@, it is replaced by the 'Name'
-- found in @m@ at the bound span. The annotations are re-keyed with
-- 'replaceAnnKey', using a 'GHC.VarPat' at the new name's location as a
-- stand-in key. NOTE(review): the exact re-keying semantics live in
-- 'replaceAnnKey' (Refact.Utils).
-- Other match contexts are returned unchanged.
identSub :: Data a => a -> [(String, GHC.SrcSpan)] -> FunBind -> M FunBind
identSub m subs old = case old of
  GHC.FunRhs (GHC.L _ name) _ _ -> resolveRdrName' swapName (findName m) old subs name
  _ -> return old
  where
    swapName :: FunBind -> Name -> M FunBind
    swapName (GHC.FunRhs oldName fixity strictness) new = do
      let fakeExpr :: Located (GHC.Pat GhcPs)
          fakeExpr = GHC.L (getLoc new) (GHC.VarPat noExt new)
      modify (\r -> replaceAnnKey r (mkAnnKey oldName) (mkAnnKey fakeExpr) (mkAnnKey new) (mkAnnKey fakeExpr))
      return $ GHC.FunRhs new fixity strictness
    swapName other _ = return other
-- | Generic placeholder substitution. When @name@ is unqualified and its
-- occurrence string is a key of @subs@, the bound span is turned into a
-- replacement with @f@ and combined with @old@ via @g@. Otherwise @old@ is
-- returned unchanged.
resolveRdrName' :: (a -> b -> M a)
                -> (SrcSpan -> b)
                -> a
                -> [(String, GHC.SrcSpan)]
                -> GHC.RdrName
                -> M a
resolveRdrName' g f old subs name =
  case name of
    GHC.Unqual occ
      | Just loc <- lookup (GHC.occNameString (GHC.occName occ)) subs -> g old (f loc)
    _ -> return old
-- | 'resolveRdrName'' specialised to located nodes. The old node and its
-- replacement are combined with 'modifyAnnKey', which works against the
-- tree @m@ to carry annotations over.
resolveRdrName :: (Data old, Data a)
               => a
               -> (SrcSpan -> Located old)
               -> Located old
               -> [(String, SrcSpan)]
               -> GHC.RdrName
               -> M (Located old)
resolveRdrName m finder old subs name =
  resolveRdrName' (modifyAnnKey m) finder old subs name
-- | Append @s@ as a prior comment (with delta @DP (1,0)@) to the annotation
-- stored under key @k@, and set that annotation's entry delta to @DP (1,0)@.
-- If @k@ is absent, the map is returned unchanged ('Map.adjust').
insertComment :: AnnKey -> String
              -> Map.Map AnnKey Annotation
              -> Map.Map AnnKey Annotation
insertComment k s = Map.adjust addPrior k
  where
    comment = Comment s GHC.noSrcSpan Nothing
    addPrior ann =
      ann { annPriorComments = annPriorComments ann ++ [(comment, DP (1,0))]
          , annEntryDelta = DP (1,0)
          }
-- | Generic replacement step, run by 'replaceWorker' at every node via
-- 'everywhereM'.
--
-- * If @p old@ holds, @old@ is replaced by @new@, and @old@'s annotations are
--   transferred to @new@ with 'modifyAnnKey'.
-- * Otherwise, the nodes may be declarations where @new@ is a 'FunBind' with
--   exactly one equation and @old@ satisfies @p@ once its local bindings are
--   stripped (see 'stripLocalBind'). In that case @new@ replaces @old@ but
--   keeps @old@'s local bindings. Their spans are moved into @new@'s file,
--   the enclosing spans are widened to cover them, and @old@'s @where@
--   keyword annotation is copied across.
-- * Otherwise @old@ is returned unchanged.
--
-- The 'Bool' in the state is set to 'True' whenever a replacement happens.
--
-- The single-equation shape of @new@ is now checked in a pattern guard, as
-- 'stripLocalBind' does for @old@. Previously a lazy irrefutable @let@
-- pattern crashed on a multi-equation replacement; such a node is now left
-- alone.
#if __GLASGOW_HASKELL__ <= 806
doGenReplacement
  :: forall ast a. (Data ast, Data a)
  => a
  -> (GHC.Located ast -> Bool)
  -> GHC.Located ast
  -> GHC.Located ast
  -> State (Anns, Bool) (GHC.Located ast)
#else
doGenReplacement
  :: forall ast a. (Data (SrcSpanLess ast), HasSrcSpan ast, Data a)
  => a
  -> (ast -> Bool)
  -> ast
  -> ast
  -> State (Anns, Bool) ast
#endif
doGenReplacement m p new old
  | p old = do
      anns <- gets fst
      let n = decomposeSrcSpan new
          o = decomposeSrcSpan old
          newAnns = execState (modifyAnnKey m o n) anns
      put (newAnns, True)
      pure new
  | Just Refl <- eqT @(SrcSpanLess ast) @(HsDecl GHC.GhcPs)
  , L _ (ValD xvald newBind@FunBind{}) <- decomposeSrcSpan new
  , Just (oldNoLocal, oldLocal) <- stripLocalBind (decomposeSrcSpan old)
  , newLoc@(RealSrcSpan newLocReal) <- getLoc new
  , let newMG = fun_matches newBind
  -- Only a single-equation replacement can receive the old local bindings.
  , L locMG [L locMatch newMatch] <- mg_alts newMG
  , p (composeSrcSpan oldNoLocal) = do
      anns <- gets fst
      let n = decomposeSrcSpan new
          o = decomposeSrcSpan old
          intAnns = execState (modifyAnnKey m o n) anns
          -- The file name of the parsed replacement template.
          newFile = srcSpanFile newLocReal
          -- Old local bindings, relocated into the template's file.
          newLocal = everywhere (mkT $ setSrcSpanFile newFile) oldLocal
          newLocalLoc = getLoc newLocal
          -- Widen a span so that it also covers the relocated local bindings.
          ensureLoc = combineSrcSpans newLocalLoc
          newGRHSs = m_grhss newMatch
          finalLoc = ensureLoc newLoc
          newWithLocalBinds = setLocalBind newLocal xvald newBind finalLoc
                                newMG (ensureLoc locMG) newMatch (ensureLoc locMatch) newGRHSs
          addLocalBindsToAnns = addAnnWhere
                              . Map.fromList
                              . map (first (expandTemplateLoc . updateFile . expandGRHSLoc))
                              . Map.toList
            where
              -- Copy the 'AnnWhere' keyword annotation from the old "Match"
              -- annotation to the new one.
              addAnnWhere :: Anns -> Anns
              addAnnWhere oldAnns =
                let oldAnns' = Map.toList oldAnns
                    po = \case
                      (AnnKey loc@(RealSrcSpan r) con, _) ->
                        loc == getLoc old && con == CN "Match" && srcSpanFile r /= newFile
                      _ -> False
                    pn = \case
                      (AnnKey loc@(RealSrcSpan r) con, _) ->
                        loc == finalLoc && con == CN "Match" && srcSpanFile r == newFile
                      _ -> False
                in fromMaybe oldAnns $ do
                     oldAnn <- snd <$> find po oldAnns'
                     annWhere <- find ((== G GHC.AnnWhere) . fst) (annsDP oldAnn)
                     newKey <- fst <$> find pn oldAnns'
                     pure $ Map.adjust (\ann -> ann {annsDP = annsDP ann ++ [annWhere]}) newKey oldAnns
              -- Widen the keys of the template's "GRHS" annotations.
              expandGRHSLoc = \case
                AnnKey loc@(RealSrcSpan r) con
                  | con == CN "GRHS", srcSpanFile r == newFile -> AnnKey (ensureLoc loc) con
                other -> other
              -- Move keys lying inside the old local bindings into the template's file.
              updateFile = \case
                AnnKey loc con
                  | loc `isSubspanOf` getLoc oldLocal -> AnnKey (setSrcSpanFile newFile loc) con
                other -> other
              -- Re-key annotations of the whole template node to its widened span.
              expandTemplateLoc = \case
                AnnKey loc con
                  | loc == newLoc -> AnnKey finalLoc con
                other -> other
          newAnns = addLocalBindsToAnns intAnns
      put (newAnns, True)
      pure $ composeSrcSpan newWithLocalBinds
  | otherwise = pure old
-- | Split off the local bindings of a simple function declaration.
--
-- For a 'FunBind' declaration with exactly one equation and exactly one
-- guarded right-hand side, return two things. The first is the declaration
-- with its local bindings replaced by empty ones, with its span narrowed to
-- run from the function name to the end of the right-hand side body. The
-- second is the removed local bindings. Any other declaration gives
-- 'Nothing'.
stripLocalBind
  :: LHsDecl GHC.GhcPs
  -> Maybe (LHsDecl GHC.GhcPs, LHsLocalBinds GHC.GhcPs)
stripLocalBind (L _ (ValD xvald origBind@FunBind{}))
  | L locMG [L locMatch origMatch] <- mg_alts origMG
  , [L _ (GRHS _ _ (L bodyLoc _))] <- grhssGRHSs (m_grhss origMatch)
  = let origGRHSs = m_grhss origMatch
        nameLoc = getLoc (fun_id origBind)
        narrowed = combineSrcSpans nameLoc bodyLoc
        withoutLocalBinds =
          setLocalBind (noLoc (EmptyLocalBinds noExt)) xvald origBind narrowed
            origMG locMG origMatch locMatch origGRHSs
    in Just (withoutLocalBinds, grhssLocalBinds origGRHSs)
  where
    origMG = fun_matches origBind
stripLocalBind _ = Nothing
-- | Rebuild a single-equation function-binding declaration at @newLoc@,
-- with its local bindings replaced by @newLocalBinds@. The match group and
-- match are rebuilt at the given locations @locMG@ and @locMatch@.
setLocalBind
  :: LHsLocalBinds GHC.GhcPs
  -> XValD GhcPs
  -> HsBind GhcPs
  -> SrcSpan
  -> MatchGroup GhcPs (LHsExpr GhcPs)
  -> SrcSpan
  -> Match GhcPs (LHsExpr GhcPs)
  -> SrcSpan
  -> GRHSs GhcPs (LHsExpr GhcPs)
  -> LHsDecl GhcPs
setLocalBind newLocalBinds xvald origBind newLoc origMG locMG origMatch locMatch origGRHSs =
  let grhss' = origGRHSs { grhssLocalBinds = newLocalBinds }
      match' = origMatch { m_grhss = grhss' }
      mg' = origMG { mg_alts = L locMG [L locMatch match'] }
      bind' = origBind { fun_matches = mg' }
  in L newLoc (ValD xvald bind')
-- | Perform a 'Replace' refactoring.
--
-- The template text @orig@ is parsed with @parser@, as a pseudo-file named
-- @template<seed>@, so that spans coming from the template carry that file
-- name. Its placeholder names are substituted with the nodes at the spans in
-- @subts@ ('substTransform'). The result then replaces the node located
-- exactly at @pos@ ('doGenReplacement').
--
-- A parse failure aborts via 'onError' labelled @"replaceWorker"@. The label
-- was previously misspelled @"replaceWorked"@. If no node was replaced, the
-- original annotations and tree are returned. Non-'Replace' refactorings are
-- a no-op.
--
-- After a successful replacement, a one-column entry delta is added in one
-- case, so that the two identifiers do not run together. That case is when
-- the new node is a variable starting with an alphanumeric character, its
-- entry delta is @DP (0, 0)@, and some alphanumeric-ending variable in @m@
-- ends exactly where @pos@ starts.
#if __GLASGOW_HASKELL__ <= 806
replaceWorker :: (Annotate a, Data mod)
              => Anns
              -> mod
              -> Parser (GHC.Located a)
              -> Int
              -> Refactoring GHC.SrcSpan
              -> (Anns, mod)
#else
replaceWorker :: (Annotate a, HasSrcSpan a, Data mod, Data (SrcSpanLess a))
              => Anns
              -> mod
              -> Parser a
              -> Int
              -> Refactoring GHC.SrcSpan
              -> (Anns, mod)
#endif
replaceWorker as m parser seed Replace{..} =
  let replExprLocation = pos
      uniqueName = "template" ++ show seed
      -- NOTE(review): 'withDynFlags' is an IO action; unsafePerformIO lets the
      -- pure 'Parser' be used here. This assumes the call has no observable
      -- side effects beyond obtaining DynFlags -- confirm.
      p s = unsafePerformIO (withDynFlags (\d -> parser d uniqueName s))
      (relat, template) = case p orig of
        Right xs -> xs
        Left err -> onError "replaceWorker" err
      (newExpr, newAnns) = runState (substTransform m subts template) (mergeAnns as relat)
      -- Last character of an identifier's occurrence name, if any.
      lst = listToMaybe . reverse . GHC.occNameString . GHC.rdrNameOcc
      -- True when the first span ends exactly where the second begins.
      adjacent (srcSpanEnd -> RealSrcLoc loc1) (srcSpanStart -> RealSrcLoc loc2) = loc1 == loc2
      adjacent _ _ = False
      ensureSpace :: Anns -> Anns
      ensureSpace = fromMaybe id $ do
        (L _ (HsVar _ (L _ newName))) :: LHsExpr GhcPs <- cast newExpr
        hd <- listToMaybe $ case newName of
          GHC.Unqual occName -> GHC.occNameString occName
          GHC.Qual moduleName _ -> GHC.moduleNameString moduleName
          GHC.Orig modu _ -> GHC.moduleNameString (GHC.moduleName modu)
          GHC.Exact name -> GHC.occNameString (GHC.nameOccName name)
        guard $ isAlphaNum hd
        let prev :: [LHsExpr GhcPs] =
              listify
                (\case
                    (L loc (HsVar _ (L _ rdr))) -> maybe False isAlphaNum (lst rdr) && adjacent loc pos
                    _ -> False
                )
                m
        guard . not . null $ prev
        pure . flip Map.adjust (mkAnnKey newExpr) $ \ann ->
          if annEntryDelta ann == DP (0, 0)
            then ann { annEntryDelta = DP (0, 1) }
            else ann
      replacementPred (GHC.L l _) = l == replExprLocation
      transformation = everywhereM (mkM (doGenReplacement m (replacementPred . decomposeSrcSpan) newExpr))
  in case runState (transformation m) (newAnns, False) of
       (finalM, (finalAs, True)) -> (ensureSpace finalAs, finalM)
       _ -> (as, m)
replaceWorker as m _ _ _ = (as, m)
-- | Return the first node of type @GHC.Located ast@ whose span equals @ss@,
-- in the traversal order of syb's 'something'. If there is none, call
-- 'error' with the label @s@ followed by the span.
findGen :: forall ast a . (Data ast, Data a) => String -> a -> SrcSpan -> GHC.Located ast
findGen s m ss =
  case something (mkQ Nothing (findLargestExpression ss)) m of
    Just found -> found
    Nothing -> error (s ++ " " ++ showGhc ss)
-- | Expression located exactly at the span; errors if absent (see 'findGen').
findExpr :: Data a => a -> SrcSpan -> Expr
findExpr m ss = findGen "expr" m ss

-- | Pattern located exactly at the span; errors if absent.
findPat :: Data a => a -> SrcSpan -> Pat
findPat m ss = findGen "pat" m ss

-- | Type located exactly at the span; errors if absent.
findType :: Data a => a -> SrcSpan -> Type
findType m ss = findGen "type" m ss

-- | Declaration located exactly at the span; errors if absent.
findDecl :: Data a => a -> SrcSpan -> Decl
findDecl m ss = findGen "decl" m ss

-- | Statement located exactly at the span; errors if absent.
findStmt :: Data a => a -> SrcSpan -> Stmt
findStmt m ss = findGen "stmt" m ss

-- | Name located exactly at the span; errors if absent.
findName :: Data a => a -> SrcSpan -> Name
findName m ss = findGen "name" m ss
-- | Query used by 'findGen': accept a located node only if its span is exactly
-- @ss@. Despite the name, no size comparison happens here; which match wins
-- is decided by the traversal that calls it.
findLargestExpression :: SrcSpan -> GHC.Located ast
                      -> Maybe (GHC.Located ast)
findLargestExpression ss e@(GHC.L l _)
  | l == ss = Just e
  | otherwise = Nothing
-- | In every list of statements anywhere inside the value, keep only the
-- statements satisfying the predicate.
doDeleteStmt :: Data a => (Stmt -> Bool) -> a -> a
doDeleteStmt keep = everywhere (mkT pruned)
  where
    pruned = filter keep
-- | In every list of imports anywhere inside the value, keep only the
-- imports satisfying the predicate.
doDeleteImport :: Data a => (Import -> Bool) -> a -> a
doDeleteImport keep = everywhere (mkT pruned)
  where
    pruned = filter keep