{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Type.Nat (
Nat(..),
toNatural,
fromNatural,
cata,
explicitShow,
explicitShowsPrec,
SNat(..),
snatToNat,
snatToNatural,
SNatI(..),
withSNat,
reify,
reflect,
reflectToNum,
eqNat,
EqNat,
discreteNat,
induction,
induction1,
InlineInduction (..),
inlineInduction,
unfoldedFix,
Plus,
Mult,
Mult2,
DivMod2,
ToGHC,
FromGHC,
nat0, nat1, nat2, nat3, nat4, nat5, nat6, nat7, nat8, nat9,
Nat0, Nat1, Nat2, Nat3, Nat4, Nat5, Nat6, Nat7, Nat8, Nat9,
proofPlusZeroN,
proofPlusNZero,
proofMultZeroN,
proofMultNZero,
proofMultOneN,
proofMultNOne,
) where
import Data.Function (fix)
import Data.Proxy (Proxy (..))
import Data.Type.Dec (Dec (..))
import Data.Type.Equality ((:~:) (..), TestEquality (..))
import Data.Typeable (Typeable)
import Numeric.Natural (Natural)
import qualified GHC.TypeLits as GHC
import Unsafe.Coerce (unsafeCoerce)
#if !MIN_VERSION_base(4,11,0)
import Data.Type.Equality (type (==))
#endif
import Data.Nat
-- | Singleton for a type-level 'Nat'.
--
-- 'SS' carries no explicit field; instead it stores the 'SNatI' dictionary
-- of the predecessor, so pattern matching on 'SS' brings @SNatI n@ into
-- scope for the predecessor @n@.
data SNat (n :: Nat) where
    SZ :: SNat 'Z
    SS :: forall n. SNatI n => SNat ('S n)
  deriving (Typeable)

deriving instance Show (SNat n)
-- | Implicit singleton: the class method produces the 'SNat' for @n@.
class SNatI (n :: Nat) where
    snat :: SNat n

instance SNatI 'Z where
    snat = SZ

-- | The successor's singleton is 'SS', which captures the predecessor's
-- 'SNatI' dictionary from this instance's context.
instance SNatI n => SNatI ('S n) where
    snat = SS
-- | Turn an explicit singleton into an implicit 'SNatI' constraint.
--
-- Matching 'SZ' fixes @n ~ 'Z@ (whose instance is global). Matching 'SS'
-- exposes the predecessor's dictionary, from which @SNatI ('S m)@ follows.
withSNat :: SNat n -> (SNatI n => r) -> r
withSNat sn k = case sn of
    SZ -> k
    SS -> k
-- | Reflect a type-level 'Nat' (known through its 'SNatI' dictionary) to a
-- term-level 'Nat'. The proxy argument only fixes @n@ and is never inspected.
reflect :: forall n proxy. SNatI n => proxy n -> Nat
reflect _ = unTagged counted
  where
    -- Start from Z and wrap one S per successor in n.
    counted :: Tagged n Nat
    counted = induction1 (Tagged Z) (retagMap S)
-- | Reflect a type-level 'Nat' to any 'Num' type, by starting at @0@ and
-- adding @1@ (as @1 + k@) once per successor. The proxy is never inspected.
reflectToNum :: forall n m proxy. (SNatI n, Num m) => proxy n -> m
reflectToNum _ = unTagged counted
  where
    counted :: Tagged n m
    counted = induction1 (Tagged 0) (retagMap (\k -> 1 + k))
-- | Reify a term-level 'Nat' as a type: the continuation is called with a
-- 'Proxy' for the corresponding type-level natural, with 'SNatI' in scope.
reify :: forall r. Nat -> (forall n. SNatI n => Proxy n -> r) -> r
reify nat k = case nat of
    Z      -> k (Proxy :: Proxy 'Z)
    S pred' -> reify pred' (\(_ :: Proxy p) -> k (Proxy :: Proxy ('S p)))
-- | Convert a singleton to the term-level 'Nat' it represents.
--
-- The singleton itself serves as the proxy for 'reflect'; 'withSNat'
-- supplies the 'SNatI' dictionary that 'reflect' needs.
snatToNat :: forall n. SNat n -> Nat
snatToNat sn = withSNat sn (reflect sn)
-- | Convert a singleton to the 'Natural' it represents.
--
-- Delegates to 'reflectToNum' at type 'Natural', using the singleton as the
-- proxy and 'withSNat' to provide the 'SNatI' dictionary.
snatToNatural :: forall n. SNat n -> Natural
snatToNatural sn = withSNat sn (reflectToNum sn)
-- | Decide whether two type-level naturals are equal, returning evidence
-- on success and 'Nothing' otherwise.
--
-- Induction is on @n@: for each @n@ we build a function that, for any
-- @m@ with 'SNatI', compares @n@ against @m@.
eqNat :: forall n m. (SNatI n, SNatI m) => Maybe (n :~: m)
eqNat = getNatEq (induction (NatEq atZero) (\ih -> NatEq (atSucc ih)))
  where
    -- 'Z equals k only when k is also 'Z.
    atZero :: forall k. SNatI k => Maybe ('Z :~: k)
    atZero = case snat :: SNat k of
        SZ -> Just Refl
        SS -> Nothing

    -- 'S p never equals 'Z; against 'S k' defer to the hypothesis for p.
    atSucc :: forall p k. SNatI k => NatEq p -> Maybe ('S p :~: k)
    atSucc ih = case snat :: SNat k of
        SZ -> Nothing
        SS -> bothSucc ih

    bothSucc :: forall p k. SNatI k => NatEq p -> Maybe ('S p :~: 'S k)
    bothSucc (NatEq ih) = case ih :: Maybe (p :~: k) of
        Just Refl -> Just Refl
        Nothing   -> Nothing

-- | Inductive hypothesis for 'eqNat': a comparison of @n@ against any @m@.
newtype NatEq n = NatEq { getNatEq :: forall m. SNatI m => Maybe (n :~: m) }
-- | Decide equality of two type-level naturals: 'Yes' with a proof, or
-- 'No' with a refutation.
--
-- Same induction shape as 'eqNat', but the negative cases carry evidence.
discreteNat :: forall n m. (SNatI n, SNatI m) => Dec (n :~: m)
discreteNat =
    getDiscreteNat (induction (DiscreteNat atZero) (\ih -> DiscreteNat (atSucc ih)))
  where
    -- 'Z :~: 'S _ is uninhabited, hence the empty case.
    atZero :: forall k. SNatI k => Dec ('Z :~: k)
    atZero = case snat :: SNat k of
        SZ -> Yes Refl
        SS -> No (\contra -> case contra of {})

    atSucc :: forall p k. SNatI k => DiscreteNat p -> Dec ('S p :~: k)
    atSucc ih = case snat :: SNat k of
        SZ -> No (\contra -> case contra of {})
        SS -> bothSucc ih

    -- 'S p :~: 'S k holds exactly when p :~: k does.
    bothSucc :: forall p k. SNatI k => DiscreteNat p -> Dec ('S p :~: 'S k)
    bothSucc (DiscreteNat ih) = case ih :: Dec (p :~: k) of
        Yes Refl  -> Yes Refl
        No notEq  -> No (\Refl -> notEq Refl)

-- | Inductive hypothesis for 'discreteNat': a decision of @n@ against any @m@.
newtype DiscreteNat n = DiscreteNat { getDiscreteNat :: forall m. SNatI m => Dec (n :~: m) }
-- | Compare two singletons by turning both into implicit 'SNatI'
-- dictionaries and deferring to 'eqNat'.
instance TestEquality SNat where
    testEquality sa sb = withSNat sa (withSNat sb eqNat)
-- | Boolean type-level equality of naturals: 'True only when both sides
-- peel down to 'Z together.
type family EqNat (n :: Nat) (m :: Nat) where
    EqNat 'Z     'Z     = 'True
    EqNat ('S i) ('S j) = EqNat i j
    EqNat i      j      = 'False

-- On older base, (==) is an open type family that needs an instance per kind.
#if !MIN_VERSION_base(4,11,0)
type instance a == b = EqNat a b
#endif
-- | Induction over a type-level natural @n@ for a type constructor @f@
-- that takes an extra parameter @a@.
--
-- The first argument handles @'Z@. The second lifts a result for @m@ to
-- one for @'S m@. The recursion walks the 'SNat' singletons produced by
-- 'snat', applying the step once per successor in @n@.
induction1
    :: forall n f a. SNatI n
    => f 'Z a
    -> (forall m. SNatI m => f m a -> f ('S m) a)
    -> f n a
induction1 z f = go snat
  where
    go :: forall m. SNat m -> f m a
    go SZ = z
    go SS = f (go snat)
-- | Induction over a type-level natural @n@ for a type constructor @f@.
--
-- The first argument handles @'Z@. The second lifts a result for @m@ to
-- one for @'S m@. The recursion walks the 'SNat' singletons produced by
-- 'snat', applying the step once per successor in @n@.
induction
    :: forall n f. SNatI n
    => f 'Z
    -> (forall m. SNatI m => f m -> f ('S m))
    -> f n
induction z f = go snat
  where
    go :: forall m. SNat m -> f m
    go SZ = z
    go SS = f (go snat)
-- | Induction over type-level naturals where the recursion is driven by
-- instance resolution rather than by matching on 'snat' at runtime: the
-- instance for @'S n@ calls the method of the instance for @n@.
--
-- The @SPECIALIZE instance@ pragmas pre-specialise the successor instance
-- for the naturals 1 through 9.
class SNatI n => InlineInduction (n :: Nat) where
    inlineInduction1
        :: f 'Z a
        -> (forall m. InlineInduction m => f m a -> f ('S m) a)
        -> f n a

instance InlineInduction 'Z where
    inlineInduction1 base _ = base

instance InlineInduction n => InlineInduction ('S n) where
    inlineInduction1 base next = next (inlineInduction1 base next)

    {-# SPECIALIZE instance InlineInduction ('S 'Z) #-}
    {-# SPECIALIZE instance InlineInduction ('S ('S 'Z)) #-}
    {-# SPECIALIZE instance InlineInduction ('S ('S ('S 'Z))) #-}
    {-# SPECIALIZE instance InlineInduction ('S ('S ('S ('S 'Z)))) #-}
    {-# SPECIALIZE instance InlineInduction ('S ('S ('S ('S ('S 'Z))))) #-}
    {-# SPECIALIZE instance InlineInduction ('S ('S ('S ('S ('S ('S 'Z)))))) #-}
    {-# SPECIALIZE instance InlineInduction ('S ('S ('S ('S ('S ('S ('S 'Z))))))) #-}
    {-# SPECIALIZE instance InlineInduction ('S ('S ('S ('S ('S ('S ('S ('S 'Z)))))))) #-}
    {-# SPECIALIZE instance InlineInduction ('S ('S ('S ('S ('S ('S ('S ('S ('S 'Z))))))))) #-}
-- | Single-index variant of 'inlineInduction1'.
--
-- @f@ is wrapped in 'Const'' to give it the unused extra parameter that
-- 'inlineInduction1' expects, and unwrapped at the end.
inlineInduction
    :: forall n f. InlineInduction n
    => f 'Z
    -> (forall m. InlineInduction m => f m -> f ('S m))
    -> f n
inlineInduction z f =
    unConst' (inlineInduction1 (Const' z) (\(Const' acc) -> Const' (f acc)))

-- | Adapter adding a phantom trailing parameter @a@ to @f n@.
newtype Const' (f :: Nat -> *) (n :: Nat) a = Const' { unConst' :: f n }
-- | Unroll @n@ applications of the function in front of 'fix'.
--
-- For @n = 0@ this is @fix f@. Each successor wraps one more application,
-- so @n = 2@ gives @f (f (fix f))@. The proxy is never inspected.
unfoldedFix :: forall n a proxy. InlineInduction n => proxy n -> (a -> a) -> a
unfoldedFix _ = getFix unrolled
  where
    unrolled :: Fix n a
    unrolled = inlineInduction1 (Fix fix) (\(Fix rest) -> Fix (\g -> g (rest g)))

-- | A fixpoint-style combinator tagged with its unrolling depth.
newtype Fix (n :: Nat) a = Fix { getFix :: (a -> a) -> a }
-- | Convert a unary type-level 'Nat' to a GHC type-level literal, adding
-- @1@ per successor.
type family ToGHC (n :: Nat) :: GHC.Nat where
    ToGHC 'Z      = 0
    ToGHC ('S k)  = 1 GHC.+ ToGHC k
-- | Convert a GHC type-level literal to a unary type-level 'Nat', peeling
-- off one 'S per unit until the literal reaches @0@.
type family FromGHC (n :: GHC.Nat) :: Nat where
    FromGHC 0 = 'Z
    FromGHC k = 'S (FromGHC (k GHC.- 1))
-- | Type-level addition, by recursion on the first argument.
type family Plus (n :: Nat) (m :: Nat) :: Nat where
    Plus 'Z      r = r
    Plus ('S l)  r = 'S (Plus l r)
-- | Type-level multiplication, by recursion on the first argument:
-- @(1 + l) * r = r + l * r@.
type family Mult (n :: Nat) (m :: Nat) :: Nat where
    Mult 'Z      r = 'Z
    Mult ('S l)  r = Plus r (Mult l r)
-- | Type-level doubling: each successor of the input contributes two.
type family Mult2 (n :: Nat) :: Nat where
    Mult2 'Z      = 'Z
    Mult2 ('S k)  = 'S ('S (Mult2 k))
-- | Type-level halving with remainder: the quotient by two, paired with
-- 'True when the input is odd.
type family DivMod2 (n :: Nat) :: (Nat, Bool) where
    DivMod2 'Z           = '( 'Z, 'False)
    DivMod2 ('S 'Z)      = '( 'Z, 'True)
    DivMod2 ('S ('S k))  = DivMod2' (DivMod2 k)

-- | Helper for 'DivMod2': bump the quotient, keep the remainder flag.
type family DivMod2' (p :: (Nat, Bool)) :: (Nat, Bool) where
    DivMod2' '(q, r) = '( 'S q, r)
-- Convenience aliases for the type-level naturals zero through nine,
-- each spelled out in unary so its value can be read off directly.
type Nat0 = 'Z
type Nat1 = 'S 'Z
type Nat2 = 'S ('S 'Z)
type Nat3 = 'S ('S ('S 'Z))
type Nat4 = 'S ('S ('S ('S 'Z)))
type Nat5 = 'S ('S ('S ('S ('S 'Z))))
type Nat6 = 'S ('S ('S ('S ('S ('S 'Z)))))
type Nat7 = 'S ('S ('S ('S ('S ('S ('S 'Z))))))
type Nat8 = 'S ('S ('S ('S ('S ('S ('S ('S 'Z)))))))
type Nat9 = 'S ('S ('S ('S ('S ('S ('S ('S ('S 'Z))))))))
-- | @0 + n = n@
--
-- Holds by computation alone: the first equation of 'Plus' reduces
-- @Plus 'Z n@ to @n@.
proofPlusZeroN :: forall n. Plus Nat0 n :~: n
proofPlusZeroN = Refl
-- | @n + 0 = n@
--
-- Proved by induction on @n@: @Plus 'Z 'Z@ reduces to @'Z@, and
-- @Plus ('S m) 'Z@ reduces to @'S (Plus m 'Z)@, which the hypothesis
-- rewrites to @'S m@.
proofPlusNZero :: forall n. SNatI n => Plus n Nat0 :~: n
proofPlusNZero = getProofPlusNZero (induction zeroCase succCase :: ProofPlusNZero n)
  where
    zeroCase :: ProofPlusNZero 'Z
    zeroCase = ProofPlusNZero Refl

    succCase :: forall m. ProofPlusNZero m -> ProofPlusNZero ('S m)
    succCase (ProofPlusNZero Refl) = ProofPlusNZero Refl

-- Keep calls visible until phase 1 so the rule below gets a chance to fire;
-- the rule swaps the runtime induction for a coerced 'Refl'.
{-# NOINLINE [1] proofPlusNZero #-}
{-# RULES "Nat: n + 0 = n" proofPlusNZero = unsafeCoerce (Refl :: () :~: ()) #-}

-- | Wrapper giving the proposition the @f n@ shape that 'induction' needs.
newtype ProofPlusNZero n = ProofPlusNZero { getProofPlusNZero :: Plus n Nat0 :~: n }
-- | @0 * n = 0@
--
-- Holds by computation alone: the first equation of 'Mult' reduces
-- @Mult 'Z n@ to @'Z@.
proofMultZeroN :: forall n. Mult Nat0 n :~: Nat0
proofMultZeroN = Refl
-- | @n * 0 = 0@
--
-- Proved by induction on @n@: @Mult 'Z Nat0@ reduces to @'Z@, and
-- @Mult ('S m) Nat0@ reduces to @Plus Nat0 (Mult m Nat0)@, i.e. to
-- @Mult m Nat0@, so the inductive hypothesis closes the step.
--
-- The proxy argument only fixes @n@; it is never inspected.
proofMultNZero :: forall n proxy. SNatI n => proxy n -> Mult n Nat0 :~: Nat0
proofMultNZero _ =
    getProofMultNZero (induction (ProofMultNZero Refl) step :: ProofMultNZero n)
  where
    step :: forall m. ProofMultNZero m -> ProofMultNZero ('S m)
    step (ProofMultNZero Refl) = ProofMultNZero Refl

-- Keep calls visible until phase 1 so the rule below gets a chance to fire.
{-# NOINLINE [1] proofMultNZero #-}

-- The rule swaps the runtime induction for a coerced 'Refl'. Because
-- 'proofMultNZero' takes a proxy argument, the replacement must itself be
-- a function. Coercing the 'Refl' constructor straight to a function type
-- and then applying it is not a valid use of 'unsafeCoerce'.
{-# RULES "Nat: n * 0 = 0" proofMultNZero = \_ -> unsafeCoerce (Refl :: () :~: ()) #-}

-- | Wrapper giving the proposition the @f n@ shape that 'induction' needs.
newtype ProofMultNZero n = ProofMultNZero { getProofMultNZero :: Mult n Nat0 :~: Nat0 }
-- | @1 * n = n@
--
-- @Mult Nat1 n@ reduces to @Plus n (Mult 'Z n)@ and then to @Plus n 'Z@,
-- so this is exactly 'proofPlusNZero'.
proofMultOneN :: forall n. SNatI n => Mult Nat1 n :~: n
proofMultOneN = viaPlus
  where
    viaPlus :: Plus n Nat0 :~: n
    viaPlus = proofPlusNZero

-- Keep calls visible until phase 1 so the rule below gets a chance to fire.
{-# NOINLINE [1] proofMultOneN #-}
{-# RULES "Nat: 1 * n = n" proofMultOneN = unsafeCoerce (Refl :: () :~: ()) #-}
-- | @n * 1 = n@
--
-- Proved by induction on @n@: @Mult 'Z Nat1@ reduces to @'Z@, and
-- @Mult ('S m) Nat1@ reduces to @'S (Mult m Nat1)@, which the hypothesis
-- rewrites to @'S m@.
proofMultNOne :: forall n. SNatI n => Mult n Nat1 :~: n
proofMultNOne = getProofMultNOne (induction zeroCase succCase :: ProofMultNOne n)
  where
    zeroCase :: ProofMultNOne 'Z
    zeroCase = ProofMultNOne Refl

    succCase :: forall m. ProofMultNOne m -> ProofMultNOne ('S m)
    succCase (ProofMultNOne Refl) = ProofMultNOne Refl

-- Keep calls visible until phase 1 so the rule below gets a chance to fire;
-- the rule swaps the runtime induction for a coerced 'Refl'.
{-# NOINLINE [1] proofMultNOne #-}
{-# RULES "Nat: n * 1 = n" proofMultNOne = unsafeCoerce (Refl :: () :~: ()) #-}

-- | Wrapper giving the proposition the @f n@ shape that 'induction' needs.
newtype ProofMultNOne n = ProofMultNOne { getProofMultNOne :: Mult n Nat1 :~: n }
-- | A value of type @a@ tagged with a phantom type-level 'Nat'.
newtype Tagged (n :: Nat) a = Tagged a deriving Show

-- | Drop the phantom tag.
unTagged :: Tagged n a -> a
unTagged t = case t of
    Tagged x -> x

-- | Map over the payload while re-tagging with an arbitrary 'Nat'.
retagMap :: (a -> b) -> Tagged n a -> Tagged m b
retagMap f (Tagged x) = Tagged (f x)