{-# LANGUAGE LambdaCase, OverloadedStrings, RecordWildCards, ViewPatterns #-}
module TypeLevel.Rewrite (plugin) where
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer
import Data.Foldable
import Data.Traversable

import Coercion (Role(Representational), mkUnivCo)
import Constraint (CtEvidence(ctev_loc), Ct, ctEvExpr, ctLoc, mkNonCanonical)
import GhcPlugins (PredType, SDoc, eqType, fsep, ppr)
import Panic (pgmError)
import Plugins (Plugin(pluginRecompile, tcPlugin), CommandLineOption, defaultPlugin, purePlugin)
import TcEvidence (EvExpr, EvTerm, evCast)
import TcPluginM (newWanted)
import TcRnTypes
import TyCoRep (UnivCoProvenance(PluginProv))
import TyCon (synTyConDefn_maybe)

import TypeLevel.Rewrite.Internal.ApplyRules
import TypeLevel.Rewrite.Internal.DecomposedConstraint
import TypeLevel.Rewrite.Internal.Lookup
import TypeLevel.Rewrite.Internal.PrettyPrint
import TypeLevel.Rewrite.Internal.TypeEq
import TypeLevel.Rewrite.Internal.TypeRule
import TypeLevel.Rewrite.Internal.TypeTerm
-- | A single rewritten wanted constraint: the constraint being discharged,
-- the constraint(s) GHC must solve in its place, and evidence for the
-- original built from the replacement's evidence (see 'solve').
data ReplaceCt = ReplaceCt
  { evidenceOfCorrectness  :: EvTerm
    -- ^ Evidence for 'replacedConstraint'; the two are paired in the
    -- solved list produced by 'combineReplaceCts'.
  , replacedConstraint     :: Ct
    -- ^ The wanted constraint that is reported as solved.
  , replacementConstraints :: [Ct]
    -- ^ Constraints handed back to GHC instead of 'replacedConstraint'.
  }
-- | Merge all rewrites into one solver result: each replaced constraint is
-- listed as solved together with its evidence, and all replacement
-- constraints are concatenated (in order) into 'TcPluginOk''s constraint list.
combineReplaceCts
  :: [ReplaceCt]
  -> TcPluginResult
combineReplaceCts replaceCts = TcPluginOk solved emitted
  where
    solved :: [(EvTerm, Ct)]
    solved =
      [ (evidenceOfCorrectness replaceCt, replacedConstraint replaceCt)
      | replaceCt <- replaceCts
      ]

    emitted :: [Ct]
    emitted = concatMap replacementConstraints replaceCts
-- | Abort compilation with instructions for configuring the plugin, followed
-- by what was expected and what was actually found.
--
-- The failure is raised with 'pgmError', which throws a 'GhcException'. The
-- earlier version used 'error'. GHC's top-level handler reports any exception
-- that is not a 'GhcException' (such as the 'ErrorCall' thrown by 'error') as
-- a compiler panic that asks the user to file a GHC bug report. That is
-- misleading for what is only a plugin misconfiguration.
usage
  :: String  -- ^ what the plugin expected to see
  -> String  -- ^ what it actually got
  -> TcPluginM a
usage expected actual
  = pgmError
  $ "usage:\n"
 ++ " {-# OPTIONS_GHC -fplugin TypeLevel.Rewrite\n"
 ++ " -fplugin-opt=TypeLevel.Rewrite:TypeLevel.Append.RightIdentity\n"
 ++ " -fplugin-opt=TypeLevel.Rewrite:TypeLevel.Append.RightAssociative #-}\n"
 ++ "Where 'TypeLevel.Append' is a module containing a type synonym named 'RightIdentity':\n"
 ++ " type RightIdentity as = (as ++ '[]) ~ as\n"
 ++ "Type expressions which match the left of the '~' will get rewritten to the type\n"
 ++ "expression on the right of the '~'. Be careful not to introduce cycles!\n"
 ++ "\n"
 ++ "expected:\n"
 ++ " " ++ expected ++ "\n"
 ++ "got:\n"
 ++ " " ++ actual
-- | Turn the plugin's command-line options into rewrite rules.
--
-- Each option must be a fully-qualified type synonym name (e.g.
-- @TypeLevel.Append.RightIdentity@). The synonym is looked up, and its
-- right-hand side is converted with 'toTypeRule_maybe'. Each of the following
-- aborts through 'usage':
--
-- * an empty option list;
-- * a name that 'splitLastDot' cannot split;
-- * a name that is not a type synonym;
-- * a definition that 'toTypeRule_maybe' rejects.
lookupTypeRules
  :: [CommandLineOption]
  -> TcPluginM [TypeRule]
lookupTypeRules options = case options of
  [] ->
    usage (show exampleOptions) "[]"
  _ ->
    for options lookupTypeRule
  where
    exampleOptions :: [String]
    exampleOptions =
      [ "TypeLevel.Append.RightIdentity"
      , "TypeLevel.Append.RightAssociative"
      ]

    -- Resolve one fully-qualified synonym name, stopping via 'usage' at the
    -- first step that fails.
    lookupTypeRule :: CommandLineOption -> TcPluginM TypeRule
    lookupTypeRule qualifiedName = do
      (moduleNameStr, tyConNameStr) <-
        orUsage (splitLastDot qualifiedName)
          (show ("TypeLevel.Append.RightIdentity" :: String))
          (show qualifiedName)
      tyCon <- lookupTyCon moduleNameStr tyConNameStr
      (_tyVars, definition) <-
        orUsage (synTyConDefn_maybe tyCon)
          ("type " ++ pprTyCon tyCon ++ " ... = ...")
          (pprTyCon tyCon ++ " is not a type synonym")
      orUsage (toTypeRule_maybe definition)
        "... ~ ..."
        (pprType definition)

    -- Unwrap a 'Just'; on 'Nothing', report the expected/actual pair.
    orUsage :: Maybe b -> String -> String -> TcPluginM b
    orUsage found expected actual = maybe (usage expected actual) pure found
-- | Plugin entry point. The @-fplugin-opt@ arguments name the type synonyms
-- used as rewrite rules (see 'usage' for the expected format). The plugin is
-- marked with 'purePlugin' for recompilation checking.
plugin :: Plugin
plugin = defaultPlugin
  { tcPlugin        = Just . rewritePlugin
  , pluginRecompile = purePlugin
  }
  where
    -- The init action resolves the rules; they become the plugin state that
    -- is passed to every 'solve' call. The stop action does nothing.
    rewritePlugin :: [CommandLineOption] -> TcPlugin
    rewritePlugin args = TcPlugin
      { tcPluginInit  = lookupTypeRules args
      , tcPluginSolve = solve
      , tcPluginStop  = const (pure ())
      }
-- | Wrap a fixed message as an 'ErrCtxt'. The callback returns the 'TidyEnv'
-- it is given, unchanged, together with the message. The flag component is
-- always 'True'.
--
-- NOTE(review): in GHC's TcRnTypes this flag reportedly marks a "landmark"
-- context that survives error-context trimming. Confirm before changing it.
mkErrCtx
  :: SDoc
  -> ErrCtxt
mkErrCtx errDoc = (True, withMessage)
  where
    withMessage env = pure (env, errDoc)
-- | Create a new wanted for @newPredType@ at the location of @oldCt@. An
-- error-context line naming the rewrite @rule@ that produced it is pushed
-- onto that location.
newRuleInducedWanted
  :: Ct
  -> TypeRule
  -> PredType
  -> TcPluginM CtEvidence
newRuleInducedWanted oldCt rule newPredType =
    relocate <$> newWanted ruleLoc newPredType
  where
    ruleMsg :: SDoc
    ruleMsg = fsep
      [ "From the typelevel rewrite rule:"
      , ppr (fromTypeRule rule)
      ]

    ruleLoc = pushErrCtxtSameOrigin (mkErrCtx ruleMsg) (ctLoc oldCt)

    -- The location is set on the returned evidence explicitly.
    -- NOTE(review): presumably 'newWanted' does not keep the pushed error
    -- context by itself. Confirm before removing this.
    relocate :: CtEvidence -> CtEvidence
    relocate ev = ev { ctev_loc = ruleLoc }
-- | The constraint solver.
--
-- 1. Every equality given @lhs ~ rhs@ (as decoded by 'asEqualityConstraint')
--    contributes the pair @(TypeEq rhs, toTypeTerm lhs)@. These pairs are
--    passed to 'applyRules' as a substitution.
-- 2. The rules are applied to each decomposable wanted.
-- 3. Each wanted whose predicate actually changed is replaced by a fresh
--    wanted for the rewritten predicate.
--
-- The original wanted is then solved by casting the new wanted's evidence
-- along a 'mkUnivCo' coercion with 'PluginProv' provenance. Presumably GHC
-- trusts this coercion rather than checking it.
solve
  :: [TypeRule]
  -> [Ct]
  -> [Ct]
  -> [Ct]
  -> TcPluginM TcPluginResult
solve _ _ _ [] =
  -- No wanteds: there is nothing to rewrite.
  pure (TcPluginOk [] [])
solve rules givens _ wanteds = do
  typeSubst <- execWriterT $
    for_ givens $ \given ->
      for_ (asEqualityConstraint given) $ \(lhs, rhs) ->
        tell [(TypeEq rhs, toTypeTerm lhs)]
  replaceCts <- execWriterT $
    for_ wanteds (rewriteWanted typeSubst)
  pure (combineReplaceCts replaceCts)
  where
    -- Rewrite one wanted. A 'ReplaceCt' is recorded only when a rule fires
    -- and the resulting predicate differs from the original one.
    rewriteWanted
      :: [(TypeEq, TypeTerm)]
      -> Ct
      -> WriterT [ReplaceCt] TcPluginM ()
    rewriteWanted typeSubst wanted =
      for_ (asDecomposedConstraint wanted) $ \types -> do
        let predType  = fromDecomposeConstraint types
            rewritten = applyRules typeSubst rules (fmap toTypeTerm types)
        for_ rewritten $ \(rule, typeTerms') -> do
          let predType' = fromDecomposeConstraint (fmap fromTypeTerm typeTerms')
          unless (predType' `eqType` predType) $ do
            evWanted' <- lift (newRuleInducedWanted wanted rule predType')
            tell [mkReplaceCt wanted predType predType' evWanted']

    -- Build the replacement record. Evidence for the old predicate is the new
    -- wanted's evidence cast along a representational coercion from
    -- predType' to predType.
    mkReplaceCt :: Ct -> PredType -> PredType -> CtEvidence -> ReplaceCt
    mkReplaceCt wanted predType predType' evWanted' = ReplaceCt
      { evidenceOfCorrectness  = evCast (ctEvExpr evWanted') co
      , replacedConstraint     = wanted
      , replacementConstraints = [mkNonCanonical evWanted']
      }
      where
        co = mkUnivCo (PluginProv "TypeLevel.Rewrite") Representational
                      predType' predType