-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE CPP #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
module Retrie.Rewrites.Types where

import Control.Monad
import Data.Maybe

import Retrie.ExactPrint
import Retrie.Expr
import Retrie.GHC
import Retrie.Quantifiers
import Retrie.Types

-- | Turn selected type synonyms of a module into rewrites.
--
-- @specs@ pairs a synonym name with a 'Direction'. A name may appear several
-- times to request several directions. Every @type@ declaration in the module
-- whose name occurs in @specs@ yields one rewrite per requested direction,
-- built by 'mkTypeRewrite'. The result is keyed by the synonym's name.
--
-- Names in @specs@ that are not declared as type synonyms in the module
-- produce no rewrites. Only the resulting map is returned; the 'Annotated'
-- wrapper produced by 'transformA' is dropped with 'astA'.
typeSynonymsToRewrites
  :: [(FastString, Direction)]
  -> AnnotatedModule
#if __GLASGOW_HASKELL__ < 900
  -> IO (UniqFM [Rewrite (LHsType GhcPs)])
#else
  -> IO (UniqFM FastString [Rewrite (LHsType GhcPs)])
#endif
typeSynonymsToRewrites specs am = astA <$> transformA am collect
  where
    -- All directions requested for each synonym name.
    fsMap = uniqBag specs

    -- Walk the module's top-level declarations and build the rewrites.
    collect m = do
      let
        tySyns =
          [ (rdr, (dir, (nm, hsq_explicit vars, rhs)))
            -- only hsq_explicit is available pre-renaming
          | L _ (TyClD _ (SynDecl _ nm vars _ rhs)) <- hsmodDecls (unLoc m)
          , let rdr = rdrFS (unLoc nm)
          , dir <- fromMaybe [] (lookupUFM fsMap rdr)
          ]
      uniqBag <$> forM tySyns (\(rdr, args) ->
        (rdr,) <$> uncurry mkTypeRewrite args)

------------------------------------------------------------------------

-- | Compile a single type synonym into a rewrite.
--
-- For a declaration @type T a b = rhs@ this builds the type application
-- @T a b@ and pairs it with @rhs@. The synonym's binders (@a@, @b@) become
-- the rewrite's quantifiers.
--
-- With 'LeftToRight' the application is the pattern and @rhs@ the template.
-- 'RightToLeft' swaps the two.
mkTypeRewrite
  :: Direction
#if __GLASGOW_HASKELL__ < 908
  -> (LocatedN RdrName, [LHsTyVarBndr () GhcPs], LHsType GhcPs)
#else
  -> (LocatedN RdrName, [LHsTyVarBndr (HsBndrVis GhcPs) GhcPs], LHsType GhcPs)
#endif
  -> TransformT IO (Rewrite (LHsType GhcPs))
mkTypeRewrite d (lhsName, vars, rhs) = do
  -- The type constructor heads the application; give it a zero entry delta.
  tc <- mkTyVar (setEntryDP lhsName (SameLine 0))
  let lvs = tyBindersToLocatedRdrNames vars
  -- Each bound variable becomes an argument with a same-line delta of 1,
  -- presumably so the application prints as @T a b@.
  args <- forM lvs $ \lv -> do
    tv <- mkTyVar lv
    return (setEntryDP tv (SameLine 1))
  lhsApps <- mkHsAppsTy (tc : args)
  let (pat, tmp) = case d of
        LeftToRight -> (lhsApps, rhs)
        RightToLeft -> (rhs, lhsApps)
  p <- pruneA pat
  t <- pruneA tmp
  return $ mkRewrite (mkQs (map unLoc lvs)) p t