{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE Trustworthy #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module Clash.Promoted.Nat
(
SNat (..)
, snatProxy
, withSNat
, snatToInteger, snatToNatural, snatToNum
, natToInteger, natToNatural, natToNum
, addSNat, mulSNat, powSNat, minSNat, maxSNat, succSNat
, subSNat, divSNat, modSNat, flogBaseSNat, clogBaseSNat, logBaseSNat, predSNat
, pow2SNat
, SNatLE (..), compareSNat
, UNat (..)
, toUNat
, fromUNat
, addUNat, mulUNat, powUNat
, predUNat, subUNat
, BNat (..)
, toBNat
, fromBNat
, showBNat
, succBNat, addBNat, mulBNat, powBNat
, predBNat, div2BNat, div2Sub1BNat, log2BNat
, stripZeros
, leToPlus
, leToPlusKN
)
where
import Data.Kind (Type)
import GHC.Show (appPrec)
import GHC.TypeLits (KnownNat, Nat, type (+), type (-), type (*),
type (^), type (<=), natVal)
import GHC.TypeLits.Extra (CLog, FLog, Div, Log, Mod, Min, Max)
import GHC.Natural (naturalFromInteger)
import Language.Haskell.TH (appT, conT, litT, numTyLit, sigE)
import Language.Haskell.TH.Syntax (Lift (..))
#if MIN_VERSION_template_haskell(2,16,0)
import Language.Haskell.TH.Compat
#endif
import Numeric.Natural (Natural)
import Unsafe.Coerce (unsafeCoerce)
import Clash.XException (ShowX (..), showsPrecXWith)
-- | Singleton value for a type-level natural number @n@.
--
-- Pattern matching on the 'SNat' constructor brings a @KnownNat n@
-- constraint into scope; the functions in this module rely on that to
-- recover the numeric value of @n@ at runtime (see 'snatToInteger').
data SNat (n :: Nat) where
  SNat :: KnownNat n => SNat n
-- | Lifting an 'SNat' into Template Haskell yields the expression
-- @SNat :: SNat N@, where @N@ is the numeric value of the index.
instance Lift (SNat n) where
  lift s = sigE [| SNat |] snatTy
    where
      -- The type annotation is what pins down the index of the spliced 'SNat'.
      snatTy = appT (conT ''SNat) (litT (numTyLit (snatToInteger s)))
#if MIN_VERSION_template_haskell(2,16,0)
  liftTyped = liftTypedFromUntyped
#endif
-- | Create an 'SNat' for the type-level natural carried by a proxy.
-- The proxy value itself is never inspected.
snatProxy :: KnownNat n => proxy n -> SNat n
snatProxy = const SNat
-- | Numbers up to and including 1024 are shown compactly as @d@ followed by
-- the number (e.g. @d8@). Larger numbers are shown as @SNat \@N@, wrapped in
-- parentheses when the surrounding precedence exceeds 'appPrec'.
instance Show (SNat n) where
  showsPrec prec s@SNat =
    let val = snatToInteger s
    in  if val > 1024
          then showParen (prec > appPrec) (showString "SNat @" . shows val)
          else showChar 'd' . shows val

-- | Delegates to 'showsPrecXWith' using this type's 'showsPrec'.
instance ShowX (SNat n) where
  showsPrecX prec = showsPrecXWith showsPrec prec
-- | Apply a function to the 'SNat' of a known type-level natural.
withSNat :: KnownNat n => (SNat n -> a) -> a
withSNat = ($ SNat)
{-# INLINE withSNat #-}
-- | The value of the type-level natural @n@ as an 'Integer'. The natural must
-- be supplied with a type application, e.g. @natToInteger \@8@.
natToInteger :: forall n . KnownNat n => Integer
natToInteger = snatToInteger (SNat :: SNat n)
{-# INLINE natToInteger #-}

-- | The value of the index of an 'SNat' as an 'Integer', read from the
-- 'KnownNat' evidence stored in the constructor.
snatToInteger :: SNat n -> Integer
snatToInteger sn = case sn of
  SNat -> natVal sn
{-# INLINE snatToInteger #-}
-- | The value of the type-level natural @n@ as a 'Natural'. The natural must
-- be supplied with a type application, e.g. @natToNatural \@8@.
natToNatural :: forall n . KnownNat n => Natural
natToNatural = snatToNatural (SNat :: SNat n)
{-# INLINE natToNatural #-}

-- | The value of the index of an 'SNat' as a 'Natural', converted from the
-- result of 'snatToInteger'.
snatToNatural :: SNat n -> Natural
snatToNatural sn = naturalFromInteger (snatToInteger sn)
{-# INLINE snatToNatural #-}
-- | The value of the type-level natural @n@ in any 'Num' type. Supply @n@
-- (and @a@, if it cannot be inferred) with type applications, e.g.
-- @natToNum \@8 \@Int@.
natToNum :: forall n a . (Num a, KnownNat n) => a
natToNum = snatToNum @a (SNat :: SNat n)
{-# INLINE natToNum #-}

-- | The value of the index of an 'SNat' in any 'Num' type, via 'fromInteger'.
-- The target type is the first type argument, so @snatToNum \@Int s@ picks
-- 'Int'.
snatToNum :: forall a n . Num a => SNat n -> a
snatToNum sn = case sn of
  SNat -> fromInteger (snatToInteger sn)
{-# INLINE snatToNum #-}
-- | Unary representation of a type-level natural number.
--
-- 'UZero' represents 0 and each 'USucc' adds one, so the value @n@ is
-- encoded as @n@ nested 'USucc' constructors around a 'UZero'.
data UNat :: Nat -> Type where
  -- | Zero.
  UZero :: UNat 0
  -- | Successor: one more than the wrapped value.
  USucc :: UNat n -> UNat (n + 1)
-- | Shows the numeric value of the index prefixed with @u@ (e.g. @u3@),
-- rather than the constructor structure.
instance KnownNat n => Show (UNat n) where
  show u = "u" ++ show (natVal u)

-- | Delegates to 'showsPrecXWith' using this type's 'showsPrec'.
instance KnownNat n => ShowX (UNat n) where
  showsPrecX prec = showsPrecXWith showsPrec prec
-- | Convert an 'SNat' to its unary representation.
--
-- The runtime 'Integer' value of @n@ is counted down to 0, emitting one
-- 'USucc' per step.
toUNat :: forall n . SNat n -> UNat n
toUNat p@SNat = fromI @n (snatToInteger p)
  where
    -- Build a @UNat m@ from a runtime counter. The type checker cannot
    -- relate the 'Integer' argument to the index @m@, so each step is retyped
    -- with 'unsafeCoerce'. This is only sound because the counter starts at
    -- the value of @n@ (it comes from 'snatToInteger' above).
    fromI :: forall m . Integer -> UNat m
    fromI 0 = unsafeCoerce @(UNat 0) @(UNat m) UZero
    fromI n = unsafeCoerce @(UNat ((m-1)+1)) @(UNat m) (USucc (fromI @(m-1) (n - 1)))
-- | Convert a unary natural back to an 'SNat', adding one for every 'USucc'.
fromUNat :: UNat n -> SNat n
fromUNat un = case un of
  UZero      -> SNat
  USucc rest -> succSNat (fromUNat rest)
-- | Add two unary naturals.
--
-- If either argument is 'UZero', the other is returned. Otherwise one
-- 'USucc' is moved from the first argument onto the result and the function
-- recurses.
addUNat :: UNat n -> UNat m -> UNat (n + m)
addUNat UZero       ys    = ys
addUNat xs          UZero = xs
addUNat (USucc xs') ys    = USucc (addUNat xs' ys)
-- | Multiply two unary naturals by adding the second argument once per
-- 'USucc' of the first. A 'UZero' on either side yields 'UZero'.
mulUNat :: UNat n -> UNat m -> UNat (n * m)
mulUNat UZero       _     = UZero
mulUNat _           UZero = UZero
mulUNat (USucc xs') ys    = addUNat ys (mulUNat xs' ys)
-- | Raise a unary natural to a unary power by repeated multiplication.
-- Any base raised to 'UZero' gives one.
powUNat :: UNat n -> UNat m -> UNat (n ^ m)
powUNat _    UZero      = USucc UZero
powUNat base (USucc e') = mulUNat base (powUNat base e')
-- | Predecessor of a strictly positive unary natural. The 'UZero' clause is
-- unreachable for well-typed arguments, because the index is @n+1@.
predUNat :: UNat (n+1) -> UNat n
predUNat (USucc prev) = prev
predUNat UZero =
  error "predUNat: impossible: 0 minus 1, -1 is not a natural number"

-- | Subtract @n@ from @m + n@ by removing one 'USucc' from both arguments at
-- a time until the subtrahend reaches 'UZero'.
subUNat :: UNat (m+n) -> UNat n -> UNat m
subUNat lhs          UZero        = lhs
subUNat (USucc lhs') (USucc rhs') = subUNat lhs' rhs'
subUNat UZero        _            = error "subUNat: impossible: 0 + (n + 1) ~ 0"
-- | Predecessor of a strictly positive 'SNat'.
predSNat :: SNat (a+1) -> SNat (a)
predSNat s = case s of SNat -> SNat
{-# INLINE predSNat #-}

-- | Successor of an 'SNat'.
succSNat :: SNat a -> SNat (a+1)
succSNat s = case s of SNat -> SNat
{-# INLINE succSNat #-}
-- | Add two 'SNat's. Declared @infixl 6@, the same as '+'.
addSNat :: SNat a -> SNat b -> SNat (a+b)
addSNat x y = case (x, y) of (SNat, SNat) -> SNat
{-# INLINE addSNat #-}
infixl 6 `addSNat`

-- | Subtract an 'SNat' from one whose index is a sum containing it.
-- Declared @infixl 6@, the same as '-'.
subSNat :: SNat (a+b) -> SNat b -> SNat a
subSNat x y = case (x, y) of (SNat, SNat) -> SNat
{-# INLINE subSNat #-}
infixl 6 `subSNat`
-- | Multiply two 'SNat's. Declared @infixl 7@, the same as '*'.
mulSNat :: SNat a -> SNat b -> SNat (a*b)
mulSNat x y = case (x, y) of (SNat, SNat) -> SNat
{-# INLINE mulSNat #-}
infixl 7 `mulSNat`

-- | Raise an 'SNat' to the power of another. Declared @infixr 8@, the same
-- as '^'.
powSNat :: SNat a -> SNat b -> SNat (a^b)
powSNat x y = case (x, y) of (SNat, SNat) -> SNat
-- NOINLINE kept as in the original; the reason is not visible in this module.
{-# NOINLINE powSNat #-}
infixr 8 `powSNat`
-- | Integer division of 'SNat's. The @1 <= b@ constraint rejects division by
-- zero at compile time. Declared @infixl 7@.
divSNat :: (1 <= b) => SNat a -> SNat b -> SNat (Div a b)
divSNat x y = case (x, y) of (SNat, SNat) -> SNat
{-# INLINE divSNat #-}
infixl 7 `divSNat`

-- | Remainder of 'SNat' division. The @1 <= b@ constraint rejects a zero
-- divisor at compile time. Declared @infixl 7@.
modSNat :: (1 <= b) => SNat a -> SNat b -> SNat (Mod a b)
modSNat x y = case (x, y) of (SNat, SNat) -> SNat
{-# INLINE modSNat #-}
infixl 7 `modSNat`
-- | The smaller of two 'SNat's.
minSNat :: SNat a -> SNat b -> SNat (Min a b)
minSNat x y = case (x, y) of (SNat, SNat) -> SNat

-- | The larger of two 'SNat's.
maxSNat :: SNat a -> SNat b -> SNat (Max a b)
maxSNat x y = case (x, y) of (SNat, SNat) -> SNat
-- | Floor of the logarithm of @x@ in base @base@ ('FLog' from
-- "GHC.TypeLits.Extra"). The constraints rule out bases below 2 and @x = 0@.
flogBaseSNat :: (2 <= base, 1 <= x)
             => SNat base
             -> SNat x
             -> SNat (FLog base x)
flogBaseSNat b v = case (b, v) of (SNat, SNat) -> SNat
-- NOINLINE kept as in the original; the reason is not visible in this module.
{-# NOINLINE flogBaseSNat #-}

-- | Ceiling of the logarithm of @x@ in base @base@ ('CLog' from
-- "GHC.TypeLits.Extra"). The constraints rule out bases below 2 and @x = 0@.
clogBaseSNat :: (2 <= base, 1 <= x)
             => SNat base
             -> SNat x
             -> SNat (CLog base x)
clogBaseSNat b v = case (b, v) of (SNat, SNat) -> SNat
{-# NOINLINE clogBaseSNat #-}

-- | Exact logarithm of @x@ in base @base@ ('Log'). The constraint requires
-- the floor and ceiling logarithms to agree, which means @x@ is an exact
-- power of @base@.
logBaseSNat :: (FLog base x ~ CLog base x)
            => SNat base
            -> SNat x
            -> SNat (Log base x)
logBaseSNat b v = case (b, v) of (SNat, SNat) -> SNat
{-# NOINLINE logBaseSNat #-}
-- | Two raised to the power of an 'SNat'.
pow2SNat :: SNat a -> SNat (2^a)
pow2SNat s = case s of SNat -> SNat
{-# INLINE pow2SNat #-}
-- | Result of comparing two type-level naturals @a@ and @b@. Each
-- constructor carries the corresponding ordering constraint as evidence.
data SNatLE a b where
  -- | @a@ is less than or equal to @b@.
  SNatLE :: forall a b . a <= b => SNatLE a b
  -- | @a@ is strictly greater than @b@, expressed as @b + 1 <= a@.
  SNatGT :: forall a b . (b+1) <= a => SNatLE a b
-- | Compare two 'SNat's at runtime and return the matching ordering evidence.
--
-- The evidence is built at fixed indices where its constraint holds
-- trivially (@0 <= 0@, and @0 + 1 <= 1@), then retyped with 'unsafeCoerce'.
-- The runtime comparison ensures the retyped constraint is actually true.
-- NOTE(review): this assumes the constraint evidence carries no
-- index-dependent runtime data; confirm for all supported GHC versions.
compareSNat :: forall a b . SNat a -> SNat b -> SNatLE a b
compareSNat a b
  | snatToInteger a <= snatToInteger b = unsafeCoerce (SNatLE @0 @0)
  | otherwise                          = unsafeCoerce (SNatGT @1 @0)
-- | Base-2 representation of a type-level natural number.
--
-- The outermost constructor is the least significant bit; for example, 6 can
-- be written @B0 (B1 (B1 BT))@. A value may also carry redundant
-- most-significant zeros (e.g. @B0 BT@ is 0); see 'stripZeros'.
data BNat :: Nat -> Type where
  -- | Zero (the empty bit string).
  BT :: BNat 0
  -- | Append a 0 bit below the wrapped value: @2 * n@.
  B0 :: BNat n -> BNat (2*n)
  -- | Append a 1 bit below the wrapped value: @2 * n + 1@.
  B1 :: BNat n -> BNat ((2*n) + 1)
-- | Shows the numeric value of the index prefixed with @b@ (e.g. @b6@),
-- rather than the bit structure. Use 'showBNat' to see the bits.
instance KnownNat n => Show (BNat n) where
  show bn = "b" ++ show (natVal bn)

-- | Delegates to 'showsPrecXWith' using this type's 'showsPrec'.
instance KnownNat n => ShowX (BNat n) where
  showsPrecX prec = showsPrecXWith showsPrec prec
-- | Render a 'BNat' as a binary literal with a @0b@ prefix. The least
-- significant bit is the right-most digit, e.g. 6 renders as @"0b110"@.
--
-- Zero written as a bare 'BT' renders as @"0b0"@, so the result always has at
-- least one digit. Redundant most-significant zeros in the value (see
-- 'stripZeros') are kept in the output.
showBNat :: BNat n -> String
showBNat = go []
  where
    -- Constructors peeled from the outside are successively more significant
    -- bits, so each one is prepended to the digits collected so far.
    go :: String -> BNat m -> String
    go []     BT     = "0b0"
    go digits BT     = "0b" ++ digits
    go digits (B0 x) = go ('0' : digits) x
    go digits (B1 x) = go ('1' : digits) x
-- | Convert an 'SNat' to its base-2 representation.
--
-- The runtime value of @n@ is repeatedly divided by 2, and each remainder
-- becomes one bit, least significant first.
toBNat :: SNat n -> BNat n
toBNat s@SNat = toBNat' (snatToInteger s)
  where
    -- Build a @BNat m@ from a runtime counter. The type indices passed to the
    -- recursive calls only give the intermediate constructors a type; the
    -- actual bits come from the runtime 'divMod', and each result is retyped
    -- with 'unsafeCoerce'. This is only sound because the counter starts at
    -- the value of the index (it comes from 'snatToInteger' above).
    toBNat' :: forall m . Integer -> BNat m
    toBNat' 0 = unsafeCoerce BT
    toBNat' n = case n `divMod` 2 of
      (n',1) -> unsafeCoerce (B1 (toBNat' @(Div (m-1) 2) n'))
      (n',_) -> unsafeCoerce (B0 (toBNat' @(Div m 2) n'))
-- | Convert a base-2 natural back to an 'SNat'. Each 'B0' doubles the value
-- of the inner bits; each 'B1' doubles it and adds one.
fromBNat :: BNat n -> SNat n
fromBNat bn = case bn of
  BT      -> SNat
  B0 rest -> SNat @2 `mulSNat` fromBNat rest
  B1 rest -> (SNat @2 `mulSNat` fromBNat rest) `addSNat` SNat @1
-- | Add two base-2 naturals bit by bit, least significant bit first.
--
-- Two 1 bits produce a 0 bit plus a carry, which is applied with 'succBNat'.
-- When one argument runs out ('BT'), the rest of the other is returned
-- unchanged.
addBNat :: BNat n -> BNat m -> BNat (n+m)
addBNat (B0 x) (B0 y) = B0 (addBNat x y)
addBNat (B0 x) (B1 y) = B1 (addBNat x y)
addBNat (B1 x) (B0 y) = B1 (addBNat x y)
addBNat (B1 x) (B1 y) = B0 (succBNat (addBNat x y))
addBNat BT     y      = y
addBNat x      BT     = x
-- | Multiply two base-2 naturals by shift-and-add over the bits of the first
-- argument. A 'BT' on either side yields 'BT'.
mulBNat :: BNat n -> BNat m -> BNat (n*m)
mulBNat BT     _  = BT
mulBNat _      BT = BT
mulBNat (B0 x) y  = B0 (mulBNat x y)
mulBNat (B1 x) y  = addBNat (B0 (mulBNat x y)) y
-- | Raise a base-2 natural to a base-2 power by square-and-multiply over the
-- bits of the exponent. Any base raised to 'BT' (zero) is one.
powBNat :: BNat n -> BNat m -> BNat (n^m)
powBNat _    BT     = B1 BT
powBNat base (B0 e) = mulBNat partial partial
  where partial = powBNat base e
powBNat base (B1 e) = mulBNat base (mulBNat partial partial)
  where partial = powBNat base e
-- | Successor of a base-2 natural. The lowest 0 bit (or the end 'BT')
-- becomes 1, and the 1 bits below it become 0.
succBNat :: BNat n -> BNat (n+1)
succBNat bn = case bn of
  BT     -> B1 BT
  B0 hi  -> B1 hi
  B1 hi  -> B0 (succBNat hi)
-- | Predecessor of a strictly positive base-2 natural.
--
-- A lowest 1 bit becomes 0. If nothing above that bit is non-zero, the
-- result is normalised to 'BT'. A lowest 0 bit becomes 1, with a borrow from
-- the bits above it.
predBNat :: (1 <= n) => BNat n -> BNat (n-1)
predBNat (B1 hi) = case stripZeros hi of
  BT  -> BT
  hi' -> B0 hi'
predBNat (B0 hi) = B1 (predBNat hi)
-- | Halve an even base-2 natural by dropping its lowest (0) bit. The 'B1'
-- case cannot occur for an index of the form @2*n@.
div2BNat :: BNat (2*n) -> BNat n
div2BNat bn = case bn of
  BT    -> BT
  B0 hi -> hi
  B1 _  -> error "div2BNat: impossible: 2*n ~ 2*n+1"

-- | Compute @n@ from an odd base-2 natural @2*n+1@ by dropping its lowest
-- (1) bit. Any other constructor cannot occur for such an index.
div2Sub1BNat :: BNat (2*n+1) -> BNat n
div2Sub1BNat bn = case bn of
  B1 hi -> hi
  _     -> error "div2Sub1BNat: impossible: 2*n+1 ~ 2*n"
-- | Base-2 logarithm of a power of two, computed by counting the 0 bits below
-- its single 1 bit.
--
-- Errors on 'BT' (the logarithm of zero is undefined). It also errors if the
-- 1 bit has further non-zero bits above it, which cannot happen for an index
-- of the form @2^n@.
log2BNat :: BNat (2^n) -> BNat n
log2BNat BT = error "log2BNat: log2(0) not defined"
log2BNat (B1 hi) = case stripZeros hi of
  BT -> BT
  _  -> error "log2BNat: impossible: 2^n ~ 2x+1"
log2BNat (B0 hi) = succBNat (log2BNat hi)
-- | Remove redundant most-significant zero bits, so that a zero value anywhere
-- in the chain collapses to 'BT'. The numeric value is unchanged.
stripZeros :: BNat n -> BNat n
stripZeros BT      = BT
stripZeros (B1 hi) = B1 (stripZeros hi)
stripZeros (B0 BT) = BT
stripZeros (B0 hi) = case stripZeros hi of
  BT  -> BT
  hi' -> B0 hi'
-- | Turn the constraint @k <= n@ into the equation @n ~ (k + m)@ for
-- @m = n - k@, and hand that equation to the continuation.
leToPlus
  :: forall (k :: Nat) (n :: Nat) r
   . (k <= n)
  => (forall m . (n ~ (k + m)) => r)
  -> r
leToPlus withEq = withEq @(n - k)
{-# INLINE leToPlus #-}

-- | Like 'leToPlus', but for known naturals. The continuation additionally
-- receives @KnownNat m@ for @m = n - k@.
leToPlusKN
  :: forall (k :: Nat) (n :: Nat) r
   . ( k <= n
     , KnownNat k
     , KnownNat n
     )
  => (forall m . (n ~ (k + m), KnownNat m) => r)
  -> r
leToPlusKN withEq = withEq @(n - k)
{-# INLINE leToPlusKN #-}