-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Retrie.Replace
  ( replace
  , Replacement(..)
  , Change(..)
  ) where

import Control.Monad.Trans.Class
import Control.Monad.Writer.Strict
import Data.Char (isSpace)
import Data.Generics

import Retrie.ExactPrint
import Retrie.Expr
import Retrie.FreeVars
import Retrie.GHC
import Retrie.Subst
import Retrie.Types
import Retrie.Universe
import Retrie.Util

------------------------------------------------------------------------

-- | Specializes 'replaceImpl' to each of the AST types that retrie supports:
-- expressions, statements, types and patterns (patterns go through
-- 'replacePat'). Via 'mkM', a node of any other type is returned unchanged.
replace
  :: (Data a, MonadIO m) => Context -> a -> TransformT (WriterT Change m) a
replace ctx =
  mkM (replaceImpl @(HsExpr GhcPs) ctx)
    `extM` replaceImpl @(Stmt GhcPs (LHsExpr GhcPs)) ctx
    `extM` replaceImpl @(HsType GhcPs) ctx
    `extM` replacePat ctx

-- | Runs 'replaceImpl' on a pattern, but only when 'dLPat' yields a located
-- pattern; otherwise the pattern is returned untouched.
--
-- We need a location available at the top level so annotations can be
-- transferred. This ensures we never try to rewrite a naked 'Pat'.
replacePat :: MonadIO m => Context -> LPat GhcPs -> TransformT (WriterT Change m) (LPat GhcPs)
replacePat ctx pat =
  case dLPat pat of
    Nothing      -> return pat
    Just located -> cLPat <$> replaceImpl ctx located

-- | Generic replacement function. This runs the 'Rewriter' carried by the
-- 'Context' against the node and, on a match:
--
--   1. grafts the matched template into the module being transformed,
--   2. substitutes the match's 'Substitution' into the grafted template,
--   3. copies annotations from the original node onto the result,
--   4. adds parentheses to expressions, types and patterns where needed,
--   5. emits a 'Change' holding a 'Replacement' (location, original text,
--      replacement text) together with the template's imports.
--
-- A match is discarded (treated as 'NoMatch') when the node overlaps the
-- rewrite's own origin, or when the template mentions a free variable that
-- is bound in the current context ('ctxtBinders').
--
-- Returns the rewritten node, or the input unchanged when nothing matches.
replaceImpl
  :: forall ast m. (Data ast, ExactPrint ast, Matchable (LocatedA ast), MonadIO m)
  => Context -> LocatedA ast -> TransformT (WriterT Change m) (LocatedA ast)
replaceImpl c e = do
  -- Matching looks through HsPar so that the decision about keeping parens
  -- can be made from the resulting expression. We still need the entry
  -- location of the parens (not of the inner expression), so 'e' itself is
  -- kept around alongside the unparenthesised node handed to the rewriter.
  outcome <- runRewriter guardSelfRewrite c (ctxtRewriter c) (getUnparened e)
  case outcome of
    NoMatch -> return e
    MatchResult sub Template{..} -> do
      -- graft template into target module
      grafted <- graftA tTemplate
      -- substitute for quantifiers in the grafted template
      substituted <- subst sub c grafted
      -- copy appropriate annotations from the old node to the template
      withAnns <- addAllAnnsT e substituted
      -- add parens to the template where needed
      parenified <-
        (mkM (parenify c) `extM` parenifyT c `extM` parenifyP c) withAnns
      -- give the printed replacement the same anchor as the replaced node
      let anchored = transferAnchor e parenified

      -- prune both sides and print them for the location-tagged log entry
      originalText    <- printNoLeadingSpaces <$> pruneA e
      replacementText <- printNoLeadingSpaces <$> pruneA anchored

      let change =
            Change
              [Replacement (getLocA e) originalText replacementText]
              [tImports]
      TransformT (lift (tell change))

      -- Make the actual replacement.
      -- NOTE(review): the text logged above is printed from 'anchored', but
      -- the node handed back is 'parenified' (no anchor transfer). Confirm
      -- this asymmetry is intentional.
      return parenified
  where
    -- Prevent rewriting the source of the rewrite itself by refusing to
    -- match under a binding of something that appears in the template.
    guardSelfRewrite rr =
      rr { rrTransformer =
             fmap (fmap (rejectSelf (rrOrigin rr) (rrQuantifiers rr)))
               <$> rrTransformer rr
         }

    -- Post-filter applied to every match produced by the transformer.
    rejectSelf origin quantifiers result
      | getLocA e `overlaps` origin = NoMatch
      | MatchResult _ Template{tTemplate = tpl} <- result
      , let fvs = freeVars quantifiers (astA tpl)
      , any (`elemFVs` fvs) (ctxtBinders c)
      = NoMatch
      | otherwise = result


-- | Records a replacement that was made. In cases where ghc-exactprint cannot
-- be used to print the resulting AST (e.g. CPP modules), we fall back on
-- splicing these strings. External tools (search, linters, etc.) can also
-- consume it.
data Replacement = Replacement
  { replLocation    :: SrcSpan
    -- ^ span of the node that was replaced
  , replOriginal    :: String
    -- ^ printed original node, leading whitespace dropped
  , replReplacement :: String
    -- ^ printed replacement node, leading whitespace dropped
  }
  deriving Show

-- | Used as the writer type during matching to indicate whether any change
-- to the module should be made.
--
-- 'NoChange' is the identity. Combining two 'Change' values concatenates
-- their replacements and their import lists, left operand first.
data Change = NoChange | Change [Replacement] [AnnotatedImports]

-- The combining logic lives in '(<>)', and 'mappend' is left to its default
-- (@mappend = (<>)@). This is the canonical form expected by GHC's
-- -Wnoncanonical-monoid-instances and the Semigroup/Monoid proposal, under
-- which 'mappend' is slated to stop being a class method.
instance Semigroup Change where
  NoChange <> other = other
  other <> NoChange = other
  Change rs1 is1 <> Change rs2 is2 = Change (rs1 <> rs2) (is1 <> is2)

instance Monoid Change where
  mempty = NoChange

-- | Exact-print an annotated AST fragment and drop any leading whitespace.
--
-- The location of the node being replaced ('e' in 'replaceImpl') accurately
-- points to its first non-space character, but exact-printing it may
-- produce leading spaces (when the annEntryDelta of the first token is
-- non-zero). So the printed text cannot simply be spliced in at that
-- location.
--
-- Finding the right annEntryDelta to zero out is hard, since it may not be
-- at the top of the redex. As janky as it seems, it's easier to just strip
-- the leading whitespace from the printed string.
printNoLeadingSpaces :: (Data k, ExactPrint k) => Annotated k -> String
printNoLeadingSpaces annotated = dropWhile isSpace (printA annotated)