{-# LANGUAGE OverloadedStrings #-}
module Retrie.Rewrites
( RewriteSpec(..)
, QualifiedName
, parseRewriteSpecs
, parseQualified
, parseAdhocs
) where
import Control.Exception
import Data.Either (partitionEithers)
import qualified Data.Map as Map
import qualified Data.Text as Text
import Data.Traversable
import System.FilePath
import Retrie.CPP
import Retrie.ExactPrint
import Retrie.Fixity
import Retrie.GHC
import Retrie.Rewrites.Function
import Retrie.Rewrites.Rules
import Retrie.Rewrites.Types
import Retrie.Types
import Retrie.Universe
-- | A module-qualified identifier or operator written as a single string,
-- e.g. @Data.Foo.bar@ or @Data.Foo.<+>@. 'parseQualified' splits such a
-- string into a module file path and an unqualified name.
type QualifiedName = String
-- | A user-level description of where a rewrite comes from. A list of these
-- is turned into concrete rewrites by 'parseRewriteSpecs'.
data RewriteSpec
-- | An equation supplied inline as text; handled by 'parseAdhocs'.
= Adhoc String
-- | A definition looked up by qualified name, applied right-to-left.
| Fold QualifiedName
-- | A rule looked up by qualified name, applied right-to-left.
| RuleBackward QualifiedName
-- | A rule looked up by qualified name, applied left-to-right.
| RuleForward QualifiedName
-- | A type synonym looked up by qualified name, applied right-to-left.
| TypeBackward QualifiedName
-- | A type synonym looked up by qualified name, applied left-to-right.
| TypeForward QualifiedName
-- | A definition looked up by qualified name, applied left-to-right.
| Unfold QualifiedName
-- | Turn a list of 'RewriteSpec's into rewrites.
--
-- Specs are first split into inline ('Adhoc') equations and specs that name
-- something inside a module. The module-based ones are resolved through the
-- supplied module parser, the inline ones through 'parseAdhocs'. The result
-- lists all module-based rewrites first, followed by the ad-hoc ones.
--
-- Throws an 'ErrorCall' if a qualified name cannot be split by
-- 'parseQualified'.
parseRewriteSpecs
  :: (FilePath -> IO (CPP AnnotatedModule))
  -> FixityEnv
  -> [RewriteSpec]
  -> IO [Rewrite Universe]
parseRewriteSpecs parser fixityEnv specs = do
  (adhocs, fileBased) <- partitionEithers <$> traverse classify specs
  fromFiles <- parseFileBased parser fileBased
  fromAdhocs <- parseAdhocs fixityEnv adhocs
  return (fromFiles ++ fromAdhocs)
  where
    -- Left: raw ad-hoc equation. Right: (file, [(kind, [(name, direction)])]).
    classify spec =
      case spec of
        Adhoc rule -> return (Left rule)
        Fold name -> fileBasedSpec FoldUnfold RightToLeft name
        RuleBackward name -> fileBasedSpec Rule RightToLeft name
        RuleForward name -> fileBasedSpec Rule LeftToRight name
        TypeBackward name -> fileBasedSpec Type RightToLeft name
        TypeForward name -> fileBasedSpec Type LeftToRight name
        Unfold name -> fileBasedSpec FoldUnfold LeftToRight name

    -- Resolve the qualified name, aborting with an exception when it is
    -- not well formed.
    fileBasedSpec ty dir name =
      either
        (\err -> throwIO (ErrorCall ("parseRewriteSpecs: " ++ err)))
        (\(fp, fs) -> return (Right (fp, [(ty, [(fs, dir)])])))
        (parseQualified name)
-- | The kind of declaration a module-based rewrite is built from: a
-- definition ('FoldUnfold'), a rule ('Rule') or a type synonym ('Type').
-- The 'Ord' instance lets values be used as 'Map.Map' keys when specs are
-- grouped in 'parseFileBased'.
data FileBasedTy = FoldUnfold | Rule | Type
deriving (Eq, Ord)
-- | Build the rewrites for all module-based specs.
--
-- Specs are grouped by file so each file is handed to the parser once, and
-- within a file they are grouped again by 'FileBasedTy' so that
-- 'constructRewrites' runs once per kind.
parseFileBased
  :: (FilePath -> IO (CPP AnnotatedModule))
  -> [(FilePath, [(FileBasedTy, [(FastString, Direction)])])]
  -> IO [Rewrite Universe]
parseFileBased _ [] = return []
parseFileBased parser specs = do
  perFile <- forM (mergeByKey specs) $ \(fp, rules) -> do
    cpp <- parser fp
    perKind <- forM (mergeByKey rules) $ \(ty, names) ->
      constructRewrites cpp ty names
    return (concat perKind)
  return (concat perFile)
  where
    -- Collapse entries that share a key, concatenating their payloads.
    mergeByKey :: Ord k => [(k,[v])] -> [(k,[v])]
    mergeByKey kvs = Map.toList (Map.fromListWith (++) kvs)
-- | Parse rewrite equations given directly as strings.
--
-- Each string is wrapped in its own @RULES@ pragma named @adhocN@ (numbered
-- from 1 in input order). All pragmas are parsed together as one piece of
-- content, and the resulting rules are requested in the left-to-right
-- direction from 'constructRewrites'.
--
-- An input without a right-hand side gets @= undefined@ appended so that
-- it still parses as a rule.
parseAdhocs :: FixityEnv -> [String] -> IO [Rewrite Universe]
parseAdhocs _ [] = return []
parseAdhocs fixities adhocs = do
  cpp <-
    parseCPP (parseContent fixities "parseAdhocs") (Text.unlines adhocRules)
  constructRewrites cpp Rule adhocSpecs
  where
    -- A rule needs a right-hand side to parse; add a dummy one when the
    -- input is only a pattern.
    addRHS s
      | hasEqualsToken s = s
      | otherwise = s ++ " = undefined"

    -- True when the string contains an '=' that stands on its own, i.e. is
    -- not part of a longer operator such as @==@, @<=@, @>=@, @/=@ or @=>@.
    -- A plain "contains '='" test would mistake those operators for the
    -- equation's own equals sign and leave the pattern without a
    -- right-hand side.
    hasEqualsToken s =
      or (zipWith3 standalone (' ' : s) s (drop 1 s ++ " "))

    standalone before c after =
      c == '=' && not (isOpChar before) && not (isOpChar after)

    isOpChar c = c == ':' || isHsSymbol c

    (adhocSpecs, adhocRules) = unzip
      [ ( (mkFastString nm, LeftToRight)
        , "{-# RULES \"" <> Text.pack nm <> "\" " <> Text.pack s <> " #-}"
        )
      | (i,s) <- zip [1..] $ map addRHS adhocs
      , let nm = "adhoc" ++ show (i::Int)
      ]
-- | Build the rewrites of one kind for a set of requested names.
--
-- 'tyBuilder' is run over every module held in the 'CPP' structure and the
-- resulting tables are merged, concatenating the rewrite lists of names
-- that occur in more than one table. Each distinct requested name is then
-- looked up; a name with no entry makes the action 'fail' with a
-- "could not find ..." message.
constructRewrites
  :: CPP AnnotatedModule
  -> FileBasedTy
  -> [(FastString, Direction)]
  -> IO [Rewrite Universe]
constructRewrites cpp ty specs = do
  tables <- traverse (tyBuilder ty specs) cpp
  let merged = foldr (plusUFM_C (++)) emptyUFM tables
      wanted = nonDetEltsUniqSet (mkUniqSet (map fst specs))
  concat <$> mapM (resolve merged) wanted
  where
    resolve table fs = maybe (fail (notFound fs)) return (lookupUFM table fs)

    notFound fs = "could not find " ++ kindLabel ++ " named " ++ unpackFS fs

    -- Human-readable name of the kind being searched for.
    kindLabel =
      case ty of
        FoldUnfold -> "definition"
        Rule -> "rule"
        Type -> "type synonym"
-- | Pick the rewrite extractor that matches the requested kind and run it
-- on a module, lifting the result to rewrites over 'Universe'.
tyBuilder
  :: FileBasedTy
  -> [(FastString, Direction)]
  -> AnnotatedModule
  -> IO (UniqFM [Rewrite Universe])
tyBuilder ty specs am =
  case ty of
    FoldUnfold -> fmap promote (dfnsToRewrites specs am)
    Rule -> fmap promote (rulesToRewrites specs am)
    Type -> fmap promote (typeSynonymsToRewrites specs am)
-- | Convert every rewrite stored in the table with 'toURewrite', so that
-- tables of different rewrite types can be combined as 'Universe' rewrites.
promote :: Matchable a => UniqFM [Rewrite a] -> UniqFM [Rewrite Universe]
promote table = map toURewrite <$> table
-- | Split a qualified name into the module's source file path and the
-- unqualified name.
--
-- The name is examined from its end. If it ends in operator characters
-- (see 'isHsSymbol'), that trailing run must start with the @.@ separating
-- module and operator, and the rest of the run is the operator. Otherwise
-- everything after the last @.@ is the name. The file path is the module
-- name passed through 'moduleNameSlashes' with an @hs@ extension.
--
-- Returns 'Left' with a message that includes the offending input when the
-- input is empty, has no module qualifier, is a malformed qualified
-- operator, or would produce an empty module name or an empty unqualified
-- name (inputs such as @.foo@ or @Foo.@).
parseQualified :: String -> Either String (FilePath, FastString)
parseQualified [] = Left "qualified name is empty"
parseQualified fqName =
  case span isHsSymbol reversed of
    (_,[]) -> mkError "unqualified operator name"
    ([],_) ->
      case span (/='.') reversed of
        (_,[]) -> mkError "unqualified function name"
        (rname,_:rmod) -> mkResult (reverse rmod) (reverse rname)
    (rop,rmod) ->
      case reverse rop of
        '.':op -> mkResult (reverse rmod) op
        _ -> mkError "malformed qualified operator"
  where
    reversed = reverse fqName
    mkError str = Left $ str ++ ": " ++ fqName
    -- Refuse to build a result from an empty component: an empty module
    -- name would give a bare ".hs" path, and an empty name can never match
    -- anything.
    mkResult moduleNameStr occNameStr
      | null moduleNameStr = mkError "missing module name"
      | null occNameStr = mkError "missing name after module qualifier"
      | otherwise = Right
          ( moduleNameSlashes (mkModuleName moduleNameStr) <.> "hs"
          , mkFastString occNameStr
          )
-- | Whether a character is one of the ASCII characters that may appear in
-- a Haskell operator.
--
-- The colon is included: it is an operator character in the Haskell 2010
-- lexical syntax, and leaving it out makes 'parseQualified' reject
-- operators such as @<:>@ or @:+:@.
isHsSymbol :: Char -> Bool
isHsSymbol = (`elem` symbols)
  where
    symbols :: String
    symbols = "!#$%&*+./<=>?@\\^|-~:"