{-# LANGUAGE TypeFamilies #-}

-- | Math module. It contains number-theoretic algorithms.
--
-- @since 1.0.0.0
module AtCoder.Math
  ( -- * Modulus operations

    -- These functions are internally used for `AtCoder.ModInt`.
    powMod,
    invMod,

    -- * Chinese Remainder Theorem
    crt,

    -- * Floor sum
    floorSum,
  )
where

import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Math (powMod)
import AtCoder.Internal.Math qualified as ACIM
import Data.Bits (bit)
import Data.Vector.Unboxed qualified as VU
import GHC.Stack (HasCallStack)

-- `powMod` is re-exported from the internal math module.

-- | Returns an integer \(y\) such that \(0 \le y < m\) and \(xy \equiv 1 \pmod m\).
--
-- ==== Constraints
-- - \(\gcd(x, m) = 1\)
-- - \(1 \leq m\)
--
-- ==== Complexity
-- - \(O(\log m)\)
--
-- ==== Example
-- >>> let m = 998244353
-- >>> (invMod 2 m) * 2 `mod` m -- (2^(-1) mod m) * 2 mod m
-- 1
--
-- @since 1.0.0.0
{-# INLINE invMod #-}
invMod ::
  (HasCallStack) =>
  -- | \(x\)
  Int ->
  -- | \(m\)
  Int ->
  -- | \(x^{-1} \bmod m\)
  Int
invMod x m =
  -- Each check is forced by a `case` before the next step runs, so an invalid
  -- modulus is reported before the extended GCD is attempted.
  case ACIA.runtimeAssert (1 <= m) ("AtCoder.Math.invMod: given invalid `m` less than 1: " ++ show m) of
    () -> case ACIM.invGcd x m of
      -- The first component is required to be 1 (the gcd check); the second
      -- component is the value returned as the inverse.
      (g, y) -> case ACIA.runtimeAssert (g == 1) "AtCoder.Math.invMod: `x^(-1) mod m` cannot be calculated when `gcd x m /= 1`" of
        () -> y

-- | Given two arrays \(r,m\) with length \(n\), it solves the modular equation system
--
-- \[
-- x \equiv r[i] \pmod{m[i]}, \forall i \in \lbrace 0,1,\cdots, n - 1 \rbrace.
-- \]
--
-- If there is no solution, it returns \((0, 0)\). Otherwise, all the solutions can be written in the form \(x \equiv y \pmod z\), using integers
-- \(y, z\) \((0 \leq y < z = \mathrm{lcm}(m[i]))\). It returns this \((y, z)\) as a pair. If \(n=0\), it returns \((0, 1)\).
--
-- ==== Constraints
-- - \(|r| = |m|\)
-- - \(1 \le m[i]\)
-- - \(\mathrm{lcm}(m[i])\) is in `Int` bounds.
--
-- ==== Complexity
-- - \(O(n \log{\mathrm{lcm}(m[i])})\)
--
-- ==== __Example__
-- `crt` calculates \(y\) such that \(y \equiv r_i \pmod m_i, \forall i \in \lbrace 0,1,\cdots, n - 1 \rbrace\):
--
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let rs = VU.fromList @Int [6, 7, 8, 9]
-- >>> let ms = VU.fromList @Int [2, 3, 4, 5]
-- >>> crt rs ms
-- (4,60)
--
-- The property can be checked as follows:
--
-- >>> let (y, {- lcm ms -} _) = crt rs ms
-- >>> VU.zipWith mod rs ms
-- [0,1,0,4]
--
-- >>> VU.zipWith mod rs ms == VU.map (y `mod`) ms
-- True
--
-- @since 1.0.0.0
{-# INLINE crt #-}
crt :: (HasCallStack) => VU.Vector Int -> VU.Vector Int -> (Int, Int)
crt r m =
  case ACIA.runtimeAssert (VU.length r == VU.length m) "AtCoder.Math.crt: given `r` and `m` with different lengths" of
    () -> go 0 0 1
  where
    n :: Int
    n = VU.length r

    -- Merges the congruences at indices i, i + 1, .. into the running system
    -- @x ≡ acc (mod modAcc)@. It starts from @(0, 1)@ (every integer) and
    -- short-circuits with @(0, 0)@ as soon as two congruences conflict.
    go :: Int -> Int -> Int -> (Int, Int)
    go !i !acc !modAcc
      | i >= n = (acc, modAcc)
      | otherwise =
          case ACIA.runtimeAssert (1 <= mi) ("AtCoder.Math.crt: `m[i]` is not positive: " ++ show mi) of
            () -> ri `seq` merge
      where
        mi = VU.unsafeIndex m i
        -- Residue normalised into [0, mi).
        ri = VU.unsafeIndex r i `mod` mi

        -- Order the two congruences so that the first modulus is the larger.
        (bigM, smallM, bigR, smallR)
          | modAcc < mi = (mi, modAcc, ri, acc)
          | otherwise = (modAcc, mi, acc, ri)

        merge
          -- smallM divides bigM: the new congruence is implied or contradictory.
          | bigM `mod` smallM == 0 =
              if bigR `mod` smallM == smallR
                then go (i + 1) bigR bigM
                else (0, 0)
          -- The residues must agree modulo gcd for any solution to exist.
          | (smallR - bigR) `mod` g /= 0 = (0, 0)
          | otherwise =
              let -- Solve (bigR + t * bigM) ≡ smallR (mod smallM) for t in [0, u1).
                  !t = (smallR - bigR) `div` g `mod` u1 * im `mod` u1
                  !acc' = bigR + t * bigM
                  -- bigM * (smallM / g) is lcm bigM smallM.
                  !modAcc' = bigM * u1
               in go (i + 1) (if acc' < 0 then acc' + modAcc' else acc') modAcc'
          where
            (g, im) = ACIM.invGcd bigM smallM
            u1 = smallM `div` g

-- | Returns \(\sum\limits_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor\).
--
-- ==== Constraints
-- - \(0 \le n < 2^{32}\)
-- - \(1 \le m < 2^{32}\)
--
-- ==== Complexity
-- - \(O(\log m)\)
--
-- ==== __Example__
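-- For non-negative \(a, b\), `floorSum` counts the lattice points \((x, y)\) with \(0 \le x < n\) and
-- \(1 \le y \le \frac {a \times x + b} {m}\), i.e. the points on or below the line \(y = \frac {a \times x + b} {m}\) and above the \(x\) axis, in \(O(\log m)\) time: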
--
-- >>> floorSum 5 1 1 1 -- floorSum n {- line information -} m a b
-- 15
--
-- @
--   y
--   ^
-- 6 |
-- 5 |           o           line: y = x + 1
-- 4 |        o  o           The number of \`o\` is 15
-- 3 |     o  o  o           (the \`o\` at (0, 1) lies on the y axis)
-- 2 |  o  o  o  o
-- 1 o  o  o  o  o
-- --+-----------------> x
--   0  1  2  3  4  5
--                  n = 5
-- @
--
-- @since 1.0.0.0
{-# INLINE floorSum #-}
floorSum ::
  (HasCallStack) =>
  -- | \(n\)
  Int ->
  -- | \(m\)
  Int ->
  -- | \(a\)
  Int ->
  -- | \(b\)
  Int ->
  -- | \(\sum\limits_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor\)
  Int
floorSum n m a b = ACIM.floorSumUnsigned n m a' b' + qa * triangular + qb * n
  where
    !() = ACIA.runtimeAssert (0 <= n && n < bit 32) $ "AtCoder.Math.floorSum: given invalid `n` (`" ++ show n ++ "`)"
    !() = ACIA.runtimeAssert (1 <= m && m < bit 32) $ "AtCoder.Math.floorSum: given invalid `m` (`" ++ show m ++ "`)"

    -- A negative coefficient is split as a = qa * m + a' with 0 <= a' < m.
    -- Then floor((a*i + b) / m) gains exactly qa * i, which sums to
    -- qa * n(n-1)/2 over i in [0, n). `a `divMod` m` never overflows for
    -- m >= 1, unlike the former `(a `mod` m - a) `div` m` (e.g. a = minBound).
    (qa, a')
      | a < 0 = a `divMod` m
      | otherwise = (0, a)

    -- Same split for b; its contribution is qb added once per term, i.e. qb * n.
    (qb, b')
      | b < 0 = b `divMod` m
      | otherwise = (0, b)

    -- n(n-1)/2 without forming n * (n - 1): for n close to 2^32 that product
    -- exceeds maxBound :: Int, and `div` on the wrapped value is off by 2^63.
    -- Halving the even factor first keeps every intermediate below 2^63.
    triangular :: Int
    triangular
      | even n = (n `quot` 2) * (n - 1)
      | otherwise = n * ((n - 1) `quot` 2)