{-# LANGUAGE CPP #-}
module Test.Hspec.Core.Shuffle (
  shuffleForest
#ifdef TEST
, shuffle
, mkArray
#endif
) where

import           Prelude ()
import           Test.Hspec.Core.Compat
import           Test.Hspec.Core.Tree

import           System.Random
import           Control.Monad.ST
import           Data.STRef
import           Data.Array.ST

-- | Shuffle a forest of spec trees using the generator stored in @ref@.
--
-- The top-level list is permuted first with 'shuffle'.  Then each resulting
-- tree is visited in order by 'shuffleTree', which shuffles its children
-- recursively, so every level of the forest gets its own permutation.  The
-- generator in @ref@ is advanced as a side effect of every random draw.
shuffleForest :: STRef s StdGen -> [Tree c a] -> ST s [Tree c a]
shuffleForest ref xs = shuffle ref xs >>= mapM (shuffleTree ref)

-- | Recursively shuffle the children of a single tree.
--
-- 'Node' and 'NodeWithCleanup' keep their own fields (description, location,
-- cleanup action) and only have their child forest shuffled via
-- 'shuffleForest'.  A 'Leaf' has no children and is returned unchanged,
-- without consuming any randomness.
shuffleTree :: STRef s StdGen -> Tree c a -> ST s (Tree c a)
shuffleTree ref t = case t of
  Node d xs -> Node d <$> shuffleForest ref xs
  NodeWithCleanup loc c xs -> NodeWithCleanup loc c <$> shuffleForest ref xs
  Leaf {} -> return t

-- | Randomly permute a list with a Fisher–Yates shuffle.
--
-- Randomness is drawn from, and written back to, the generator stored in
-- @ref@.
--
-- How it works:
--
-- * The list is first copied into a mutable array with 'mkArray'.
-- * Positions @i@ of that array are visited in ascending order.
-- * For each @i@, an index @j@ is drawn from the inclusive range @[i, n]@.
--   The element currently at @j@ becomes output element @i@, and the old
--   element at @i@ is moved into slot @j@.
--
-- Exactly one random draw is made per element, so an empty list leaves the
-- generator untouched.
shuffle :: STRef s StdGen -> [a] -> ST s [a]
shuffle ref xs = do
  arr <- mkArray xs
  bounds@(_, n) <- getBounds arr
  forM (range bounds) $ \i -> do
    j <- randomIndex (i, n)
    vi <- readArray arr i
    vj <- readArray arr j
    -- Later iterations only read indices greater than @i@, so slot @i@
    -- never needs @vj@ written back; only the displaced @vi@ is saved.
    writeArray arr j vi
    return vj
  where
    -- Draw a value from the given inclusive interval using the generator
    -- in @ref@, then store the advanced generator back into @ref@.
    randomIndex interval = do
      (a, gen) <- randomR interval <$> readSTRef ref
      writeSTRef ref gen
      return a

-- | Copy a list into a fresh mutable array indexed from @1@ to @length xs@.
--
-- Elements keep their list order.  An empty list yields an array with
-- bounds @(1, 0)@, which contains no elements.
mkArray :: [a] -> ST s (STArray s Int a)
mkArray xs = newListArray (1, length xs) xs