{-# LANGUAGE TypeFamilies #-}

-- | Monoid action \(f: x \rightarrow ax + b\), represented as a \(2 \times 2\) matrix. Less
-- efficient than @Affine1@, but supports inverse operations (see `inv`).
--
-- @since 1.1.0.0
module AtCoder.Extra.Monoid.Mat2x2
  ( -- * Mat2x2
    Mat2x2 (..),
    Mat2x2Repr,

    -- * Constructors
    new,
    unMat2x2,
    ident,
    zero,

    -- * Actions
    act,

    -- * Operators
    map,
    det,
    inv,
  )
where

import AtCoder.Extra.Math qualified as ACEM
import AtCoder.Extra.Monoid.V2 (V2 (..))
import AtCoder.LazySegTree (SegAct (..))
import Data.Semigroup (Dual (..), Semigroup (..))
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import GHC.Stack (HasCallStack)
import Prelude hiding (map)

-- | Monoid action \(f: x \rightarrow ax + b\), represented as a \(2 \times 2\) matrix
-- \(\begin{pmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{pmatrix}\). Less efficient than
-- @Affine1@, but supports inverse operations (see `inv`).
--
-- ==== Composition and dual
-- The affine transformation acts as a left monoid action: \(f_2 (f_1 v) = (f_2 \circ f_1) v\). To
-- apply the leftmost transformation first in a segment tree, wrap `Mat2x2` in @Data.Monoid.Dual@.
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Monoid.Mat2x2 qualified as Mat2x2
-- >>> import AtCoder.Extra.Monoid.V2 qualified as V2
-- >>> import AtCoder.Extra.Monoid (SegAct(..), Mat2x2(..), V2(..))
-- >>> import AtCoder.LazySegTree qualified as LST
-- >>> seg <- LST.build @_ @(Mat2x2 Int) @(V2 Int) $ VU.generate 3 V2.new -- [0, 1, 2]
-- >>> LST.applyIn seg 0 3 $ Mat2x2.new 2 1 -- [1, 3, 5]
-- >>> V2.unV2 <$> LST.allProd seg
-- 9
--
-- @since 1.1.0.0
newtype Mat2x2 a = Mat2x2 (Mat2x2Repr a)
  deriving newtype
    ( -- | Component-wise equality of the underlying tuple.
      --
      -- @since 1.1.0.0
      Eq,
      -- | Lexicographic order of the underlying tuple.
      --
      -- @since 1.1.0.0
      Ord,
      -- | Shown as the underlying tuple.
      --
      -- @since 1.1.0.0
      Show
    )

-- | Internal representation of `Mat2x2`: the four components
-- \((a_{11}, a_{12}, a_{21}, a_{22})\) in row-major order. A tuple is not the fastest possible
-- layout, but it makes implementing `Data.Vector.Unboxed.Unbox` straightforward.
--
-- @since 1.1.0.0
type Mat2x2Repr a = (a, a, a, a)

-- | \(O(1)\) Creates a one-dimensional affine transformation \(f: x \rightarrow a \times x + b\).
--
-- The transformation is stored as the matrix
-- \(\begin{pmatrix} a & b \\ 0 & 1 \end{pmatrix}\), so acting on the vector \((x, 1)\) yields
-- \((a x + b, 1)\).
--
-- @since 1.1.0.0
{-# INLINE new #-}
new :: (Num a) => a -> a -> Mat2x2 a
new !a !b = Mat2x2 repr
  where
    -- row-major: (a11, a12, a21, a22)
    repr = (a, b, 0, 1)

-- | \(O(1)\) Unwraps a `Mat2x2` into its four components
-- \((a_{11}, a_{12}, a_{21}, a_{22})\).
--
-- @since 1.1.0.0
{-# INLINE unMat2x2 #-}
unMat2x2 :: Mat2x2 a -> Mat2x2Repr a
unMat2x2 m = case m of
  Mat2x2 repr -> repr

-- | \(O(1)\) The zero matrix. Every component is \(0\), so it maps every vector to \((0, 0)\).
-- This differs from @new 0 0@, the affine map \(x \rightarrow 0\), whose matrix is
-- \(\begin{pmatrix} 0 & 0 \\ 0 & 1 \end{pmatrix}\).
--
-- @since 1.1.0.0
{-# INLINE zero #-}
zero :: (Num a) => Mat2x2 a
zero = Mat2x2 (z, z, z, z)
  where
    z = 0

-- | \(O(1)\) The identity matrix, i.e. the identity transformation \(x \rightarrow x\).
--
-- @since 1.1.0.0
{-# INLINE ident #-}
ident :: (Num a) => Mat2x2 a
ident = Mat2x2 (one, nil, nil, one)
  where
    one = 1
    nil = 0

-- | \(O(1)\) Matrix-vector product: multiplies the `Mat2x2` by the column vector `V2`.
{-# INLINE mulMV #-}
mulMV :: (Num a) => Mat2x2 a -> V2 a -> V2 a
mulMV (Mat2x2 (!m11, !m12, !m21, !m22)) (V2 (!v1, !v2)) =
  let -- each output component is the dot product of a matrix row with the vector
      !y1 = m11 * v1 + m12 * v2
      !y2 = m21 * v1 + m22 * v2
   in V2 (y1, y2)

-- | \(O(1)\) Matrix-matrix product. @mulMM l r@ applied to a vector equals applying @r@ first and
-- then @l@.
{-# INLINE mulMM #-}
mulMM :: (Num a) => Mat2x2 a -> Mat2x2 a -> Mat2x2 a
mulMM (Mat2x2 (!l11, !l12, !l21, !l22)) (Mat2x2 (!r11, !r12, !r21, !r22)) =
  let -- dot product of a row (x1, x2) of the left matrix with a column (y1, y2) of the right
      dot x1 x2 y1 y2 = x1 * y1 + x2 * y2
      !p11 = dot l11 l12 r11 r21
      !p12 = dot l11 l12 r12 r22
      !p21 = dot l21 l22 r11 r21
      !p22 = dot l21 l22 r12 r22
   in Mat2x2 (p11, p12, p21, p22)

-- | \(O(1)\) Applies the transformation to a vector, i.e. the matrix-vector product of the
-- `Mat2x2` and the `V2`.
--
-- @since 1.1.0.0
{-# INLINE act #-}
act :: (Num a) => Mat2x2 a -> V2 a -> V2 a
act m v = mulMV m v

-- | \(O(1)\) Applies a function to every component of the `Mat2x2`. Both the input components
-- and the results are forced to WHNF.
--
-- @since 1.1.0.0
{-# INLINE map #-}
map :: (a -> b) -> Mat2x2 a -> Mat2x2 b
map f (Mat2x2 (!x11, !x12, !x21, !x22)) =
  let !y11 = f x11
      !y12 = f x12
      !y21 = f x21
      !y22 = f x22
   in Mat2x2 (y11, y12, y21, y22)

-- | \(O(1)\) Returns the determinant of the matrix, \(a_{11} a_{22} - a_{12} a_{21}\).
--
-- Only `Num` is required, so this also works for integral element types such as `Int`.
--
-- @since 1.1.0.0
{-# INLINE det #-}
det :: (Num e) => Mat2x2 e -> e
det (Mat2x2 (!a, !b, !c, !d)) = a * d - b * c

-- | \(O(1)\) Returns the inverse matrix, based on the `Fractional` instance (mainly for @ModInt@):
-- \(\frac{1}{\det} \begin{pmatrix} a_{22} & -a_{12} \\ -a_{21} & a_{11} \end{pmatrix}\).
--
-- ==== Constraints
-- - The determinant (`det`) of the matrix must be non-zero, otherwise an error is thrown.
--
-- @since 1.1.0.0
{-# INLINE inv #-}
inv :: (HasCallStack, Fractional e, Eq e) => Mat2x2 e -> Mat2x2 e
inv m@(Mat2x2 (!a, !b, !c, !d)) = Mat2x2 (a', b', c', d')
  where
    !det_ = det m
    -- Reject singular matrices explicitly instead of relying on `recip 0`, whose result depends
    -- on the element type (e.g. `Infinity` for `Double`).
    !r
      | det_ == 0 = error "AtCoder.Extra.Monoid.Mat2x2.inv: the determinant of the matrix must be non zero"
      | otherwise = recip det_
    !a' = r * d
    !b' = r * negate b
    !c' = r * negate c
    !d' = r * a

-- | Matrix multiplication. @f2 <> f1@ applies @f1@ first and then @f2@. `stimes` converts the
-- repetition count to `Int` and delegates to @ACEM.stimes'@.
--
-- @since 1.1.0.0
instance (Num a) => Semigroup (Mat2x2 a) where
  {-# INLINE (<>) #-}
  lhs <> rhs = mulMM lhs rhs
  {-# INLINE stimes #-}
  stimes n m = ACEM.stimes' (fromIntegral n) m

-- | The neutral element is the identity matrix (the same value as `ident`).
--
-- @since 1.1.0.0
instance (Num a) => Monoid (Mat2x2 a) where
  {-# INLINE mempty #-}
  mempty = Mat2x2 (1, 0, 0, 1)

-- | Acts on a `V2` by matrix-vector multiplication.
--
-- @since 1.1.0.0
instance (Num a) => SegAct (Mat2x2 a) (V2 a) where
  {-# INLINE segAct #-}
  segAct m v = mulMV m v

-- | `Dual` only reverses the composition order of the matrices. A single wrapped matrix still
-- acts on a `V2` by ordinary matrix-vector multiplication.
--
-- @since 1.1.0.0
instance (Num a) => SegAct (Dual (Mat2x2 a)) (V2 a) where
  {-# INLINE segAct #-}
  segAct d = mulMV (getDual d)

-- | Unboxed mutable vector of `Mat2x2`: a newtype wrapper around the unboxed mutable vector of
-- its tuple representation `Mat2x2Repr`.
--
-- @since 1.1.0.0
newtype instance VU.MVector s (Mat2x2 a) = MV_Mat2x2 (VU.MVector s (Mat2x2Repr a))

-- | Unboxed immutable vector of `Mat2x2`: a newtype wrapper around the unboxed vector of its
-- tuple representation `Mat2x2Repr`.
--
-- @since 1.1.0.0
newtype instance VU.Vector (Mat2x2 a) = V_Mat2x2 (VU.Vector (Mat2x2Repr a))

-- | Mutable-vector operations are reused from the wrapped `Mat2x2Repr` vector. NOTE(review):
-- presumably this relies on GeneralizedNewtypeDeriving being enabled in the package's default
-- extensions; only TypeFamilies is enabled in this file. TODO: confirm.
--
-- @since 1.1.0.0
deriving instance (VU.Unbox a) => VGM.MVector VUM.MVector (Mat2x2 a)

-- | Immutable-vector operations are reused from the wrapped `Mat2x2Repr` vector, as for the
-- mutable instance.
--
-- @since 1.1.0.0
deriving instance (VU.Unbox a) => VG.Vector VU.Vector (Mat2x2 a)

-- | Empty instance: `VU.Unbox` has no methods of its own. It just ties together the
-- `VGM.MVector` and `VG.Vector` instances derived above.
--
-- @since 1.1.0.0
instance (VU.Unbox a) => VU.Unbox (Mat2x2 a)