{-# LANGUAGE TypeFamilies #-}
module AtCoder.Math
(
powMod,
invMod,
crt,
floorSum,
)
where
import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Math (powMod)
import AtCoder.Internal.Math qualified as ACIM
import Data.Bits (bit)
import Data.Vector.Unboxed qualified as VU
import GHC.Stack (HasCallStack)
-- | Modular inverse of @x@ modulo @m@.
--
-- The result is the second component of @'ACIM.invGcd' x m@. Two checks are
-- performed with 'ACIA.runtimeAssert' before the result is returned:
--
-- - @1 <= m@;
-- - the first component of @'ACIM.invGcd' x m@ (the gcd, per the assertion
--   message) equals 1, i.e. the inverse exists.
--
-- NOTE(review): the range of the returned value is whatever 'ACIM.invGcd'
-- produces (presumably @[0, m)@) — confirm against that module.
{-# INLINE invMod #-}
invMod ::
  (HasCallStack) =>
  -- | @x@, the value to invert
  Int ->
  -- | @m@, the modulus (must be at least 1)
  Int ->
  -- | @x^(-1) mod m@
  Int
invMod x m = inverse
  where
    -- Validate the modulus before invGcd is consulted.
    !_ = ACIA.runtimeAssert (1 <= m) $ "AtCoder.Math.invMod: given invalid `m` less than 1: " ++ show m
    (!g, !inverse) = ACIM.invGcd x m
    -- An inverse exists only when x and m are coprime.
    !_ = ACIA.runtimeAssert (g == 1) "AtCoder.Math.invMod: `x^(-1) mod m` cannot be calculated when `gcd x m /= 1`"
-- | Chinese remainder theorem: merges the congruences @x ≡ r[i] (mod m[i])@
-- one entry at a time.
--
-- Returns @(y, z)@ such that the solutions are exactly @x ≡ y (mod z)@, where
-- @z@ is the lcm of the moduli. Returns @(0, 0)@ when the system has no
-- solution, and @(0, 1)@ for empty input.
--
-- Checked with 'ACIA.runtimeAssert': @r@ and @m@ have the same length, and
-- every @m[i]@ is at least 1. Each @r[i]@ is first reduced with 'mod', so
-- negative residues are accepted.
{-# INLINE crt #-}
crt :: (HasCallStack) => VU.Vector Int -> VU.Vector Int -> (Int, Int)
crt r m = go 0 0 1
  where
    -- Forced before any element is read with unsafeIndex.
    !_ = ACIA.runtimeAssert (VU.length r == VU.length m) "AtCoder.Math.crt: given `r` and `m` with different lengths"

    n :: Int
    n = VU.length r

    -- @go i acc md@: @acc (mod md)@ is the congruence obtained by merging
    -- entries @0 .. i - 1@; entry @i@ is merged next.
    go :: Int -> Int -> Int -> (Int, Int)
    go !i !acc !md
      | i >= n = (acc, md)
      | otherwise =
          let !mi = VU.unsafeIndex m i
              !_ = ACIA.runtimeAssert (1 <= mi) $ "AtCoder.Math.crt: `m[i]` is not positive: " ++ show mi
              !ri = VU.unsafeIndex r i `mod` mi
              -- Arrange the two congruences so that big >= small.
              (!big, !small, !rBig, !rSmall) =
                if md < mi then (mi, md, ri, acc) else (md, mi, acc, ri)
           in if big `mod` small == 0
                then
                  -- small divides big: the residues simply have to agree.
                  if rBig `mod` small == rSmall then go (i + 1) rBig big else (0, 0)
                else merge (i + 1) big small rBig rSmall

    -- Combines x ≡ rBig (mod big) with x ≡ rSmall (mod small) when small does
    -- not divide big, then continues with entry @next@.
    merge :: Int -> Int -> Int -> Int -> Int -> (Int, Int)
    merge next big small rBig rSmall =
      let (!g, !im) = ACIM.invGcd big small
          u1 = small `div` g
          diff = rSmall - rBig
       in if diff `mod` g /= 0
            then (0, 0)
            else
              -- Each factor is reduced mod u1 before multiplying to keep the
              -- intermediate product small.
              let !t = diff `div` g `mod` u1 * im `mod` u1
                  !acc' = rBig + t * big
                  !md' = big * u1
               in go next (if acc' < 0 then acc' + md' else acc') md'
-- | Floor sum: @sum [(a * i + b) \`div\` m | i <- [0 .. n - 1]]@, with the
-- arithmetic wrapping modulo @2^64@ like 'Int'.
--
-- This assumes that 'ACIM.floorSumUnsigned' computes the same sum for
-- non-negative @a@ and @b@. A negative @a@ (resp. @b@) is first reduced into
-- @[0, m)@. The quotient that was removed is then compensated for:
--
-- - for @a@, by @n * (n - 1) / 2 * (-(a \`div\` m))@;
-- - for @b@, by @n * (-(b \`div\` m))@.
--
-- Constraints, checked with 'ACIA.runtimeAssert':
--
-- - @0 <= n < 2^32@
-- - @1 <= m < 2^32@
{-# INLINE floorSum #-}
floorSum ::
  (HasCallStack) =>
  -- | @n@, the number of terms
  Int ->
  -- | @m@, the divisor
  Int ->
  -- | @a@, the coefficient of @i@
  Int ->
  -- | @b@, the constant term
  Int ->
  -- | the floor sum (modulo @2^64@)
  Int
floorSum n m a b = ACIM.floorSumUnsigned n m a' b' - da - db
  where
    !_ = ACIA.runtimeAssert (0 <= n && n < bit 32) $ "AtCoder.Math.floorSum: given invalid `n` (`" ++ show n ++ "`)"
    !_ = ACIA.runtimeAssert (1 <= m && m < bit 32) $ "AtCoder.Math.floorSum: given invalid `m` (`" ++ show m ++ "`)"

    -- n * (n - 1) / 2 without signed overflow. With n < 2^32, the product
    -- n * (n - 1) can exceed maxBound :: Int before the division, so the even
    -- factor is halved first. The result is then < 2^63.
    triangular :: Int
    triangular
      | even n = (n `div` 2) * (n - 1)
      | otherwise = n * ((n - 1) `div` 2)

    -- Note: (x `mod` m - x) `div` m == negate (x `div` m) exactly. The
    -- right-hand side avoids the overflow of `x mod m - x` when x is close to
    -- minBound.

    a' :: Int
    a'
      | a < 0 = a `mod` m
      | otherwise = a

    da :: Int
    da
      | a < 0 = triangular * negate (a `div` m)
      | otherwise = 0

    b' :: Int
    b'
      | b < 0 = b `mod` m
      | otherwise = b

    db :: Int
    db
      | b < 0 = n * negate (b `div` m)
      | otherwise = 0