-- | Tests for "AtCoder.Math": modular exponentiation and inverse, floor sums
-- and the Chinese remainder theorem.
module Tests.Math (tests) where

import AtCoder.Math qualified as AM
import Control.Monad (when)
import Control.Monad.Fix (fix)
import Data.Foldable
import Data.List qualified as L
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import Test.Tasty
import Test.Tasty.HUnit

-- | Reference implementation of 'AM.floorSum': sums @(a * i + b)@ divided by
-- @m@ for @i@ in @[0 .. n - 1]@. 'div' rounds toward negative infinity, so each
-- term is the true floor even when the numerator is negative.
floorSumNaive :: Int -> Int -> Int -> Int -> Int
floorSumNaive n m a b = sum [(a * i + b) `div` m | i <- [0 .. n - 1]]

-- | Compares 'AM.powMod' with repeated modular multiplication for every base in
-- @[-100 .. 100]@, exponent in @[0 .. 100]@ and modulus in @[1 .. 100]@.
unit_powMod :: TestTree
unit_powMod = testCase "powMod" $ do
  let naive x n modulo =
        let y = x `mod` modulo
            -- Starting from @1 `mod` modulo@ makes @x ^ 0@ equal 0 when the
            -- modulus is 1.
            z0 = 1 `mod` modulo
         in L.foldl' (\z' _ -> z' * y `mod` modulo) z0 [0 .. n - 1]
  for_ [-100 .. 100] $ \a -> do
    for_ [0 .. 100] $ \b -> do
      for_ [1 .. 100] $ \c -> do
        naive a b c @=? AM.powMod a b c

-- | Hand-picked 'AM.invMod' cases at the extremes of 'Int'.
unit_invBoundHand :: TestTree
unit_invBoundHand = testCase "invBoundHand" $ do
  let min_ = minBound @Int
  let max_ = maxBound @Int
  AM.invMod (-1) max_ @=? AM.invMod min_ max_
  1 @=? AM.invMod max_ (max_ - 1)
  max_ - 1 @=? AM.invMod (max_ - 1) max_
  2 @=? AM.invMod (max_ `div` 2 + 1) max_

-- | For every @a@ in @[-100 .. 100]@ and modulus @b@ in @[1 .. 1000]@ with @a@
-- coprime to @b@, 'AM.invMod' must return some @c@ in @[0, b)@ with
-- @a * c == 1 (mod b)@.
unit_invMod :: TestTree
unit_invMod = testCase "invMod" $ do
  for_ [-100 .. 100] $ \a -> do
    for_ [1 .. 1000] $ \b -> do
      when (gcd (a `mod` b) b == 1) $ do
        let c = AM.invMod a b
        assertBool "" $ 0 <= c
        assertBool "" $ c < b
        (1 `mod` b) @=? (a * c `mod` b + b) `mod` b

-- | Modulo 1 every inverse is 0, including for arguments near 'minBound' and
-- 'maxBound'.
unit_invModZero :: TestTree
unit_invModZero = testCase "invModZero" $ do
  0 @=? AM.invMod 0 1
  for_ [0 .. 10 - 1] $ \i -> do
    0 @=? AM.invMod i 1
    0 @=? AM.invMod (-i) 1
    0 @=? AM.invMod (minBound @Int + i) 1
    0 @=? AM.invMod (maxBound @Int - i) 1

-- | Exhaustively compares 'AM.floorSum' with 'floorSumNaive' for small @n@,
-- @m@, @a@ and @b@ (including negative @a@ and @b@).
unit_floorSum :: TestTree
unit_floorSum = testCase "floorSum" $ do
  for_ [0 .. 20 - 1] $ \n -> do
    for_ [1 .. 20 - 1] $ \m -> do
      for_ [-20 .. 19] $ \a -> do
        for_ [-20 .. 19] $ \b -> do
          -- The naive result is the expected value, so it goes on the left of
          -- '@=?' to keep failure messages correctly labelled.
          floorSumNaive n m a b @=? AM.floorSum n m a b

-- | One hand-checked CRT instance: the congruences x = 1 (mod 2), x = 2 (mod 3)
-- and x = 1 (mod 2) have solution 5 modulo 6.
unit_crtHand :: TestTree
unit_crtHand = testCase "crtHand" $ do
  let (!res1, !res2) = AM.crt (VU.fromList [1, 2, 1]) (VU.fromList [2, 3, 2])
  5 @=? res1
  6 @=? res2

-- | Exhaustive two-congruence check of 'AM.crt'. A returned modulus of 0 means
-- "no solution", which is confirmed by brute force over one full period.
-- Otherwise the modulus must be @lcm a b@ and the residue must satisfy both
-- congruences.
unit_crt2 :: TestTree
unit_crt2 = testCase "crt2" $ do
  for_ [1 .. 20] $ \a -> do
    for_ [1 .. 20] $ \b -> do
      for_ [-10 .. 10] $ \c -> do
        for_ [-10 .. 10] $ \d -> do
          let (!res1, !res2) = AM.crt (VU.fromList [c, d]) (VU.fromList [a, b])
          let period = lcm a b
          if res2 == 0
            then do
              -- Compare against the normalised residues. @x `mod` a@ is never
              -- negative, so comparing it with a negative raw @c@ would make
              -- the assertion trivially true.
              for_ [0 .. period - 1] $ \x -> do
                assertBool "" $ x `mod` a /= c `mod` a || x `mod` b /= d `mod` b
            else do
              period @=? res2
              c `mod` a @=? res1 `mod` a
              d `mod` b @=? res1 `mod` b

-- | Exhaustive three-congruence check of 'AM.crt', with the same contract as
-- 'unit_crt2'.
unit_crt3 :: TestTree
unit_crt3 = testCase "crt3" $ do
  for_ [1 .. 5] $ \a -> do
    for_ [1 .. 5] $ \b -> do
      for_ [1 .. 5] $ \c -> do
        for_ [-5 .. 5] $ \d -> do
          for_ [-5 .. 5] $ \e -> do
            for_ [-5 .. 5] $ \f -> do
              let (!res1, !res2) =
                    AM.crt (VU.fromList [d, e, f]) (VU.fromList [a, b, c])
              let period = lcm (lcm a b) c
              if res2 == 0
                then do
                  -- Normalise the residues; see 'unit_crt2'.
                  for_ [0 .. period - 1] $ \x -> do
                    assertBool "" $
                      x `mod` a /= d `mod` a
                        || x `mod` b /= e `mod` b
                        || x `mod` c /= f `mod` c
                else do
                  period @=? res2
                  d `mod` a @=? res1 `mod` a
                  e `mod` b @=? res1 `mod` b
                  f `mod` c @=? res1 `mod` c

-- | Moduli whose product (900577 * 10^12) is large enough that naive
-- intermediate products would overflow 'Int'.
unit_crtOverflow :: TestTree
unit_crtOverflow = testCase "crtOverflow" $ do
  let r0 = 0
  let r1 = 1_000_000_000_000 - 2
  let m0 = 900577
  let m1 = 1_000_000_000_000
  let (!res1, !res2) = AM.crt (VU.fromList [r0, r1]) (VU.fromList [m0, m1])
  m0 * m1 @=? res2
  r0 @=? res1 `mod` m0
  r1 @=? res1 `mod` m1

-- | 'AM.crt' with moduli whose least common multiple is 'maxBound' or
-- @maxBound - 1@.
unit_crtBound :: TestTree
unit_crtBound = testCase "crtBound" $ do
  let inf = maxBound @Int
  -- Candidate answers: 1 .. 10, inf - 9 .. inf, and three common primes.
  let ps = VU.create $ do
        p <- VUM.unsafeNew (2 * 10 + 3)
        for_ [1 .. 10] $ \i -> do
          VGM.write p (2 * (i - 1) + 0) i
          VGM.write p (2 * (i - 1) + 1) $ inf - (i - 1)
        VGM.write p (2 * 10 + 0) 998244353
        VGM.write p (2 * 10 + 1) 1_000_000_007
        VGM.write p (2 * 10 + 2) 1_000_000_009
        pure p

  let pairs =
        [ (inf, inf),
          (1, inf),
          (inf, 1),
          (7, inf),
          (inf `div` 337, 337),
          (2, (inf - 1) `div` 2)
        ]
  for_ pairs $ \(!a_, !b_) -> do
    -- Check both argument orders.
    for_ [(a_, b_), (b_, a_)] $ \(!a, !b) -> do
      VU.forM_ ps $ \ans -> do
        let (!res1, !res2) =
              AM.crt (VU.fromList [ans `mod` a, ans `mod` b]) (VU.fromList [a, b])
        -- Divide before multiplying so the lcm itself cannot overflow.
        let period = a `div` gcd a b * b
        period @=? res2
        ans `mod` period @=? res1

  -- inf = 2^63 - 1 = 49 * 73 * 127 * 337 * 92737 * 649657. Every ordering of
  -- these pairwise-coprime moduli must reconstruct @ans `mod` inf@.
  factorInf <- VU.thaw $ VU.fromList [49 :: Int, 73, 127, 337, 92737, 649657]
  fix $ \loop -> do
    -- 'VU.freeze' copies, so this snapshot stays valid after
    -- 'VUM.nextPermutation' mutates @factorInf@ below.
    factors <- VU.freeze factorInf
    VU.forM_ ps $ \ans -> do
      let r = VU.map (ans `mod`) factors
      let (!res1, !res2) = AM.crt r factors
      ans `mod` inf @=? res1
      inf @=? res2
    hasNext <- VUM.nextPermutation factorInf
    when hasNext loop

  -- inf - 1 = 2 * 3 * 715827883 * 2147483647.
  factorInf1 <- VU.thaw $ VU.fromList [2 :: Int, 3, 715827883, 2147483647]
  fix $ \loop -> do
    factors <- VU.freeze factorInf1
    VU.forM_ ps $ \ans -> do
      let r = VU.map (ans `mod`) factors
      let (!res1, !res2) = AM.crt r factors
      ans `mod` (inf - 1) @=? res1
      (inf - 1) @=? res2
    hasNext <- VUM.nextPermutation factorInf1
    when hasNext loop

-- | All tests in this module.
tests :: [TestTree]
tests =
  [ unit_powMod,
    unit_invBoundHand,
    unit_invMod,
    unit_invModZero,
    unit_floorSum,
    unit_crtHand,
    unit_crt2,
    unit_crt3,
    unit_crtOverflow,
    unit_crtBound
  ]