{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeInType          #-}
{-# LANGUAGE TypeOperators       #-}

-- |
-- Module      : GHC.TypeLits.Compare
-- Copyright   : (c) Justin Le 2016
-- License     : MIT
-- Maintainer  : justin@jle.im
-- Stability   : unstable
-- Portability : non-portable
--
--
-- This module provides the ability to refine given 'KnownNat' instances
-- using "GHC.TypeLits"'s comparison API, and also the ability to prove
-- inequalities and upper/lower limits.
--
-- If a library function requires a @1 '<=' n@ constraint, but only
-- @'KnownNat' n@ is available:
--
-- @
-- foo :: (KnownNat n, 1 '<=' n) => 'Data.Proxy.Proxy' n -> Int
--
-- bar :: KnownNat n => Proxy n -> Int
-- bar n = case (Proxy :: Proxy 1) '%<=?' n of
--           'LE'  'Refl' -> foo n
--           'NLE' _    -> 0
-- @
--
-- @foo@ requires that @1 <= n@, but @bar@ has to handle all cases of @n@.
-- @%<=?@ lets you compare the 'KnownNat's in two 'Data.Proxy.Proxy's and returns
-- a @:<=?@, which has two constructors, 'LE' and 'NLE'.
--
-- If you pattern match on the result, in the 'LE' branch, the constraint
-- @1 <= n@ will be satisfied according to GHC, so @bar@ can safely call
-- @foo@, and GHC will recognize that @1 <= n@.
--
-- In the 'NLE' branch, GHC instead learns that @1 <=? n@ is @'False@ and
-- that @n <= 1@, so any functions that require those constraints are
-- callable.
--
-- For convenience, 'isLE' and 'isNLE' are also offered:
--
-- @
-- bar :: KnownNat n => Proxy n -> Int
-- bar n = case 'isLE' (Proxy :: Proxy 1) n of
--           'Just' Refl -> foo n
--           'Nothing'   -> 0
-- @
--
-- Similarly, if a library function requires something involving 'CmpNat',
-- you can use 'cmpNat' and the 'SCmpNat' type:
--
-- @
-- foo1 :: (KnownNat n, 'CmpNat' 5 n ~ LT) => Proxy n -> Int
-- foo2 :: (KnownNat n, CmpNat 5 n ~ GT) => Proxy n -> Int
--
-- bar :: KnownNat n => Proxy n -> Int
-- bar n = case 'cmpNat' (Proxy :: Proxy 5) n of
--           'CLT' Refl   -> foo1 n
--           'CEQ' Refl _ -> 0
--           'CGT' Refl   -> foo2 n
-- @
--
-- You can pass the 'SCmpNat' that 'cmpNat' gives you to 'flipCmpNat' and
-- 'cmpNatLE' to "flip" the inequality or turn it into something compatible
-- with '<=?' (useful when you have to work with libraries that mix the
-- two methods), or use 'cmpNatEq' and 'eqCmpNat' to convert to/from
-- witnesses for equality of the two 'Nat's.
--
-- This module is useful for helping bridge between libraries that use
-- different 'Nat'-based comparison systems in their type constraints.
module GHC.TypeLits.Compare
  ( -- * '<=' and '<=?'
    (:<=?)(..)
  , (%<=?)
    -- ** Convenience functions
  , isLE
  , isNLE
    -- * 'CmpNat'
  , SCmpNat(..)
  , GHC.TypeLits.Compare.cmpNat
    -- ** Manipulating witnesses
  , flipCmpNat
  , cmpNatEq
  , eqCmpNat
  , reflCmpNat
  , cmpNatLE
  , cmpNatGOrdering
  )
  where

import           Data.Kind
import           Data.Type.Equality
import           GHC.TypeLits ( Nat, KnownNat, CmpNat
                              , type (<=?)
                              , natVal )
import           Unsafe.Coerce
import           Data.GADT.Compare

-- | Simplified version of '%<=?': check if @m@ is less than or equal to
-- @n@.  If it is, match on @'Just' 'Refl'@ to get GHC to believe it,
-- within the body of the pattern match.
--
-- Yields 'Nothing' whenever '%<=?' classifies the pair as 'NLE'.
isLE
    :: (KnownNat m, KnownNat n)
    => p m
    -> q n
    -> Maybe ((m <=? n) :~: 'True)
isLE m n = case m %<=? n of
    LE  Refl -> Just Refl
    NLE _ _  -> Nothing

-- | Simplified version of '%<=?': check if @m@ is /not/ less than or equal
-- to @n@.  If it is not, match on @'Just' 'Refl'@ to get GHC to believe
-- it, within the body of the pattern match.
--
-- Yields 'Nothing' whenever '%<=?' classifies the pair as 'LE'.
isNLE
    :: (KnownNat m, KnownNat n)
    => p m
    -> q n
    -> Maybe ((m <=? n) :~: 'False)
isNLE m n = case m %<=? n of
    NLE Refl Refl -> Just Refl
    LE  _         -> Nothing

-- | Evidence of how two type-level naturals @m@ and @n@ relate under
-- '<=?'.  There are exactly two cases; pattern matching on either one
-- brings the stored equalities into scope for GHC.
data (:<=?) :: Nat -> Nat -> Type where
    -- | @m@ is less than or equal to @n@.
    LE  :: forall (m :: Nat) (n :: Nat).
           (m <=? n) :~: 'True
        -> m :<=? n
    -- | @m@ is not less than or equal to @n@, which in turn means @n@ is
    -- less than or equal to @m@.
    NLE :: forall (m :: Nat) (n :: Nat).
           (m <=? n) :~: 'False
        -> (n <=? m) :~: 'True
        -> m :<=? n

-- | Compare @m@ and @n@, classifying their relationship into some
-- constructor of ':<=?'.
--
-- The decision is made on the runtime values reported by 'natVal'.
(%<=?)
     :: (KnownNat m, KnownNat n)
     => p m
     -> q n
     -> (m :<=? n)
m %<=? n
    -- GHC cannot connect the term-level 'natVal' comparison to the
    -- type-level '<=?', so the matching 'Refl' witnesses are produced with
    -- 'unsafeCoerce'.  This is sound only as long as the 'KnownNat'
    -- instances report the true values of @m@ and @n@.  In the 'NLE' case,
    -- @m@ is strictly greater than @n@, hence @n <=? m@ is also 'True.
    | natVal m <= natVal n = LE  (unsafeCoerce Refl)
    | otherwise            = NLE (unsafeCoerce Refl) (unsafeCoerce Refl)

-- | Evidence of how two type-level naturals @m@ and @n@ relate under
-- 'CmpNat': less than, equal to, or greater than.  Pattern matching on a
-- constructor brings the stored equalities into scope for GHC.
data SCmpNat :: Nat -> Nat -> Type where
    -- | @'CmpNat' m n@ is 'LT'.
    CLT :: forall (m :: Nat) (n :: Nat).
           CmpNat m n :~: 'LT
        -> SCmpNat m n
    -- | @'CmpNat' m n@ is 'EQ'; additionally carries a proof that @m@ and
    -- @n@ are the same type.
    CEQ :: forall (m :: Nat) (n :: Nat).
           CmpNat m n :~: 'EQ
        -> m :~: n
        -> SCmpNat m n
    -- | @'CmpNat' m n@ is 'GT'.
    CGT :: forall (m :: Nat) (n :: Nat).
           CmpNat m n :~: 'GT
        -> SCmpNat m n

-- | Compare @m@ and @n@, classifying their relationship into some
-- constructor of 'SCmpNat'.
--
-- The decision is made by 'compare'-ing the runtime values reported by
-- 'natVal'.
cmpNat
    :: (KnownNat m, KnownNat n)
    => p m
    -> q n
    -> SCmpNat m n
cmpNat m n = case compare (natVal m) (natVal n) of
    -- GHC cannot relate the runtime 'Ordering' to the type-level 'CmpNat',
    -- so each branch fabricates its witnesses with 'unsafeCoerce'.  This is
    -- sound only as long as the 'KnownNat' instances are truthful; equal
    -- 'natVal's also justify the @m :~: n@ proof carried by 'CEQ'.
    LT -> CLT (unsafeCoerce Refl)
    EQ -> CEQ (unsafeCoerce Refl) (unsafeCoerce Refl)
    GT -> CGT (unsafeCoerce Refl)

-- | Flip an inequality: a comparison of @m@ against @n@ becomes the
-- comparison of @n@ against @m@.  'CLT' and 'CGT' swap; 'CEQ' stays 'CEQ'.
flipCmpNat :: SCmpNat m n -> SCmpNat n m
flipCmpNat = \case
    -- The input witnesses are matched (forced) before new ones are
    -- fabricated with 'unsafeCoerce', since GHC cannot derive
    -- @CmpNat n m@ from @CmpNat m n@ on its own.
    CLT Refl      -> CGT (unsafeCoerce Refl)
    CEQ Refl Refl -> CEQ (unsafeCoerce Refl) Refl
    CGT Refl      -> CLT (unsafeCoerce Refl)

-- | @'CmpNat' m n@ being 'EQ' implies that @m@ is equal to @n@.
cmpNatEq :: (CmpNat m n :~: 'EQ) -> (m :~: n)
cmpNatEq = \case
    -- Matching on the input 'Refl' forces it, so a bottom argument is not
    -- silently turned into a proof.  GHC cannot conclude @m ~ n@ from the
    -- 'CmpNat' equation itself, hence 'unsafeCoerce'.
    Refl -> unsafeCoerce Refl

-- | A witness of equality implies that @'CmpNat' m n@ is 'EQ'.
eqCmpNat :: (m :~: n) -> (CmpNat m n :~: 'EQ)
eqCmpNat = \case
    -- Matching on the input 'Refl' forces it before the result witness is
    -- fabricated with 'unsafeCoerce'.
    Refl -> unsafeCoerce Refl

-- | Inject a witness of equality into an 'SCmpNat' at 'CEQ'.  The
-- 'CmpNat' witness is derived from the equality via 'eqCmpNat', and the
-- equality itself is stored alongside it.
reflCmpNat :: (m :~: n) -> SCmpNat m n
reflCmpNat r = CEQ (eqCmpNat r) r

-- | Convert to ':<=?'.  'CLT' and 'CEQ' both become 'LE'; 'CGT' becomes
-- 'NLE'.
cmpNatLE :: SCmpNat m n -> (m :<=? n)
cmpNatLE = \case
    -- The input witnesses are forced first.  GHC cannot translate 'CmpNat'
    -- facts into '<=?' facts, so the resulting witnesses come from
    -- 'unsafeCoerce'.
    CLT Refl      -> LE  (unsafeCoerce Refl)
    CEQ Refl Refl -> LE  (unsafeCoerce Refl)
    CGT Refl      -> NLE (unsafeCoerce Refl) (unsafeCoerce Refl)

-- | Convert to 'GOrdering'.  'CLT', 'CEQ' and 'CGT' map to 'GLT', 'GEQ'
-- and 'GGT' respectively.
--
-- @since 0.4.0.0
cmpNatGOrdering :: SCmpNat n m -> GOrdering n m
cmpNatGOrdering = \case
    CLT Refl      -> GLT
    -- Matching the second 'Refl' brings @n ~ m@ into scope, which 'GEQ'
    -- requires.
    CEQ Refl Refl -> GEQ
    CGT Refl      -> GGT