{-# LANGUAGE MagicHash #-}
module AtCoder.Convolution
(
convolution,
convolutionRaw,
convolution64,
)
where
import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Bit qualified as ACIB
import AtCoder.Internal.Convolution qualified as ACIC
import AtCoder.Internal.Math qualified as ACIM
import AtCoder.ModInt qualified as AM
import Data.Bits (bit)
import Data.Proxy (Proxy (..))
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Exts (proxy#)
import GHC.Stack (HasCallStack)
import GHC.TypeNats (natVal')
-- | Convolution of two sequences of 'AM.ModInt' values.
--
-- * Returns an empty vector when either input is empty.
-- * Otherwise asserts (via 'ACIA.runtimeAssert') that @(modulus - 1)@ is
--   divisible by @z = 'ACIB.bitCeil' (n + m - 1)@, where @n@ and @m@ are the
--   input lengths and @modulus@ is the type-level modulus @p@.
-- * Delegates to 'ACIC.convolutionNaive' when the shorter input has at most
--   60 elements, and to 'ACIC.convolutionFft' otherwise.
--
-- NOTE(review): 'ACIB.bitCeil' presumably returns the smallest power of two
-- that is not less than its argument, and 'ACIA.runtimeAssert' presumably
-- raises an error carrying the message when the condition is 'False';
-- confirm against their definitions.
convolution ::
  forall p.
  (HasCallStack, AM.Modulus p) =>
  VU.Vector (AM.ModInt p) ->
  VU.Vector (AM.ModInt p) ->
  VU.Vector (AM.ModInt p)
convolution a b
  | n == 0 || m == 0 = VU.empty
  | otherwise =
      let z = ACIB.bitCeil (n + m - 1)
          !modulus = fromIntegral (natVal' (proxy# @p)) :: Int
          -- The check runs before choosing the algorithm, so it also applies
          -- to the naive path. The `m` in the message text stands for the
          -- modulus, so report `modulus` (not the local `m`, which is the
          -- length of `b`).
          !_ =
            ACIA.runtimeAssert ((modulus - 1) `mod` z == 0) $
              "AtCoder.Convolution.convolution: not works when `(m - 1) mod z /= 0`: " ++ show (modulus, z)
       in if min n m <= 60
            then ACIC.convolutionNaive a b
            else ACIC.convolutionFft a b
  where
    -- Lengths of the two inputs.
    n = VU.length a
    m = VU.length b
-- | Convolution over plain 'Integral' values, computed modulo the type-level
-- modulus @p@.
--
-- The 'Proxy' argument is only used to select @p@; its value is ignored.
-- Both inputs are converted to 'AM.ModInt' @p@ with 'fromIntegral', passed to
-- 'convolution', and the result is converted back with 'fromIntegral'.
--
-- * Returns an empty vector when either input is empty.
-- * Otherwise asserts (via 'ACIA.runtimeAssert') that @(modulus - 1)@ is
--   divisible by @z = 'ACIB.bitCeil' (n + m - 1)@ before doing any work, so
--   that a failure is reported under this function's name.
convolutionRaw ::
  forall p a.
  (HasCallStack, AM.Modulus p, Integral a, VU.Unbox a) =>
  Proxy p ->
  VU.Vector a ->
  VU.Vector a ->
  VU.Vector a
convolutionRaw _ a b
  | n == 0 || m == 0 = VU.empty
  | otherwise =
      let z = ACIB.bitCeil (n + m - 1)
          !modulus = fromIntegral (natVal' (proxy# @p)) :: Int
          -- The `m` in the message text stands for the modulus, so report
          -- `modulus` (not the local `m`, which is the length of `b`).
          !_ =
            ACIA.runtimeAssert ((modulus - 1) `mod` z == 0) $
              "AtCoder.Convolution.convolutionRaw: not works when `(m - 1) mod z /= 0`: " ++ show (modulus, z)
       in VU.map fromIntegral $
            convolution @p (VU.map fromIntegral a) (VU.map fromIntegral b)
  where
    -- Lengths of the two inputs.
    n = VU.length a
    m = VU.length b
-- | Convolution of two 'Int' sequences, assembled from three modular
-- convolutions.
--
-- * Returns an empty vector when either input is empty.
-- * Otherwise asserts (via 'ACIA.runtimeAssert') that the output length
--   @n + m - 1@ does not exceed @2^24@.
-- * Runs 'convolution' under the moduli 754974721, 167772161 and 469762049,
--   then combines the three residues of every output position into a single
--   'Int'.
--
-- NOTE(review): 'ACIM.invGcd' is presumably an extended-gcd helper whose
-- second component is the modular inverse of the first argument modulo the
-- second; confirm against its definition.
convolution64 ::
  (HasCallStack) =>
  VU.Vector Int ->
  VU.Vector Int ->
  VU.Vector Int
convolution64 a b
  | lenA == 0 || lenB == 0 = VU.empty
  | otherwise =
      -- Force the length check before any convolution work starts.
      ACIA.runtimeAssert (outLen <= bit maxAbBit) "AtCoder.Convolution.convolution64: given too long vector as input"
        `seq` VU.create
          ( do
              out <- VUM.unsafeNew outLen
              VU.imapM_
                (\i residues -> VGM.write out i (combine residues))
                (VU.zip3 c1 c2 c3)
              pure out
          )
  where
    lenA, lenB, outLen :: Int
    lenA = VU.length a
    lenB = VU.length b
    outLen = lenA + lenB - 1

    maxAbBit :: Int
    maxAbBit = 24

    mod1, mod2, mod3 :: Int
    mod1 = 754974721
    mod2 = 167772161
    mod3 = 469762049

    -- Pairwise products of the moduli, plus the product of all three. The
    -- triple product exceeds 2^63, so on a 64-bit 'Int' it wraps around.
    m2m3, m1m3, m1m2, m1m2m3 :: Int
    m2m3 = mod2 * mod3
    m1m3 = mod1 * mod3
    m1m2 = mod1 * mod2
    m1m2m3 = mod1 * mod2 * mod3

    -- Second components of 'ACIM.invGcd' for each (product, modulus) pair.
    (!_, !i1) = ACIM.invGcd m2m3 mod1
    (!_, !i2) = ACIM.invGcd m1m3 mod2
    (!_, !i3) = ACIM.invGcd m1m2 mod3

    -- The same convolution evaluated under each of the three moduli.
    c1 = convolution (VU.map (AM.new @754974721) a) (VU.map (AM.new @754974721) b)
    c2 = convolution (VU.map (AM.new @167772161) a) (VU.map (AM.new @167772161) b)
    c3 = convolution (VU.map (AM.new @469762049) a) (VU.map (AM.new @469762049) b)

    -- Correction subtracted from the combined value, indexed by `slot` below.
    correction :: VU.Vector Int
    correction = VU.fromListN 5 [0, 0, m1m2m3, 2 * m1m2m3, 3 * m1m2m3]

    -- Combine the three residues of one output position into one 'Int'.
    combine (AM.ModInt !r1, AM.ModInt !r2, AM.ModInt !r3) =
      let !t1 = (fromIntegral r1 * i1) `mod` mod1 * m2m3
          !t2 = (fromIntegral r2 * i2) `mod` mod2 * m1m3
          !t3 = (fromIntegral r3 * i3) `mod` mod3 * m1m2
          !total = t1 + t2 + t3
          -- Difference between the first residue and `total` reduced modulo
          -- `mod1`, normalised to a non-negative value.
          gap = fromIntegral r1 - total `mod` mod1
          slot = (if gap < 0 then gap + mod1 else gap) `mod` 5
       in total - correction VG.! slot