{-# LANGUAGE AllowAmbiguousTypes     #-}
{-# LANGUAGE TypeApplications        #-}
{-# LANGUAGE TypeOperators           #-}
{-# LANGUAGE UndecidableInstances    #-}
{-# LANGUAGE UndecidableSuperClasses #-}

module ZkFold.Symbolic.Data.Combinators where

import           Control.Applicative              (Applicative)
import           Control.Monad                    (mapM)
import           Data.Foldable                    (foldlM)
import           Data.Functor.Rep                 (Representable, mzipRep, mzipWithRep)
import           Data.Kind                        (Type)
import           Data.List                        (find, splitAt)
import           Data.List.Split                  (chunksOf)
import           Data.Maybe                       (fromMaybe)
import           Data.Proxy                       (Proxy (..))
import           Data.Ratio                       ((%))
import           Data.Traversable                 (Traversable, for, sequenceA)
import           Data.Type.Bool                   (If)
import           Data.Type.Ord
import           GHC.Base                         (const, return)
import           GHC.List                         (reverse)
import           GHC.TypeNats
import           Prelude                          (error, head, pure, tail, ($), (.), (<$>), (<>))
import qualified Prelude                          as Haskell
import           Type.Errors

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number (value)
import           ZkFold.Symbolic.Class            (Arithmetic, BaseField)
import           ZkFold.Symbolic.MonadCircuit

-- | Zip two representable containers position-wise with an effectful
-- function, then sequence the resulting effects in the container's
-- 'Traversable' order.
mzipWithMRep ::
  (Representable f, Traversable f, Applicative m) =>
  (a -> b -> m c) -> f a -> f b -> m (f c)
mzipWithMRep g lhs rhs = sequenceA $ mzipWithRep g lhs rhs

--------------------------------------------------------------------------------------------------

-- | A class for isomorphic types.
--
-- The @Iso b a@ superclass guarantees that whenever a conversion from @a@ to
-- @b@ exists, the conversion back from @b@ to @a@ is defined as well.
class (Iso b a) => Iso a b where
    -- | Convert a value of type @a@ into the isomorphic type @b@.
    from :: a -> b

-- | Types that can grow or shrink their capacity. Growing adds zero bits at the
-- beginning (i.e. before the higher register); shrinking removes higher bits.
class Resize a b where
    -- | Convert @a@ into @b@ by zero-padding or dropping higher bits.
    resize :: a -> b

-- | Convert an @ArithmeticCircuit@ to bits and return their corresponding variables.
--
-- @toBits regs hiBits loBits@ treats the first register of @regs@ as the high
-- register and expands it into @hiBits@ bits. Every remaining register is
-- expanded into @loBits@ bits. The expansion uses 'expansion', which also
-- constrains the bits to recombine to their register.
--
-- 'expansion' yields bits least significant first. Each register's bits are
-- reversed, so every register appears most significant bit first. The result
-- is the high register's bits followed by the low registers' bits, in register
-- order.
--
-- Calls 'error' when @regs@ is empty: there is no high register to expand.
toBits
    :: (MonadCircuit v a w m, Arithmetic a)
    => [v]
    -> Natural
    -> Natural
    -> m [v]
toBits [] _ _ = error "toBits: expected at least one (high) register, got an empty list"
toBits (high : lows) hiBits loBits = do
    -- Low registers are expanded first, matching the original allocation order.
    bitsLow  <- Haskell.concatMap Haskell.reverse <$> mapM (expansion loBits) lows
    bitsHigh <- Haskell.reverse <$> expansion hiBits high
    pure $ bitsHigh <> bitsLow

-- | The inverse of @toBits@.
--
-- @fromBits hiBits loBits bits@ takes the first @hiBits@ entries of @bits@ as
-- the high register. It cuts the remaining entries into chunks of @loBits@
-- entries, one chunk per low register.
--
-- Each group is reversed, undoing the per-register reversal done by @toBits@,
-- and recombined with 'horner'. The result is the high register followed by
-- the low registers.
fromBits
    :: Natural
    -> Natural
    -> (forall v w m. MonadCircuit v a w m => [v] -> m [v])
fromBits hiBits loBits bits = do
    let (highGroup, lowPart) = splitAt (Haskell.fromIntegral hiBits) bits
        lowGroups            = chunksOf (Haskell.fromIntegral loBits) lowPart
    lowRegs <- mapM (\grp -> horner (Haskell.reverse grp)) lowGroups
    highReg <- horner (Haskell.reverse highGroup)
    pure (highReg : lowRegs)

-- | How wide each register is.
--
-- 'Auto' derives the width from the total number of bits and the field (see
-- 'registerSize'). 'Fixed' sets the width explicitly.
data RegisterSize
    = Auto
    | Fixed Natural
    deriving (Haskell.Eq)

-- | Reflects a type-level 'RegisterSize' to its term-level value.
class KnownRegisterSize (r :: RegisterSize) where
  regSize :: RegisterSize

-- | 'Auto' reflects to itself.
instance KnownRegisterSize Auto where
  regSize = Auto

-- | @'Fixed' n@ reflects to @'Fixed' n@, reading @n@ from its 'KnownNat' instance.
instance KnownNat n => KnownRegisterSize (Fixed n) where
  regSize = Fixed $ value @n

-- | The register size plus @ceiling (log2 numberOfRegisters)@ extra bits.
--
-- NOTE(review): presumably this is the widest a register value can get once
-- one value per register has been accumulated into it. Confirm against callers.
maxOverflow :: forall a n r . (Finite a, KnownNat n, KnownRegisterSize r) => Natural
maxOverflow = baseBits + carryBits
    where
        baseBits, carryBits :: Natural
        baseBits  = registerSize @a @n @r
        carryBits = Haskell.ceiling (log2 (numberOfRegisters @a @n @r))

-- | The number of bits left for the high register.
--
-- Computed as @n -! registerSize * (numberOfRegisters -! 1)@: the total bit
-- count minus the bits occupied by the lower registers.
highRegisterSize :: forall a n r . (Finite a, KnownNat n, KnownRegisterSize r) => Natural
highRegisterSize = totalBits -! regBits * (regCount -! 1)
    where
        totalBits, regBits, regCount :: Natural
        totalBits = getNatural @n
        regBits   = registerSize @a @n @r
        regCount  = numberOfRegisters @a @n @r

-- | The number of bits stored in each register.
--
-- For @'Fixed' rs@ this is @rs@. For 'Auto' the @n@ bits are spread over
-- 'numberOfRegisters' registers, rounding up.
registerSize  :: forall a n r. (Finite a, KnownNat n, KnownRegisterSize r) => Natural
registerSize = sizeFor (regSize @r)
    where
        sizeFor :: RegisterSize -> Natural
        sizeFor (Fixed rs) = rs
        sizeFor Auto       = Haskell.ceiling (getNatural @n % numberOfRegisters @a @n @r)

-- | Type-level ceiling division: @Ceil a b = ceil (a / b)@.
type Ceil a b = Div (a + b - 1) b

-- | Type-level counterpart of 'registerSize'. A 'Fixed' size is used as is.
-- For 'Auto', the @bits@ are spread over 'NumberOfRegisters' registers,
-- rounding up.
type family GetRegisterSize (a :: Type) (bits :: Natural) (r :: RegisterSize) :: Natural where
    GetRegisterSize a bits (Fixed rs) = rs
    GetRegisterSize a bits Auto       = Ceil bits (NumberOfRegisters a bits Auto)

-- | Shorthand for a 'KnownNat' instance of the register count needed for
-- @bits@ bits over the base field of @c@.
type KnownRegisters c bits r = KnownNat (NumberOfRegisters (BaseField c) bits r)

-- | Type-level counterpart of 'numberOfRegisters'.
--
-- * @'Fixed' rs@: @ceil (bits / rs)@, i.e. @Div bits rs@, plus one when
--   @Mod bits rs > 0@.
-- * 'Auto': the first register count @x@ in @[1 .. 50]@ satisfying
--   @bits <= x * MaxRegisterSize a x@ (see 'NumberOfRegisters'').
type family NumberOfRegisters (a :: Type) (bits :: Natural) (r :: RegisterSize ) :: Natural where
  -- ceil (bits / rs). NOTE(review): rs is not checked against MaxRegisterSize a here.
  NumberOfRegisters a bits (Fixed rs) = If (Mod bits rs >? 0 ) (Div bits rs + 1) (Div bits rs)
  -- TODO: Compilation takes ages if this constant is greater than 10000.
  -- But it is weird anyway if someone is trying to store a value
  -- which requires more than 50 registers.
  NumberOfRegisters a bits Auto       = NumberOfRegisters' a bits (ListRange 1 50)

-- | Scans the candidate register counts in order. Returns the first @x@ with
-- @bits <= x * MaxRegisterSize a x@ (the 'LT' and 'EQ' branches of the
-- comparison), and returns 0 if the list is exhausted.
--
-- NOTE(review): the term-level 'numberOfRegisters' searches a much larger range
-- and calls 'error' instead of returning 0. The two can therefore disagree for
-- values that need more than 50 registers.
type family NumberOfRegisters' (a :: Type) (bits :: Natural) (c :: [Natural]) :: Natural where
    NumberOfRegisters' a bits '[] = 0
    NumberOfRegisters' a bits (x ': xs) =
        OrdCond (CmpNat bits (x * MaxRegisterSize a x))
            x                              -- bits <  capacity: x registers suffice
            x                              -- bits == capacity: x registers suffice
            (NumberOfRegisters' a bits xs) -- bits >  capacity: try the next count

-- | Bits that fit into one element of @a@: @floor (log2 (Order a))@
-- ('Log2' rounds down).
type family BitLimit (a :: Type) :: Natural where
    BitLimit a = Log2 (Order a)

-- | @ceil (log2 regCount)@. Since @2 ^ Log2 regCount <= regCount@, the 'LT'
-- branch cannot be taken. 'EQ' means @regCount@ is a power of two.
type family MaxAdded (regCount :: Natural) :: Natural where
    MaxAdded regCount =
        OrdCond (CmpNat regCount (2 ^ Log2 regCount))
            (TypeError (Text "Impossible"))
            (Log2 regCount)
            (1 + Log2 regCount)

-- | Largest register size when using @regCount@ registers:
-- @floor ((BitLimit a - MaxAdded regCount) / 2)@.
-- NOTE(review): type-level '-' does not reduce if @MaxAdded regCount > BitLimit a@.
type family MaxRegisterSize (a :: Type) (regCount :: Natural) :: Natural where
    MaxRegisterSize a regCount = Div (BitLimit a - MaxAdded regCount) 2

-- | The inclusive range @[from .. to]@ as a type-level list.
-- Does not reduce to a finite list when @from > to@.
type family ListRange (from :: Natural) (to :: Natural) :: [Natural] where
    ListRange from from = '[from]
    ListRange from to = from ': ListRange (from + 1) to

-- | The number of registers needed to store an @n@-bit value over the field @a@.
--
-- * @'Fixed' rs@: @ceil (n / rs)@ registers.
-- * 'Auto': the smallest register count @c@ with @c * maxRegisterSize c >= n@,
--   where @maxRegisterSize c = floor ((bitLimit - ceil (log2 c)) / 2)@ and
--   @bitLimit = floor (log2 (order a))@. Calls 'error' if no such @c@ is found.
--
-- Both logarithms are computed exactly on 'Natural', so the result agrees with
-- the type-level 'NumberOfRegisters' (which uses 'Log2' / 'MaxAdded').
-- Converting the field order to 'Haskell.Double' first rounds large orders:
-- an order just below a power of two, e.g. @2^61 - 1@, becomes @2^61@ and
-- could be credited with one bit too many.
numberOfRegisters :: forall a n r . ( Finite a, KnownNat n, KnownRegisterSize r) => Natural
numberOfRegisters =  case regSize @r of
    Auto -> fromMaybe (error "too many bits, field is not big enough")
        $ find (\c -> c * maxRegisterSize c Haskell.>= getNatural @n) [1 .. maxRegisterCount]
        where
            maxRegisterCount :: Natural
            maxRegisterCount = 2 ^ bitLimit

            -- Bits that fit into one field element (mirrors type-level 'BitLimit').
            bitLimit :: Natural
            bitLimit = floorLog2 (order @a)

            -- Largest register size for @regCount@ registers
            -- (mirrors type-level 'MaxRegisterSize').
            maxRegisterSize :: Natural -> Natural
            maxRegisterSize regCount = (bitLimit -! ceilLog2 regCount) `div` 2

            -- Exact floor (log2 x); 0 for x < 2.
            floorLog2 :: Natural -> Natural
            floorLog2 x
                | x Haskell.< 2     = 0
                | Haskell.otherwise = 1 + floorLog2 (x `div` 2)

            -- Exact ceiling (log2 x); 0 for x < 2 (mirrors type-level 'MaxAdded').
            ceilLog2 :: Natural -> Natural
            ceilLog2 x
                | x Haskell.< 2     = 0
                | Haskell.otherwise = 1 + floorLog2 (x -! 1)
    Fixed rs -> Haskell.ceiling (value @n % rs)

-- | Base-2 logarithm of a natural number, evaluated in 'Haskell.Double'.
-- The argument is converted to 'Haskell.Double' first, so large inputs are
-- subject to floating-point rounding.
log2 :: Natural -> Haskell.Double
log2 x = Haskell.logBase 2 (Haskell.fromIntegral x)

-- | Reflect the type-level natural @n@ to a term-level 'Natural'.
getNatural :: forall n . KnownNat n => Natural
getNatural = natVal @n Proxy

-- | The maximum number of bits a Field element can encode:
-- @floor (log2 (order p))@.
--
-- Computed exactly on 'Natural', matching the type-level 'BitLimit' (which
-- uses 'Log2'). Going through 'Haskell.Double' rounds large orders: an order
-- just below a power of two, e.g. @2^61 - 1@, becomes @2^61@ and could be
-- reported as holding one bit more than it actually can.
maxBitsPerFieldElement :: forall p. Finite p => Natural
maxBitsPerFieldElement = floorLog2 (order @p)
    where
        -- Exact floor (log2 x); 0 for x < 2.
        floorLog2 :: Natural -> Natural
        floorLog2 x
            | x Haskell.< 2     = 0
            | Haskell.otherwise = 1 + floorLog2 (x `div` 2)

-- | The maximum number of bits it makes sense to encode in a register.
--
-- This is the capacity of one field element, capped at the @n@ bits actually
-- requested: if field elements can hold more bits than needed, the smaller
-- number wins.
maxBitsPerRegister :: forall p n. (Finite p, KnownNat n) => Natural
maxBitsPerRegister = Haskell.min fieldCapacity requested
    where
        fieldCapacity, requested :: Natural
        fieldCapacity = maxBitsPerFieldElement @p
        requested     = getNatural @n

-- | The number of bits remaining for the higher register, assuming all lower
-- registers hold 'maxBitsPerFieldElement' bits each.
--
-- This is @n mod maxBitsPerFieldElement@, or a full 'maxBitsPerFieldElement'
-- when that remainder is zero.
highRegisterBits :: forall p n. (Finite p, KnownNat n) => Natural
highRegisterBits =
    let perElement, remainder :: Natural
        perElement = maxBitsPerFieldElement @p
        remainder  = getNatural @n `mod` perElement
    in if remainder Haskell.== 0 then perElement else remainder

-- | The lowest possible number of registers to encode @n@ bits using Field
-- elements from @p@, assuming each register stores the largest possible number
-- of bits ('maxBitsPerRegister').
--
-- This is the ceiling division @ceil (n / maxBitsPerRegister)@.
minNumberOfRegisters :: forall p n. (Finite p, KnownNat n) => Natural
minNumberOfRegisters = (bits + perRegister -! 1) `div` perRegister
    where
        bits, perRegister :: Natural
        bits        = getNatural @n
        perRegister = maxBitsPerRegister @p @n

---------------------------------------------------------------

-- | @expansion n k@ computes a binary expansion of @k@ if it fits in @n@ bits.
-- It is 'expansionW' with one-bit words, so the bits come least significant first.
expansion :: (MonadCircuit i a w m, Arithmetic a) => Natural -> i -> m [i]
expansion bitCount var = expansionW @1 bitCount var

-- | @expansionW n k@ splits @k@ into @n@ words of @r@ bits each via 'wordsOf'
-- (least significant word first). It recombines the words with 'hornerW' and
-- adds the constraint @k - recombined@, tying the words back to @k@.
-- Returns the words.
expansionW :: forall r i a w m . (KnownNat r, MonadCircuit i a w m, Arithmetic a) => Natural -> i -> m [i]
expansionW n k = do
    limbs      <- wordsOf @r n k
    recombined <- hornerW @r limbs
    constraint (\x -> x k - x recombined)
    pure limbs

-- | @bitsOf n k@ creates @n@ new variables whose witnesses are the @n@ least
-- significant bits of @k@, least significant first.
--
-- Only per-bit range constraints are added (via 'wordsOf'). Nothing ties the
-- bits back to @k@; 'expansion' adds that constraint.
bitsOf :: (MonadCircuit i a w m, Arithmetic a) => Natural -> i -> m [i]
bitsOf bitCount var = wordsOf @1 bitCount var

-- | @wordsOf n k@ creates @n@ new @r@-bit words whose witnesses are the @n@
-- lowest base-@2^r@ digits of @k@, least significant first.
--
-- Word @j@ is @(k div 2^(r*j)) mod 2^r@. Each word is created with
-- 'newRanged' and bound @2^r - 1@ (presumably an inclusive upper bound).
-- No constraint ties the words back to @k@.
--
-- For @n = 0@ no words are created.
wordsOf :: forall r i a w m . (KnownNat r, MonadCircuit i a w m, Arithmetic a) => Natural -> i -> m [i]
-- takeWhile instead of [0 .. n -! 1]: with n = 0 the latter yields [0] (or
-- fails), which would create one word instead of none.
wordsOf n k = for (Haskell.takeWhile (Haskell.< n) [0 ..]) $ \j ->
    newRanged (fromConstant $ wordSize -! 1) (repr j (at k))
    where
        -- Radix of one word: 2^r.
        wordSize :: Natural
        wordSize = 2 ^ value @r

        -- The j-th base-2^r digit of a witness value.
        repr :: ResidueField n x => Natural -> x -> x
        repr j =
              fromConstant
              . (`mod` fromConstant wordSize)
              . (`div` fromConstant (wordSize ^ j))
              . toConstant

-- | @hornerW [b0, ..., bn]@ computes @b0 + 2^r b1 + ... + 2^(r n) bn@ using
-- Horner's scheme.
--
-- Starting from the last digit, each step assigns a new variable
-- @digit + 2^r * acc@ via 'newAssigned'. A single-element list returns that
-- element unchanged; an empty list yields a new variable assigned 'zero'.
hornerW :: forall r i a w m . (KnownNat r, MonadCircuit i a w m) => [i] -> m i
hornerW xs = case reverse xs of
    []           -> newAssigned (const zero)
    (top : rest) ->
        foldlM (\acc d -> newAssigned (\x -> x d + scale base (x acc))) top rest
    where
        -- Radix of one word: 2^r.
        base :: Natural
        base = 2 ^ (value @r)

-- | @horner [b0, ..., bn]@ computes @b0 + 2 b1 + ... + 2^n bn@ using Horner's
-- scheme. It is 'hornerW' with one-bit words.
horner :: MonadCircuit i a w m => [i] -> m i
horner digits = hornerW @1 digits

-- | @splitExpansion n1 n2 k@ computes two values @(l, h)@ such that
-- @k = 2^n1 h + l@, @l@ fits in @n1@ bits and @h@ fits in @n2@ bits (if such
-- values exist).
--
-- The witnesses are @l = k mod 2^n1@ and @h = (k div 2^n1) mod 2^n2@. They are
-- created with 'newRanged' bounds @2^n1 - 1@ and @2^n2 - 1@. The constraint
-- @k - l - 2^n1 h@ links them to @k@.
splitExpansion :: (MonadCircuit i a w m, Arithmetic a) => Natural -> Natural -> i -> m (i, i)
splitExpansion n1 n2 k = do
    lo <- newRanged (fromConstant @Natural (2 ^ n1 -! 1)) (takeLow (at k))
    hi <- newRanged (fromConstant @Natural (2 ^ n2 -! 1)) (takeHigh (at k))
    constraint (\x -> x k - x lo - scale (2 ^ n1 :: Natural) (x hi))
    pure (lo, hi)
    where
        -- The lowest n1 bits of a witness value.
        takeLow :: ResidueField n x => x -> x
        takeLow = fromConstant . (`mod` fromConstant @Natural (2 ^ n1)) . toConstant

        -- The n2 bits of a witness value that sit above its lowest n1 bits.
        takeHigh :: ResidueField n x => x -> x
        takeHigh =
            fromConstant
            . (`mod` fromConstant @Natural (2 ^ n2))
            . (`div` fromConstant @Natural (2 ^ n1))
            . toConstant

-- | For every variable @v@ in the container, allocate a flag @j@ and a value
-- @k@, returning all flags and all values.
--
-- The new variables are constrained by @v * j@ and @v * k + j - 1@ (each
-- presumably required to equal zero). Under that reading:
--
-- * @v = 0@ forces @j = 1@.
-- * Otherwise @j = 0@ and @k@ is the inverse of @v@.
--
-- The witnesses supplied are @1 - v // v@ for @j@ and @finv v@ for @k@.
runInvert :: (MonadCircuit i a w m, Representable f, Traversable f) => f i -> m (f i, f i)
runInvert inputs = do
    flags <- for inputs $ \v ->
        newConstrained (\x flag -> x v * x flag) (one - at v // at v)
    inverses <- for (mzipRep inputs flags) $ \(v, flag) ->
        newConstrained (\x inv -> x v * x inv + x flag - one) (finv (at v))
    pure (flags, inverses)

-- | Per-element flags from 'runInvert': each flag is constrained to be 1 when
-- its input is zero.
--
-- The companion inverse variables are still allocated by 'runInvert' but are
-- discarded here.
isZero :: (MonadCircuit i a w m, Representable f, Traversable f) => f i -> m (f i)
isZero inputs = Haskell.fmap Haskell.fst (runInvert inputs)