{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
module AtCoder.Extra.Semigroup.Matrix
(
Matrix (..),
new,
zero,
ident,
diag,
map,
mulToCol,
mul,
mulMod,
mulMint,
pow,
powMod,
powMint,
)
where
import AtCoder.Extra.Math qualified as ACEM
import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Barrett qualified as BT
import AtCoder.ModInt qualified as M
import Data.Foldable (for_)
import Data.Semigroup (Semigroup (..))
import Data.Vector qualified as V
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Exts (proxy#)
import GHC.Stack (HasCallStack)
import GHC.TypeNats (KnownNat, natVal')
import Prelude hiding (map)
-- | A dense matrix whose entries are kept in a single unboxed vector in
-- row-major order: entry @(r, c)@ is stored at index @r * wM + c@.
data Matrix a = Matrix
  { hM :: {-# UNPACK #-} !Int
  -- ^ Number of rows.
  , wM :: {-# UNPACK #-} !Int
  -- ^ Number of columns.
  , vecM :: !(VU.Vector a)
  -- ^ Row-major entries. 'new' checks that there are exactly @hM * wM@ of
  -- them; using the constructor directly performs no such check.
  }
  deriving (Show, Eq)
-- | A column vector: the right operand and the result of 'mulToCol'.
type Col = VU.Vector
-- | @new h w vec@ builds an @h × w@ matrix from its row-major entries.
--
-- Calls 'error' when @vec@ does not contain exactly @h * w@ elements.
{-# INLINE new #-}
new :: (HasCallStack, VU.Unbox a) => Int -> Int -> VU.Vector a -> Matrix a
new h w vec
  | len == h * w = Matrix h w vec
  | otherwise = error "AtCoder.Extra.Matrix: size mismatch"
  where
    len = VU.length vec
-- | The @n × n@ matrix with every entry equal to @0@.
{-# INLINE zero #-}
zero :: (VU.Unbox a, Num a) => Int -> Matrix a
zero n = Matrix n n zeros
  where
    zeros = VU.replicate (n * n) 0
-- | The @n × n@ identity matrix: @1@ on the main diagonal, @0@ elsewhere.
{-# INLINE ident #-}
ident :: (VU.Unbox a, Num a) => Int -> Matrix a
ident n = Matrix n n $ VU.create $ do
  buf <- VUM.replicate (n * n) 0
  -- In row-major order, consecutive diagonal cells are @n + 1@ slots apart.
  let stride = n + 1
  for_ [0 .. n - 1] $ \i -> VGM.write buf (stride * i) 1
  pure buf
-- | @diag n xs@ is the @n × n@ matrix with element @i@ of @xs@ at cell
-- @(i, i)@ and @0@ everywhere else.
--
-- If @xs@ has fewer than @n@ elements, the remaining diagonal cells stay
-- @0@. If it has more, the write for index @n@ falls outside the backing
-- vector and the bounds-checked 'VGM.write' raises an error.
{-# INLINE diag #-}
diag :: (VU.Unbox a, Num a) => Int -> VU.Vector a -> Matrix a
diag n xs = Matrix n n $ VU.create $ do
  buf <- VUM.replicate (n * n) 0
  VU.imapM_ (\i x -> VGM.write buf (i * (n + 1)) x) xs
  pure buf
-- | Applies a function to every entry, keeping the dimensions unchanged.
{-# INLINE map #-}
map :: (VU.Unbox a, VU.Unbox b) => (a -> b) -> Matrix a -> Matrix b
map f (Matrix h w entries) = Matrix h w (VU.map f entries)
-- | Multiplies a matrix by a column vector.
--
-- The result has one entry per row (length 'hM'). The column's length must
-- equal 'wM'; this is checked with 'ACIA.runtimeAssert'.
{-# INLINE mulToCol #-}
mulToCol :: (Num a, VU.Unbox a) => Matrix a -> Col a -> Col a
mulToCol (Matrix h w entries) !col = VU.convert (V.map dot rows)
  where
    !_ = ACIA.runtimeAssert (VU.length col == w) "AtCoder.Extra.Matrix.mulToCol: size mismatch"
    -- Cut the row-major backing vector into @h@ consecutive rows of width @w@.
    rows = V.unfoldrExactN h (VU.splitAt w) entries
    -- Keep the original operand order (column entry times row entry).
    dot row = VU.sum (VU.zipWith (*) col row)
-- | The matrix product @a * b@.
--
-- Requires @'wM' a == 'hM' b@, which is checked with 'ACIA.runtimeAssert'.
-- The result has dimensions @'hM' a × 'wM' b@.
{-# INLINE mul #-}
mul :: (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul !a !b = Matrix h w' (VU.unfoldrExactN (h * w') step (0, 0))
  where
    Matrix h w vecA = a
    Matrix h' w' vecB = b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mul: matrix size mismatch"
    -- Visit the output cells in row-major order and emit one entry per step.
    step (!row, !col)
      | col + 1 >= w' = (x, (row + 1, 0))
      | otherwise = (x, (row, col + 1))
      where
        !x = cell row col
    -- Dot product of row @row@ of @a@ with column @col@ of @b@.
    cell row col =
      VU.sum $
        VU.imap
          (\k x -> x * VG.unsafeIndex vecB (col + k * w'))
          (VU.unsafeSlice (w * row) w vecA)
-- | @mulMod m a b@ is the matrix product @a * b@ with every entry reduced
-- modulo @m@.
--
-- Products are computed with a Barrett reducer built from @m@, and sums are
-- reduced with 'rem'. Requires @'wM' a == 'hM' b@, which is checked with
-- 'ACIA.runtimeAssert'. When the shared dimension is @0@, the result is the
-- @'hM' a × 'wM' b@ zero matrix, the same as 'mul' gives.
--
-- NOTE(review): @m@ is truncated to 'Word32' for 'BT.new32', and entries are
-- converted to 'Word64' with 'fromIntegral'. Entries are therefore presumably
-- expected to lie in @[0, m)@; confirm against "AtCoder.Internal.Barrett".
{-# INLINE mulMod #-}
mulMod :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod !m !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral m
    -- Dot product of row @row@ of @a@ and column @col@ of @b@, modulo @m@.
    -- Seeding the fold with 0 keeps this total for an empty inner dimension,
    -- where 'VU.foldl1'' would throw. On non-empty rows the result is
    -- unchanged, because each Barrett product is already below @m@.
    f row col =
      VU.foldl' addMod 0 $
        VU.imap
          (\iRow x -> mulMod_ x (VG.unsafeIndex vecB (col + iRow * w')))
          (VU.unsafeSlice (w * row) w vecA)
    addMod x y = (x + y) `rem` m
    mulMod_ x y = fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mulMod: matrix size mismatch"
-- | Matrix product over @'M.ModInt' a@.
--
-- Reads the modulus @a@ from the type with 'natVal'', builds a Barrett reducer
-- for it, and delegates to 'mulMintImpl'.
{-# INLINE mulMint #-}
mulMint :: forall a. (KnownNat a) => Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMint = mulMintImpl reducer
  where
    !reducer = BT.new32 (fromIntegral (natVal' (proxy# @a)))
-- | Worker shared by 'mulMint' and 'powMint': the matrix product over
-- @'M.ModInt' a@, using the given Barrett reducer @bt@ for entry products.
--
-- Requires @'wM' a == 'hM' b@, which is checked with 'ACIA.runtimeAssert'.
{-# INLINE mulMintImpl #-}
mulMintImpl :: forall a. (KnownNat a) => BT.Barrett -> Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMintImpl !bt !a !b = Matrix h w' (VU.unfoldrExactN (h * w') step (0, 0))
  where
    Matrix h w vecA = a
    Matrix h' w' vecB = b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mulMint: matrix size mismatch"
    -- Visit the output cells in row-major order and emit one entry per step.
    step (!row, !col)
      | col + 1 >= w' = (x, (row + 1, 0))
      | otherwise = (x, (row, col + 1))
      where
        !x = cell row col
    -- Dot product of row @row@ of @a@ with column @col@ of @b@. The sum uses
    -- the 'Num' instance of 'M.ModInt'.
    cell :: Int -> Int -> M.ModInt a
    cell row col =
      VU.sum $
        VU.imap
          (\k x -> mulRaw x (VG.unsafeIndex vecB (col + k * w')))
          (VU.unsafeSlice (w * row) w vecA)
    -- Multiply the raw 'Word32' representations with the Barrett reducer.
    mulRaw :: M.ModInt a -> M.ModInt a -> M.ModInt a
    mulRaw (M.ModInt x) (M.ModInt y) =
      M.unsafeNew . fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
-- | @pow k mat@ raises the square matrix @mat@ to the @k@-th power with 'mul'.
--
-- Returns the identity matrix when @k == 0@, and calls 'error' when @k < 0@.
-- Whether @mat@ is square is checked with 'ACIA.runtimeAssert'. The positive
-- case is delegated to 'ACEM.power', presumably binary exponentiation.
--
-- The diagnostics name 'pow' itself. They previously named 'powMod', which
-- misattributed failures.
{-# INLINE pow #-}
pow :: Int -> Matrix Int -> Matrix Int
pow k mat
  | k < 0 = error "AtCoder.Extra.Matrix.pow: the exponential must be non-negative"
  | k == 0 = ident $ hM mat
  | otherwise = ACEM.power mul k mat
  where
    !_ = ACIA.runtimeAssert (hM mat == wM mat) "AtCoder.Extra.Matrix.pow: matrix size mismatch"
-- | @powMod m k mat@ raises the square matrix @mat@ to the @k@-th power,
-- multiplying with @'mulMod' m@.
--
-- Calls 'error' when @k < 0@. When @k == 0@ it returns the plain 'ident'
-- matrix; its entries are @1@ and are not reduced modulo @m@. Whether @mat@
-- is square is checked with 'ACIA.runtimeAssert'.
{-# INLINE powMod #-}
powMod :: Int -> Int -> Matrix Int -> Matrix Int
powMod m k mat = case compare k 0 of
  LT -> error "AtCoder.Extra.Matrix.powMod: the exponential must be non-negative"
  EQ -> ident n
  GT -> ACEM.power (mulMod m) k mat
  where
    n = hM mat
    !_ = ACIA.runtimeAssert (n == wM mat) "AtCoder.Extra.Matrix.powMod: matrix size mismatch"
-- | @powMint k mat@ raises the square matrix @mat@ over @'M.ModInt' m@ to the
-- @k@-th power.
--
-- Calls 'error' when @k < 0@ and returns 'ident' when @k == 0@. Whether @mat@
-- is square is checked with 'ACIA.runtimeAssert'.
powMint :: forall m. (KnownNat m) => Int -> Matrix (M.ModInt m) -> Matrix (M.ModInt m)
powMint k mat = case compare k 0 of
  LT -> error "AtCoder.Extra.Matrix.powMint: the exponential must be non-negative"
  EQ -> ident n
  GT -> ACEM.power (mulMintImpl reducer) k mat
  where
    n = hM mat
    !_ = ACIA.runtimeAssert (n == wM mat) "AtCoder.Extra.Matrix.powMint: matrix size mismatch"
    -- Built once per call and shared by every multiplication in the power.
    !reducer = BT.new32 $ fromIntegral $ natVal' (proxy# @m)
-- | Matrices form a semigroup under multiplication: '(<>)' is 'mul'.
-- 'stimes' converts the count to 'Int' and delegates to 'ACEM.power'.
instance (Num a, VU.Unbox a) => Semigroup (Matrix a) where
  {-# INLINE (<>) #-}
  x <> y = mul x y
  {-# INLINE stimes #-}
  stimes n = ACEM.power (<>) (fromIntegral n)