{-# LANGUAGE CPP, BangPatterns, ScopedTypeVariables, GeneralizedNewtypeDeriving, FlexibleContexts #-}
module Math.Combinat.Permutations
(
Permutation (..)
, fromPermutation
, permutationArray
, permutationUArray
, uarrayToPermutationUnsafe
, isPermutation
, maybePermutation
, toPermutation
, toPermutationUnsafe
, permutationSize
, DisjointCycles (..)
, fromDisjointCycles
, disjointCyclesUnsafe
, permutationToDisjointCycles
, disjointCyclesToPermutation
, numberOfCycles
, concatPermutations
, isIdentityPermutation
, isReversePermutation
, isEvenPermutation
, isOddPermutation
, signOfPermutation
, signValueOfPermutation
, module Math.Combinat.Sign
, isCyclicPermutation
, transposition
, transpositions
, adjacentTransposition
, adjacentTranspositions
, cycleLeft
, cycleRight
, reversePermutation
, inversions
, numberOfInversions
, numberOfInversionsNaive
, numberOfInversionsMerge
, bubbleSort2
, bubbleSort
, identity
, inverse
, multiply
, multiplyMany
, multiplyMany'
, permute
, permuteList
, permuteLeft , permuteRight
, permuteLeftList , permuteRightList
, sortingPermutationAsc
, sortingPermutationDesc
, asciiPermutation
, asciiDisjointCycles
, twoLineNotation
, inverseTwoLineNotation
, genericTwoLineNotation
, permutations
, _permutations
, permutationsNaive
, _permutationsNaive
, countPermutations
, randomPermutation
, _randomPermutation
, randomCyclicPermutation
, _randomCyclicPermutation
, randomPermutationDurstenfeld
, randomCyclicPermutationSattolo
, permuteMultiset
, countPermuteMultiset
, fasc2B_algorithm_L
)
where
import Control.Monad
import Control.Monad.ST
import Data.List hiding ( permutations )
import Data.Ord ( comparing )
import Data.Array (Array)
import Data.Array.ST
import Data.Array.Unboxed
import Data.Array.IArray
import Data.Array.MArray
import Data.Array.Unsafe
import Math.Combinat.ASCII
import Math.Combinat.Classes
import Math.Combinat.Helper
import Math.Combinat.Sign
import Math.Combinat.Numbers ( factorial , binomial )
import System.Random
-- | A permutation, stored as an unboxed array of 'Int's indexed by 'Int'.
-- The constructing functions of this module ('toPermutationUnsafe',
-- 'maybePermutation', 'identity', ...) build the array with bounds @(1,n)@.
-- Equality and ordering are the derived ones of the underlying array.
newtype Permutation = Permutation (UArray Int Int) deriving (Eq,Ord)
-- | Shows a permutation as the expression @toPermutation [..]@, adding
-- parentheses when the surrounding precedence is above function application.
instance Show Permutation where
  showsPrec prec (Permutation table) =
    showParen (prec > 10) (showString "toPermutation " . showsPrec 11 (elems table))
-- | Parses the textual form produced by the 'Show' instance, that is the
-- token @toPermutation@ followed by a list of integers. The list is passed
-- through 'toPermutation'.
instance Read Permutation where
  readsPrec prec = readParen (prec > 10) parse
    where
      -- list monad: every successful lex / read alternative yields one result,
      -- a failing pattern match simply yields no result
      parse input = do
        ("toPermutation", afterKeyword) <- lex input
        (entries, remaining) <- readsPrec 11 afterKeyword
        return (toPermutation entries, remaining)
-- | Permutations are drawn with 'asciiPermutation'.
instance DrawASCII Permutation where
  ascii perm = asciiPermutation perm
-- | A list of cycles, each cycle being a list of 'Int's. The constructor
-- itself performs no validation of the cycles.
newtype DisjointCycles = DisjointCycles [[Int]] deriving (Eq,Ord,Show,Read)
-- | Returns the entries of the underlying array as a list, in index order.
fromPermutation :: Permutation -> [Int]
fromPermutation (Permutation table) = elems table
-- | Unwraps the underlying unboxed array (no copying is done).
permutationUArray :: Permutation -> UArray Int Int
permutationUArray (Permutation table) = table
-- | Copies the permutation into a boxed 'Array' with bounds @(1,n)@.
-- The lower bound of the underlying array is expected to be 1.
permutationArray :: Permutation -> Array Int Int
permutationArray (Permutation table) =
  listArray (1, size) [ table ! i | i <- [1 .. size] ]
  where
    (1, size) = bounds table
-- | Wraps a list into a 'Permutation' with bounds @(1, length xs)@.
-- The input is not checked to be a permutation.
toPermutationUnsafe :: [Int] -> Permutation
toPermutationUnsafe xs = Permutation (listArray (1, length xs) xs)
-- | Wraps an unboxed array into a 'Permutation' without checking its bounds
-- or its contents.
uarrayToPermutationUnsafe :: UArray Int Int -> Permutation
uarrayToPermutationUnsafe table = Permutation table
-- | Checks whether the list is a permutation of @[1..n]@, where @n@ is the
-- length of the list.
isPermutation :: [Int] -> Bool
isPermutation xs = counts ! 0 == 0 && all (\k -> counts ! k == 1) [1 .. size]
  where
    size = length xs

    -- slot 0 collects every out-of-range entry, slot k counts occurrences of k
    counts :: UArray Int Int
    counts = accumArray (+) 0 (0, size) [ (bucket x, 1) | x <- xs ]

    bucket :: Int -> Int
    bucket x
      | x < 1 || x > size = 0
      | otherwise         = x
-- | Validates the list and converts it to a 'Permutation'. Returns 'Nothing'
-- as soon as an entry is outside @[1..n]@ or occurs for the second time
-- (@n@ being the length of the list).
maybePermutation :: [Int] -> Maybe Permutation
maybePermutation input = runST check
  where
    size = length input

    check :: forall s. ST s (Maybe Permutation)
    check = do
      -- seen ! k records whether the value k has already been encountered
      seen <- newArray (1, size) False :: ST s (STUArray s Int Bool)
      let loop [] = return (Just (Permutation (listArray (1, size) input)))
          loop (x : rest)
            | x < 1 || x > size = return Nothing
            | otherwise = do
                already <- readArray seen x
                if already
                  then return Nothing
                  else writeArray seen x True >> loop rest
      loop input
-- | Checked conversion from a list; calls 'error' when 'maybePermutation'
-- rejects the input.
toPermutation :: [Int] -> Permutation
toPermutation xs =
  maybe (error "toPermutation: not a permutation") id (maybePermutation xs)
-- | The upper bound of the underlying array, i.e. @n@ for an array with
-- bounds @(1,n)@.
permutationSize :: Permutation -> Int
permutationSize (Permutation table) = upper
  where
    (_, upper) = bounds table
-- | The width of a permutation is its size.
instance HasWidth Permutation where
  width perm = permutationSize perm
-- | Checks whether every index @i@ in @[1..n]@ is mapped to itself.
isIdentityPermutation :: Permutation -> Bool
isIdentityPermutation (Permutation table) = all fixed [1 .. size]
  where
    (1, size) = bounds table
    fixed i = table ! i == i
-- | Places the two permutations side by side: the entries of the first one
-- are followed by the entries of the second one, each shifted by the size of
-- the first.
concatPermutations :: Permutation -> Permutation -> Permutation
concatPermutations perm1 perm2 =
  toPermutationUnsafe (fromPermutation perm1 ++ shifted)
  where
    offset  = permutationSize perm1
    shifted = [ k + offset | k <- fromPermutation perm2 ]
-- | ASCII drawing of a permutation; currently the same as 'twoLineNotation'.
asciiPermutation :: Permutation -> ASCII
asciiPermutation perm = twoLineNotation perm
-- | Draws each cycle as its own two-line box and joins the boxes
-- horizontally (top aligned, separated by one space).
asciiDisjointCycles :: DisjointCycles -> ASCII
asciiDisjointCycles (DisjointCycles cycles) =
  hCatWith VTop (HSepSpaces 1) (map drawCycle cycles)
  where
    -- the first element is appended again so that the cycle closes on itself
    drawCycle cyc = genericTwoLineNotation (pairs (cyc ++ [head cyc]))
-- | Two-line notation: the top row is @1,2,3,...@ and the bottom row holds
-- the corresponding entries of the permutation.
twoLineNotation :: Permutation -> ASCII
twoLineNotation (Permutation table) = genericTwoLineNotation numbered
  where
    numbered = zip [1 ..] (elems table)
-- | Like 'twoLineNotation', but the columns are reordered so that the
-- bottom row (the entries of the permutation) is increasing.
inverseTwoLineNotation :: Permutation -> ASCII
inverseTwoLineNotation (Permutation table) = genericTwoLineNotation columns
  where
    columns = sortBy (comparing snd) (zip [1 ..] (elems table))
-- | Renders a list of integer pairs as two lines of text: the first
-- components on the top line, the second components on the bottom line.
-- Within a column both numbers are right-aligned to the wider of the two;
-- columns are separated by one space and each line is wrapped in @( ... )@.
genericTwoLineNotation :: [(Int,Int)] -> ASCII
genericTwoLineNotation xys = asciiFromLines [ bracket tops, bracket bottoms ]
  where
    bracket cells = "( " ++ unwords cells ++ " )"

    (tops, bottoms) = unzip [ alignColumn (show x) (show y) | (x, y) <- xys ]

    alignColumn upper lower = (padLeft upper, padLeft lower)
      where
        colWidth    = max (length upper) (length lower)
        padLeft str = replicate (colWidth - length str) ' ' ++ str
-- | Unwraps the list of cycles.
fromDisjointCycles :: DisjointCycles -> [[Int]]
fromDisjointCycles (DisjointCycles cycleList) = cycleList
-- | Wraps a list of cycles without checking that they are disjoint or
-- otherwise valid.
disjointCyclesUnsafe :: [[Int]] -> DisjointCycles
disjointCyclesUnsafe cycleList = DisjointCycles cycleList
-- | Cycle decompositions are drawn with 'asciiDisjointCycles'.
instance DrawASCII DisjointCycles where
  ascii cycles = asciiDisjointCycles cycles
-- | Counts the cycles that are stored in the list.
instance HasNumberOfCycles DisjointCycles where
  numberOfCycles (DisjointCycles cycleList) = length cycleList
-- | Counts the cycles of the decomposition computed by
-- 'permutationToDisjointCycles'.
instance HasNumberOfCycles Permutation where
  numberOfCycles perm = numberOfCycles (permutationToDisjointCycles perm)
-- | Builds the permutation of size @n@ described by the given cycles.
-- Elements that appear in no cycle are mapped to themselves. An empty cycle
-- in the input is reported with 'error'.
disjointCyclesToPermutation :: Int -> DisjointCycles -> Permutation
disjointCyclesToPermutation n (DisjointCycles cycles) = Permutation table
  where
    -- every element paired with its successor in the cycle, the last element
    -- wrapping around to the first one
    successors :: [Int] -> [(Int,Int)]
    successors cyc@(start : rest) = zip cyc (rest ++ [start])
    successors []                 = error "disjointCyclesToPermutation: empty cycle"

    table = runSTUArray $ do
      -- start from the identity, then overwrite the images along the cycles
      arr <- newListArray (1, n) [1 .. n]
      forM_ (concatMap successors cycles) $ \(from, to) -> writeArray arr from to
      return arr
-- | Decomposes a permutation into disjoint cycles.
--
-- Every cycle is listed starting from its smallest element and followed
-- along the permutation (@k, perm!k, perm!(perm!k), ...@); the cycles appear
-- in increasing order of their starting element. Fixed points (cycles of
-- length one) are left out of the result.
--
-- The empty permutation has no cycles. It is handled explicitly, because
-- the traversal below always starts by visiting index 1, which does not
-- exist in an empty array.
permutationToDisjointCycles :: Permutation -> DisjointCycles
permutationToDisjointCycles (Permutation perm) = res
  where
    (1,n) = bounds perm

    -- keeps everything except one-element cycles
    nonTrivial :: [Int] -> Bool
    nonTrivial [_] = False
    nonTrivial _   = True

    res
      | n < 1     = DisjointCycles []
      | otherwise = runST $ do
          tag    <- newArray (1,n) False
          cycles <- unfoldM (step tag) 1
          return (DisjointCycles $ filter nonTrivial cycles)

    -- extracts the cycle through k and looks for the next unvisited start
    step :: STUArray s Int Bool -> Int -> ST s ([Int],Maybe Int)
    step tag k = do
      cyc <- worker tag k k [k]
      m   <- next tag (k+1)
      return (reverse cyc, m)

    -- the first index at or after k which has not been tagged yet
    next :: STUArray s Int Bool -> Int -> ST s (Maybe Int)
    next tag k
      | k > n     = return Nothing
      | otherwise = do
          visited <- readArray tag k
          if visited then next tag (k+1) else return (Just k)

    -- follows the cycle started at k, tagging every element on the way;
    -- the cycle is accumulated in reverse order
    worker :: STUArray s Int Bool -> Int -> Int -> [Int] -> ST s [Int]
    worker tag k l cyc = do
      writeArray tag l True
      let m = perm ! l
      if m == k
        then return cyc
        else worker tag k m (m:cyc)
-- | Decides whether the permutation is even.
--
-- For every cycle the number of steps needed to close it (its length minus
-- one) is counted; the permutation is even exactly when the sum of these
-- numbers is even.
--
-- The empty permutation is reported as even. It is handled explicitly,
-- because the traversal below always starts by visiting index 1, which does
-- not exist in an empty array.
isEvenPermutation :: Permutation -> Bool
isEvenPermutation (Permutation perm) = res
  where
    (1,n) = bounds perm

    res
      | n < 1     = True
      | otherwise = runST $ do
          tag    <- newArray (1,n) False
          cycles <- unfoldM (step tag) 1
          return $ even (sum cycles)

    -- measures the cycle through k and looks for the next unvisited start
    step :: STUArray s Int Bool -> Int -> ST s (Int,Maybe Int)
    step tag k = do
      cyclen <- worker tag k k 0
      m      <- next tag (k+1)
      return (cyclen,m)

    -- the first index at or after k which has not been tagged yet
    next :: STUArray s Int Bool -> Int -> ST s (Maybe Int)
    next tag k
      | k > n     = return Nothing
      | otherwise = do
          visited <- readArray tag k
          if visited then next tag (k+1) else return (Just k)

    -- walks along the cycle started at k, tagging its elements; the counter
    -- starts at 0, so the result is the cycle length minus one
    worker :: STUArray s Int Bool -> Int -> Int -> Int -> ST s Int
    worker tag k l cyclen = do
      writeArray tag l True
      let m = perm ! l
      if m == k
        then return cyclen
        else worker tag k m (1+cyclen)
-- | The negation of 'isEvenPermutation'.
isOddPermutation :: Permutation -> Bool
isOddPermutation perm = not (isEvenPermutation perm)
-- | 'Plus' for even permutations, 'Minus' for odd ones.
signOfPermutation :: Permutation -> Sign
signOfPermutation perm
  | isEvenPermutation perm = Plus
  | otherwise              = Minus
{-# SPECIALIZE signValueOfPermutation :: Permutation -> Int     #-}
{-# SPECIALIZE signValueOfPermutation :: Permutation -> Integer #-}
-- | The sign of the permutation converted to a number with 'signValue'.
signValueOfPermutation :: Num a => Permutation -> a
signValueOfPermutation perm = signValue (signOfPermutation perm)
-- | Checks whether the permutation consists of a single cycle through all
-- of its @n@ elements.
--
-- 'permutationToDisjointCycles' omits fixed points, so an empty cycle list
-- means that the permutation is the identity. The identity is a single
-- @n@-cycle only when @n <= 1@; for larger sizes it consists of @n@ separate
-- fixed points and is therefore rejected, consistently with the single-cycle
-- case below, which also rejects any permutation having a fixed point.
isCyclicPermutation :: Permutation -> Bool
isCyclicPermutation perm =
  case cycles of
    []    -> n <= 1
    [cyc] -> length cyc == n
    _     -> False
  where
    n = permutationSize perm
    DisjointCycles cycles = permutationToDisjointCycles perm
-- | All index pairs @(i,j)@ with @i < j@ whose entries are out of order
-- (the entry at @i@ is greater than the entry at @j@), listed in
-- lexicographic order.
inversions :: Permutation -> [(Int,Int)]
inversions (Permutation arr) =
  [ (i, j)
  | i <- [1 .. size - 1]
  , let atI = arr ! i
  , j <- [i + 1 .. size]
  , atI > arr ! j
  ]
  where
    size = snd (bounds arr)
-- | The number of inversions; delegates to 'numberOfInversionsMerge'.
numberOfInversions :: Permutation -> Int
numberOfInversions perm = numberOfInversionsMerge perm
-- | Counts the inversions with a merge sort that keeps track of how many
-- out-of-order pairs each merge step resolves.
numberOfInversionsMerge :: Permutation -> Int
numberOfInversionsMerge (Permutation arr) = fst (countSort size (elems arr))
  where
    size = snd (bounds arr)

    -- The first argument is the length of the list. Returns the inversion
    -- count together with the sorted list.
    countSort :: Int -> [Int] -> (Int,[Int])
    countSort 0 _   = (0, [])
    countSort 1 [x] = (0, [x])
    countSort 2 [x, y]
      | x > y     = (1, [y, x])
      | otherwise = (0, [x, y])
    countSort len xs = mergeCount (countSort half front) (countSort (len - half) back)
      where
        half          = len `div` 2
        (front, back) = splitAt half xs

    -- Merges two sorted halves; adds the inversions between the halves to
    -- the counts already collected inside each half.
    mergeCount :: (Int,[Int]) -> (Int,[Int]) -> (Int,[Int])
    mergeCount (!leftCnt, lefts) (!rightCnt, rights) = (leftCnt + rightCnt + crossCnt, merged)
      where
        (crossCnt, merged) = weave 0 lefts rights

        -- 'passed' is the number of right-hand elements emitted so far; each
        -- left-hand element is inverted with exactly those
        weave !passed ls [] = (passed * length ls, ls)
        weave _       [] rs = (0, rs)
        weave !passed lls@(l : ls) rrs@(r : rs)
          | l < r     = let (cnt, rest) = weave passed ls rrs       in (cnt + passed, l : rest)
          | otherwise = let (cnt, rest) = weave (passed + 1) lls rs in (cnt, r : rest)
-- | Counts the inversions by testing every index pair @i < j@.
numberOfInversionsNaive :: Permutation -> Int
numberOfInversionsNaive (Permutation arr) = length outOfOrder
  where
    size       = snd (bounds arr)
    outOfOrder =
      [ ()
      | i <- [1 .. size - 1]
      , j <- [i + 1 .. size]
      , arr ! i > arr ! j
      ]
-- | The result of 'bubbleSort', with every index @i@ expanded to the pair
-- @(i,i+1)@.
bubbleSort2 :: Permutation -> [(Int,Int)]
bubbleSort2 perm = [ (i, i + 1) | i <- bubbleSort perm ]
-- | Returns a list of indices @y@, each standing for the swap of the
-- adjacent positions @y@ and @y+1@. Starting from the arrangement
-- @1,2,...,n@, the swaps move, for every position @x@ in turn, the entry of
-- the permutation at @x@ from wherever it currently sits to position @x@.
bubbleSort :: Permutation -> [Int]
bubbleSort (Permutation target) = runST sorter
  where
    size = snd (bounds target)

    sorter :: forall s. ST s [Int]
    sorter = do
      -- current ! p  = the number which sits at position p right now
      -- position ! k = the position at which the number k sits right now
      current  <- newListArray (1, size) [1 .. size] :: ST s (STUArray s Int Int)
      position <- newListArray (1, size) [1 .. size] :: ST s (STUArray s Int Int)

      let -- exchange the numbers at positions y and y+1, updating both tables
          swapAt y = do
            a <- readArray current y
            b <- readArray current (y + 1)
            writeArray current y       b
            writeArray current (y + 1) a
            posA <- readArray position a
            posB <- readArray position b
            writeArray position a posB
            writeArray position b posA

          -- bring the number destined for position dest to that position
          place dest = do
            from <- readArray position (target ! dest)
            let steps = path from dest
            mapM_ swapAt steps
            return steps

      fmap concat (mapM place [1 .. size])

    -- the adjacent swaps which carry an element from position from to position to
    path :: Int -> Int -> [Int]
    path from to = case compare to from of
      EQ -> []
      GT -> [from .. to - 1]
      LT -> [from - 1, from - 2 .. to]
-- | The permutation @[n,n-1,...,2,1]@.
reversePermutation :: Int -> Permutation
reversePermutation n = Permutation (listArray (1, n) (reverse [1 .. n]))
-- | Checks whether the entries are exactly @n,n-1,...,1@.
isReversePermutation :: Permutation -> Bool
isReversePermutation (Permutation arr) = elems arr == reverse [1 .. size]
  where
    (1, size) = bounds arr
-- | The permutation of size @n@ which exchanges @i@ and @j@ and fixes
-- everything else. Calls 'error' unless both indices are within @[1..n]@.
transposition :: Int -> (Int,Int) -> Permutation
transposition n (i, j)
  | valid i && valid j = Permutation (listArray (1, n) [1 .. n] // [(i, j), (j, i)])
  | otherwise          = error "transposition: index out of range"
  where
    valid k = k >= 1 && k <= n
-- | Starts from the identity arrangement of size @n@ and applies the given
-- swaps of positions one after the other, in list order.
transpositions :: Int -> [(Int,Int)] -> Permutation
transpositions n list = Permutation (runSTUArray build)
  where
    build :: ST s (STUArray s Int Int)
    build = do
      arr <- newListArray (1, n) [1 .. n]
      forM_ list $ \(i, j) -> do
        atI <- readArray arr i
        atJ <- readArray arr j
        writeArray arr i atJ
        writeArray arr j atI
      return arr
-- | The transposition of size @n@ exchanging @k@ and @k+1@; @k@ must satisfy
-- @0 < k < n@, otherwise 'error' is called.
adjacentTransposition :: Int -> Int -> Permutation
adjacentTransposition n k =
  if k > 0 && k < n
    then transposition n (k, k + 1)
    else error "adjacentTransposition: index out of range"
-- | Starts from the identity arrangement of size @n@ and applies, in list
-- order, the swaps of the adjacent positions @i@ and @i+1@ for every @i@ in
-- the list.
--
-- Every index must satisfy @1 <= i < n@ (the same range which
-- 'adjacentTransposition' accepts); anything else, including 0, is reported
-- as @adjacentTranspositions: index out of range@.
adjacentTranspositions :: Int -> [Int] -> Permutation
adjacentTranspositions n list = Permutation (runSTUArray action)
  where
    action :: ST s (STUArray s Int Int)
    action = do
      arr <- newArray_ (1,n)
      forM_ [1..n] $ \i -> writeArray arr i i
      let doSwap i
            -- position 0 does not exist either, hence the lower limit is 1
            | i < 1 || i >= n = error "adjacentTranspositions: index out of range"
            | otherwise       = do
                u <- readArray arr i
                v <- readArray arr (i+1)
                writeArray arr i     v
                writeArray arr (i+1) u
      mapM_ doSwap list
      return arr
-- | The permutation @[2,3,...,n,1]@.
cycleLeft :: Int -> Permutation
cycleLeft n = Permutation (listArray (1, n) (map successor [1 .. n]))
  where
    successor k = if k == n then 1 else k + 1
-- | The permutation @[n,1,2,...,n-1]@.
cycleRight :: Int -> Permutation
cycleRight n = Permutation (listArray (1, n) (map predecessor [1 .. n]))
  where
    predecessor k = if k == 1 then n else k - 1
-- | Product of two permutations of the same size: entry @i@ of the result
-- is the entry of the first permutation at the index given by entry @i@ of
-- the second one. Calls 'error' when the sizes differ.
multiply :: Permutation -> Permutation -> Permutation
multiply (Permutation perm1) pi2@(Permutation perm2)
  | size1 == size2 = Permutation (permute pi2 perm1)
  | otherwise      = error "multiply: permutations of different sets"
  where
    size1 = snd (bounds perm1)
    size2 = snd (bounds perm2)

infixr 7 `multiply`
-- | The inverse permutation: whenever the input maps @k@ to @v@, the result
-- maps @v@ to @k@.
inverse :: Permutation -> Permutation
inverse (Permutation perm1) = Permutation flipped
  where
    size    = snd (bounds perm1)
    flipped = array (1, size) [ (v, k) | (k, v) <- assocs perm1 ]
-- | The identity permutation @[1,2,...,n]@.
identity :: Int -> Permutation
identity n = Permutation table
  where
    table = listArray (1, n) [1 .. n]
-- | Multiplies a non-empty list of permutations with a strict left fold.
-- An empty list is an 'error', since the size of the result is unknown.
multiplyMany :: [Permutation] -> Permutation
multiplyMany ps
  | null ps   = error "multiplyMany: empty list, we don't know size of the result"
  | otherwise = foldl1' multiply ps
-- | Multiplies a possibly empty list of permutations of size @n@; the empty
-- product is the identity. The size of the first permutation is compared
-- with @n@ here, 'multiply' compares the sizes of its own arguments.
multiplyMany' :: Int -> [Permutation] -> Permutation
multiplyMany' n []         = identity n
multiplyMany' n ps@(first : _)
  | permutationSize first == n = foldl1' multiply ps
  | otherwise                  = error "multiplyMany': incompatible permutation size(s)"
{-# SPECIALIZE permute :: Permutation -> Array  Int b   -> Array  Int b   #-}
{-# SPECIALIZE permute :: Permutation -> UArray Int Int -> UArray Int Int #-}
-- | Action of a permutation on an array; the same as 'permuteRight'.
permute :: IArray arr b => Permutation -> arr Int b -> arr Int b
permute perm ar = permuteRight perm ar
-- | Action of a permutation on a list; the same as 'permuteRightList'.
permuteList :: Permutation -> [a] -> [a]
permuteList perm xs = permuteRightList perm xs
{-# SPECIALIZE permuteRight :: Permutation -> Array  Int b   -> Array  Int b   #-}
{-# SPECIALIZE permuteRight :: Permutation -> UArray Int Int -> UArray Int Int #-}
-- | Rearranges an array with bounds @(1,n)@: position @i@ of the result
-- holds the element found at index @perm!i@ of the input. Calls 'error' when
-- the bounds of the array are not @(1,n)@.
permuteRight :: IArray arr b => Permutation -> arr Int b -> arr Int b
permuteRight (Permutation perm) ar
  | lower == 1 && upper == n = listArray (1, n) (map pick [1 .. n])
  | otherwise                = error "permuteRight: array bounds do not match"
  where
    n              = snd (bounds perm)
    (lower, upper) = bounds ar
    pick i         = ar ! (perm ! i)
-- | List version of 'permuteRight': the list is loaded into a boxed array
-- with bounds @(1,n)@, permuted, and read back.
permuteRightList :: forall a . Permutation -> [a] -> [a]
permuteRightList perm xs = elems (permuteRight perm boxed)
  where
    boxed :: Array Int a
    boxed = listArray (1, permutationSize perm) xs
{-# SPECIALIZE permuteLeft :: Permutation -> Array  Int b   -> Array  Int b   #-}
{-# SPECIALIZE permuteLeft :: Permutation -> UArray Int Int -> UArray Int Int #-}
-- | Rearranges an array with bounds @(1,n)@: the element at index @i@ of the
-- input is stored at position @perm!i@ of the result. Calls 'error' when the
-- bounds of the array are not @(1,n)@.
permuteLeft :: IArray arr b => Permutation -> arr Int b -> arr Int b
permuteLeft (Permutation perm) ar
  | lower == 1 && upper == n = array (1, n) (map place [1 .. n])
  | otherwise                = error "permuteLeft: array bounds do not match"
  where
    n              = snd (bounds perm)
    (lower, upper) = bounds ar
    place i        = (perm ! i, ar ! i)
-- | List version of 'permuteLeft': the list is loaded into a boxed array
-- with bounds @(1,n)@, permuted, and read back.
permuteLeftList :: forall a. Permutation -> [a] -> [a]
permuteLeftList perm xs = elems (permuteLeft perm boxed)
  where
    boxed :: Array Int a
    boxed = listArray (1, permutationSize perm) xs
-- | Numbers the elements of the list from 1, sorts them (stably) into
-- ascending order, and returns the resulting sequence of original positions
-- as a permutation.
sortingPermutationAsc :: Ord a => [a] -> Permutation
sortingPermutationAsc xs = toPermutation [ idx | (idx, _) <- ordered ]
  where
    ordered = sortBy (comparing snd) (zip [1 ..] xs)
-- | Like 'sortingPermutationAsc', but the elements are ordered with
-- 'reverseComparing'.
sortingPermutationDesc :: Ord a => [a] -> Permutation
sortingPermutationDesc xs = toPermutation [ idx | (idx, _) <- ordered ]
  where
    ordered = sortBy (reverseComparing snd) (zip [1 ..] xs)
-- | All permutations of @[1..n]@; delegates to 'permutationsNaive'.
permutations :: Int -> [Permutation]
permutations n = permutationsNaive n
-- | All permutations of @[1..n]@ as plain lists; delegates to
-- '_permutationsNaive'.
_permutations :: Int -> [[Int]]
_permutations n = _permutationsNaive n
-- | The lists produced by '_permutations', each wrapped with
-- 'toPermutationUnsafe'.
permutationsNaive :: Int -> [Permutation]
permutationsNaive n = [ toPermutationUnsafe xs | xs <- _permutations n ]
-- | All permutations of @[1..n]@ in lexicographic order, generated by
-- choosing each remaining element in turn as the next entry.
_permutationsNaive :: Int -> [[Int]]
_permutationsNaive n = expand [1 .. n]
  where
    expand []   = [[]]
    expand pool = concatMap (\pick -> map (pick :) (expand (without pick pool))) pool

    -- the pool consists of distinct numbers, so exactly one element is dropped
    without pick = filter (/= pick)
-- | The number of permutations of @n@ elements, computed with 'factorial'.
countPermutations :: Int -> Integer
countPermutations n = factorial n
-- | A random permutation of size @n@; delegates to
-- 'randomPermutationDurstenfeld'.
randomPermutation :: RandomGen g => Int -> g -> (Permutation,g)
randomPermutation n gen = randomPermutationDurstenfeld n gen
-- | Like 'randomPermutation', but the result is returned as a plain list.
_randomPermutation :: RandomGen g => Int -> g -> ([Int],g)
_randomPermutation n rndgen =
  let (perm, nextGen) = randomPermutationDurstenfeld n rndgen
  in  (fromPermutation perm, nextGen)
-- | A random cyclic permutation of size @n@; delegates to
-- 'randomCyclicPermutationSattolo'.
randomCyclicPermutation :: RandomGen g => Int -> g -> (Permutation,g)
randomCyclicPermutation n gen = randomCyclicPermutationSattolo n gen
-- | Like 'randomCyclicPermutation', but the result is returned as a plain
-- list.
_randomCyclicPermutation :: RandomGen g => Int -> g -> ([Int],g)
_randomCyclicPermutation n rndgen =
  let (perm, nextGen) = randomCyclicPermutationSattolo n rndgen
  in  (fromPermutation perm, nextGen)
-- | The shared shuffle 'randomPermutationDurstenfeldSattolo' with the
-- Sattolo flag switched off.
randomPermutationDurstenfeld :: RandomGen g => Int -> g -> (Permutation,g)
randomPermutationDurstenfeld n gen = randomPermutationDurstenfeldSattolo False n gen
-- | The shared shuffle 'randomPermutationDurstenfeldSattolo' with the
-- Sattolo flag switched on.
randomCyclicPermutationSattolo :: RandomGen g => Int -> g -> (Permutation,g)
randomCyclicPermutationSattolo n gen = randomPermutationDurstenfeldSattolo True n gen
-- | Common implementation of the two shuffles. The array starts out as
-- @1,2,...,n@; then, for the positions @n, n-1, ..., 2@ in turn, the entry at
-- the current position is exchanged with the entry at a randomly drawn
-- position.
--
-- With the flag set to 'False' the random position is drawn from
-- @[1..current]@, so an entry may stay where it is. With the flag set to
-- 'True' it is drawn from @[1..current-1]@, so the entry at the current
-- position is always moved.
--
-- For @n <= 1@ no random number is consumed and the generator is returned
-- unchanged; in particular @n = 0@ yields the empty permutation.
randomPermutationDurstenfeldSattolo :: RandomGen g => Bool -> Int -> g -> (Permutation,g)
randomPermutationDurstenfeldSattolo isSattolo n rnd = res
  where
    res = runST $ do
      ar <- newArray_ (1,n)
      forM_ [1..n] $ \i -> writeArray ar i i
      rnd' <- worker n (if isSattolo then n-1 else n) rnd ar
      -- the mutable array is not touched after this point, so freezing it
      -- without a copy is fine
      perm <- Data.Array.Unsafe.unsafeFreeze ar
      return (Permutation perm, rnd')

    -- first argument: the current position; second argument: the upper limit
    -- of the random index drawn for it
    worker :: RandomGen g => Int -> Int -> g -> STUArray s Int Int -> ST s g
    worker n m rnd ar
      -- nothing is left to exchange; "<=" (instead of "==") also terminates
      -- the recursion for empty and negative sizes
      | n <= 1    = return rnd
      | otherwise = do
          let (k,rnd') = randomR (1,m) rnd
          when (k /= n) $ do
            y <- readArray ar k
            z <- readArray ar n
            writeArray ar n y
            writeArray ar k z
          worker (n-1) (m-1) rnd' ar
-- | All distinct orderings of a list which may contain repeated elements;
-- delegates to 'fasc2B_algorithm_L'.
permuteMultiset :: (Eq a, Ord a) => [a] -> [[a]]
permuteMultiset xs = fasc2B_algorithm_L xs
-- | The factorial of the length of the list, divided by the product of the
-- factorials of the multiplicities of its distinct elements.
countPermuteMultiset :: (Eq a, Ord a) => [a] -> Integer
countPermuteMultiset xs =
  factorial total `div` product (map (factorial . length) classes)
  where
    total   = length xs
    classes = group (sort xs)
-- | Enumerates the arrangements of a list by starting from the sorted list
-- and repeatedly stepping to the next arrangement until no further step
-- exists.
fasc2B_algorithm_L :: (Eq a, Ord a) => [a] -> [[a]]
fasc2B_algorithm_L xs = unfold1 successor (sort xs)
  where
    -- one step; Nothing when the whole list is non-increasing
    successor current = case findPivot (reverse current, []) of
      Nothing                       -> Nothing
      Just (pivot : before, suffix) -> Just (bump pivot before (reverse suffix, []))
      Just ([], _)                  -> error "permute: should not happen"

    -- Scans from the right end. The first component is the part not scanned
    -- yet (reversed), the second one is the suffix scanned so far. Stops at
    -- the first element which is smaller than its right neighbour.
    findPivot (rest@(a : as), suffix@(b : _))
      | a >= b    = findPivot (as, a : suffix)
      | otherwise = Just (rest, suffix)
    findPivot (a : as, []) = findPivot (as, [a])
    findPivot ([], _)      = Nothing

    -- Walks through the reversed suffix up to the first element greater than
    -- the pivot, puts that element in the pivot's place, and rebuilds the
    -- tail from the skipped elements, the pivot and the remaining elements.
    bump !pivot before (c : cs, skipped)
      | pivot >= c = bump pivot before (cs, c : skipped)
      | otherwise  = reverse (c : before) ++ reverse (pivot : skipped) ++ cs
    bump _ _ ([], _) = error "permute: should not happen"