{-|
Copyright  :  (C) 2013-2016, University of Twente,
                  2016     , Myrtle Software Ltd
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>
-}

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}

{-# LANGUAGE Trustworthy #-}

{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise       #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module Clash.Promoted.Nat
  ( -- * Singleton natural numbers
    -- ** Data type
    SNat (..)
    -- ** Construction
  , snatProxy
  , withSNat
    -- ** Conversion
  , snatToInteger, snatToNatural, snatToNum
    -- ** Arithmetic
  , addSNat, mulSNat, powSNat, minSNat, maxSNat, succSNat
    -- *** Partial
  , subSNat, divSNat, modSNat, flogBaseSNat, clogBaseSNat, logBaseSNat, predSNat
    -- *** Specialised
  , pow2SNat
    -- *** Comparison
  , SNatLE (..), compareSNat
    -- * Unary/Peano-encoded natural numbers
    -- ** Data type
  , UNat (..)
    -- ** Construction
  , toUNat
    -- ** Conversion
  , fromUNat
    -- ** Arithmetic
  , addUNat, mulUNat, powUNat
    -- *** Partial
  , predUNat, subUNat
    -- * Base-2 encoded natural numbers
    -- ** Data type
  , BNat (..)
    -- ** Construction
  , toBNat
    -- ** Conversion
  , fromBNat
    -- ** Pretty printing base-2 encoded natural numbers
  , showBNat
    -- ** Arithmetic
  , succBNat, addBNat, mulBNat, powBNat
    -- *** Partial
  , predBNat, div2BNat, div2Sub1BNat, log2BNat
    -- ** Normalisation
  , stripZeros
    -- * Constraints on natural numbers
  , leToPlus
  , leToPlusKN
  )
where

import Data.Kind          (Type)
import GHC.Show           (appPrec)
import GHC.TypeLits       (KnownNat, Nat, type (+), type (-), type (*),
                           type (^), type (<=), natVal)
import GHC.TypeLits.Extra (CLog, FLog, Div, Log, Mod, Min, Max)
import GHC.Natural        (naturalFromInteger)
import Language.Haskell.TH (appT, conT, litT, numTyLit, sigE)
import Language.Haskell.TH.Syntax (Lift (..))
import Numeric.Natural    (Natural)
import Unsafe.Coerce      (unsafeCoerce)
import Clash.XException   (ShowX (..), showsPrecXWith)

{- $setup
>>> :set -XBinaryLiterals
>>> import Clash.Promoted.Nat.Literals (d789)
-}

-- | Singleton value for a type-level natural number @n@.
--
-- Pattern matching on the 'SNat' constructor brings a @'KnownNat' n@
-- constraint into scope.
--
-- * "Clash.Promoted.Nat.Literals" contains a list of predefined 'SNat' literals
-- * "Clash.Promoted.Nat.TH" has functions to easily create large ranges of new
--   'SNat' literals
data SNat :: Nat -> Type where
  SNat :: KnownNat n => SNat n

-- | Lifts an 'SNat' to the expression @SNat :: SNat N@, where @N@ is the
-- singleton's value written as a type-level literal.
instance Lift (SNat n) where
  lift s = sigE [| SNat |] snatType
    where
      snatType = appT (conT ''SNat) (litT (numTyLit (snatToInteger s)))

-- | Create an @`SNat` n@ from a proxy for /n/.
--
-- The proxy value is never inspected; only its type is used.
snatProxy :: KnownNat n => proxy n -> SNat n
snatProxy = const SNat

-- | Values up to and including 1024 are shown as @dN@ (e.g. @d789@), the same
-- spelling as the literals in "Clash.Promoted.Nat.Literals". Larger values are
-- shown as @SNat \@N@, parenthesised when the precedence exceeds 'appPrec'.
instance Show (SNat n) where
  showsPrec prec s@SNat
    | val > 1024 = showParen (prec > appPrec) (showString "SNat @" . shows val)
    | otherwise  = showChar 'd' . shows val
    where
      val = snatToInteger s

-- | Built from this type's 'showsPrec' via 'showsPrecXWith'.
instance ShowX (SNat n) where
  showsPrecX prec = showsPrecXWith showsPrec prec

-- | Supply a function with a singleton natural @n@ according to the context,
-- i.e. the @'KnownNat' n@ constraint in scope at the call site.
withSNat :: KnownNat n => (SNat n -> a) -> a
withSNat k = k SNat
{-# INLINE withSNat #-}

-- | Reify the type-level 'Nat' @n@ to its term-level 'Integer' representation.
snatToInteger :: SNat n -> Integer
snatToInteger s = case s of
  -- matching the constructor provides the KnownNat evidence natVal needs
  SNat -> natVal s
{-# INLINE snatToInteger #-}

-- | Reify the type-level 'Nat' @n@ to its term-level 'Natural' representation.
snatToNatural :: SNat n -> Natural
snatToNatural s = naturalFromInteger (snatToInteger s)
{-# INLINE snatToNatural #-}


-- | Reify the type-level 'Nat' @n@ to its term-level 'Num'ber.
--
-- The result type @a@ is quantified first, so it can be picked with a type
-- application, e.g. @snatToNum \@Int s@.
snatToNum :: forall a n . Num a => SNat n -> a
snatToNum s = case s of
  SNat -> fromInteger (snatToInteger s)
{-# INLINE snatToNum #-}

-- | Unary (Peano) representation of a type-level natural: 'UZero' is 0 and
-- each 'USucc' adds one.
--
-- __NB__: Not synthesizable
data UNat (n :: Nat) where
  UZero :: UNat 0
  USucc :: UNat n -> UNat (n + 1)

-- | Shows the value in decimal, prefixed with @u@ (e.g. @u3@).
instance KnownNat n => Show (UNat n) where
  show u = "u" ++ show (natVal u)

-- | Built from this type's 'showsPrec' via 'showsPrecXWith'.
instance KnownNat n => ShowX (UNat n) where
  showsPrecX prec = showsPrecXWith showsPrec prec

-- | Convert a singleton natural number to its unary representation
--
-- __NB__: Not synthesizable
toUNat :: forall n . SNat n -> UNat n
toUNat s@SNat = build @n (snatToInteger s)
  where
    -- Wraps @k@ 'USucc's around 'UZero'. GHC cannot relate the run-time
    -- counter @k@ to the type index @m@, so every step re-labels its result
    -- with 'unsafeCoerce'. This is sound as long as @k@ is the value of @m@,
    -- which holds for the initial call (the value of @n@ at index @n@).
    build :: forall m . Integer -> UNat m
    build k
      | k == 0    = unsafeCoerce @(UNat 0) @(UNat m) UZero
      | otherwise = unsafeCoerce @(UNat ((m-1)+1)) @(UNat m)
                                 (USucc (build @(m-1) (k - 1)))

-- | Convert a unary-encoded natural number to its singleton representation
--
-- __NB__: Not synthesizable
fromUNat :: UNat n -> SNat n
fromUNat u = case u of
  UZero   -> SNat @0
  USucc p -> addSNat (fromUNat p) (SNat @1)

-- | Add two unary-encoded natural numbers
--
-- The first argument is always forced; the second one only when the first is
-- non-zero.
--
-- __NB__: Not synthesizable
addUNat :: UNat n -> UNat m -> UNat (n + m)
addUNat x y = case x of
  UZero    -> y
  USucc x' -> case y of
    UZero -> x
    _     -> USucc (addUNat x' y)

-- | Multiply two unary-encoded natural numbers
--
-- Implemented as repeated addition: @(x + 1) * y = y + x * y@.
--
-- __NB__: Not synthesizable
mulUNat :: UNat n -> UNat m -> UNat (n * m)
mulUNat x y = case x of
  UZero    -> UZero
  USucc x' -> case y of
    UZero -> UZero
    _     -> addUNat y (mulUNat x' y)

-- | Raise a unary-encoded natural number to the power of another
-- (@n ^ m@). Only the exponent is inspected directly.
--
-- __NB__: Not synthesizable
powUNat :: UNat n -> UNat m -> UNat (n ^ m)
powUNat base e = case e of
  UZero   -> USucc UZero
  USucc k -> mulUNat base (powUNat base k)

-- | Predecessor of a unary-encoded natural number
--
-- __NB__: Not synthesizable
predUNat :: UNat (n+1) -> UNat n
predUNat u = case u of
  USucc p -> p
  -- unreachable for well-typed input: n + 1 can never be 0
  UZero   ->
    error "predUNat: impossible: 0 minus 1, -1 is not a natural number"

-- | Subtract two unary-encoded natural numbers
--
-- Strips one 'USucc' from both arguments until the subtrahend reaches 'UZero'.
--
-- __NB__: Not synthesizable
subUNat :: UNat (m+n) -> UNat n -> UNat m
subUNat x y = case y of
  UZero    -> x
  USucc y' -> case x of
    USucc x' -> subUNat x' y'
    UZero    -> error "subUNat: impossible: 0 + (n + 1) ~ 0"

-- | Predecessor of a singleton natural number
predSNat :: forall a . SNat (a+1) -> SNat a
predSNat SNat = SNat @a
{-# INLINE predSNat #-}

-- | Successor of a singleton natural number
succSNat :: forall a . SNat a -> SNat (a+1)
succSNat SNat = SNat @(a + 1)
{-# INLINE succSNat #-}

-- | Add two singleton natural numbers
addSNat :: forall a b . SNat a -> SNat b -> SNat (a+b)
addSNat SNat SNat = SNat @(a + b)
{-# INLINE addSNat #-}
infixl 6 `addSNat`

-- | Subtract two singleton natural numbers: given @a + b@ and @b@, yield @a@
subSNat :: forall a b . SNat (a+b) -> SNat b -> SNat a
subSNat SNat SNat = SNat @a
{-# INLINE subSNat #-}
infixl 6 `subSNat`

-- | Multiply two singleton natural numbers
mulSNat :: forall a b . SNat a -> SNat b -> SNat (a*b)
mulSNat SNat SNat = SNat @(a * b)
{-# INLINE mulSNat #-}
infixl 7 `mulSNat`

-- | Raise a singleton natural number to the power of another (@a ^ b@)
powSNat :: forall a b . SNat a -> SNat b -> SNat (a^b)
powSNat SNat SNat = SNat @(a ^ b)
{-# NOINLINE powSNat #-}
infixr 8 `powSNat`

-- | Division of two singleton natural numbers
--
-- @b@ is quantified before @a@ to keep the type-application order of the
-- original implicit signature, where @b@ first occurs in the context.
divSNat :: forall b a . (1 <= b) => SNat a -> SNat b -> SNat (Div a b)
divSNat SNat SNat = SNat @(Div a b)
{-# INLINE divSNat #-}
infixl 7 `divSNat`

-- | Modulo of two singleton natural numbers
--
-- @b@ is quantified before @a@ to keep the type-application order of the
-- original implicit signature, where @b@ first occurs in the context.
modSNat :: forall b a . (1 <= b) => SNat a -> SNat b -> SNat (Mod a b)
modSNat SNat SNat = SNat @(Mod a b)
{-# INLINE modSNat #-}
infixl 7 `modSNat`

-- | Minimum of two singleton natural numbers
minSNat :: forall a b . SNat a -> SNat b -> SNat (Min a b)
minSNat SNat SNat = SNat @(Min a b)

-- | Maximum of two singleton natural numbers
maxSNat :: forall a b . SNat a -> SNat b -> SNat (Max a b)
maxSNat SNat SNat = SNat @(Max a b)

-- | Floor of the logarithm of a natural number, @FLog base x@
flogBaseSNat :: forall base x . (2 <= base, 1 <= x)
             => SNat base -- ^ Base
             -> SNat x
             -> SNat (FLog base x)
flogBaseSNat SNat SNat = SNat @(FLog base x)
{-# NOINLINE flogBaseSNat #-}

-- | Ceiling of the logarithm of a natural number, @CLog base x@
clogBaseSNat :: forall base x . (2 <= base, 1 <= x)
             => SNat base -- ^ Base
             -> SNat x
             -> SNat (CLog base x)
clogBaseSNat SNat SNat = SNat @(CLog base x)
{-# NOINLINE clogBaseSNat #-}

-- | Exact integer logarithm of a natural number
--
-- __NB__: Only works when the argument is a power of the base, as required by
-- the @FLog base x ~ CLog base x@ constraint.
logBaseSNat :: forall base x . (FLog base x ~ CLog base x)
            => SNat base -- ^ Base
            -> SNat x
            -> SNat (Log base x)
logBaseSNat SNat SNat = SNat @(Log base x)
{-# NOINLINE logBaseSNat #-}

-- | Two raised to the power of a singleton natural number (@2 ^ a@)
pow2SNat :: forall a . SNat a -> SNat (2^a)
pow2SNat SNat = SNat @(2 ^ a)
{-# INLINE pow2SNat #-}

-- | Ordering relation between two 'Nat's.
--
-- Matching on a constructor brings its ordering constraint into scope.
data SNatLE (a :: Nat) (b :: Nat) where
  -- | @a@ is less than or equal to @b@
  SNatLE :: forall a b . a <= b => SNatLE a b
  -- | @a@ is strictly greater than @b@
  SNatGT :: forall a b . (b+1) <= a => SNatLE a b

-- | Get an ordering relation between two 'SNat's.
--
-- The choice of constructor is made at run time by comparing the two values.
compareSNat :: forall a b . SNat a -> SNat b -> SNatLE a b
-- NOTE(review): 'unsafeCoerce' re-labels a witness built at fixed indices
-- (@0 <= 0@, resp. @0 + 1 <= 1@) as one for @a@ and @b@. This presumes the
-- constraint evidence stored in 'SNatLE' / 'SNatGT' has no run-time content
-- that depends on the indices.
compareSNat a b
  | snatToInteger a <= snatToInteger b = unsafeCoerce (SNatLE @0 @0)
  | otherwise                          = unsafeCoerce (SNatGT @1 @0)

-- | Base-2 encoded natural number
--
--    * __NB__: The least significant bit is the left/outer-most constructor
--    * __NB__: Not synthesizable
--
-- >>> B0 (B1 (B1 BT))
-- b6
--
-- == Constructors
--
-- * Starting/Terminating element:
--
--      @
--      __BT__ :: 'BNat' 0
--      @
--
-- * Append a zero (/0/):
--
--      @
--      __B0__ :: 'BNat' n -> 'BNat' (2 '*' n)
--      @
--
-- * Append a one (/1/):
--
--      @
--      __B1__ :: 'BNat' n -> 'BNat' ((2 '*' n) '+' 1)
--      @
data BNat (n :: Nat) where
  BT :: BNat 0
  B0 :: BNat n -> BNat (2*n)
  B1 :: BNat n -> BNat ((2*n) + 1)

-- | Shows the value in decimal, prefixed with @b@ (e.g. @b6@).
instance KnownNat n => Show (BNat n) where
  show bn = "b" ++ show (natVal bn)

-- | Built from this type's 'showsPrec' via 'showsPrecXWith'.
instance KnownNat n => ShowX (BNat n) where
  showsPrecX prec = showsPrecXWith showsPrec prec

-- | Show a base-2 encoded natural as a binary literal
--
-- __NB__: The LSB is shown as the right-most bit
--
-- >>> d789
-- d789
-- >>> toBNat d789
-- b789
-- >>> showBNat (toBNat d789)
-- "0b1100010101"
-- >>> 0b1100010101 :: Integer
-- 789
showBNat :: BNat n -> String
showBNat = walk ""
  where
    -- Walks from the LSB (outermost) inward, prepending each bit, so the
    -- accumulated string ends up MSB-first.
    walk :: String -> BNat m -> String
    walk acc bn = case bn of
      BT    -> "0b" ++ acc
      B0 bs -> walk ('0' : acc) bs
      B1 bs -> walk ('1' : acc) bs

-- | Convert a singleton natural number to its base-2 representation
--
-- __NB__: Not synthesizable
toBNat :: SNat n -> BNat n
toBNat s@SNat = build (snatToInteger s)
  where
    -- Emits bits least-significant first (the outermost constructor is the
    -- LSB) until the value reaches 0. The index of each intermediate result
    -- cannot be expressed, so every step is re-labelled with 'unsafeCoerce';
    -- the outcome matches index @n@ because it is built from @n@'s own value.
    build :: Integer -> BNat m
    build 0 = unsafeCoerce BT
    build k
      | odd k     = unsafeCoerce (B1 (build (k `div` 2)))
      | otherwise = unsafeCoerce (B0 (build (k `div` 2)))

-- | Convert a base-2 encoded natural number to its singleton representation
--
-- __NB__: Not synthesizable
fromBNat :: BNat n -> SNat n
fromBNat bn = case bn of
  BT    -> SNat @0
  B0 bs -> twice (fromBNat bs)
  B1 bs -> addSNat (twice (fromBNat bs)) (SNat @1)
  where
    -- doubling, as encoded by B0 / B1
    twice :: SNat k -> SNat (2 * k)
    twice = mulSNat (SNat @2)

-- | Add two base-2 encoded natural numbers
--
-- Bit-wise ripple addition; two set bits produce a carry via 'succBNat'.
-- If the first argument is 'BT', the second is returned without being forced.
--
-- __NB__: Not synthesizable
addBNat :: BNat n -> BNat m -> BNat (n+m)
addBNat x y = case x of
  BT    -> y
  B0 x' -> case y of
    BT    -> x
    B0 y' -> B0 (addBNat x' y')
    B1 y' -> B1 (addBNat x' y')
  B1 x' -> case y of
    BT    -> x
    B0 y' -> B1 (addBNat x' y')
    B1 y' -> B0 (succBNat (addBNat x' y'))

-- | Multiply two base-2 encoded natural numbers
--
-- Shift-and-add: @2a * b = 2 (a * b)@ and @(2a + 1) * b = 2 (a * b) + b@.
--
-- __NB__: Not synthesizable
mulBNat :: BNat n -> BNat m -> BNat (n*m)
mulBNat x y = case x of
  BT    -> BT
  B0 x' -> case y of
    BT -> BT
    _  -> B0 (mulBNat x' y)
  B1 x' -> case y of
    BT -> BT
    _  -> addBNat (B0 (mulBNat x' y)) y

-- | Raise a base-2 encoded natural number to the power of another (@n ^ m@)
--
-- Square-and-multiply over the bits of the exponent.
--
-- __NB__: Not synthesizable
powBNat :: BNat n -> BNat m -> BNat (n^m)
powBNat base e = case e of
  BT    -> B1 BT
  B0 e' -> square (powBNat base e')
  B1 e' -> mulBNat base (square (powBNat base e'))
  where
    -- the argument is computed once and shared by both factors
    square :: BNat k -> BNat (k * k)
    square z = mulBNat z z

-- | Successor of a base-2 encoded natural number
--
-- __NB__: Not synthesizable
succBNat :: BNat n -> BNat (n+1)
succBNat bn = case bn of
  BT    -> B1 BT
  B0 bs -> B1 bs
  -- a set LSB overflows: clear it and carry into the rest
  B1 bs -> B0 (succBNat bs)

-- | Predecessor of a base-2 encoded natural number
--
-- When the result is zero it is returned as 'BT', with no trailing zeros.
--
-- __NB__: Not synthesizable
predBNat :: (1 <= n) => BNat n -> BNat (n-1)
predBNat bn = case bn of
  B1 a -> case stripZeros a of
    BT -> BT
    a' -> B0 a'
  -- a clear LSB borrows from the rest
  B0 x -> B1 (predBNat x)

-- | Divide a base-2 encoded natural number by 2
--
-- __NB__: Not synthesizable
div2BNat :: BNat (2*n) -> BNat n
div2BNat bn = case bn of
  BT   -> BT
  B0 x -> x
  B1 _ -> error "div2BNat: impossible: 2*n ~ 2*n+1"

-- | Subtract 1 and divide a base-2 encoded natural number by 2
--
-- __NB__: Not synthesizable
div2Sub1BNat :: BNat (2*n+1) -> BNat n
div2Sub1BNat bn = case bn of
  B1 x -> x
  _    -> error "div2Sub1BNat: impossible: 2*n+1 ~ 2*n"

-- | Get the log2 of a base-2 encoded natural number
--
-- Counts the zero bits below the single set bit; calls 'error' on 'BT'.
--
-- __NB__: Not synthesizable
log2BNat :: BNat (2^n) -> BNat n
log2BNat bn = case bn of
  BT   -> error "log2BNat: log2(0) not defined"
  B1 x -> case stripZeros x of
    BT -> BT
    _  -> error "log2BNat: impossible: 2^n ~ 2x+1"
  B0 x -> succBNat (log2BNat x)

-- | Strip non-contributing zeros (the most-significant 'B0's directly
-- wrapping 'BT') from a base-2 encoded natural number
--
-- >>> B1 (B0 (B0 (B0 BT)))
-- b1
-- >>> showBNat (B1 (B0 (B0 (B0 BT))))
-- "0b0001"
-- >>> showBNat (stripZeros (B1 (B0 (B0 (B0 BT)))))
-- "0b1"
-- >>> stripZeros (B1 (B0 (B0 (B0 BT))))
-- b1
--
-- __NB__: Not synthesizable
stripZeros :: BNat n -> BNat n
stripZeros bn = case bn of
  BT   -> BT
  B1 x -> B1 (stripZeros x)
  -- a 'B0' over something that strips to zero is itself zero
  B0 x -> case stripZeros x of
    BT -> BT
    k  -> B0 k

-- | Change a function that has an argument with an @(n ~ (k + m))@ constraint to a
-- function with an argument that has an @(k <= n)@ constraint.
--
-- === __Examples__
--
-- Example 1
--
-- @
-- f :: Index (n+1) -> Index (n + 1) -> Bool
--
-- g :: forall n. (1 '<=' n) => Index n -> Index n -> Bool
-- g a b = 'leToPlus' \@1 \@n (f a b)
-- @
--
-- Example 2
--
-- @
-- head :: Vec (n + 1) a -> a
--
-- head' :: forall n a. (1 '<=' n) => Vec n a -> a
-- head' = 'leToPlus' \@1 \@n head
-- @
leToPlus
  :: forall (k :: Nat) (n :: Nat) r
   . ( k <= n
     )
  => (forall m . (n ~ (k + m)) => r)
  -- ^ Context with the @(n ~ (k + m))@ constraint
  -> r
-- The continuation is instantiated at @m = n - k@. The resulting
-- @n ~ (k + (n - k))@ obligation is presumably discharged by the
-- GHC.TypeLits.Normalise plugin enabled at the top of this module.
leToPlus r = r @(n - k)
{-# INLINE leToPlus #-}

-- | Same as 'leToPlus' with added 'KnownNat' constraints: @k@ and @n@ must be
-- known, and the continuation additionally receives @'KnownNat' m@.
leToPlusKN
  :: forall (k :: Nat) (n :: Nat) r
   . ( k <= n
     , KnownNat k
     , KnownNat n
     )
  => (forall m . (n ~ (k + m), KnownNat m) => r)
  -- ^ Context with the @(n ~ (k + m))@ and @'KnownNat' m@ constraints
  -> r
-- Instantiate the continuation at the difference @m = n - k@.
leToPlusKN cont = cont @(n - k)
{-# INLINE leToPlusKN #-}