{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE Trustworthy #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_HADDOCK show-extensions #-}
{- |
Value-level representations of type-level natural numbers: the singleton
'SNat', the unary 'UNat' and the base-2 'BNat', together with conversions
between them and arithmetic on each.
-}
module Clash.Promoted.Nat
(
  -- * Singleton natural numbers
  -- ** Data type
SNat (..)
  -- ** Construction
, snatProxy
, withSNat
  -- ** Conversion
, snatToInteger, snatToNatural, snatToNum
  -- ** Conversion (ambiguous types)
, natToInteger, natToNatural, natToNum
  -- ** Arithmetic
, addSNat, mulSNat, powSNat, minSNat, maxSNat, succSNat
  -- *** Partial
, subSNat, divSNat, modSNat, flogBaseSNat, clogBaseSNat, logBaseSNat, predSNat
  -- *** Specialised
, pow2SNat
  -- *** Comparison
, SNatLE (..), compareSNat
  -- * Unary/Peano-encoded natural numbers
  -- ** Data type
, UNat (..)
  -- ** Construction
, toUNat
  -- ** Conversion
, fromUNat
  -- ** Arithmetic
, addUNat, mulUNat, powUNat
  -- *** Partial
, predUNat, subUNat
  -- * Base-2 encoded natural numbers
  -- ** Data type
, BNat (..)
  -- ** Construction
, toBNat
  -- ** Conversion
, fromBNat
  -- ** Pretty printing
, showBNat
  -- ** Arithmetic
, succBNat, addBNat, mulBNat, powBNat
  -- *** Partial
, predBNat, div2BNat, div2Sub1BNat, log2BNat
  -- ** Normalisation
, stripZeros
  -- * Constraints on natural numbers
, leToPlus
, leToPlusKN
)
where
import Data.Kind (Type)
import GHC.Show (appPrec)
import GHC.TypeLits (KnownNat, Nat, type (+), type (-), type (*),
type (^), type (<=), natVal)
import GHC.TypeLits.Extra (CLog, FLog, Div, Log, Mod, Min, Max)
import GHC.Natural (naturalFromInteger)
import Language.Haskell.TH (appT, conT, litT, numTyLit, sigE)
import Language.Haskell.TH.Syntax (Lift (..))
import Numeric.Natural (Natural)
import Unsafe.Coerce (unsafeCoerce)
import Clash.XException (ShowX (..), showsPrecXWith)
-- | Singleton type for a type-level natural number @n@.
--
-- The single constructor can only be built when a 'KnownNat' instance for @n@
-- is available, and pattern matching on it brings that instance back into
-- scope.
data SNat (n :: Nat) where
  SNat :: KnownNat n => SNat n
-- | Lifts an 'SNat' to a Template Haskell expression: the bare 'SNat'
-- constructor with a type annotation that applies the 'SNat' type to the
-- numeric literal the value stands for.
instance Lift (SNat n) where
  lift s =
    sigE
      [| SNat |]
      (conT ''SNat `appT` litT (numTyLit (snatToInteger s)))
-- | Build the 'SNat' singleton for @n@ from any proxy-like value whose last
-- type argument is @n@. The proxy argument itself is never inspected.
snatProxy :: KnownNat n => proxy n -> SNat n
snatProxy = const SNat
-- | Numbers up to and including 1024 are shown as the letter @d@ followed by
-- the number (e.g. @d8@). Larger numbers are shown as a type application of
-- the constructor (e.g. @SNat \@1025@), wrapped in parentheses when the
-- surrounding precedence is above that of function application.
instance Show (SNat n) where
  showsPrec prec s@SNat =
    if value > 1024
      then showParen (prec > appPrec) (showString "SNat @" . shows value)
      else showChar 'd' . shows value
    where
      value = snatToInteger s

-- | 'showsPrecX' is the ordinary 'showsPrec' wrapped by 'showsPrecXWith'.
instance ShowX (SNat n) where
  showsPrecX = showsPrecXWith showsPrec
-- | Hand the 'SNat' singleton for @n@ to a function that expects one, using
-- the 'KnownNat' instance that is in scope at the call site.
withSNat :: KnownNat n => (SNat n -> a) -> a
withSNat consumer = consumer SNat
{-# INLINE withSNat #-}
-- | The 'Integer' value of the type-level natural @n@.
--
-- @n@ only occurs in the constraint, so it has to be selected with a visible
-- type application, e.g. @natToInteger \@5@.
natToInteger :: forall n . KnownNat n => Integer
natToInteger = natVal (SNat @n)
{-# INLINE natToInteger #-}
-- | The 'Integer' value represented by a singleton natural number.
--
-- Matching on the constructor exposes the 'KnownNat' instance that 'natVal'
-- needs.
snatToInteger :: SNat n -> Integer
snatToInteger s = case s of
  SNat -> natVal s
{-# INLINE snatToInteger #-}
-- | The 'Natural' value of the type-level natural @n@.
--
-- @n@ only occurs in the constraint, so it has to be selected with a visible
-- type application, e.g. @natToNatural \@5@.
natToNatural :: forall n . KnownNat n => Natural
natToNatural = naturalFromInteger (natToInteger @n)
{-# INLINE natToNatural #-}
-- | The 'Natural' value represented by a singleton natural number, obtained
-- by converting the result of 'snatToInteger'.
snatToNatural :: SNat n -> Natural
snatToNatural s = naturalFromInteger (snatToInteger s)
{-# INLINE snatToNatural #-}
-- | The value of the type-level natural @n@ in any 'Num' type.
--
-- @n@ only occurs in the constraint, so it has to be selected with a visible
-- type application, e.g. @natToNum \@5 \@Int@. The conversion goes through
-- 'fromInteger', so the treatment of numbers outside the range of the target
-- type is whatever that type's 'Num' instance does.
natToNum :: forall n a . (Num a, KnownNat n) => a
natToNum = fromInteger (natToInteger @n)
{-# INLINE natToNum #-}
-- | The value represented by a singleton natural number in any 'Num' type.
--
-- The conversion goes through 'fromInteger', so the treatment of numbers
-- outside the range of the target type is whatever that type's 'Num' instance
-- does.
snatToNum :: forall a n . Num a => SNat n -> a
snatToNum s = case s of
  SNat -> fromInteger (natVal s)
{-# INLINE snatToNum #-}
-- | Unary (Peano-style) encoding of a type-level natural number: a value of
-- type @UNat n@ is 'UZero' wrapped in 'USucc' constructors, one per unit the
-- type index is raised.
data UNat :: Nat -> Type where
  -- | Zero
  UZero :: UNat 0
  -- | Successor: one more than the wrapped number
  USucc :: UNat n -> UNat (n + 1)
-- | Shown as the letter @u@ followed by the number (e.g. @u3@). The number is
-- read from the 'KnownNat' instance; the constructors are not traversed.
instance KnownNat n => Show (UNat n) where
  show u = showChar 'u' (show (natVal u))

-- | 'showsPrecX' is the ordinary 'showsPrec' wrapped by 'showsPrecXWith'.
instance KnownNat n => ShowX (UNat n) where
  showsPrecX = showsPrecXWith showsPrec
-- | Convert a singleton natural number to its unary encoding.
--
-- The structure is built by counting the runtime 'Integer' down to zero, so
-- the work is linear in the represented number. Because the type index of
-- each intermediate result cannot be derived from a runtime value, every
-- step is retyped with 'unsafeCoerce'; the type applications spell out the
-- source and target type of each coercion.
toUNat :: forall n . SNat n -> UNat n
toUNat s@SNat = build @n (snatToInteger s)
  where
    -- Produces as many 'USucc' layers as the given count; @m@ is the type
    -- index the caller claims for the result.
    build :: forall m . Integer -> UNat m
    build remaining = case remaining of
      0 -> unsafeCoerce @(UNat 0) @(UNat m) UZero
      _ -> unsafeCoerce @(UNat ((m-1)+1)) @(UNat m)
             (USucc (build @(m-1) (remaining - 1)))
-- | Convert a unary-encoded natural number to its singleton representation,
-- adding one for every 'USucc' layer on top of the singleton for zero.
fromUNat :: UNat n -> SNat n
fromUNat u = case u of
  UZero -> SNat @0
  USucc p -> fromUNat p `addSNat` SNat @1
-- | Add two unary-encoded natural numbers.
--
-- If either argument is zero the other is returned as is; otherwise one
-- 'USucc' is peeled off the first argument per recursive step.
addUNat :: UNat n -> UNat m -> UNat (n + m)
addUNat lhs rhs = case (lhs, rhs) of
  (UZero, _) -> rhs
  (_, UZero) -> lhs
  (USucc lhs', _) -> USucc (addUNat lhs' rhs)
-- | Multiply two unary-encoded natural numbers.
--
-- If either argument is zero the result is zero; otherwise the second
-- argument is added once for every 'USucc' layer of the first.
mulUNat :: UNat n -> UNat m -> UNat (n * m)
mulUNat lhs rhs = case (lhs, rhs) of
  (UZero, _) -> UZero
  (_, UZero) -> UZero
  (USucc lhs', _) -> addUNat rhs (mulUNat lhs' rhs)
-- | Raise a unary-encoded natural number to the power of another.
--
-- A zero exponent gives one; otherwise the base is multiplied in once for
-- every 'USucc' layer of the exponent.
powUNat :: UNat n -> UNat m -> UNat (n ^ m)
powUNat base ex = case ex of
  UZero -> USucc UZero
  USucc ex' -> mulUNat base (powUNat base ex')
-- | Predecessor of a unary-encoded natural number.
--
-- The argument type rules out zero, so the 'UZero' alternative only exists
-- to raise a descriptive error.
predUNat :: UNat (n+1) -> UNat n
predUNat u = case u of
  USucc p -> p
  UZero ->
    error "predUNat: impossible: 0 minus 1, -1 is not a natural number"
-- | Subtract the second unary-encoded natural number from the first.
--
-- One 'USucc' layer is removed from both arguments per step until the
-- subtrahend reaches zero. The type of the first argument guarantees it is
-- at least as large as the second, so running out of layers on the first
-- argument only exists to raise a descriptive error.
subUNat :: UNat (m+n) -> UNat n -> UNat m
subUNat whole part = case (whole, part) of
  (_, UZero) -> whole
  (USucc whole', USucc part') -> subUNat whole' part'
  (UZero, _) -> error "subUNat: impossible: 0 + (n + 1) ~ 0"
-- | Predecessor of a singleton natural number.
--
-- Matching on the argument exposes its 'KnownNat' instance; the instance
-- required by the result constructor is solved from it.
predSNat :: SNat (a+1) -> SNat (a)
predSNat SNat = SNat
{-# INLINE predSNat #-}
-- | Successor of a singleton natural number.
--
-- Matching on the argument exposes its 'KnownNat' instance; the instance
-- required by the result constructor is solved from it.
succSNat :: SNat a -> SNat (a+1)
succSNat SNat = SNat
{-# INLINE succSNat #-}
-- | Add two singleton natural numbers.
--
-- Matching on both arguments exposes their 'KnownNat' instances; the
-- instance required by the result constructor is solved from them.
addSNat :: SNat a -> SNat b -> SNat (a+b)
addSNat SNat SNat = SNat
{-# INLINE addSNat #-}
-- Same fixity as '+'.
infixl 6 `addSNat`
-- | Subtract the second singleton natural number from the first.
--
-- The type of the first argument guarantees that it is at least as large as
-- the second, so the result is always a natural number.
subSNat :: SNat (a+b) -> SNat b -> SNat a
subSNat SNat SNat = SNat
{-# INLINE subSNat #-}
-- Same fixity as '-'.
infixl 6 `subSNat`
-- | Multiply two singleton natural numbers.
mulSNat :: SNat a -> SNat b -> SNat (a*b)
mulSNat SNat SNat = SNat
{-# INLINE mulSNat #-}
-- Same fixity as '*'.
infixl 7 `mulSNat`
-- | Raise the first singleton natural number to the power of the second.
powSNat :: SNat a -> SNat b -> SNat (a^b)
powSNat SNat SNat = SNat
-- NOTE(review): unlike most of its siblings this function is NOINLINE, which
-- looks deliberate (e.g. so tooling can recognise it by name) — confirm
-- before changing.
{-# NOINLINE powSNat #-}
-- Same fixity as '^'.
infixr 8 `powSNat`
-- | Divide the first singleton natural number by the second; the result is
-- indexed by @Div a b@.
--
-- The @1 <= b@ constraint rules out division by zero.
divSNat :: (1 <= b) => SNat a -> SNat b -> SNat (Div a b)
divSNat SNat SNat = SNat
{-# INLINE divSNat #-}
-- Same fixity as 'div'.
infixl 7 `divSNat`
-- | Remainder of dividing the first singleton natural number by the second;
-- the result is indexed by @Mod a b@.
--
-- The @1 <= b@ constraint rules out a zero divisor.
modSNat :: (1 <= b) => SNat a -> SNat b -> SNat (Mod a b)
modSNat SNat SNat = SNat
{-# INLINE modSNat #-}
-- Same fixity as 'mod'.
infixl 7 `modSNat`
-- | Minimum of two singleton natural numbers; the result is indexed by
-- @Min a b@.
minSNat :: SNat a -> SNat b -> SNat (Min a b)
minSNat SNat SNat = SNat
-- | Maximum of two singleton natural numbers; the result is indexed by
-- @Max a b@.
maxSNat :: SNat a -> SNat b -> SNat (Max a b)
maxSNat SNat SNat = SNat
-- | Floor of the logarithm of a singleton natural number; the result is
-- indexed by @FLog base x@.
flogBaseSNat :: (2 <= base, 1 <= x)
  => SNat base -- ^ Base of the logarithm, at least 2
  -> SNat x -- ^ Number to take the logarithm of, at least 1
  -> SNat (FLog base x)
flogBaseSNat SNat SNat = SNat
-- NOTE(review): NOINLINE looks deliberate (e.g. so tooling can recognise the
-- function by name) — confirm before changing.
{-# NOINLINE flogBaseSNat #-}
-- | Ceiling of the logarithm of a singleton natural number; the result is
-- indexed by @CLog base x@.
clogBaseSNat :: (2 <= base, 1 <= x)
  => SNat base -- ^ Base of the logarithm, at least 2
  -> SNat x -- ^ Number to take the logarithm of, at least 1
  -> SNat (CLog base x)
clogBaseSNat SNat SNat = SNat
-- NOTE(review): NOINLINE looks deliberate (e.g. so tooling can recognise the
-- function by name) — confirm before changing.
{-# NOINLINE clogBaseSNat #-}
-- | Exact integer logarithm of a singleton natural number; the result is
-- indexed by @Log base x@.
--
-- The constraint only admits arguments for which the floor and the ceiling
-- of the logarithm coincide.
logBaseSNat :: (FLog base x ~ CLog base x)
  => SNat base -- ^ Base of the logarithm
  -> SNat x -- ^ Number to take the logarithm of
  -> SNat (Log base x)
logBaseSNat SNat SNat = SNat
-- NOTE(review): NOINLINE looks deliberate (e.g. so tooling can recognise the
-- function by name) — confirm before changing.
{-# NOINLINE logBaseSNat #-}
-- | Two raised to the power of a singleton natural number.
pow2SNat :: SNat a -> SNat (2^a)
pow2SNat SNat = SNat
{-# INLINE pow2SNat #-}
-- | Evidence of how two type-level natural numbers @a@ and @b@ are ordered.
-- Matching on a constructor brings the corresponding constraint into scope.
data SNatLE a b where
  -- | @a@ is less than or equal to @b@
  SNatLE :: forall a b . a <= b => SNatLE a b
  -- | @a@ is strictly greater than @b@, expressed as @b + 1 <= a@
  SNatGT :: forall a b . (b+1) <= a => SNatLE a b
-- | Compare two singleton natural numbers at runtime and return the matching
-- type-level ordering evidence.
--
-- The type checker cannot derive the ordering constraint from a runtime
-- comparison, so the evidence is built at concrete indices for which the
-- constraint trivially holds and then retyped with 'unsafeCoerce'. The
-- coercion is justified only by the 'Integer' comparison that selects the
-- branch.
compareSNat :: forall a b . SNat a -> SNat b -> SNatLE a b
compareSNat a b
  | snatToInteger a <= snatToInteger b = unsafeCoerce (SNatLE @0 @0)
  | otherwise = unsafeCoerce (SNatGT @1 @0)
-- | Base-2 encoding of a type-level natural number. The outermost
-- constructor is the least significant bit and the chain ends in 'BT'.
data BNat :: Nat -> Type where
  -- | Terminator; on its own it represents zero
  BT :: BNat 0
  -- | Appends a zero bit: twice the wrapped number
  B0 :: BNat n -> BNat (2*n)
  -- | Appends a one bit: twice the wrapped number, plus one
  B1 :: BNat n -> BNat ((2*n) + 1)
-- | Shown as the letter @b@ followed by the number in decimal (e.g. @b5@).
-- The number is read from the 'KnownNat' instance; the constructors are not
-- traversed. Use 'showBNat' to see the individual bits.
instance KnownNat n => Show (BNat n) where
  show b = showChar 'b' (show (natVal b))

-- | 'showsPrecX' is the ordinary 'showsPrec' wrapped by 'showsPrecXWith'.
instance KnownNat n => ShowX (BNat n) where
  showsPrecX = showsPrecXWith showsPrec
-- | Render a base-2 encoded natural number as a binary literal, most
-- significant bit first. Every constructor is printed, including
-- non-contributing leading zeros.
--
-- >>> showBNat (B0 (B1 BT))
-- "0b10"
showBNat :: BNat n -> String
showBNat = render ""
  where
    -- The accumulator holds the digits of the bits already visited; each
    -- deeper (more significant) bit is put in front of it.
    render :: String -> BNat m -> String
    render digits bits = case bits of
      BT -> "0b" ++ digits
      B0 rest -> render ('0' : digits) rest
      B1 rest -> render ('1' : digits) rest
-- | Convert a singleton natural number to its base-2 encoding.
--
-- The bits are produced from the runtime 'Integer' by repeated division by
-- two, least significant bit first. Because the type index of each
-- intermediate result cannot be derived from a runtime value, every step is
-- retyped with 'unsafeCoerce'.
toBNat :: SNat n -> BNat n
toBNat s@SNat = encode (snatToInteger s)
  where
    -- @m@ is the type index the caller claims for the result.
    encode :: forall m . Integer -> BNat m
    encode 0 = unsafeCoerce BT
    encode i
      | lowBit == 1 = unsafeCoerce (B1 (encode @(Div (m-1) 2) rest))
      | otherwise = unsafeCoerce (B0 (encode @(Div m 2) rest))
      where
        (rest, lowBit) = i `divMod` 2
-- | Convert a base-2 encoded natural number to its singleton representation:
-- each bit doubles the value of the bits above it, and a one bit adds one on
-- top of that.
fromBNat :: BNat n -> SNat n
fromBNat bits = case bits of
  BT -> SNat @0
  B0 hi -> double (fromBNat hi)
  B1 hi -> double (fromBNat hi) `addSNat` SNat @1
  where
    double :: SNat k -> SNat (2*k)
    double = mulSNat (SNat @2)
-- | Add two base-2 encoded natural numbers bit by bit, starting at the least
-- significant bit.
--
-- When one argument runs out of bits the remainder of the other is returned
-- as is. Two one bits produce a zero bit, and the carry is applied by taking
-- the successor of the sum of the remaining bits.
addBNat :: BNat n -> BNat m -> BNat (n+m)
addBNat lhs rhs = case (lhs, rhs) of
  (BT, _) -> rhs
  (_, BT) -> lhs
  (B0 l, B0 r) -> B0 (addBNat l r)
  (B0 l, B1 r) -> B1 (addBNat l r)
  (B1 l, B0 r) -> B1 (addBNat l r)
  (B1 l, B1 r) -> B0 (succBNat (addBNat l r))
-- | Multiply two base-2 encoded natural numbers.
--
-- The first argument is consumed bit by bit: every bit doubles the product
-- of the remaining bits with the second argument, and a one bit additionally
-- adds the second argument once. If either argument is 'BT' the result is
-- 'BT'.
mulBNat :: BNat n -> BNat m -> BNat (n*m)
mulBNat lhs rhs = case (lhs, rhs) of
  (BT, _) -> BT
  (_, BT) -> BT
  (B0 l, _) -> B0 (mulBNat l rhs)
  (B1 l, _) -> B0 (mulBNat l rhs) `addBNat` rhs
-- | Raise a base-2 encoded natural number to the power of another by
-- repeated squaring.
--
-- The exponent is consumed bit by bit: every bit squares the power for the
-- remaining bits, and a one bit additionally multiplies by the base once. An
-- exponent of 'BT' gives one.
powBNat :: BNat n -> BNat m -> BNat (n^m)
powBNat base ex = case ex of
  BT -> B1 BT
  B0 half -> square (powBNat base half)
  B1 half -> mulBNat base (square (powBNat base half))
  where
    -- Taking the operand as a parameter keeps it shared between both
    -- factors of the multiplication.
    square :: BNat k -> BNat (k*k)
    square z = mulBNat z z
-- | Successor of a base-2 encoded natural number.
--
-- A zero bit is flipped to one; a one bit is flipped to zero and the carry
-- is propagated into the more significant bits.
succBNat :: BNat n -> BNat (n+1)
succBNat b = case b of
  BT -> B1 BT
  B0 hi -> B1 hi
  B1 hi -> B0 (succBNat hi)
-- | Predecessor of a base-2 encoded natural number.
--
-- A one bit is flipped to zero; if the more significant bits amount to zero
-- after 'stripZeros' the result is just 'BT', so no redundant zero bit is
-- produced. A zero bit is flipped to one and the borrow is propagated into
-- the more significant bits. There is no alternative for 'BT' because the
-- @1 <= n@ constraint excludes zero.
predBNat :: (1 <= n) => BNat n -> BNat (n-1)
predBNat b = case b of
  B1 hi -> case stripZeros hi of
    BT -> BT
    hi' -> B0 hi'
  B0 hi -> B1 (predBNat hi)
-- | Halve an even base-2 encoded natural number by dropping its least
-- significant (zero) bit.
--
-- The argument type only admits even numbers, so the 'B1' alternative only
-- exists to raise a descriptive error.
div2BNat :: BNat (2*n) -> BNat n
div2BNat b = case b of
  BT -> BT
  B0 hi -> hi
  B1 _ -> error "div2BNat: impossible: 2*n ~ 2*n+1"
-- | Subtract one from an odd base-2 encoded natural number and halve the
-- result, by dropping its least significant (one) bit.
--
-- The argument type only admits odd numbers, so every other alternative only
-- exists to raise a descriptive error.
div2Sub1BNat :: BNat (2*n+1) -> BNat n
div2Sub1BNat b = case b of
  B1 hi -> hi
  _ -> error "div2Sub1BNat: impossible: 2*n+1 ~ 2*n"
-- | Base-2 logarithm of a base-2 encoded power of two.
--
-- Every zero bit below the single one bit contributes one to the result. The
-- one bit must have nothing but non-contributing zeros above it. The argument
-- type only admits powers of two, so the remaining alternatives only exist
-- to raise a descriptive error.
log2BNat :: BNat (2^n) -> BNat n
log2BNat b = case b of
  BT -> error "log2BNat: log2(0) not defined"
  B1 hi -> case stripZeros hi of
    BT -> BT
    _ -> error "log2BNat: impossible: 2^n ~ 2x+1"
  B0 hi -> succBNat (log2BNat hi)
-- | Remove non-contributing zero bits, i.e. 'B0' constructors that have no
-- 'B1' above them, from a base-2 encoded natural number. The represented
-- number does not change.
--
-- >>> showBNat (B1 (B0 (B0 (B0 BT))))
-- "0b0001"
-- >>> showBNat (stripZeros (B1 (B0 (B0 (B0 BT)))))
-- "0b1"
stripZeros :: BNat n -> BNat n
stripZeros b = case b of
  BT -> BT
  B1 hi -> B1 (stripZeros hi)
  B0 BT -> BT
  B0 hi -> case stripZeros hi of
    -- Everything above this zero bit was zero as well, so drop it too.
    BT -> BT
    hi' -> B0 hi'
-- | Run a computation that needs @n@ to be written as @k + m@ for some @m@,
-- given only the knowledge that @k <= n@.
--
-- The computation is instantiated with @m@ chosen as @n - k@.
leToPlus
  :: forall (k :: Nat) (n :: Nat) r
   . (k <= n)
  => (forall m . (n ~ (k + m)) => r)
  -- ^ Computation that relies on @n ~ (k + m)@
  -> r
leToPlus body = body @(n - k)
{-# INLINE leToPlus #-}
-- | Like running a computation that needs @n@ to be written as @k + m@ given
-- only @k <= n@, but the computation additionally receives a 'KnownNat'
-- instance for @m@; in exchange, 'KnownNat' instances for @k@ and @n@ are
-- required.
--
-- The computation is instantiated with @m@ chosen as @n - k@.
leToPlusKN
  :: forall (k :: Nat) (n :: Nat) r
   . (k <= n, KnownNat k, KnownNat n)
  => (forall m . (n ~ (k + m), KnownNat m) => r)
  -- ^ Computation that relies on @n ~ (k + m)@ and on @m@ being known
  -> r
leToPlusKN body = body @(n - k)
{-# INLINE leToPlusKN #-}