{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}

-- | A simple HxW matrix backed by a vector, mainly for binary exponentiation.
--
-- The matrix is a left semigroup action: \(m_2 (m_1 v) = (m_2 \circ m_1) v\).
--
-- @since 1.1.0.0
module AtCoder.Extra.Semigroup.Matrix
  ( -- * Matrix
    Matrix (..),

    -- * Constructors
    new,
    zero,
    ident,
    diag,

    -- * Mapping
    map,

    -- * Multiplications
    mulToCol,
    mul,
    mulMod,
    mulMint,

    -- * Powers
    pow,
    powMod,
    powMint,
  )
where

import AtCoder.Extra.Math qualified as ACEM
import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Barrett qualified as BT
import AtCoder.ModInt qualified as M
import Data.Foldable (for_)
import Data.Semigroup (Semigroup (..))
import Data.Vector qualified as V
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Exts (proxy#)
import GHC.Stack (HasCallStack)
import GHC.TypeNats (KnownNat, natVal')
import Prelude hiding (map)

-- | A simple HxW matrix backed by a vector, mainly for binary exponentiation.
--
-- The matrix is a left semigroup action: \(m_2 (m_1 v) = (m_2 \circ m_1) v\).
--
-- Elements are stored in row-major order: the element at row \(i\), column
-- \(j\) lives at index \(i w + j\) of `vecM`. Prefer `new` over the raw
-- constructor, since `new` checks that the vector length equals \(h w\).
--
-- @since 1.1.0.0
data Matrix a = Matrix
  { -- | The number of rows, \(h\).
    --
    -- @since 1.1.0.0
    hM :: {-# UNPACK #-} !Int,
    -- | The number of columns, \(w\).
    --
    -- @since 1.1.0.0
    wM :: {-# UNPACK #-} !Int,
    -- | The elements in row-major order, expected to have length \(h w\).
    --
    -- @since 1.1.0.0
    vecM :: !(VU.Vector a)
  }
  deriving
    ( -- | @since 1.1.0.0
      Show,
      -- | @since 1.1.0.0
      Eq
    )

-- | A column vector, represented as a plain unboxed vector. `mulToCol` takes
-- and returns values of this type.
--
-- @since 1.1.0.0
type Col a = VU.Vector a

-- | \(O(hw)\) Creates an HxW matrix from its elements given in row-major order.
--
-- Throws an error when the length of the vector is not \(h w\).
--
-- @since 1.1.0.0
{-# INLINE new #-}
new :: (HasCallStack, VU.Unbox a) => Int -> Int -> VU.Vector a -> Matrix a
new h w vec
  | sizeMatches = Matrix h w vec
  | otherwise = error "AtCoder.Extra.Matrix: size mismatch"
  where
    sizeMatches = VU.length vec == h * w

-- | \(O(n^2)\) Creates an NxN matrix in which every entry is zero.
--
-- @since 1.1.0.0
{-# INLINE zero #-}
zero :: (VU.Unbox a, Num a) => Int -> Matrix a
zero n = Matrix {hM = n, wM = n, vecM = VU.replicate (n * n) 0}

-- | \(O(n^2)\) Creates an NxN identity matrix: ones on the main diagonal and
-- zeros everywhere else.
--
-- @since 1.1.0.0
{-# INLINE ident #-}
ident :: (VU.Unbox a, Num a) => Int -> Matrix a
ident n = Matrix n n $ VU.create $ do
  vec <- VUM.replicate (n * n) 0
  -- The diagonal cell (i, i) lives at row-major index i * n + i = i * (n + 1).
  vec <$ for_ [0 .. n - 1] (\i -> VGM.write vec (i * (n + 1)) 1)

-- | \(O(n^2)\) Creates an NxN diagonal matrix whose \(i\)-th diagonal entry is
-- the \(i\)-th element of @xs@. All other entries are zero.
--
-- Diagonal entries not covered by @xs@ stay zero. Elements of @xs@ past the
-- first @n@ would be written outside the \(n^2\) storage, so @xs@ should have
-- at most @n@ elements.
--
-- @since 1.1.0.0
{-# INLINE diag #-}
diag :: (VU.Unbox a, Num a) => Int -> VU.Vector a -> Matrix a
diag n xs = Matrix n n $ VU.create $ do
  vec <- VUM.replicate (n * n) 0
  -- The diagonal cell (i, i) lives at row-major index i * n + i = i * (n + 1).
  vec <$ VU.iforM_ xs (\i x -> VGM.write vec (i * (n + 1)) x)

-- | \(O(hw)\) Applies a function to every element of the `Matrix`. The shape
-- (`hM` and `wM`) is kept unchanged.
--
-- @since 1.1.0.0
{-# INLINE map #-}
map :: (VU.Unbox a, VU.Unbox b) => (a -> b) -> Matrix a -> Matrix b
map f mat = mat {vecM = VU.map f (vecM mat)}

-- | \(O(hw)\) Multiplies an HxW matrix to a Wx1 column vector, returning an
-- Hx1 column vector.
--
-- The length of the column vector must equal the width of the matrix. This is
-- checked with `ACIA.runtimeAssert`.
--
-- @since 1.1.0.0
{-# INLINE mulToCol #-}
mulToCol :: (Num a, VU.Unbox a) => Matrix a -> Col a -> Col a
mulToCol Matrix {..} !col = VU.convert $ V.map (VU.sum . VU.zipWith (*) col) rows
  where
    !n = VU.length col
    !_ = ACIA.runtimeAssert (n == wM) "AtCoder.Extra.Matrix.mulToCol: size mismatch"
    -- Split the row-major storage into `hM` rows of `wM` elements each; every
    -- output element is the dot product of one row with `col`.
    rows = V.unfoldrExactN hM (VU.splitAt wM) vecM

-- | \(O(h_1 K w_2)\) Multiplies an H1xK matrix by a KxW2 matrix, producing an
-- H1xW2 matrix.
--
-- The shared dimension \(K\) (the width of the left operand and the height of
-- the right operand) is checked with `ACIA.runtimeAssert`.
--
-- @since 1.1.0.0
{-# INLINE mul #-}
mul :: (Num e, VU.Unbox e) => Matrix e -> Matrix e -> Matrix e
mul !lhs !rhs = Matrix h1 w2 $ VU.unfoldrExactN (h1 * w2) step (0, 0)
  where
    h1 = hM lhs
    k = wM lhs
    w2 = wM rhs
    !_ = ACIA.runtimeAssert (k == hM rhs) "AtCoder.Extra.Matrix.mul: matrix size mismatch"
    -- Visit the output cells in row-major order, wrapping to the next row
    -- after the last column.
    step (!row, !col)
      | col + 1 >= w2 = (cell, (row + 1, 0))
      | otherwise = (cell, (row, col + 1))
      where
        !cell = dot row col
    -- Inner product of row `row` of lhs with column `col` of rhs.
    dot row col =
      VU.sum $
        VU.imap
          (\i x -> x * VG.unsafeIndex (vecM rhs) (col + i * w2))
          (VU.unsafeSlice (k * row) k (vecM lhs))

-- | \(O(h_1 w_2 K)\) Multiplies an H1xK matrix by a KxW2 matrix, reducing every
-- entry modulo @m@.
--
-- * @m@: the modulus. It is converted to 'Word32' to build the Barrett
--   reduction used for the products.
-- * The two matrix operands: the shared dimension \(K\) is checked with
--   `ACIA.runtimeAssert`.
--
-- NOTE(review): entries of both matrices are presumably expected to lie in
-- \([0, m)\), since they are fed straight into `BT.mulMod`. TODO: confirm
-- `BT.mulMod`'s precondition.
--
-- When \(K = 0\), every entry of the result is @0@, matching `mul`.
--
-- @since 1.1.0.0
{-# INLINE mulMod #-}
mulMod :: Int -> Matrix Int -> Matrix Int -> Matrix Int
mulMod !m !a !b =
  Matrix h w' $
    VU.unfoldrExactN
      (h * w')
      ( \(!row, !col) ->
          let !x = f row col
           in if col + 1 >= w'
                then (x, (row + 1, 0))
                else (x, (row, col + 1))
      )
      (0, 0)
  where
    !bt = BT.new32 $ fromIntegral m
    -- Inner product of row `row` of a with column `col` of b, modulo m.
    -- Folding from 0 (rather than with `VU.foldl1'`) makes an empty inner
    -- dimension yield 0 instead of crashing on an empty vector.
    f :: Int -> Int -> Int
    f row col =
      VU.foldl' addMod 0 $
        VU.imap
          (\iRow x -> mulMod_ x (VG.unsafeIndex vecB (col + iRow * w')))
          (VU.unsafeSlice (w * row) w vecA)
    addMod :: Int -> Int -> Int
    addMod x y = (x + y) `rem` m
    mulMod_ :: Int -> Int -> Int
    mulMod_ x y = fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)
    h = hM a
    w = wM a
    h' = hM b
    vecA = vecM a
    w' = wM b
    vecB = vecM b
    !_ = ACIA.runtimeAssert (w == h') "AtCoder.Extra.Matrix.mulMod: matrix size mismatch"

-- | \(O(h_1 w_2 K)\) `mul` specialized to `M.ModInt`: multiplies an H1xK matrix
-- by a KxW2 matrix whose entries are `M.ModInt a`.
--
-- A Barrett reduction for the modulus @a@ is built from the type-level
-- natural and handed to the shared multiplication routine.
--
-- @since 1.1.0.0
{-# INLINE mulMint #-}
mulMint :: forall a. (KnownNat a) => Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMint =
  let !barrett = BT.new32 . fromIntegral $ natVal' (proxy# @a)
   in mulMintImpl barrett

-- Shared implementation of `mulMint` and `powMint`. It multiplies an H1xK
-- matrix by a KxW2 matrix of `M.ModInt a` and computes the element products
-- with the given Barrett reduction. Both callers build that Barrett value from
-- the modulus @a@.
{-# INLINE mulMintImpl #-}
mulMintImpl :: forall a. (KnownNat a) => BT.Barrett -> Matrix (M.ModInt a) -> Matrix (M.ModInt a) -> Matrix (M.ModInt a)
mulMintImpl !bt !lhs !rhs = Matrix h1 w2 $ VU.unfoldrExactN (h1 * w2) step (0, 0)
  where
    h1 = hM lhs
    k = wM lhs
    w2 = wM rhs
    !_ = ACIA.runtimeAssert (k == hM rhs) "AtCoder.Extra.Matrix.mulMint: matrix size mismatch"
    -- Visit the output cells in row-major order, wrapping to the next row
    -- after the last column.
    step (!row, !col)
      | col + 1 >= w2 = (cell, (row + 1, 0))
      | otherwise = (cell, (row, col + 1))
      where
        !cell = dot row col
    -- Inner product of row `row` of lhs with column `col` of rhs.
    dot :: Int -> Int -> M.ModInt a
    dot row col =
      VU.sum $
        VU.imap
          (\i x -> mulBarrett x (VG.unsafeIndex (vecM rhs) (col + i * w2)))
          (VU.unsafeSlice (k * row) k (vecM lhs))
    -- Multiplies the raw Word32 representatives with the Barrett reduction and
    -- rewraps the product with `M.unsafeNew`.
    mulBarrett :: M.ModInt a -> M.ModInt a -> M.ModInt a
    mulBarrett (M.ModInt x) (M.ModInt y) =
      M.unsafeNew . fromIntegral $ BT.mulMod bt (fromIntegral x) (fromIntegral y)

-- | \(O(n^3 \log k)\) Calculates \(M^k\) for a square NxN matrix \(M\) by
-- binary exponentiation with `mul`. \(M^0\) is the identity matrix.
--
-- Throws an error when \(k < 0\). That \(M\) is square is checked with
-- `ACIA.runtimeAssert`.
--
-- @since 1.1.0.0
{-# INLINE pow #-}
pow :: Int -> Matrix Int -> Matrix Int
pow k mat
  | k < 0 = error "AtCoder.Extra.Matrix.pow: the exponential must be non-negative"
  | k == 0 = ident $ hM mat
  | otherwise = ACEM.power mul k mat
  where
    !_ = ACIA.runtimeAssert (hM mat == wM mat) "AtCoder.Extra.Matrix.pow: matrix size mismatch"

-- | \(O(n^3 \log k)\) Calculates \(M^k\) for a square NxN matrix \(M\),
-- reducing entries modulo @m@ through `mulMod`.
--
-- * @m@: the modulus.
-- * @k@: the exponent. An error is thrown when \(k < 0\).
-- * @mat@: the matrix \(M\). That it is square is checked with
--   `ACIA.runtimeAssert`.
--
-- \(M^0\) is the identity matrix reduced modulo @m@, so it is the zero matrix
-- when \(m = 1\).
--
-- @since 1.1.0.0
{-# INLINE powMod #-}
powMod :: Int -> Int -> Matrix Int -> Matrix Int
powMod m k mat
  | k < 0 = error "AtCoder.Extra.Matrix.powMod: the exponential must be non-negative"
  -- The ones on the identity's diagonal are 0 modulo 1.
  | k == 0 = if m == 1 then zero n else ident n
  | otherwise = ACEM.power (mulMod m) k mat
  where
    n = hM mat
    !_ = ACIA.runtimeAssert (n == wM mat) "AtCoder.Extra.Matrix.powMod: matrix size mismatch"

-- | \(O(n^3 \log k)\) Calculates \(M^k\) for a square NxN matrix \(M\),
-- specialized to `M.ModInt`. \(M^0\) is the identity matrix.
--
-- Throws an error when \(k < 0\). That \(M\) is square is checked with
-- `ACIA.runtimeAssert`.
--
-- @since 1.1.0.0
{-# INLINE powMint #-}
powMint :: forall m. (KnownNat m) => Int -> Matrix (M.ModInt m) -> Matrix (M.ModInt m)
powMint k mat
  | k < 0 = error "AtCoder.Extra.Matrix.powMint: the exponential must be non-negative"
  | k == 0 = ident $ hM mat
  | otherwise = ACEM.power (mulMintImpl bt) k mat
  where
    !_ = ACIA.runtimeAssert (hM mat == wM mat) "AtCoder.Extra.Matrix.powMint: matrix size mismatch"
    -- Build the Barrett reduction for the modulus once per call and reuse it
    -- for every multiplication performed by the exponentiation.
    !bt = BT.new32 $ fromIntegral (natVal' (proxy# @m))

-- | Combines matrices by matrix multiplication (`mul`). `stimes` is delegated
-- to `ACEM.power` with the exponent converted to 'Int'.
--
-- @since 1.1.0.0
instance (Num a, VU.Unbox a) => Semigroup (Matrix a) where
  {-# INLINE (<>) #-}
  (<>) = mul
  {-# INLINE stimes #-}
  stimes n x = ACEM.power (<>) (fromIntegral n) x