-- | A permutation represented by a vector, mainly for binary exponentiation.
--
-- The permutation is a left semigroup action: \(p_2 (p_1 x) = (p_2 \circ p_1) x\), provided that
-- neither permutation contains the invalid marker \(-1\) (see `zero`).
--
-- ==== __Example__
-- >>> import AtCoder.Extra.Semigroup.Permutation qualified as Permutation
-- >>> import Data.Semigroup (Semigroup (stimes))
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> let perm = Permutation.new $ VU.fromList [1, 2, 3, 0]
-- >>> Permutation.act perm 1
-- 2
--
-- >>> Permutation.act (perm <> perm) 1
-- 3
--
-- >>> Permutation.act (stimes 3 perm) 1
-- 0
--
-- @since 1.1.0.0
module AtCoder.Extra.Semigroup.Permutation
  ( -- * Permutation
    Permutation (..),

    -- * Constructors
    new,
    unsafeNew,
    ident,
    zero,

    -- * Actions
    act,

    -- * Metadata
    length,
  )
where

import AtCoder.Internal.Assert qualified as ACIA
import Data.Vector.Generic qualified as VG
import Data.Vector.Unboxed qualified as VU
import GHC.Stack (HasCallStack)
import Prelude hiding (length)

-- | A permutation represented by a vector, mainly for binary exponentiation.
--
-- The permutation is a left semigroup action: \(p_2 (p_1 x) = (p_2 \circ p_1) x\).
-- The law holds when neither permutation contains the invalid marker \(-1\). `act` maps a
-- \(-1\) slot to its own index, while `(<>)` propagates \(-1\) into the composed result, so the
-- two sides can differ once \(-1\) slots are involved.
--
-- @since 1.1.0.0
newtype Permutation = Permutation
  { -- | The image of each index; \(-1\) marks an invalid slot.
    unPermutation :: VU.Vector Int
  }
  deriving newtype
    ( -- | @since 1.1.0.0
      Eq,
      -- | @since 1.1.0.0
      Show
    )

-- | \(O(n)\) Creates a `Permutation`, performing boundary check on input vector.
--
-- Every element must lie in \([-1, n)\), where \(n\) is the length of the vector; \(-1\) marks
-- an invalid slot. The check runs (via @runtimeAssert@) when the result is evaluated.
--
-- @since 1.1.0.0
{-# INLINE new #-}
new :: (HasCallStack) => VU.Vector Int -> Permutation
new xs = Permutation xs
  where
    n :: Int
    n = VU.length xs
    -- Strict binding: forcing the result also forces the boundary check.
    !_ = ACIA.runtimeAssert (VU.all inBounds xs) "AtCoder.Extra.Semigroup.Permutation.new: index boundary error"
    inBounds :: Int -> Bool
    inBounds i = -1 <= i && i < n

-- | \(O(1)\) Wraps a vector as a `Permutation` without inspecting its elements.
--
-- The caller is responsible for every element lying in \([-1, n)\), where \(n\) is the vector
-- length: composition with `(<>)` indexes into the vector without bounds checks.
--
-- @since 1.1.0.0
{-# INLINE unsafeNew #-}
unsafeNew :: (HasCallStack) => VU.Vector Int -> Permutation
unsafeNew xs = Permutation xs

-- | \(O(n)\) Creates the identity `Permutation` of length \(n\), which maps every index
-- \(i \in [0, n)\) to itself.
--
-- @since 1.1.0.0
{-# INLINE ident #-}
ident :: Int -> Permutation
ident n = Permutation $ VU.generate n id

-- | \(O(n)\) Creates a zero `Permutation` of length \(n\), in which every slot holds the
-- invalid marker \(-1\).
--
-- It's similar to `ident` under `act`, because a \(-1\) slot maps an index to itself. Under
-- composition it invalidates the corresponding slots: composing it with a permutation of the
-- same length via `(<>)`, on either side, yields a vector of \(-1\)s.
--
-- @since 1.1.0.0
{-# INLINE zero #-}
zero :: Int -> Permutation
zero n = Permutation (VU.replicate n (-1))

-- | \(O(1)\) Returns the image of index \(i\) under the permutation.
--
-- If the slot at \(i\) holds the invalid marker \(-1\), \(i\) itself is returned. The lookup is
-- bounds-checked.
--
-- @since 1.1.0.0
{-# INLINE act #-}
act :: (HasCallStack) => Permutation -> Int -> Int
act (Permutation vec) i
  | j == -1 = i
  | otherwise = j
  where
    j = vec VG.! i

-- | \(O(1)\) Returns the number of slots, i.e. the length of the underlying vector.
--
-- @since 1.1.0.0
{-# INLINE length #-}
length :: (HasCallStack) => Permutation -> Int
length (Permutation vec) = VU.length vec

-- | \(O(n)\) Composition: @p2 <> p1@ applies @p1@ first and then @p2@, so the result maps
-- \(i\) to \(p_2[p_1[i]]\).
--
-- If \(p_1[i] = -1\) or \(p_2[p_1[i]] = -1\), the result holds \(-1\) at \(i\). Both operands
-- must have the same length.
--
-- @since 1.1.0.0
instance Semigroup Permutation where
  {-# INLINE (<>) #-}
  Permutation r2 <> Permutation r1 = Permutation $ VU.map f r1
    where
      -- Strict binding: forcing the result also forces the length check.
      !_ = ACIA.runtimeAssert (VU.length r2 == VU.length r1) "AtCoder.Extra.Semigroup.Permutation.(<>): length mismatch"
      f :: Int -> Int
      -- An invalid slot of the right operand stays invalid.
      f (-1) = -1
      -- Elements are expected to be in bounds (checked by `new`), so the bounds check is skipped.
      f i = VG.unsafeIndex r2 i