{-# LANGUAGE AllowAmbiguousTypes #-}

module ZkFold.Base.Algorithm.ReedSolomon where


import           Data.Bool                                  (bool)
import           Data.Vector                                as V hiding (sum)
import           GHC.Natural                                (Natural)
import           Prelude                                    (Eq, Int, Integer, Maybe (..), error, fromIntegral, iterate,
                                                             min, ($), (.), (<=), (==))
import qualified Prelude                                    as P

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number           (KnownNat, value)
import           ZkFold.Base.Algebra.Polynomials.Univariate


-- | Number of symbol errors a Reed–Solomon code can correct, computed as
-- @(n -! k) `div` 2@ from the type-level naturals @n@ and @k@.
--
-- Assumes @n@ is the codeword length and @k@ the message length (TODO
-- confirm with callers). '-!' is presumably truncated subtraction on
-- 'Natural', so @k >= n@ would give 0 — verify against its definition.
numberOfError :: forall n k. (KnownNat n, KnownNat k) => Natural
numberOfError = paritySymbols `div` 2
    where
        paritySymbols = value @n -! value @k

-- | Generator polynomial @(x - a) * (x - a^2) * ... * (x - a^r)@.
--
-- One linear factor is built per power of @a@ and they are multiplied with a
-- right fold starting from 'one'. When @r <= 0@ there are no factors and the
-- result is 'one'.
generator :: (Field a, Eq a) => Int -> a -> Poly a
generator r a = V.foldr mulFactor one powers
    where
        -- a, a^2, ..., a^r
        powers = V.iterateN r (* a) a

        mulFactor root acc = linearFactor root * acc

        -- x - root; coefficients are listed lowest degree first.
        linearFactor root = toPoly (V.fromList [negate root, one])

-- | Systematic Reed–Solomon encoding of the message coefficients @msg@.
--
-- The message polynomial is transformed with @'scaleP' 'one' r@ (presumably
-- multiplication by @x^r@, i.e. shifting the message up by @r@ coefficients —
-- confirm against 'scaleP'). The remainder of that polynomial against
-- @'generator' r prim_elem@ is then subtracted. The remainder is presumably
-- the second component of 'qr' (verify). With that reading, the message sits
-- at degrees @>= r@, and the lower @r@ coefficients carry the parity symbols.
encode :: (Field c, Eq c) => [c] -> c -> Int -> Poly c
encode msg prim_elem r = shifted - remainder
    where
        shifted = scaleP one (fromIntegral r) (toPoly (V.fromList msg))
        (_, remainder) = qr shifted (generator r prim_elem)

-- | Decode a received Reed–Solomon word and return its message part.
--
-- Arguments:
--
--   * @encoded@      — the received polynomial;
--   * @primeElement@ — the element @α@ whose powers define the syndromes
--     (presumably the same element passed to 'encode' — confirm);
--   * @r@            — number of parity symbols;
--   * @n@            — number of coefficient positions (0 .. n-1) searched
--     for errors.
--
-- Syndromes are evaluated at @α, α^2, ..., α^r@, so the first consecutive
-- root exponent is 1 ("beta = one").
--
-- If every syndrome is zero, the received coefficients from degree @r@
-- upward are returned unchanged. Otherwise, 'berlekamp' yields the locator
-- Λ(x). Positions @i@ with @Λ(α^-i) == 0@ are corrected using
-- @Ω(x) / Λ'(x)@, where @Ω = Λ·S mod x^r@.
--
-- Calls 'error' with @"Can't decode"@ if the corrected word still has a
-- non-zero syndrome. Any 'error' raised by 'berlekamp' is propagated.
decode :: (Field c, Eq c) => Poly c -> c -> Int -> Int -> Poly c
decode encoded primeElement r n = bool decoded encoded' isCorrect
    where
        -- α, α^2, ..., α^r
        rElems = iterateN r (* primeElement) primeElement

        -- Message part of the uncorrected word: drop the r lowest coefficients.
        encoded' = toPoly . V.drop r $ fromPoly encoded
        -- S(x) = S_1 + S_2 x + ... + S_r x^(r-1), with S_j = encoded(α^j).
        syndromes = toPoly . V.map (evalPoly encoded) $ rElems
        isCorrect = zero == syndromes

        -- Locator polynomial Λ(x).
        (_, lx) = berlekamp syndromes r

        -- Candidates (i, α^-i) for i = 0 .. n-1; keep those that are roots of Λ.
        invPE = finv primeElement
        es1 = V.indexed . V.fromList $ one : P.take (n P.- 1) (iterate (* invPE) invPE)
        iroots = mapMaybe (\(i, x) -> bool Nothing (Just (i, x)) (evalPoly lx x == zero)) es1

        -- Ω(x) = Λ(x)·S(x) mod x^r, and the formal derivative Λ'(x).
        omega = toPoly $ take r $ fromPoly (lx * syndromes)
        lx' = diff lx

        -- Correction term Σ (Ω(x_i) / Λ'(x_i)) · x^i over the located positions.
        -- NOTE(review): Forney's formula for first root exponent 1 reads
        -- e_i = -Ω/Λ', so *adding* this term presumably subtracts the error
        -- from the received word — confirm.
        -- Strict left fold: avoids building a chain of (+) thunks.
        err = V.foldl' (+) zero $ map (\(i, x) ->
            let xi = bool (monomial (fromIntegral i) one) (constant one) (i == 0)
                ei = evalPoly omega x * finv (evalPoly lx' x)
            in constant ei * xi) iroots

        fx = encoded + err
        -- All syndromes of the corrected word must vanish.
        checkSum = V.map (evalPoly fx) rElems

        decoded = bool (error "Can't decode") (toPoly $ V.drop r $ fromPoly fx) (all (== zero) checkSum)


-- | Berlekamp–Massey over the first @r@ coefficients of the syndrome
-- polynomial @s@. It returns @(L, C(x))@: the register length and the
-- connection polynomial, which 'decode' uses as the error locator.
--
-- A zero @s@ (degree -1) gives @(0, 1)@. @C@ starts at 1 and is only
-- changed by multiples of @x^m@ with @m >= 1@, so its constant term stays 1.
-- Calls 'error' with @"locators didn't find"@ when the final @deg C /= L@.
berlekamp :: forall c . (Field c, Eq c) => Poly c -> Int -> (Integer, Poly c)
berlekamp s r
    | deg s == -1 = (0, one)
    | P.otherwise = go cx0 bx0 0 0 1 one
    where
        sv = fromPoly s
        -- Initial C(x) and B(x), both 1.
        cx0 = one :: Poly c
        bx0 = one :: Poly c
        lenS = fromIntegral r

        -- go cx bx n l m b:
        --   cx  current connection polynomial C(x)
        --   bx  C(x) as it was before the last length change, B(x)
        --   n   index of the syndrome coefficient being processed
        --   l   current register length L
        --   m   steps since the last length change
        --   b   discrepancy recorded at the last length change
        go :: Poly c -> Poly c -> Integer -> Integer -> Natural -> c -> (Integer, Poly c)
        go cx bx n l m b
            | n == lenS = bool (error "locators didn't find") (l, cx) (deg cx == l)
            | P.otherwise = bool innerChoice (go cx bx n' l (m+1) b) (d == zero)
            where
                -- Discrepancy d = Σ_i c_i · s_(n-i).
                d = scalarN (fromIntegral n + 1) (fromIntegral l + 1) lxv sv

                lxv = fromPoly cx
                -- C(x) - (d / b) · B(x) · x^m
                cx' = cx - fromConstant d * constant (finv b) * bx * monomial m one

                n' = n + 1
                -- Lengthen the register (L := n + 1 - L, B := old C) only when 2L <= n.
                innerChoice = bool (go cx' bx n' l (m+1) b) (go cx' cx n' (n+1-l) 1 d) (2*l <= n)


-- | Shifted, reversed dot product used by 'berlekamp' for the discrepancy.
--
-- Let @offset = q - length rv@. The sum pairs @lv@ (with its first
-- @offset@ entries dropped when @offset > 0@) against the first @q@ entries
-- of @rv@ in reverse order. Products are summed over the shorter of the two.
-- When @min l (length lv) <= offset@, the result is 'zero' and no products
-- are computed.
scalarN :: (Semiring c) => Int -> Int -> Vector c -> Vector c -> c
scalarN q l lv rv
    | min l (V.length lv) <= offset = zero
    | P.otherwise                   = sum (V.zipWith (*) lShifted rWindow)
    where
        offset   = q P.- V.length rv
        lShifted = V.drop offset lv
        rWindow  = V.reverse (V.take q rv)


-- | Formal derivative: @d/dx (Σ c_i x^i) = Σ i·c_i x^(i-1)@.
--
-- Each coefficient is scaled by its index (as an 'Integer'), then the
-- constant-term slot is dropped. 'V.drop' 1 is used rather than the partial
-- 'V.tail', so an empty coefficient vector (e.g. the zero polynomial) is
-- handled instead of crashing.
diff :: (Field c, Eq c) => Poly c -> Poly c
diff p = toPoly . V.drop 1 . V.imap scaleByIndex $ fromPoly p
    where
        scaleByIndex i c = scale (fromIntegral i :: Integer) c