-- | Small prime fields (up to @p < 2^31@), without type safety.
--
-- This module is considered internal.
--

{-# LANGUAGE BangPatterns #-}
module Math.FiniteField.PrimeField.Small.Raw where

--------------------------------------------------------------------------------

import Data.Bits
import Data.Int
import Data.Word
import GHC.TypeNats (Nat)

-- import Math.FiniteField.Primes

--------------------------------------------------------------------------------

-- | The prime modulus @p@. The arithmetic in this module assumes
-- @p < 2^31@, so that the product of two reduced elements (@< 2^62@)
-- fits in a 'Word64' without overflow.
type P = Word64

-- | A field element, stored as a residue that is expected to lie in
-- @[0, p)@; the operations below assume their inputs are reduced.
type F = Word64

-- | Additive inverse modulo @p@: @neg p x == (p - x) mod p@.
-- Zero is its own negation; a reduced nonzero @x@ maps to @p - x@.
neg :: P -> F -> F
neg !p !x
  | x == 0    = 0
  | otherwise = p - x

-- | Addition modulo @p@. For reduced inputs the raw sum is below @2p@,
-- so at most one subtraction of @p@ brings it back into range.
add :: P -> F -> F -> F
add !p !x !y
  | total < p = total
  | otherwise = total - p
  where
    total = x + y

-- | Subtraction modulo @p@. When @x < y@ the difference would wrap
-- around in unsigned arithmetic, so @p@ is added first.
sub :: P -> F -> F -> F
sub !p !x !y
  | x < y     = p + x - y
  | otherwise = x - y

-- | Multiplication modulo @p@. With @p < 2^31@ and reduced arguments,
-- the intermediate product stays below @2^62@ and cannot overflow.
mul :: P -> F -> F -> F
mul !p !x !y = (x * y) `mod` p

--------------------------------------------------------------------------------
-- * Nontrivial operations

-- | Exponentiation @z^e@ modulo the prime @p@ (square-and-multiply).
--
-- By convention @0^e == 0@ for every @e@ (including @e <= 0@).
--
-- For nonzero @z@, @z^(p-1) == 1@ (Fermat's little theorem), so any
-- exponent outside @[0, p-1)@ is first reduced modulo @p-1@. Because 'mod'
-- with a positive divisor never returns a negative value, this also
-- covers negative exponents, including @minBound :: Int64@, whose
-- 'negate' is itself and would otherwise never terminate. No modular
-- inverse is needed. Assumes @z@ is reduced (@0 <= z < p@).
pow :: P -> F -> Int64 -> F
pow !p !z !e
  | z == 0             = 0
  | e == 0             = 1
  | e < 0 || e >= pm1i = go 1 z (mod e pm1i)
  | otherwise          = go 1 z e
  where
    -- order of the multiplicative group; fits in Int64 since p < 2^31
    pm1i :: Int64
    pm1i = fromIntegral (p - 1)

    -- Right-to-left binary exponentiation; the exponent is non-negative here.
    go :: F -> F -> Int64 -> F
    go !acc !y !k
      | k == 0    = acc
      | odd k     = go (mul p acc y) (mul p y y) (shiftR k 1)
      | otherwise = go acc           (mul p y y) (shiftR k 1)

-- | Exponentiation @z^e@ modulo @p@ with an arbitrary-precision exponent.
--
-- @0^e@ is @0@. A negative exponent is handled by exponentiating the
-- inverse of @z@. A non-negative exponent at least @p-1@ is reduced modulo
-- @p-1@ (valid for nonzero @z@ by Fermat's little theorem). The result
-- then fits the 'Int64' exponent of 'pow'.
pow' :: P -> F -> Integer -> F
pow' !p !z !e
  | z == 0    = 0
  | e == 0    = 1
  | e < 0     = pow' p (inv p z) (negate e)
  | otherwise = pow p z (fromIntegral small)
  where
    order :: Integer
    order = toInteger (p - 1)
    -- exponent in [0, p-1), small enough for Int64 since p < 2^31
    small :: Integer
    small = if e >= order then e `mod` order else e

-- | Multiplicative inverse modulo @p@, computed with the binary extended
-- Euclidean algorithm ('euclid64' seeded with coefficient 1).
-- By convention the "inverse" of @0@ is @0@.
inv :: P -> F -> F
inv !p !a = if a == 0 then 0 else euclid64 p 1 0 a p

-- | Division @a / b@ modulo @p@, computed in one pass of the binary
-- extended Euclidean algorithm ('euclid64' seeded with @a@ rather than 1).
-- Dividing by @0@ yields @0@.
div :: P -> F -> F -> F
div !p !a !b
  | b /= 0    = euclid64 p a 0 b p
  | otherwise = 0

-- | Division @a / b@ modulo @p@, computed as @a * inv b@.
-- Dividing by @0@ yields @0@, since @inv p 0 == 0@.
div2 :: P -> F -> F -> F
div2 !p !a !b =
  let bInv = inv p b
  in  mul p a bInv

--------------------------------------------------------------------------------
-- * Euclidean algorithm

-- | Extended binary Euclidean algorithm modulo an odd prime @p@.
--
-- @euclid64 p x1 x2 u v@ performs binary-GCD steps on @(u, v)@: strip
-- factors of two, then subtract the smaller from the larger. It applies
-- the same operations, modulo @p@, to the coefficients @(x1, x2)@. It
-- returns @x1@ as soon as @u == 1@ and @x2@ as soon as @v == 1@.
--
-- When called as @euclid64 p a 0 b p@ it maintains @x1 * b == a * u@ and
-- @x2 * b == a * v@ (mod @p@), so the result is @a / b@ mod @p@. In
-- particular, @euclid64 p 1 0 b p@ is the inverse of @b@.
--
-- Assumptions (not checked):
--
-- * @p@ is odd, because halving uses @(p+1)/2@ as the inverse of 2.
-- * The coefficients are reduced (@< p@).
--
-- If @u@ and @v@ are not coprime (e.g. @b@ is a nonzero multiple of @p@),
-- the result is @0@ rather than an infinite loop.
euclid64 :: Word64 -> Word64 -> Word64 -> Word64 -> Word64 -> Word64
euclid64 !p !x1Init !x2Init !uInit !vInit = go x1Init x2Init uInit vInit
  where
    -- (p+1)/2, the inverse of 2 modulo an odd p
    halfp1 :: Word64
    halfp1 = shiftR (p + 1) 1

    modp :: Word64 -> Word64
    modp !n = mod n p

    -- x/2 mod p: for odd x < p, (x-1)/2 + (p+1)/2 == (x+p)/2
    halve :: Word64 -> Word64
    halve !x = if even x then shiftR x 1 else shiftR x 1 + halfp1

    -- a - b mod p without unsigned wrap-around
    subp :: Word64 -> Word64 -> Word64
    subp !a !b = if a >= b then modp (a - b) else modp (a + p - b)

    go :: Word64 -> Word64 -> Word64 -> Word64 -> Word64
    go !x1 !x2 !u !v
      | u == 1           = x1
      | v == 1           = x2
      -- Reachable only when gcd(u, v) /= 1 (u == v > 1 was subtracted
      -- to 0). Without this guard stepU would halve 0 forever.
      | u == 0 || v == 0 = 0
      | otherwise        = stepU x1 x2 u v

    -- Remove factors of two from u, halving x1 alongside; then handle v.
    stepU :: Word64 -> Word64 -> Word64 -> Word64 -> Word64
    stepU !x1 !x2 !u !v
      | even u    = stepU (halve x1) x2 (shiftR u 1) v
      | otherwise = stepV x1 x2 u v

    -- Remove factors of two from v, halving x2 alongside.
    stepV :: Word64 -> Word64 -> Word64 -> Word64 -> Word64
    stepV !x1 !x2 !u !v
      | even v    = stepV x1 (halve x2) u (shiftR v 1)
      | otherwise = final x1 x2 u v

    -- Both u and v are odd: subtract the smaller from the larger.
    final :: Word64 -> Word64 -> Word64 -> Word64 -> Word64
    final !x1 !x2 !u !v
      | u >= v    = go (subp x1 x2) x2 (u - v) v
      | otherwise = go x1 (subp x2 x1) u (v - u)

--------------------------------------------------------------------------------