{-# LANGUAGE AllowAmbiguousTypes #-}
module ZkFold.Base.Algorithm.ReedSolomon where
import Data.Bool (bool)
import Data.Vector as V hiding (sum)
import GHC.Natural (Natural)
import Prelude (Eq, Int, Integer, Maybe (..), error, fromIntegral, iterate,
min, ($), (.), (<=), (==))
import qualified Prelude as P
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number (KnownNat, value)
import ZkFold.Base.Algebra.Polynomials.Univariate
-- | Number of symbol errors correctable by a Reed–Solomon code of length @n@
-- carrying @k@ message symbols, i.e. @(n - k) `div` 2@.
--
-- Both sizes are type-level naturals, so the function is called with type
-- applications, e.g. @numberOfError \@n \@k@.
--
-- NOTE(review): '-!' is assumed to be truncated subtraction on 'Natural'
-- (yielding 0 when @k > n@) — confirm against ZkFold.Base.Algebra.Basic.Class.
numberOfError :: forall n k. (KnownNat n, KnownNat k) => Natural
numberOfError = checkSymbols `div` 2
  where
    -- Redundancy of the code: codeword length minus message length.
    checkSymbols = value @n -! value @k
-- | Generator polynomial of a Reed–Solomon code with @r@ check symbols:
--
-- > g(x) = (x - a) * (x - a^2) * ... * (x - a^r)
--
-- where @a@ is the primitive element. With @r = 0@ the product is empty and
-- the result is 'one'.
generator :: (Field a, Eq a) => Int -> a -> Poly a
generator r a = V.foldr (*) one linearFactors
  where
    -- Consecutive powers a, a^2, ..., a^r (r elements).
    powers = V.iterateN r (* a) a
    -- Each power p becomes the degree-1 polynomial x - p; the coefficient
    -- vector is given lowest degree first ([constant, x-coefficient]).
    linearFactors = V.map (\p -> toPoly (V.fromList [negate p, one])) powers
-- | Systematic Reed–Solomon encoding.
--
-- The message coefficients are shifted up by @r@ positions (via 'scaleP'
-- with exponent @r@). The remainder of dividing that shifted polynomial by
-- @'generator' r prim_elem@ is then subtracted, so the result is divisible
-- by the generator. The message occupies the coefficients from degree @r@
-- upward; the low @r@ coefficients hold the check symbols.
encode :: (Field c, Eq c) => [c] -> c -> Int -> Poly c
encode msg prim_elem r = shiftedMsg - checkPoly
  where
    shiftedMsg = scaleP one (fromIntegral r) (toPoly (V.fromList msg))
    (_, checkPoly) = qr shiftedMsg (generator r prim_elem)
-- | Decode a received Reed–Solomon word produced by 'encode'.
--
-- Arguments:
--
--   * @encoded@      – the received (possibly corrupted) codeword polynomial;
--   * @primeElement@ – the primitive element @α@ used by 'encode';
--   * @r@            – the number of check symbols;
--   * @n@            – the codeword length (number of positions searched for
--                      errors).
--
-- Returns the message part, i.e. the coefficients from degree @r@ upward.
-- If all syndromes are zero, the received word is returned as-is (minus the
-- check symbols). Otherwise the errors are located with 'berlekamp' and
-- valued with the Forney formula. Calls 'error' with @"Can't decode"@ if
-- the corrected word still has non-zero syndromes.
decode :: (Field c, Eq c) => Poly c -> c -> Int -> Int -> Poly c
decode encoded primeElement r n
  | isCorrect   = encoded'
  | P.otherwise = decoded
  where
    -- Evaluation points α, α^2, ..., α^r (the roots of 'generator' r α).
    rElems = V.iterateN r (* primeElement) primeElement

    -- Message part of the received word, used when no error is detected.
    encoded' = toPoly . V.drop r $ fromPoly encoded

    -- Syndromes S_j = encoded(α^j), j = 1..r, stored as the coefficients of
    -- S(x), lowest degree first.
    syndromes = toPoly . V.map (evalPoly encoded) $ rElems
    isCorrect = zero == syndromes

    -- Error-locator polynomial Λ(x).
    (_, lx) = berlekamp syndromes r

    -- Candidate inverse locators α^{-i} for every position i in [0, n),
    -- paired with their position.
    invPE = finv primeElement
    es1 = V.indexed . V.fromList $ one : P.take (n P.- 1) (iterate (* invPE) invPE)

    -- Positions i whose α^{-i} is a root of Λ(x): the error positions.
    iroots = V.filter (\(_, x) -> evalPoly lx x == zero) es1

    -- Error evaluator Ω(x) = Λ(x)·S(x) mod x^r, and Λ'(x) for Forney.
    omega = toPoly . V.take r . fromPoly $ lx * syndromes
    lx'   = diff lx

    -- Forney step for one error position: Ω(X⁻¹)/Λ'(X⁻¹) placed at x^i.
    -- With syndromes starting at α^1 this value is the NEGATED error, so it is
    -- added (not subtracted) to the received word below.
    errorTerm (i, x) =
      let xi = bool (monomial (fromIntegral i) one) (constant one) (i == 0)
          ei = evalPoly omega x * finv (evalPoly lx' x)
      in constant ei * xi

    -- Strict left fold: avoids building a chain of (+) thunks.
    err = V.foldl' (+) zero (V.map errorTerm iroots)

    -- Corrected codeword; it must evaluate to zero at every α^j.
    fx       = encoded + err
    checkSum = V.map (evalPoly fx) rElems

    decoded
      | V.all (== zero) checkSum = toPoly . V.drop r $ fromPoly fx
      | P.otherwise              = error "Can't decode"
-- | Berlekamp–Massey algorithm over the syndrome polynomial @s@.
--
-- @r@ is the number of syndromes to process. The result is @(l, Λ(x))@:
-- @Λ@ is the connection (error-locator) polynomial and @l@ is the register
-- length tracked by the algorithm. A zero syndrome polynomial (degree -1)
-- yields @(0, 1)@. Calls 'error' with @"locators didn't find"@ if, after
-- @r@ steps, the degree of @Λ@ differs from @l@.
berlekamp :: forall c . (Field c, Eq c) => Poly c -> Int -> (Integer, Poly c)
berlekamp s r
  | deg s == -1 = (0, one)
  | P.otherwise = go cx0 bx0 0 0 1 one
  where
    -- Syndrome coefficients S_0, S_1, ... (lowest degree first).
    sv :: Vector c
    sv = fromPoly s

    -- Initial connection polynomial C(x) and its saved copy B(x), both 1.
    cx0 :: Poly c
    cx0 = one

    bx0 :: Poly c
    bx0 = one

    lenS :: Integer
    lenS = fromIntegral r

    -- @go C B n L m b@ processes syndrome index @n@, where
    --   C – current connection polynomial,
    --   B – copy of C saved at the last length change,
    --   L – current register length,
    --   m – steps since B was saved (shift applied to B),
    --   b – discrepancy at the time B was saved.
    go :: Poly c -> Poly c -> Integer -> Integer -> Natural -> c -> (Integer, Poly c)
    go cx bx n l m b
      | n == lenS   = if deg cx == l then (l, cx) else error "locators didn't find"
      | d == zero   = go cx bx n' l (m + 1) b
      -- Length change: the old C becomes the new B and the shift resets.
      | 2 * l <= n  = go cx' cx n' (n + 1 - l) 1 d
      | P.otherwise = go cx' bx n' l (m + 1) b
      where
        -- Discrepancy d_n = Σ_i C_i · S_{n-i}.
        d = scalarN (fromIntegral n P.+ 1) (fromIntegral l P.+ 1) (fromPoly cx) sv
        -- C(x) - (d / b) · x^m · B(x)
        cx' = cx - fromConstant d * constant (finv b) * bx * monomial m one
        n' = n + 1
-- | @scalarN q l lv rv@ computes @Σ lv[i] * rv[q-1-i]@ over all indices valid
-- for both vectors. This is the coefficient of @x^(q-1)@ in the product of
-- the polynomials whose coefficient vectors (lowest degree first) are @lv@
-- and @rv@.
--
-- It returns 'zero' outright when @min l (length lv) <= q - length rv@.
-- NOTE(review): as called from 'berlekamp', @l@ is the register length + 1,
-- presumably never smaller than @length lv@ — confirm if reused elsewhere.
scalarN :: (Semiring c) => Int -> Int -> Vector c -> Vector c -> c
scalarN q l lv rv
  | min l (V.length lv) <= shift = zero
  | P.otherwise = sum (V.zipWith (*) (V.drop shift lv) (V.reverse (V.take q rv)))
  where
    -- How many leading entries of lv have no partner in rv because rv is
    -- shorter than q (non-positive means nothing is skipped).
    shift = q P.- V.length rv
-- | Formal derivative of a polynomial: d/dx Σ c_i x^i = Σ i·c_i x^(i-1).
--
-- Each coefficient is scaled by its index, and the first entry (the
-- constant term, whose scaled value is always zero) is discarded. 'V.drop'
-- is used instead of 'V.tail' so an empty coefficient vector (e.g. the zero
-- polynomial) yields 'zero' rather than crashing.
diff :: (Field c, Eq c) => Poly c -> Poly c
diff p =
  toPoly
    . V.drop 1
    . V.imap (\i c -> scale (fromIntegral i :: Integer) c)
    $ fromPoly p