{-|
Copyright  :  (C) 2015-2016, University of Twente
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

To use the plugin, add the

@
{\-\# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver \#-\}
@

pragma to the header of your file.

-}

{-# LANGUAGE CPP           #-}
{-# LANGUAGE TupleSections #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module GHC.TypeLits.Extra.Solver
  ( plugin )
where

-- external
import Control.Monad.Trans.Maybe (MaybeT (..))
import Data.Maybe                (catMaybes)
import GHC.TcPluginM.Extra       (evByFiat, lookupModule, lookupName
                                 ,tracePlugin, newWanted)
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra (flattenGivens)
#else
import Control.Monad ((<=<))
#endif

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Builtin.Names (eqPrimTyConKey, hasKey)
import GHC.Builtin.Types (promotedTrueDataCon, promotedFalseDataCon)
#if MIN_VERSION_ghc(9,2,0)
import GHC.Builtin.Types (boolTy, naturalTy)
#else
import GHC.Builtin.Types (typeNatKind)
#endif
import GHC.Builtin.Types.Literals (typeNatTyCons)
#if MIN_VERSION_ghc(9,2,0)
import GHC.Builtin.Types.Literals (typeNatCmpTyCon, typeNatDivTyCon, typeNatModTyCon)
#else
import GHC.Builtin.Types.Literals (typeNatLeqTyCon)
#endif
import GHC.Core.Predicate (EqRel (NomEq), Pred (EqPred), classifyPredType)
import GHC.Core.TyCo.Rep (Type (..))
import GHC.Core.Type (Kind, eqType, mkTyConApp, splitTyConApp_maybe, typeKind)
import GHC.Data.FastString (fsLit)
import GHC.Driver.Plugins (Plugin (..), defaultPlugin, purePlugin)
import GHC.Tc.Plugin (TcPluginM, tcLookupTyCon, tcPluginTrace)
import GHC.Tc.Types (TcPlugin(..), TcPluginResult (..))
import GHC.Tc.Types.Constraint
  (Ct, ctEvidence, ctEvPred, ctLoc, isWantedCt, cc_ev)
#if MIN_VERSION_ghc(9,2,0)
import GHC.Tc.Types.Constraint (Ct (CQuantCan), qci_ev)
#endif
import GHC.Tc.Types.Evidence (EvTerm)
import GHC.Types.Name.Occurrence (mkTcOcc)
import GHC.Unit.Module (mkModuleName)
import GHC.Utils.Outputable (Outputable (..), (<+>), ($$), text)
#else
import FastString (fsLit)
import Module     (mkModuleName)
import OccName    (mkTcOcc)
import Outputable (Outputable (..), (<+>), ($$), text)
import Plugins    (Plugin (..), defaultPlugin)
#if MIN_VERSION_ghc(8,6,0)
import Plugins    (purePlugin)
#endif
import PrelNames  (eqPrimTyConKey, hasKey)
import TcEvidence (EvTerm)
import TcPluginM  (TcPluginM, tcLookupTyCon, tcPluginTrace)
import TcRnTypes  (TcPlugin(..), TcPluginResult (..))
import Type       (Kind, eqType, mkTyConApp, splitTyConApp_maybe)
import TyCoRep    (Type (..))
import TysWiredIn (typeNatKind, promotedTrueDataCon, promotedFalseDataCon)
import TcTypeNats (typeNatLeqTyCon)
#if MIN_VERSION_ghc(8,4,0)
import TcTypeNats (typeNatTyCons)
#else
import TcPluginM  (zonkCt)
#endif

#if MIN_VERSION_ghc(8,10,0)
import Constraint (Ct, ctEvidence, ctEvPred, ctLoc, isWantedCt, cc_ev)
import Predicate  (EqRel (NomEq), Pred (EqPred), classifyPredType)
import Type       (typeKind)
#else
import TcRnTypes  (Ct, CtEvidence, ctEvidence, ctEvPred, ctLoc, isWantedCt, cc_ev)
import TcType     (typeKind)
import Type       (EqRel (NomEq), PredTree (EqPred), classifyPredType)
#endif
#endif

-- internal
import GHC.TypeLits.Extra.Solver.Operations
import GHC.TypeLits.Extra.Solver.Unify

#if MIN_VERSION_ghc(9,2,0)
-- | GHC >= 9.2 no longer provides 'typeNatKind' through the imports used
-- above, so define it here as 'naturalTy'. The rest of the module can then
-- refer to the kind of type-level naturals by a single name on every GHC.
typeNatKind :: Kind
typeNatKind = naturalTy
#endif

-- | A solver implemented as a type-checker plugin for:
--
--     * 'Div': type-level 'div'
--
--     * 'Mod': type-level 'mod'
--
--     * 'FLog': type-level equivalent of <https://hackage.haskell.org/package/integer-gmp/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
--       i.e. the exact integer equivalent to "@'floor' ('logBase' x y)@"
--
--     * 'CLog': type-level equivalent of /the ceiling of/ <https://hackage.haskell.org/package/integer-gmp/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
--       i.e. the exact integer equivalent to "@'ceiling' ('logBase' x y)@"
--
--     * 'Log': type-level equivalent of <https://hackage.haskell.org/package/integer-gmp/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
--       where the operation only reduces when "@'floor' ('logBase' b x) ~ 'ceiling' ('logBase' b x)@"
--
--     * 'GCD': a type-level 'gcd'
--
--     * 'LCM': a type-level 'lcm'
--
-- To use the plugin, add
--
-- @
-- {\-\# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver \#-\}
-- @
--
-- to the header of your file.
plugin :: Plugin
plugin
  = defaultPlugin
  { -- command-line options are ignored ('const'); the solver is always on
    tcPlugin = const $ Just normalisePlugin
#if MIN_VERSION_ghc(8,6,0)
    -- the plugin does not depend on anything but the source it checks,
    -- so using it should not force recompilation
  , pluginRecompile = purePlugin
#endif
  }

-- | The type-checker plugin itself, wrapped in 'tracePlugin' so that its
-- phases show up in type-checker traces under the name
-- @ghc-typelits-extra@.
--
-- * Initialisation looks up the type constructors the solver reasons about
--   ('lookupExtraDefs').
--
-- * Solving is delegated to 'decideEqualSOP'.
--
-- * Stopping the plugin does nothing.
normalisePlugin :: TcPlugin
normalisePlugin = tracePlugin "ghc-typelits-extra"
  TcPlugin { tcPluginInit  = lookupExtraDefs
           , tcPluginSolve = decideEqualSOP
           , tcPluginStop  = const (return ())
           }

-- | Solver entry point, called with the givens, deriveds and wanteds.
--
-- Only wanted constraints trigger any work. When there are no wanteds, or
-- none of them is a constraint 'toSolverConstraint' understands, nothing is
-- solved.
--
-- Otherwise the givens are converted too (on GHC >= 8.4 together with their
-- flattened forms; on older GHCs after zonking). The combined list, givens
-- first, is then handed to 'simplifyExtra'.
--
-- Of the solved constraints, only the wanted ones are reported back with
-- evidence. An 'Impossible' outcome is reported as a contradiction on the
-- offending constraint.
decideEqualSOP :: ExtraDefs -> [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginResult
decideEqualSOP _    _givens _deriveds []      = return (TcPluginOk [] [])
decideEqualSOP defs givens  _deriveds wanteds = do
  -- GHC 7.10.1 puts deriveds with the wanteds, so filter them out
  let wanteds' = filter isWantedCt wanteds
      -- keep only the constraints the solver knows how to reason about
      toSolverConstraints :: [Ct] -> TcPluginM [SolverConstraint]
      toSolverConstraints cts =
        catMaybes <$> mapM (runMaybeT . toSolverConstraint defs) cts
  unit_wanteds <- toSolverConstraints wanteds'
  case unit_wanteds of
    [] -> return (TcPluginOk [] [])
    _  -> do
#if MIN_VERSION_ghc(8,4,0)
      unit_givens <- toSolverConstraints (givens ++ flattenGivens givens)
#else
      unit_givens <- catMaybes <$>
        mapM ((runMaybeT . toSolverConstraint defs) <=< zonkCt) givens
#endif
      sr <- simplifyExtra defs (unit_givens ++ unit_wanteds)
      tcPluginTrace "normalised" (ppr sr)
      case sr of
        Simplified evs new ->
          return (TcPluginOk (filter (isWantedCt . snd) evs) new)
        Impossible eq ->
          return (TcPluginContradiction [fromSolverConstraint eq])

-- | A GHC constraint in the form the solver reasons about. Each constructor
-- keeps the original 'Ct', so that evidence or a contradiction can be
-- reported against it (see 'fromSolverConstraint').
data SolverConstraint
   = NatEquality Ct ExtraOp ExtraOp Normalised
     -- ^ @t1 ~ t2@: the constraint, both sides as 'ExtraOp's, and the merged
     -- 'Normalised' flag of the two sides.
   | NatInequality Ct ExtraOp ExtraOp Bool Normalised
     -- ^ @(x <=? y) ~ b@: the constraint, both operands as 'ExtraOp's, the
     -- expected truth value @b@, and the merged 'Normalised' flag.

instance Outputable SolverConstraint where
  ppr (NatEquality ct op1 op2 norm) =
    text "NatEquality" $$ ppr ct $$ ppr op1 $$ ppr op2 $$ ppr norm
  -- print the underlying constraint as well, like the NatEquality case
  ppr (NatInequality ct op1 op2 b norm) =
    text "NatInequality" $$ ppr ct $$ ppr op1 $$ ppr op2 $$ ppr b $$ ppr norm

-- | Outcome of 'simplifyExtra'.
data SimplifyResult
  = Simplified [(EvTerm,Ct)] [Ct]
    -- ^ Solved constraints paired with their evidence, and the newly created
    -- wanted constraints.
  | Impossible SolverConstraint
    -- ^ A constraint that was found to be unsatisfiable.

instance Outputable SimplifyResult where
  ppr (Simplified evs new) =
    text "Simplified" $$ text "Solved:" $$ ppr evs $$ text "New:" $$ ppr new
  ppr (Impossible sct) = text "Impossible" <+> ppr sct

-- | Try to discharge the given solver constraints one by one, left to right.
--
-- Processing accumulates:
--
-- * evidence for the constraints that were solved, and
--
-- * new wanted constraints created from normalised wanteds.
--
-- It stops early with 'Impossible' when a constraint is decided to be
-- unsatisfiable.
simplifyExtra :: ExtraDefs -> [SolverConstraint] -> TcPluginM SimplifyResult
simplifyExtra defs eqs = tcPluginTrace "simplifyExtra" (ppr eqs) >> simples [] [] eqs
  where
    -- Evidence (by fiat) for a solved constraint. 'Nothing' when 'evMagic'
    -- cannot build evidence for it; such entries are dropped at the end.
    solved :: Ct -> Maybe (EvTerm, Ct)
    solved ct = (,) <$> evMagic ct <*> pure ct

    simples :: [Maybe (EvTerm, Ct)] -> [Ct] -> [SolverConstraint] -> TcPluginM SimplifyResult
    simples evs news [] = return (Simplified (catMaybes evs) news)
    simples evs news (eq@(NatEquality ct u v norm):eqs') = do
      ur <- unifyExtra ct u v
      tcPluginTrace "unifyExtra result" (ppr ur)
      case ur of
        Win -> simples (solved ct:evs) news eqs'
        -- a failed unification is only reported as a contradiction when
        -- nothing has been solved so far and no constraints remain
        Lose | null evs && null eqs' -> return (Impossible eq)
        -- a normalised wanted: solve the original and emit the normalised
        -- form as a new wanted instead
        _ | norm == Normalised && isWantedCt ct -> do
          newCt <- createWantedFromNormalised defs eq
          simples (solved ct:evs) (newCt:news) eqs'
        Lose -> simples evs news eqs'
        Draw -> simples evs news eqs'
    simples evs news (eq@(NatInequality ct u v b norm):eqs') = do
      tcPluginTrace "unifyExtra leq result" (ppr (u,v,b))
      -- NB: when a guard below fails, matching falls through to the later
      -- alternatives
      case (u,v) of
        -- two literals: decide the inequality directly
        (I i,I j)
          | (i <= j) == b -> simples (solved ct:evs) news eqs'
          | otherwise     -> return (Impossible eq)
        -- p <= Max x y holds when p is one of the arguments of Max
        (p, Max x y)
          | b && (p == x || p == y) -> simples (solved ct:evs) news eqs'

        -- transform:  q ~ Max x y => (p <=? q ~ True)
        -- to:         (p <=? Max x y) ~ True
        -- and try to solve that along with the rest of the eqs'
        (p, q@(V _))
          | b -> case findMax q eqs of
                   Just m  -> simples evs news (NatInequality ct p m b norm:eqs')
                   Nothing -> simples evs news eqs'
        _ | norm == Normalised && isWantedCt ct -> do
          newCt <- createWantedFromNormalised defs eq
          simples (solved ct:evs) (newCt:news) eqs'
        _ -> simples evs news eqs'

    -- look for a given constraint of the form: c ~ Max x y (or Max x y ~ c)
    findMax :: ExtraOp -> [SolverConstraint] -> Maybe ExtraOp
    findMax c = go
      where
        go [] = Nothing
        go ((NatEquality ct a b@(Max _ _) _) :_)
          | c == a && not (isWantedCt ct)
            = Just b
        go ((NatEquality ct a@(Max _ _) b _) :_)
          | c == b && not (isWantedCt ct)
            = Just a
        go (_:rest) = go rest


-- | Extract the Nat (in)equality constraints.
--
-- Converts a GHC constraint into a 'SolverConstraint' when it has one of the
-- recognised shapes; otherwise the 'MaybeT' yields 'Nothing'. Recognised
-- shapes, with all operands normalised by 'normaliseNat':
--
-- * nominal equalities @t1 ~ t2@ where either side has kind 'typeNatKind';
--
-- * @(x <=? y) ~ b@ where @b@ is a promoted 'True or 'False.
--   * Before GHC 9.2, @<=?@ is matched as the built-in 'typeNatLeqTyCon'.
--   * From GHC 9.2, it is matched as
--     @OrdCond (CmpNat x y) 'True 'True 'False@, using 'ordTyCon' and
--     'typeNatCmpTyCon'.
toSolverConstraint :: ExtraDefs -> Ct -> MaybeT TcPluginM SolverConstraint
toSolverConstraint defs ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
    EqPred NomEq t1 t2
      | isNatKind (typeKind t1) || isNatKind (typeKind t2)
      -> do
         (t1', n1) <- normaliseNat defs t1
         (t2', n2) <- normaliseNat defs t2
         pure (NatEquality ct t1' t2' (mergeNormalised n1 n2))
#if MIN_VERSION_ghc(9,2,0)
      | TyConApp tc [_,cmpNat,TyConApp tt1 [],TyConApp tt2 [],TyConApp ff1 []] <- t1
      , tc == ordTyCon defs
      , TyConApp cmpNatTc [x,y] <- cmpNat
      , cmpNatTc == typeNatCmpTyCon
      , tt1 == promotedTrueDataCon
      , tt2 == promotedTrueDataCon
      , ff1 == promotedFalseDataCon
#else
      | TyConApp tc [x,y] <- t1
      , tc == typeNatLeqTyCon
#endif
      , TyConApp tc' [] <- t2
      , Just b <- boolLit tc'
      -> do
          (x', n1) <- normaliseNat defs x
          (y', n2) <- normaliseNat defs y
          pure (NatInequality ct x' y' b (mergeNormalised n1 n2))
    _ -> MaybeT (pure Nothing)
  where
    isNatKind :: Kind -> Bool
    isNatKind = (`eqType` typeNatKind)

    -- the Bool denoted by a promoted 'True / 'False type constructor
    boolLit tc'
      | tc' == promotedTrueDataCon  = Just True
      | tc' == promotedFalseDataCon = Just False
      | otherwise                   = Nothing

-- | Create a new wanted constraint that states the /normalised/ form of a
-- solver constraint. It reuses the location ('ctLoc') of the original
-- constraint and the type constructor of its predicate.
--
-- The original predicate must be an application of the 'eqPrimTyConKey'
-- type constructor to exactly four arguments. The first two (presumably the
-- kinds) are kept; the last two are replaced by the reified normalised sides.
-- Any other predicate makes this 'fail' in 'TcPluginM'.
createWantedFromNormalised :: ExtraDefs -> SolverConstraint -> TcPluginM Ct
createWantedFromNormalised defs sct = do
  let (ct, t1, t2) = extractCtSides sct
  newPredTy <- case splitTyConApp_maybe $ ctEvPred $ ctEvidence ct of
    Just (tc, [a, b, _, _]) | tc `hasKey` eqPrimTyConKey ->
      pure (mkTyConApp tc [a, b, t1, t2])
    _ -> fail "ghc-typelits-extra: createWantedFromNormalised: not a primitive equality"
  ev <- newWanted (ctLoc ct) newPredTy
  let ctN = case ct of
#if MIN_VERSION_ghc(9,2,0)
              -- quantified constraints carry their evidence in qci_ev
              CQuantCan qc -> CQuantCan (qc { qci_ev = ev })
#endif
              ctX -> ctX { cc_ev = ev }
  return ctN
  where
    -- the original constraint together with the reified left- and
    -- right-hand sides of its normalised form
    extractCtSides (NatEquality ct t1 t2 _) = (ct, reifyEOP defs t1, reifyEOP defs t2)
    extractCtSides (NatInequality ct x y b _) =
      let tc = if b then promotedTrueDataCon else promotedFalseDataCon
#if MIN_VERSION_ghc(9,2,0)
          -- x <=? y  ==  OrdCond (CmpNat x y) 'True 'True 'False
          t1 = TyConApp (ordTyCon defs)
                  [ boolTy
                  , TyConApp typeNatCmpTyCon [reifyEOP defs x, reifyEOP defs y]
                  , TyConApp promotedTrueDataCon []
                  , TyConApp promotedTrueDataCon []
                  , TyConApp promotedFalseDataCon []
                  ]
#else
          t1 = TyConApp typeNatLeqTyCon [reifyEOP defs x, reifyEOP defs y]
#endif
          t2 = TyConApp tc []
      in (ct, t1, t2)

-- | The original GHC constraint a 'SolverConstraint' was built from.
fromSolverConstraint :: SolverConstraint -> Ct
fromSolverConstraint (NatEquality   ct _ _ _)   = ct
fromSolverConstraint (NatInequality ct _ _ _ _) = ct

-- | Look up the type constructors the solver works with. They are supplied
-- to the 'ExtraDefs' constructor in this order:
--
-- * @Max@, @Min@;
-- * @Div@ and @Mod@ (GHC's built-ins from 8.4 on, otherwise the ones from
--   "GHC.TypeLits.Extra");
-- * @FLog@, @CLog@, @Log@, @GCD@, @LCM@ from "GHC.TypeLits.Extra";
-- * the constructor used to recognise @<=?@: GHC's built-in one before 9.2,
--   @OrdCond@ from "Data.Type.Ord" from 9.2 on.
lookupExtraDefs :: TcPluginM ExtraDefs
lookupExtraDefs = do
    md <- lookupModule myModule myPackage
#if MIN_VERSION_ghc(9,2,0)
    md2 <- lookupModule ordModule basePackage
#endif
    ExtraDefs <$> look md "Max"
              <*> look md "Min"
#if MIN_VERSION_ghc(9,2,0)
              -- GHC 9.2 dropped typeNatLeqTyCon from typeNatTyCons, so the
              -- positional indices used below for older GHCs would select the
              -- wrong families here; use the named TyCons instead.
              <*> pure typeNatDivTyCon
              <*> pure typeNatModTyCon
#elif MIN_VERSION_ghc(8,4,0)
              -- relies on Div and Mod being at positions 5 and 6
              <*> pure (typeNatTyCons !! 5)
              <*> pure (typeNatTyCons !! 6)
#else
              <*> look md "Div"
              <*> look md "Mod"
#endif
              <*> look md "FLog"
              <*> look md "CLog"
              <*> look md "Log"
              <*> look md "GCD"
              <*> look md "LCM"
#if MIN_VERSION_ghc(9,2,0)
              <*> look md2 "OrdCond"
#else
              <*> pure typeNatLeqTyCon
#endif
  where
    look md s = tcLookupTyCon =<< lookupName md (mkTcOcc s)
    myModule  = mkModuleName "GHC.TypeLits.Extra"
    myPackage = fsLit "ghc-typelits-extra"
#if MIN_VERSION_ghc(9,2,0)
    ordModule   = mkModuleName "Data.Type.Ord"
    basePackage = fsLit "base"
#endif

-- Utils
-- | Evidence "by fiat" ('evByFiat') for a nominal equality constraint.
-- Returns 'Nothing' for any other kind of constraint.
evMagic :: Ct -> Maybe EvTerm
evMagic ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
    EqPred NomEq t1 t2 -> Just (evByFiat "ghc-typelits-extra" t1 t2)
    _                  -> Nothing