-- | Extra math module.
--
-- @since 1.0.0.0
module AtCoder.Extra.Math
  ( -- * Re-exports from the internal math module
    isPrime32,
    ACIM.invGcd,
    ACIM.primitiveRoot,

    -- * Binary exponentiation

    -- | ==== __Examples__
    -- >>> import AtCoder.Extra.Math qualified as M
    -- >>> import Data.Semigroup (Product(..), Sum(..))
    -- >>> getProduct $ M.power (<>) 32 (Product 2)
    -- 4294967296
    --
    -- >>> getProduct $ M.stimes' 32 (Product 2)
    -- 4294967296
    --
    -- >>> getProduct $ M.mtimes' 32 (Product 2)
    -- 4294967296
    power,
    stimes',
    mtimes',
  )
where

import AtCoder.Internal.Math qualified as ACIM
import Data.Bits ((.>>.))

-- | \(O(k \log^3 n) (k = 3)\). Returns whether the given `Int` value is a prime number.
--
-- This is a re-export of the internal @isPrime@ under a name that makes the supported input
-- range explicit.
--
-- ==== Constraints
-- - \(n < 4759123141 (2^{32} < 4759123141)\), otherwise the return value can lie
--   (see [Wikipedia](https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test#Testing_against_small_sets_of_bases)).
--
-- @since 1.1.0.0
{-# INLINE isPrime32 #-}
isPrime32 :: Int -> Bool
isPrime32 = ACIM.isPrime

-- | Calculates \(x^n\) with custom multiplication operator using the binary exponentiation
-- technique.
--
-- The internal implementation is taken from @Data.Semigroup.stimes@, but `power` uses strict
-- evaluation and is often much faster.
--
-- The operator is expected to be associative; every operand passed to it is some power of the
-- input value, so the result does not depend on how the multiplications are grouped.
--
-- ==== Complexity
-- - \(O(\log n)\)
--
-- ==== Constraints
-- - \(n \gt 0\). A non-positive \(n\) raises an error via @errorWithoutStackTrace@.
--
-- @since 1.0.0.0
{-# INLINE power #-}
power :: (a -> a -> a) -> Int -> a -> a
power op n0 x1
  | n0 <= 0 = errorWithoutStackTrace "AtCoder.Extra.Math.power: positive multiplier expected"
  | otherwise = f x1 n0
  where
    -- Phase 1: no odd bit has been consumed yet, so there is no accumulator. Keep squaring
    -- until the lowest set bit of the exponent is reached.
    f !x !n
      | even n = f (x `op` x) (n .>>. 1)
      | n == 1 = x
      | otherwise = g (x `op` x) (n .>>. 1) x
    -- Phase 2: @z@ accumulates the product of the squares selected by the odd bits seen so far.
    g !x !n !z
      | even n = g (x `op` x) (n .>>. 1) z
      | n == 1 = x `op` z
      | otherwise = g (x `op` x) (n .>>. 1) (x `op` z)

-- | Strict variant of @Data.Semigroup.stimes@.
--
-- Implemented as `power` specialized to the semigroup operator @(<>)@.
--
-- ==== Complexity
-- - \(O(\log n)\)
--
-- ==== Constraints
-- - \(n \gt 0\). A non-positive \(n\) raises the error thrown by `power`.
--
-- @since 1.0.0.0
{-# INLINE stimes' #-}
stimes' :: (Semigroup a) => Int -> a -> a
stimes' = power (<>)

-- | Strict variant of @Data.Semigroup.mtimesDefault@.
--
-- Returns @mempty@ when \(n = 0\); for positive \(n\) it delegates to `power` with @(<>)@.
--
-- ==== Complexity
-- - \(O(\log n)\)
--
-- ==== Constraints
-- - \(n \ge 0\). A negative \(n\) raises an error via @errorWithoutStackTrace@.
--
-- @since 1.0.0.0
{-# INLINE mtimes' #-}
mtimes' :: (Monoid a) => Int -> a -> a
mtimes' n x = case compare n 0 of
  LT -> errorWithoutStackTrace "AtCoder.Extra.Math.mtimes': non-negative multiplier expected"
  EQ -> mempty
  GT -> power (<>) n x