{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}

-- | Fast modular multiplication for `Word32` using Barrett reduction.
-- Reference: https://en.wikipedia.org/wiki/Barrett_reduction
--
-- ==== __Example__
-- >>> let bt = new32 10 -- mod 10
-- >>> umod bt
-- 10
--
-- >>> mulMod bt 7 7
-- 9
--
-- @since 1.0.0.0
module AtCoder.Internal.Barrett
  ( -- * Barrett
    Barrett,
    -- * Constructor
    new32,
    new64,
    -- * Accessor
    umod,
    -- * Barrett reduction
    mulMod,
  )
where

import Data.WideWord.Word128 (Word128 (..))
import Data.Word (Word32, Word64)

-- | Precomputed state for fast modular multiplication via Barrett reduction:
-- the modulus \(m\) together with a scaled reciprocal of it, which lets
-- `mulMod` replace a division by \(m\) with a widening multiplication.
-- Reference: https://en.wikipedia.org/wiki/Barrett_reduction
--
-- @since 1.0.0.0
data Barrett = Barrett
  { mBarrett :: {-# UNPACK #-} !Word32,
    -- ^ The modulus \(m\).
    imBarrett :: {-# UNPACK #-} !Word64
    -- ^ Scaled reciprocal of the modulus, set by `new32` / `new64` to
    -- \(\lfloor (2^{64} - 1) / m \rfloor + 1\) (computed with `Word64`
    -- wrap-around).
  }
  deriving
    ( -- | @since 1.0.0.0
      Eq,
      -- | @since 1.0.0.0
      Show
    )

-- | Builds a `Barrett` for the modulus \(m\), given as a `Word32`.
--
-- The stored reciprocal is \(\lfloor (2^{64} - 1) / m \rfloor + 1\), computed
-- in `Word64`. For \(m = 1\) the @+ 1@ wraps around, so the reciprocal is
-- @0@. Passing \(m = 0\) throws a divide-by-zero exception.
--
-- @since 1.0.0.0
{-# INLINE new32 #-}
new32 :: Word32 -> Barrett
new32 m = Barrett {mBarrett = m, imBarrett = inv}
  where
    inv :: Word64
    inv = maxBound `div` fromIntegral m + 1

-- | Builds a `Barrett` for the modulus \(m\), given as a `Word64`.
--
-- The reciprocal \(\lfloor (2^{64} - 1) / m \rfloor + 1\) is computed from the
-- full 64-bit \(m\). The stored modulus keeps only the low 32 bits of \(m\)
-- (via `fromIntegral`), so the two agree only when \(m < 2^{32}\).
-- NOTE(review): callers presumably guarantee \(m < 2^{32}\) — confirm.
-- Passing \(m = 0\) throws a divide-by-zero exception.
--
-- @since 1.0.0.0
{-# INLINE new64 #-}
new64 :: Word64 -> Barrett
new64 m =
  let inv :: Word64
      inv = maxBound `div` m + 1
   in Barrett (fromIntegral m) inv

-- | Returns the modulus \(m\) this `Barrett` was built for.
--
-- @since 1.0.0.0
{-# INLINE umod #-}
umod :: Barrett -> Word32
umod = mBarrett

-- | Calculates \(a b \bmod m\).
--
-- Preconditions (not checked): \(a < m\) and \(b < m\). Under them the
-- product \(a b\) fits in a `Word64` and one conditional correction step
-- yields the exact remainder. For unreduced operands the result may be wrong.
--
-- @since 1.0.0.0
{-# INLINE mulMod #-}
mulMod :: Barrett -> Word64 -> Word64 -> Word64
mulMod Barrett {..} a b =
  let -- Widen the modulus once instead of at every use site.
      m :: Word64 = fromIntegral mBarrett
      z :: Word64 = a * b
      -- High 64 bits of z * imBarrett: an estimate of z `div` m that, under
      -- the preconditions, is never too small and at most one too large.
      x :: Word64 = word128Hi64 (fromIntegral z * fromIntegral imBarrett :: Word128)
      y :: Word64 = x * m
   in -- When x overshoots by one, z - y wraps below zero; adding m back
      -- restores the remainder.
      z - y + if z < y then m else 0