{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-}
#if __GLASGOW_HASKELL__ < 801
#define nonDetCmpType cmpType
#endif
module GHC.TypeLits.Normalise.Unify
(
CType (..)
, CoreSOP
, normaliseNat
, reifySOP
, UnifyItem (..)
, CoreUnify
, substsSOP
, substsSubst
, UnifyResult (..)
, unifyNats
, unifiers
, fvSOP
, subtractIneq
, solveIneq
, ineqToSubst
, subtractionToPred
, isNatural
)
where
import Control.Monad.Trans.Writer.Strict
import Data.Function (on)
import Data.List ((\\), intersect, mapAccumL, nub)
import Data.Maybe (fromMaybe, mapMaybe)
import GHC.Base (isTrue#,(==#))
import GHC.Integer (smallInteger)
import GHC.Integer.Logarithms (integerLogBase#)
import Outputable (Outputable (..), (<+>), ($$), text)
import TcPluginM (TcPluginM, tcPluginTrace)
import TcRnMonad (Ct, ctEvidence, isGiven)
import TcRnTypes (ctEvPred)
import TcTypeNats (typeNatAddTyCon, typeNatExpTyCon, typeNatMulTyCon,
typeNatSubTyCon, typeNatLeqTyCon)
import Type (EqRel (NomEq), PredTree (EqPred), TyVar, classifyPredType,
coreView, eqType, mkNumLitTy, mkTyConApp, mkTyVarTy,
nonDetCmpType, PredType, mkPrimEqPred)
import TyCoRep (Kind, Type (..), TyLit (..))
import TysWiredIn (boolTy, promotedTrueDataCon)
import UniqSet (UniqSet, unionManyUniqSets, emptyUniqSet, unionUniqSets,
unitUniqSet)
import GHC.TypeLits.Normalise.SOP
import GHC.TypeLits (Nat)
-- | A 'Type' wrapped in a newtype so that it can carry 'Eq' and 'Ord'
-- instances, which are defined through 'eqType' and 'nonDetCmpType'.
newtype CType = CType { unCType :: Type }
  deriving Outputable

-- | Two wrapped types are equal exactly when 'eqType' says they are.
instance Eq CType where
  (==) = eqType `on` unCType

-- | Wrapped types are ordered by 'nonDetCmpType' (which the CPP macro at the
-- top of the file maps to @cmpType@ on older compilers).
instance Ord CType where
  compare = nonDetCmpType `on` unCType
-- | Sum-of-products term whose variables are 'TyVar's and whose opaque
-- constants are 'CType's.
type CoreSOP = SOP TyVar CType
-- | A single product of a 'CoreSOP'.
type CoreProduct = Product TyVar CType
-- | A single symbol (factor) of a 'CoreProduct'.
type CoreSymbol = Symbol TyVar CType
-- | Convert a type into its sum-of-products form.
--
-- The type is first expanded with 'coreView' as far as possible. Type
-- variables and numeric literals become 'V' and 'I' symbols; applications of
-- the addition, subtraction, multiplication and exponentiation type
-- constructors to exactly two arguments are normalised recursively. Every
-- other type is kept as an opaque 'C' symbol.
--
-- Each subtraction @x - y@ that is encountered is recorded as @(x,y)@ in the
-- 'Writer' log.
normaliseNat :: Type -> Writer [(Type,Type)] CoreSOP
normaliseNat ty
  | Just expanded <- coreView ty = normaliseNat expanded
  | otherwise = case ty of
      TyVarTy v          -> pure (S [P [V v]])
      LitTy (NumTyLit i) -> pure (S [P [I i]])
      TyConApp tc [lhs,rhs]
        | tc == typeNatAddTyCon -> binary mergeSOPAdd lhs rhs
        | tc == typeNatSubTyCon -> do
            -- Log the subtraction before normalising its operands.
            tell [(lhs,rhs)]
            mergeSOPAdd <$> normaliseNat lhs
                        <*> (negated <$> normaliseNat rhs)
        | tc == typeNatMulTyCon -> binary mergeSOPMul lhs rhs
        | tc == typeNatExpTyCon -> binary normaliseExp lhs rhs
      -- Anything else (including the operators above applied to a different
      -- number of arguments) is treated as an opaque constant.
      _ -> pure (S [P [C (CType ty)]])
  where
    -- Normalise both operands (left first) and combine them with @op@.
    binary op l r = op <$> normaliseNat l <*> normaliseNat r
    -- Multiply a term by -1.
    negated       = mergeSOPMul (S [P [I (-1)]])
-- | Convert a sum-of-products term back into a 'Type'.
--
-- Products with a negative leading integer coefficient are turned into
-- subtractions; all other products are added. The empty sum is reified as
-- the literal @0@.
reifySOP :: CoreSOP -> Type
reifySOP = combineP . map negateP . unS
  where
    sub, add :: Type -> Type -> Type
    sub l r = mkTyConApp typeNatSubTyCon [l, r]
    add l r = mkTyConApp typeNatAddTyCon [l, r]

    -- 'Left' holds the absolute value of a product with a negative leading
    -- coefficient, 'Right' holds a product that is kept as-is. A coefficient
    -- of exactly -1 is dropped when other factors remain.
    negateP :: CoreProduct -> Either CoreProduct CoreProduct
    negateP (P ((I i):ps@(_:_))) | i == (-1) = Left (P ps)
    negateP (P ((I i):ps))       | i < 0     = Left (P ((I (abs i)):ps))
    negateP ps                               = Right ps

    combineP :: [Either CoreProduct CoreProduct] -> Type
    combineP []  = mkNumLitTy 0
    combineP [p] = either (sub (mkNumLitTy 0) . reifyProduct) reifyProduct p
    combineP [p1,p2] = case (p1,p2) of
      -- Both negative: -x - y is 0 - (x + y). Reifying it as 0 - (x - y),
      -- as was done previously, denotes y - x instead.
      (Left  x, Left  y) -> sub (mkNumLitTy 0)
                                (add (reifyProduct x) (reifyProduct y))
      -- x negative, y positive: y - x
      (Left  x, Right y) -> sub (reifyProduct y) (reifyProduct x)
      -- x positive, y negative: x - y
      (Right x, Left  y) -> sub (reifyProduct x) (reifyProduct y)
      -- both positive: x + y
      (Right x, Right y) -> add (reifyProduct x) (reifyProduct y)
    combineP (p:ps) =
      let es = combineP ps
      in  either (\x -> sub es (reifyProduct x))
                 (\x -> add (reifyProduct x) es)
                 p
-- | Convert a product back into a 'Type' by multiplying its reified factors.
--
-- Before reification, an exponentiation symbol is merged with the group
-- produced directly after it when both have the same base, so that their
-- exponents end up in a single exponent sum.
--
-- The empty product is reified as the literal @1@ (the multiplicative
-- identity) instead of crashing on @head []@.
reifyProduct :: CoreProduct -> Type
reifyProduct (P ps) =
    case map reifySymbol (foldr mergeExp [] ps) of
      []     -> mkNumLitTy 1
      -- Same association as @foldr f (head ts) (tail ts)@.
      (t:ts) -> foldr (\t1 t2 -> mkTyConApp typeNatMulTyCon [t1,t2]) t ts
  where
    -- 'Right' groups an exponentiation base with the list of exponents that
    -- were collected for it; every non-exponentiation symbol stays 'Left'.
    mergeExp :: CoreSymbol -> [Either CoreSymbol (CoreSOP,[CoreProduct])]
                           -> [Either CoreSymbol (CoreSOP,[CoreProduct])]
    mergeExp (E s p)   []     = [Right (s,[p])]
    mergeExp (E s1 p1) (y:ys)
      | Right (s2,p2) <- y
      , s1 == s2
      = Right (s1,(p1:p2)) : ys
      | otherwise
      = Right (s1,[p1]) : y : ys
    mergeExp x ys = Left x : ys
-- | Convert a single factor back into a 'Type'.
--
-- A 'Left' carries a plain symbol; a 'Right' carries an exponentiation base
-- together with the products that make up its (summed) exponent.
reifySymbol :: Either CoreSymbol (CoreSOP,[CoreProduct]) -> Type
reifySymbol factor = case factor of
    Left (I i)        -> mkNumLitTy i
    Left (C c)        -> unCType c
    Left (V v)        -> mkTyVarTy v
    Left (E s p)      -> power (reifySOP s) (reifyProduct p)
    Right (base,exps) -> power (reifySOP base) (reifySOP (S exps))
  where
    power b e = mkTyConApp typeNatExpTyCon [b, e]
-- | Turn an inequality triple into a single term by subtraction.
--
-- For @(x,y,True)@ the result is @y - x@; for @(x,y,False)@ the result is
-- @x - (y + 1)@.
subtractIneq
  :: (CoreSOP, CoreSOP, Bool)
  -> CoreSOP
subtractIneq (x,y,isLE) =
    if isLE
       then y `minus` x
       else x `minus` mergeSOPAdd y (S [P [I 1]])
  where
    -- Subtraction expressed as adding the term multiplied by -1.
    minus l r = mergeSOPAdd l (mergeSOPMul (S [P [I (-1)]]) r)
-- | Try to read a two-product sum as an inequality.
--
-- This succeeds when one of the two products starts with a negative integer
-- coefficient (the first product is tried before the second): that product,
-- with its coefficient negated, becomes the left-hand side and the other
-- product becomes the right-hand side. The flag of the result is always
-- 'True'.
sopToIneq
  :: CoreSOP
  -> Maybe Ineq
sopToIneq sop = case sop of
    S [P (I i : l), r] | i < 0 -> asIneq i l r
    S [r, P (I i : l)] | i < 0 -> asIneq i l r
    _                          -> Nothing
  where
    asIneq coeff factors other =
      Just ( mergeSOPMul (S [P [I (negate coeff)]]) (S [P factors])
           , S [other]
           , True )
-- | Turn an inequality whose right-hand side is a single type variable, and
-- whose flag is 'True', into a substitution of that variable by the
-- left-hand side. Every other inequality yields 'Nothing'.
ineqToSubst
  :: Ineq
  -> Maybe CoreUnify
ineqToSubst ineq = case ineq of
  (lhs, S [P [V v]], True) -> Just (SubstItem v lhs)
  _                        -> Nothing
-- | For a subtraction @x - y@, given as @(x,y)@, build the equality
-- predicate between the @<=?@ type constructor applied to @y@ and @x@ and
-- the promoted @True@ constructor, paired with 'boolTy'.
subtractionToPred
  :: (Type,Type)
  -> (PredType, Kind)
subtractionToPred (x,y) = (mkPrimEqPred yLeqX true, boolTy)
  where
    yLeqX = mkTyConApp typeNatLeqTyCon [y,x]
    true  = mkTyConApp promotedTrueDataCon []
-- | 'UnifyItem' specialised to 'TyVar' variables and 'CType' constants.
type CoreUnify = UnifyItem TyVar CType
-- | A single result of unification: either a substitution of a variable by
-- a term, or a pair of terms that still have to be made equal.
data UnifyItem v c = SubstItem { siVar :: v        -- ^ variable to replace
                               , siSOP :: SOP v c  -- ^ term to replace it with
                               }
                   | UnifyItem { siLHS :: SOP v c  -- ^ left-hand side
                               , siRHS :: SOP v c  -- ^ right-hand side
                               }
  deriving Eq
-- | Substitutions are printed as @var := term@ and unification obligations
-- as @lhs :~ rhs@.
instance (Outputable v, Outputable c) => Outputable (UnifyItem v c) where
  ppr item = case item of
    SubstItem var sop -> ppr var <+> text " := " <+> ppr sop
    UnifyItem lhs rhs -> ppr lhs <+> text " :~ " <+> ppr rhs
-- | Apply a list of unifiers to a term, from left to right.
--
-- Only 'SubstItem's have an effect; 'UnifyItem's are skipped.
substsSOP :: (Ord v, Ord c) => [UnifyItem v c] -> SOP v c -> SOP v c
substsSOP items sop = case items of
  []                    -> sop
  SubstItem v e : rest  -> substsSOP rest (substSOP v e sop)
  UnifyItem _ _ : rest  -> substsSOP rest sop
-- | @substSOP tv e sop@ replaces every occurrence of the variable @tv@ in
-- @sop@ by the term @e@ and adds the resulting products back together with
-- 'mergeSOPAdd'.
--
-- The empty sum @S []@ contains no variables and is returned unchanged;
-- folding it with @foldr1@ would raise an exception.
substSOP :: (Ord v, Ord c) => v -> SOP v c -> SOP v c -> SOP v c
substSOP tv e (S ps) = case map (substProduct tv e) ps of
  []    -> S []
  sops  -> foldr1 mergeSOPAdd sops
-- | @substProduct tv e p@ replaces every occurrence of the variable @tv@ in
-- the product @p@ by the term @e@ and multiplies the resulting factors back
-- together with 'mergeSOPMul'.
--
-- The empty product @P []@ contains no variables and is returned unchanged
-- (wrapped in a one-product sum); folding it with @foldr1@ would raise an
-- exception.
substProduct :: (Ord v, Ord c) => v -> SOP v c -> Product v c -> SOP v c
substProduct tv e (P syms) = case map (substSymbol tv e) syms of
  []    -> S [P []]
  sops  -> foldr1 mergeSOPMul sops
-- | @substSymbol tv e sym@ replaces the variable @tv@ inside a single symbol
-- by the term @e@.
--
-- Integers, constants and other variables are returned as one-symbol terms;
-- for an exponentiation the substitution is applied to both base and
-- exponent and the result is re-normalised with 'normaliseExp'.
substSymbol :: (Ord v, Ord c) => v -> SOP v c -> Symbol v c -> SOP v c
substSymbol tv e sym = case sym of
    I _            -> unchanged
    C _            -> unchanged
    V tv'
      | tv == tv'  -> e
      | otherwise  -> unchanged
    E base ex      -> normaliseExp (substSOP tv e base) (substProduct tv e ex)
  where
    unchanged = S [P [sym]]
-- | Apply the unifiers given as first argument to the terms stored in every
-- unifier of the second argument: the replacement term of a 'SubstItem', and
-- both sides of a 'UnifyItem'.
substsSubst :: (Ord v, Ord c) => [UnifyItem v c] -> [UnifyItem v c] -> [UnifyItem v c]
substsSubst s = map apply
  where
    inSubst = substsSOP s

    apply item = case item of
      SubstItem var sop -> SubstItem var (inSubst sop)
      UnifyItem lhs rhs -> UnifyItem (inSubst lhs) (inSubst rhs)
{-# INLINEABLE substsSubst #-}
-- | Outcome of 'unifyNats'.
data UnifyResult
  = Win              -- ^ Both terms are equal
  | Lose             -- ^ The terms have the same free variables and no
                     --   opaque constants, yet differ
  | Draw [CoreUnify] -- ^ Undecided; carries the unifiers that were found
                     --   (the list may be empty)
-- | Prints the constructor name, followed by the unifiers for 'Draw'.
instance Outputable UnifyResult where
  ppr result = case result of
    Win        -> text "Win"
    Draw subst -> text "Draw" <+> ppr subst
    Lose       -> text "Lose"
-- | Unify two terms that originate from the given constraint.
--
-- Emits a @"unifyNats"@ trace message showing the constraint and both terms,
-- then returns the result of the pure 'unifyNats''.
unifyNats :: Ct -> CoreSOP -> CoreSOP -> TcPluginM UnifyResult
unifyNats ct u v =
  tcPluginTrace "unifyNats" (ppr ct $$ ppr u $$ ppr v)
    >> return (unifyNats' ct u v)
-- | Pure core of 'unifyNats'.
--
-- * Different sets of free variables: 'Draw' with the unifiers found.
-- * Same free variables and equal terms: 'Win'.
-- * Same free variables, different terms, and an opaque constant on either
--   side: 'Draw' with the unifiers found.
-- * Same free variables, different terms, no opaque constants: 'Lose'.
--
-- A 'UnifyItem' that merely restates the two input terms is removed from the
-- reported unifiers.
unifyNats' :: Ct -> CoreSOP -> CoreSOP -> UnifyResult
unifyNats' ct u v
    | not (eqFV u v)                             = draw
    | u == v                                     = Win
    | containsConstants u || containsConstants v = draw
    | otherwise                                  = Lose
  where
    draw = Draw (filter isNew (unifiers ct u v))

    isNew item = case item of
      UnifyItem x y -> not (x == u && y == v)
      _             -> True
-- | Find unifiers for two terms that originate from the given constraint.
--
-- When either side is a single variable or a single opaque constant, the
-- constraint is inspected first; a unifier is only produced when the
-- constraint is a nominal equality and
--
-- * for a variable: the reified variable differs from the corresponding
--   side of the equality, or the constraint is a given;
-- * for a constant: the reified form of at least one term differs from the
--   corresponding side of the equality.
--
-- All remaining shapes are delegated to 'unifiers''.
unifiers :: Ct -> CoreSOP -> CoreSOP -> [CoreUnify]
unifiers ct u v = case (u, v) of
    (S [P [V x]], _) -> case predTree of
      EqPred NomEq t1 _ | u `differsFrom` t1 || given -> [SubstItem x v]
      _                                               -> []
    (_, S [P [V x]]) -> case predTree of
      EqPred NomEq _ t2 | v `differsFrom` t2 || given -> [SubstItem x u]
      _                                               -> []
    (S [P [C _]], _) -> constantCase
    (_, S [P [C _]]) -> constantCase
    _                -> unifiers' ct u v
  where
    evidence = ctEvidence ct
    given    = isGiven evidence
    predTree = classifyPredType (ctEvPred evidence)

    -- Does the reified term differ from the given type?
    differsFrom sop ty = CType (reifySOP sop) /= CType ty

    constantCase = case predTree of
      EqPred NomEq t1 t2
        | u `differsFrom` t1 || v `differsFrom` t2 -> [UnifyItem u v]
      _                                            -> []
-- | Structural unification of two terms; the workhorse behind 'unifiers'.
--
-- The equations are tried from top to bottom. Each one either produces
-- unifiers directly or reduces the problem to smaller terms and recurses.
-- The constraint is only consulted (through 'isGiven') in the final, general
-- equation and by the functions this one calls.
unifiers' :: Ct -> CoreSOP -> CoreSOP -> [CoreUnify]
-- A single variable against the empty sum: substitute the variable by 0.
unifiers' _ct (S [P [V x]]) (S [])        = [SubstItem x (S [P [I 0]])]
unifiers' _ct (S [])        (S [P [V x]]) = [SubstItem x (S [P [I 0]])]

-- A single variable against any other term: substitute the variable.
unifiers' _ct (S [P [V x]]) s             = [SubstItem x s]
unifiers' _ct s             (S [P [V x]]) = [SubstItem x s]

-- A single opaque constant on either side: the terms must be unified as is.
unifiers' _ct s1@(S [P [C _]]) s2               = [UnifyItem s1 s2]
unifiers' _ct s1               s2@(S [P [C _]]) = [UnifyItem s1 s2]

-- Two exponentiations with the same base: unify the exponents.
unifiers' ct (S [P [E s1 p1]]) (S [P [E s2 p2]])
  | s1 == s2 = unifiers' ct (S [p1]) (S [p2])

-- An exponentiation whose base is a single product, against a single product
-- that contains every factor of that base: unify the remaining factors with
-- @base^(-1) * base^p1@.
unifiers' ct (S [P [E (S [P s1]) p1]]) (S [P p2])
  | all (`elem` p2) s1
  = let base = intersect s1 p2
        diff = p2 \\ s1
    in  unifiers ct (S [P diff]) (S [P [E (S [P base]) (P [I (-1)]),E (S [P base]) p1]])

unifiers' ct (S [P p2]) (S [P [E (S [P s1]) p1]])
  | all (`elem` p2) s1
  = let base = intersect s1 p2
        diff = p2 \\ s1
    in  unifiers ct (S [P [E (S [P base]) (P [I (-1)]),E (S [P base]) p1]]) (S [P diff])

-- @i^p@ against the literal @j@: when @j@ is an exact power @i^k@, unify the
-- exponent with @k@.
unifiers' ct (S [P [E (S [P [I i]]) p]]) (S [P [I j]])
  = case integerLogBase i j of
      Just k  -> unifiers' ct (S [p]) (S [P [I k]])
      Nothing -> []

unifiers' ct (S [P [I j]]) (S [P [E (S [P [I i]]) p]])
  = case integerLogBase i j of
      Just k  -> unifiers' ct (S [p]) (S [P [I k]])
      Nothing -> []

-- A single exponentiation against a product consisting solely of
-- exponentiations that all share its base: unify the exponent with the sum
-- of the other exponents.
unifiers' ct (S [P [E s1 p1]]) (S [p2]) = case collectBases p2 of
  Just (b:bs,ps) | all (== s1) (b:bs) ->
    unifiers' ct (S [p1]) (S ps)
  _ -> []

unifiers' ct (S [p2]) (S [P [E s1 p1]]) = case collectBases p2 of
  Just (b:bs,ps) | all (== s1) (b:bs) ->
    unifiers' ct (S ps) (S [p1])
  _ -> []

-- @i * ps@ against the literal @j@: when 'safeDiv' yields a quotient, unify
-- the remaining factors with it.
unifiers' ct (S [P ((I i):ps)]) (S [P [I j]]) =
  case safeDiv j i of
    Just k -> unifiers' ct (S [P ps]) (S [P [I k]])
    _      -> []

unifiers' ct (S [P [I j]]) (S [P ((I i):ps)]) =
  case safeDiv j i of
    Just k -> unifiers' ct (S [P ps]) (S [P [I k]])
    _      -> []

-- Two single products: drop the factors they have in common (a side that
-- becomes empty is replaced by the literal 1) and unify what is left. No
-- common factors means no unifiers.
unifiers' ct (S [P ps1]) (S [P ps2])
    | null psx  = []
    | otherwise = unifiers' ct (S [P ps1'']) (S [P ps2''])
  where
    ps1'  = ps1 \\ psx
    ps2'  = ps2 \\ psx
    ps1'' | null ps1' = [I 1]
          | otherwise = ps1'
    ps2'' | null ps2' = [I 1]
          | otherwise = ps2'
    psx   = intersect ps1 ps2

-- Both sums start with an integer literal: subtract the smaller literal from
-- both sides. Equal literals fall through to the general equation below.
unifiers' ct (S ((P [I i]):ps1)) (S ((P [I j]):ps2))
    | i < j     = unifiers' ct (S ps1) (S ((P [I (j-i)]):ps2))
    | i > j     = unifiers' ct (S ((P [I (i-j)]):ps1)) (S ps2)

-- General case.
unifiers' ct s1@(S ps1) s2@(S ps2) = case sopToIneq k1 of
    -- Rewrite the problem through 'subtractIneq' / 'sopToIneq', but only when
    -- that yields a different pair of terms and neither new side is known to
    -- be non-natural. The second comparison used to be against @s1@, i.e.
    -- the new right-hand side was compared with the old LEFT-hand side.
    Just (s1',s2',_)
      | s1' /= s1 || s2' /= s2
      , fromMaybe True (isNatural s1')
      , fromMaybe True (isNatural s2')
      -> unifiers' ct s1' s2'
    -- No products in common and equally many products on both sides: unify
    -- the products pairwise and accept the outcome only if all pairs agree
    -- on a single unifier.
    _ | null psx
      , length ps1 == length ps2
      -> case nub (concat (zipWith (\x y -> unifiers' ct (S [x]) (S [y])) ps1 ps2)) of
           []  -> unifiers'' ct (S ps1) (S ps2)
           -- The lengths are already known to be equal from the guard above.
           [k] -> [k]
           _   -> []
      | null psx
      , isGiven (ctEvidence ct)
      -> unifiers'' ct (S ps1) (S ps2)
      | null psx
      -> []
    -- Drop the products both sides have in common (a side that becomes empty
    -- is replaced by the literal 0) and unify what is left.
    _ -> unifiers' ct (S ps1'') (S ps2'')
  where
    k1 = subtractIneq (s1,s2,True)
    ps1'  = ps1 \\ psx
    ps2'  = ps2 \\ psx
    ps1'' | null ps1' = [P [I 0]]
          | otherwise = ps1'
    ps2'' | null ps2' = [P [I 0]]
          | otherwise = ps2'
    psx = intersect ps1 ps2
-- | Last-resort unification, used only for given constraints.
--
-- When one side has the shape @i + v@ (an integer literal followed by a
-- single variable; the left side is tried first), the variable is
-- substituted by the other side minus @i@. Everything else, and every
-- non-given constraint, yields no unifiers.
unifiers'' :: Ct -> CoreSOP -> CoreSOP -> [CoreUnify]
unifiers'' ct lhs rhs
    | not (isGiven (ctEvidence ct))  = []
    | S [P [I i],P [V v]] <- lhs     = [SubstItem v (rhs `minusLit` i)]
    | S [P [I i],P [V v]] <- rhs     = [SubstItem v (lhs `minusLit` i)]
    | otherwise                      = []
  where
    minusLit sop i = mergeSOPAdd sop (S [P [I (negate i)]])
-- | Split a product that consists solely of exponentiations into the list of
-- bases and the list of exponents. Returns 'Nothing' as soon as any factor
-- is not an exponentiation.
collectBases :: CoreProduct -> Maybe ([CoreSOP],[CoreProduct])
collectBases (P syms) = unzip <$> mapM baseAndExponent syms
  where
    baseAndExponent sym = case sym of
      E base ex -> Just (base,ex)
      _         -> Nothing
-- | The set of type variables occurring anywhere in a term.
fvSOP :: CoreSOP -> UniqSet TyVar
fvSOP (S prods) = unionManyUniqSets [ fvProduct prod | prod <- prods ]
-- | The set of type variables occurring anywhere in a product.
fvProduct :: CoreProduct -> UniqSet TyVar
fvProduct (P syms) = unionManyUniqSets [ fvSymbol sym | sym <- syms ]
-- | The set of type variables occurring in a symbol: none for integers and
-- opaque constants, the variable itself for 'V', and the variables of both
-- base and exponent for 'E'.
fvSymbol :: CoreSymbol -> UniqSet TyVar
fvSymbol sym = case sym of
  I _       -> emptyUniqSet
  C _       -> emptyUniqSet
  V v       -> unitUniqSet v
  E base ex -> unionUniqSets (fvSOP base) (fvProduct ex)
-- | Do both terms mention exactly the same set of type variables?
eqFV :: CoreSOP -> CoreSOP -> Bool
eqFV lhs rhs = fvSOP lhs == fvSOP rhs
-- | Does any product of the term have an opaque constant ('C') as one of its
-- direct factors? Bases and exponents of exponentiations are not inspected.
containsConstants :: CoreSOP -> Bool
containsConstants (S prods) = or [ isConstant sym | P syms <- prods, sym <- syms ]
  where
    isConstant (C _) = True
    isConstant _     = False
-- | Exact integer division.
--
-- @safeDiv i j@ is @Just k@ when @j@ divides @i@ without remainder (with
-- @i == k * j@) and 'Nothing' otherwise. A divisor of @0@ yields @Just 0@
-- rather than an exception.
safeDiv :: Integer -> Integer -> Maybe Integer
safeDiv i j = case j of
    0 -> Just 0
    _ | remainder == 0 -> Just quotient
      | otherwise      -> Nothing
  where
    -- Lazily bound, so 'divMod' is never evaluated for a zero divisor.
    (quotient, remainder) = i `divMod` j
-- | Exact integer logarithm.
--
-- @integerLogBase x y@ is @Just k@ when @x > 1@, @y > 0@ and @x ^ k == y@;
-- in every other case it is 'Nothing'.
--
-- Exactness is detected by comparing the rounded-down logarithms of @y@ and
-- @y - 1@: they differ precisely when @y@ is a power of @x@.
integerLogBase :: Integer -> Integer -> Maybe Integer
integerLogBase x y
  -- x ^ 0 == 1. This case is answered directly because the comparison below
  -- would have to pass y - 1 == 0 to 'integerLogBase#', which is only
  -- specified for positive arguments.
  | x > 1 && y == 1 = Just 0
  | x > 1 && y > 1 =
      -- The unlifted results are bound with 'let' on this branch only, so the
      -- primitive is never called with arguments outside its domain.
      let z1 = integerLogBase# x y
          z2 = integerLogBase# x (y-1)
      in  if isTrue# (z1 ==# z2)
             then Nothing
             else Just (smallInteger z1)
  | otherwise = Nothing
-- | Try to decide whether a term is a natural number.
--
-- @Just True@ / @Just False@ are definite answers, 'Nothing' means the
-- question could not be decided. The alternatives are tried in order:
--
-- * the empty sum and the empty product are natural;
-- * in a single product, a non-negative literal and a variable are skipped,
--   a negative literal makes the product non-natural, an exponentiation is
--   skipped only when base and exponent are both natural (otherwise the
--   result is undecided), and any other factor leaves it undecided;
-- * the two-product sum @-1 + s^p@ is natural when both @s@ and @p@ are;
-- * any other sum is decided only when its first product and the remaining
--   sum agree.
isNatural :: CoreSOP -> Maybe Bool
isNatural sop = case sop of
  S []    -> Just True
  S [P []] -> Just True
  S [P (I i : rest)]
    | i >= 0    -> isNatural (S [P rest])
    | otherwise -> Just False
  S [P (V _ : rest)] -> isNatural (S [P rest])
  S [P (E base ex : rest)] -> do
    baseNat <- isNatural base
    exNat   <- isNatural (S [ex])
    if baseNat && exNat
       then isNatural (S [P rest])
       else Nothing
  S [P [I (-1)],P [E base ex]] ->
    (&&) <$> isNatural base <*> isNatural (S [ex])
  S [P _] -> Nothing
  S (p : ps) -> do
    headNat <- isNatural (S [p])
    restNat <- isNatural (S ps)
    -- Only a unanimous verdict counts; a mixed one is undecided.
    if headNat == restNat
       then Just headNat
       else Nothing
-- | Try to decide a wanted inequality from a given one.
--
-- The first argument bounds the recursion depth; 'Nothing' is returned once
-- it reaches zero. When either inequality does not carry the flag 'True',
-- the answer is @Just False@. Identical inequalities give @Just True@.
-- Otherwise every rule in 'ineqRules' is applied and each resulting pair is
-- solved recursively with the depth reduced by one: the answer is 'Nothing'
-- when no recursive call reached a verdict, and otherwise @Just True@
-- exactly when at least one verdict was 'True'.
solveIneq
  :: Word
  -> Ineq
  -> Ineq
  -> Maybe Bool
solveIneq depth want have
    | depth == 0                           = Nothing
    | not (flagOf want && flagOf have)     = Just False
    | want == have                         = Just True
    | otherwise = case verdicts of
        [] -> Nothing
        _  -> Just (or verdicts)
  where
    flagOf (_,_,flag) = flag

    -- All (want,have) pairs derivable by a single rule application.
    derived  = [ pair | rule <- ineqRules, pair <- rule want have ]
    verdicts = mapMaybe (uncurry (solveIneq (depth - 1))) derived
-- | An inequality between two terms. 'subtractIneq' treats a 'True' flag as
-- @lhs <= rhs@ holding and a 'False' flag as it not holding.
type Ineq = (CoreSOP, CoreSOP, Bool)
-- | A rule takes the wanted and the given inequality (in that order) and
-- returns zero or more new (wanted, given) pairs to try.
type IneqRule = Ineq -> Ineq -> [(Ineq,Ineq)]
-- | All rewrite rules that 'solveIneq' applies to a (wanted, given) pair.
-- The pairs produced by the rules are tried in the order listed here.
ineqRules
  :: [IneqRule]
ineqRules =
  [ leTrans
  , plusMonotone
  , timesMonotone
  , powMonotone
  , pow2MonotoneSpecial
  , haveSmaller
  , haveBigger
  ]
-- | Rule that adjusts a literal bound of the given inequality to the one of
-- the wanted inequality.
--
-- * Both left-hand sides are literals and the given one is at least as large
--   as the wanted one: the given left-hand side is replaced by the wanted
--   one.
-- * Otherwise, both right-hand sides are literals and the given one is
--   strictly smaller than the wanted one: the given right-hand side is
--   replaced by the wanted one.
--
-- The wanted inequality is passed on unchanged; the new given inequality
-- carries the flag of the wanted one.
leTrans :: IneqRule
leTrans want@(a,b,le) (x,y,_)
    | Just wantLo <- literal a
    , Just haveLo <- literal x
    , haveLo >= wantLo
    = [(want,(a,y,le))]
    | Just wantHi <- literal b
    , Just haveHi <- literal y
    , haveHi < wantHi
    = [(want,(x,b,le))]
    | otherwise
    = []
  where
    -- The value of a term that is exactly one integer literal.
    literal :: CoreSOP -> Maybe Integer
    literal (S [P [I n]]) = Just n
    literal _             = Nothing
-- | Rule that rewrites one inequality by passing it through 'subtractIneq'
-- and 'sopToIneq'.
--
-- The wanted inequality is tried first and the given one second; a rewrite
-- only counts when it produces an inequality different from the original.
plusMonotone :: IneqRule
plusMonotone want have
    | Just want' <- rebalanced want = [(want',have)]
    | Just have' <- rebalanced have = [(want,have')]
    | otherwise                     = []
  where
    rebalanced :: Ineq -> Maybe Ineq
    rebalanced ineq = case sopToIneq (subtractIneq ineq) of
      Just ineq' | ineq' /= ineq -> Just ineq'
      _                          -> Nothing
-- | Rule that shrinks the given inequality; the wanted one is unchanged.
--
-- * When the given left-hand side has at least two products, two
--   alternatives are produced: one without its second product and one
--   without its first product.
-- * When the given inequality has the literal @1@ on the left and, on the
--   right, a single product of a leading literal followed by at least one
--   more factor, that leading literal is dropped.
--
-- Only given inequalities whose flag is 'True' are considered.
haveSmaller :: IneqRule
haveSmaller want have = case have of
  (S (x:y:ys),us,True) ->
    [ (want,(S (x:ys),us,True))
    , (want,(S (y:ys),us,True))
    ]
  (S [P [I 1]],S [P (I _:p@(_:_))],True) ->
    [(want,(S [P [I 1]],S [P p],True))]
  _ -> []
-- | Rule that enlarges the right-hand side of the given inequality with the
-- products of the wanted right-hand side that it does not already contain.
--
-- Applies only when both flags are 'True' and at least one such product
-- exists; the wanted inequality is unchanged.
haveBigger :: IneqRule
haveBigger want have = case (want,have) of
  ((_,S vs,True),(as,S bs,True))
    | missing@(_:_) <- vs \\ bs
    -> [(want,(as,mergeSOPAdd (S bs) (S missing),True))]
  _ -> []
-- | Rule that multiplies both sides of one inequality by factors taken from
-- the other inequality.
--
-- The four cases are tried in this order; the first one that applies wins:
--
-- 1. factors of the wanted left-hand side are multiplied onto the given
--    inequality;
-- 2. factors of the wanted right-hand side are multiplied onto the given
--    inequality;
-- 3. factors of the given left-hand side are multiplied onto the wanted
--    inequality;
-- 4. factors of the given right-hand side are multiplied onto the wanted
--    inequality.
--
-- The rewritten inequality always carries the flag of the wanted one.
timesMonotone :: IneqRule
timesMonotone want@(a,b,le) have@(x,y,_)
    | Just k <- missingFactors a x y
    = [(want,(mergeSOPMul k x, mergeSOPMul k y,le))]
    | Just k <- missingFactors b x y
    = [(want,(mergeSOPMul k x, mergeSOPMul k y,le))]
    | Just k <- missingFactors x a b
    = [((mergeSOPMul k a, mergeSOPMul k b,le),have)]
    | Just k <- missingFactors y a b
    = [((mergeSOPMul k a, mergeSOPMul k b,le),have)]
    | otherwise
    = []
  where
    -- @missingFactors big l r@ requires all three terms to be single
    -- products, with @big@ having at least two factors. It computes the
    -- factors of @big@ that are absent from @l@ and those absent from @r@;
    -- when both lists are non-empty, the shorter one (the first on a tie) is
    -- returned as a single-product term.
    missingFactors :: CoreSOP -> CoreSOP -> CoreSOP -> Maybe CoreSOP
    missingFactors (S [P big@(_:_:_)]) (S [P l]) (S [P r])
      | not (null notInL)
      , not (null notInR)
      = Just (S [P (if length notInL <= length notInR then notInL else notInR)])
      where
        notInL = big \\ l
        notInR = big \\ r
    missingFactors _ _ _ = Nothing
-- | Rule that reduces an inequality whose right-hand side is a single
-- exponentiation.
--
-- The given inequality is inspected first; when its right-hand side is a
-- single exponentiation, only the given inequality is rewritten (possibly
-- producing nothing). Otherwise the same reduction is attempted on the
-- wanted inequality.
powMonotone :: IneqRule
powMonotone want have
    | (x,S [P [E yS yP]],le) <- have
    = [ (want,have') | have' <- reduce x yS yP le ]
    | (a,S [P [E bS bP]],le) <- want
    = [ (want',have) | want' <- reduce a bS bP le ]
    | otherwise
    = []
  where
    -- @reduce lhs base ex flag@ reduces @lhs <= base ^ ex@:
    --
    -- * @lhs@ is an exponentiation with the same base: compare exponents;
    -- * @lhs@ is an exponentiation with the same exponent: compare bases;
    -- * @lhs@ equals the base: compare the literal 1 with the exponent.
    reduce :: CoreSOP -> CoreSOP -> CoreProduct -> Bool -> [Ineq]
    reduce lhs base ex flag = case lhs of
      S [P [E lBase lEx]]
        | lBase == base -> [(S [lEx],S [ex],flag)]
        | lEx == ex     -> [(lBase,base,flag)]
      _ | lhs == base   -> [(S [P [I 1]],S [ex],flag)]
        | otherwise     -> []
-- | Rule that replaces both sides of an inequality by the results of
-- @'facSOP' 2@, when that succeeds for both sides.
--
-- The wanted inequality is tried first; when it cannot be rewritten, the
-- given inequality is tried. The rewritten inequality keeps its own flag.
pow2MonotoneSpecial :: IneqRule
pow2MonotoneSpecial want@(a,b,wantFlag) have@(x,y,haveFlag)
  | Just a' <- facSOP 2 a
  , Just b' <- facSOP 2 b
  = [((a',b',wantFlag),have)]
  | Just x' <- facSOP 2 x
  , Just y' <- facSOP 2 y
  = [(want,(x',y',haveFlag))]
  | otherwise
  = []
-- | Apply 'facSymbol' to every factor of a single-product term and collect
-- the products of all results into one sum.
--
-- Returns 'Nothing' when the term is not a single product, or when
-- 'facSymbol' fails for any factor.
facSOP
  :: Integer
  -> CoreSOP
  -> Maybe CoreSOP
facSOP n sop = case sop of
  S [P syms] -> S . concatMap unS <$> mapM (facSymbol n) syms
  _          -> Nothing
-- | Per-factor step of 'facSOP'.
--
-- * An integer literal that is an exact power @n ^ j@ becomes the literal
--   @j@.
-- * An exponentiation whose base is accepted by @'facSOP' n@ becomes that
--   result multiplied by the exponent.
-- * Everything else yields 'Nothing'.
facSymbol
  :: Integer
  -> CoreSymbol
  -> Maybe CoreSOP
facSymbol n sym = case sym of
  I i       -> (\j -> S [P [I j]]) <$> integerLogBase n i
  E base ex -> (\base' -> mergeSOPMul base' (S [ex])) <$> facSOP n base
  _         -> Nothing