{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Mutable.Shuffle where
import Control.Monad.Primitive
import Control.Monad.Random (MonadRandom (..))
import Data.Vector.Mutable
import Prelude hiding (length, read, tail)
import System.Random (RandomGen)
import qualified System.Random as SR
-- | Shuffle a mutable vector in place with the Fisher–Yates algorithm,
-- drawing random indices from an explicit 'RandomGen'.
--
-- At each step an index is drawn from the not-yet-fixed suffix. That element
-- is swapped to the front of the suffix, and the process continues on the
-- remaining elements. Vectors with fewer than two elements are left untouched.
--
-- Returns the generator advanced past every draw, so the caller can keep
-- threading it.
shuffle
  :: forall m a g
   . ( PrimMonad m
     , RandomGen g
     )
  => MVector (PrimState m) a -> g -> m g
{-# INLINABLE shuffle #-}
shuffle mutV gen = go mutV gen (length mutV - 1)
  where
    go :: MVector (PrimState m) a -> g -> Int -> m g
    {-# INLINE go #-}
    -- 'maxInd' is the last valid index of the current suffix 'v'. Stopping
    -- at @<= 0@ (not just @== 0@) also covers the empty vector, where the
    -- initial value is @-1@ and a swap would index out of bounds.
    go _ g maxInd | maxInd <= 0 = pure g
    go v g maxInd = do
      let (ind, newGen) :: (Int, g) = SR.randomR (0, maxInd) g
      swap v 0 ind
      go (tail v) newGen (maxInd - 1)
-- | Shuffle a mutable vector in place with the Fisher–Yates algorithm,
-- drawing random indices from the ambient 'MonadRandom'.
--
-- Vectors with fewer than two elements are left untouched.
shuffleM
  :: forall m a
   . ( PrimMonad m
     , MonadRandom m
     )
  => MVector (PrimState m) a -> m ()
{-# INLINABLE shuffleM #-}
shuffleM mutV = go mutV (length mutV - 1)
  where
    go :: MVector (PrimState m) a -> Int -> m ()
    {-# INLINE go #-}
    -- Stop once at most one element remains. The @<= 0@ test also handles
    -- the empty vector, whose initial 'maxInd' is @-1@.
    go _ maxInd | maxInd <= 0 = pure ()
    go v maxInd = do
      ind <- getRandomR (0, maxInd)
      swap v 0 ind
      go (tail v) (maxInd - 1)
{-# SPECIALISE shuffleM :: MVector RealWorld a -> IO () #-}
-- | Shuffle only the first @numberOfShuffles@ elements of a mutable vector in
-- place (Fisher–Yates restricted to that prefix). Elements after the prefix
-- are not moved. Randomness comes from the ambient 'MonadRandom'.
--
-- Calls 'error' if @numberOfShuffles@ is negative or greater than the length
-- of the vector. A prefix of length 0 or 1 leaves the vector unchanged.
shuffleK
  :: forall m a
   . ( PrimMonad m
     , MonadRandom m
     )
  => Int -> MVector (PrimState m) a -> m ()
{-# INLINABLE shuffleK #-}
shuffleK numberOfShuffles mutV
  -- Validate the caller's argument itself, not @numberOfShuffles - 1@, so
  -- that 0 is accepted as an (empty) prefix rather than reported as negative.
  | numberOfShuffles < 0
  = error "Cannot pass negative value to ShuffleK"
  | numberOfShuffles > length mutV
  = error "Cannot pass value greater than the length of the vector to ShuffleK"
  | otherwise
  = go mutV (numberOfShuffles - 1)
  where
    go :: MVector (PrimState m) a -> Int -> m ()
    {-# INLINE go #-}
    -- 'maxInd' is the last prefix index still inside the suffix 'v'. Draws
    -- never exceed it, which keeps the shuffle within the requested prefix.
    go _ maxInd | maxInd <= 0 = pure ()
    go v maxInd = do
      ind <- getRandomR (0, maxInd)
      swap v 0 ind
      go (tail v) (maxInd - 1)
-- | Rearrange a mutable vector in place into a random cyclic permutation (a
-- single cycle through every position) using Sattolo's algorithm. Random
-- indices are drawn from an explicit 'RandomGen'.
--
-- Vectors with fewer than two elements are left untouched. Returns the
-- generator advanced past every draw.
maximalCycle
  :: forall m a g
   . ( PrimMonad m
     , RandomGen g
     )
  => MVector (PrimState m) a -> g -> m g
{-# INLINABLE maximalCycle #-}
maximalCycle mutV gen = go mutV gen (length mutV - 1)
  where
    go :: MVector (PrimState m) a -> g -> Int -> m g
    {-# INLINE go #-}
    -- Indices are drawn from @[1, maxInd]@, so the front of the suffix is
    -- never swapped with itself; that is what forces a single cycle. The
    -- @<= 0@ base case also covers the empty vector (initial 'maxInd' = -1).
    go _ g maxInd | maxInd <= 0 = pure g
    go v g maxInd = do
      let (ind, newGen) :: (Int, g) = SR.randomR (1, maxInd) g
      swap v 0 ind
      go (tail v) newGen (maxInd - 1)
-- | Rearrange a mutable vector in place into a random cyclic permutation (a
-- single cycle through every position) using Sattolo's algorithm. Random
-- indices are drawn from the ambient 'MonadRandom'.
--
-- Vectors with fewer than two elements are left untouched.
maximalCycleM
  :: forall m a
   . ( PrimMonad m
     , MonadRandom m
     )
  => MVector (PrimState m) a -> m ()
{-# INLINABLE maximalCycleM #-}
maximalCycleM mutV = go mutV (length mutV - 1)
  where
    go :: MVector (PrimState m) a -> Int -> m ()
    {-# INLINE go #-}
    -- Drawing from @[1, maxInd]@ (never 0) is what forces a single cycle.
    -- The @<= 0@ base case also covers the empty vector (initial 'maxInd' = -1).
    go _ maxInd | maxInd <= 0 = pure ()
    go v maxInd = do
      ind <- getRandomR (1, maxInd)
      swap v 0 ind
      go (tail v) (maxInd - 1)
{-# SPECIALISE maximalCycleM :: MVector RealWorld a -> IO () #-}
-- | Rearrange a mutable vector in place into a random derangement: after the
-- call, no position holds a value equal (by '==') to the value it held
-- originally. Random indices come from an explicit 'RandomGen', and the
-- advanced generator is returned.
--
-- Works by rejection sampling. A Fisher–Yates shuffle is run step by step.
-- As soon as some position would receive a value equal to its original one,
-- the vector is restored from a snapshot and the shuffle starts over.
--
-- An empty vector is returned unchanged. A single-element vector has no
-- derangement, so this calls 'error' instead of retrying forever. If a longer
-- vector has no derangement by value (e.g. every element compares equal),
-- the restart loop does not terminate.
derangement
  :: forall m a g
   . ( PrimMonad m
     , RandomGen g
     , Eq a
     )
  => MVector (PrimState m) a -> g -> m g
{-# INLINABLE derangement #-}
derangement mutV gen
  | n == 0 = pure gen
  | n == 1 = error "Cannot compute a derangement of a single-element vector"
  | otherwise = do
      -- Snapshot of the original contents. It is used both to detect fixed
      -- points and to restore 'mutV' when a run is rejected.
      mutV_copy <- clone mutV
      go mutV_copy mutV gen 0 (n - 1)
  where
    n = length mutV

    go :: MVector (PrimState m) a -> MVector (PrimState m) a -> g -> Int -> Int -> m g
    {-# INLINE go #-}
    -- Last position: nothing is left to choose, only check for a fixed point.
    go c v g lastInd 0 = do
      v_last_old <- read c lastInd
      v_last_new <- read v 0
      if v_last_old == v_last_new
        then restart c g
        else pure g
    -- 'currInd' is the position in the full vector that the front of the
    -- suffix 'v' corresponds to; 'maxInd' is the last index of 'v'.
    go c v oldGen currInd maxInd = do
      let (swapInd, newGen) :: (Int, g) = SR.randomR (0, maxInd) oldGen
      v_old <- read c currInd
      v_ind <- read v swapInd
      if v_old == v_ind
        then restart c newGen
        else do
          swap v 0 swapInd
          go c (tail v) newGen (currInd + 1) (maxInd - 1)

    -- Reject the current run: restore the original contents and start over.
    restart :: MVector (PrimState m) a -> g -> m g
    restart c g = do
      unsafeCopy mutV c
      go c mutV g 0 (n - 1)
-- | Rearrange a mutable vector in place into a random derangement: after the
-- call, no position holds a value equal (by '==') to the value it held
-- originally. Random indices come from the ambient 'MonadRandom'.
--
-- Works by rejection sampling. A Fisher–Yates shuffle is run step by step.
-- As soon as some position would receive a value equal to its original one,
-- the vector is restored from a snapshot and the shuffle starts over.
--
-- An empty vector is left unchanged. A single-element vector has no
-- derangement, so this calls 'error' instead of retrying forever. If a longer
-- vector has no derangement by value (e.g. every element compares equal),
-- the restart loop does not terminate.
derangementM
  :: forall m a
   . ( PrimMonad m
     , MonadRandom m
     , Eq a
     )
  => MVector (PrimState m) a -> m ()
{-# INLINABLE derangementM #-}
derangementM mutV
  | n == 0 = pure ()
  | n == 1 = error "Cannot compute a derangement of a single-element vector"
  | otherwise = do
      -- Snapshot of the original contents. It is used both to detect fixed
      -- points and to restore 'mutV' when a run is rejected.
      mutV_copy <- clone mutV
      go mutV_copy mutV 0 (n - 1)
  where
    n = length mutV

    go :: MVector (PrimState m) a -> MVector (PrimState m) a -> Int -> Int -> m ()
    {-# INLINE go #-}
    -- Last position: nothing is left to choose, only check for a fixed point.
    go c v lastInd 0 = do
      v_last_old <- read c lastInd
      v_last_new <- read v 0
      if v_last_old == v_last_new
        then restart c
        else pure ()
    -- 'currInd' is the position in the full vector that the front of the
    -- suffix 'v' corresponds to; 'maxInd' is the last index of 'v'.
    go c v currInd maxInd = do
      swapInd :: Int <- getRandomR (0, maxInd)
      v_old <- read c currInd
      v_ind <- read v swapInd
      if v_old == v_ind
        then restart c
        else do
          swap v 0 swapInd
          go c (tail v) (currInd + 1) (maxInd - 1)

    -- Reject the current run: restore the original contents and start over.
    restart :: MVector (PrimState m) a -> m ()
    restart c = do
      unsafeCopy mutV c
      go c mutV 0 (n - 1)