-- | A rearrangement is a generalisation of transposition, where the
-- dimensions are arbitrarily permuted.
module Futhark.IR.Prop.Rearrange
  ( rearrangeShape,
    rearrangeInverse,
    rearrangeReach,
    rearrangeCompose,
    isPermutationOf,
    transposeIndex,
    isMapTranspose,
  )
where

import Data.List (sortOn, tails)
import Futhark.Util

-- | Calculate the given permutation of the list: element @i@ of the
-- result is element @perm !! i@ of the input, so the result has as
-- many elements as the permutation.  It is an error if the permutation
-- goes out of bounds; the error is only raised once an offending
-- element of the result is demanded.
rearrangeShape :: [Int] -> [a] -> [a]
rearrangeShape perm l = [select i | i <- perm]
  where
    len = length l
    select i
      | i < 0 || i >= len =
          error $ show perm ++ " is not a valid permutation for input."
      | otherwise = l !! i

-- | Produce the inverse permutation.  For a valid permutation @perm@,
-- the result @inv@ satisfies @inv !! (perm !! j) == j@ for every
-- position @j@.
rearrangeInverse :: [Int] -> [Int]
rearrangeInverse perm =
  -- Pair each target with its source position, order by target, and
  -- keep the sources.
  [source | (_, source) <- sortOn fst (zip perm [0 ..])]

-- | Return the first dimension not affected by the permutation, i.e.
-- the smallest @k@ such that every position @j >= k@ is mapped to
-- itself.  For example, the permutation @[1,0,2]@ would return @2@,
-- an identity permutation returns @0@, and @[0,2,1]@ returns @3@.
rearrangeReach :: [Int] -> Int
rearrangeReach perm =
  -- Count the tails of the (value, position) pairs that still contain a
  -- moved position.  The final tail is always empty and therefore
  -- trivially fixed, so the count never exceeds @length perm@ and no
  -- fallback case is needed.
  length $ takeWhile (not . all (uncurry (==))) $ tails $ zip perm [0 ..]

-- | Compose two permutations, with the second given permutation being
-- applied first, so that @rearrangeShape (rearrangeCompose p q) xs@
-- equals @rearrangeShape p (rearrangeShape q xs)@.  The composition is
-- obtained by permuting the second permutation with the first.
rearrangeCompose :: [Int] -> [Int] -> [Int]
rearrangeCompose outer inner = rearrangeShape outer inner

-- | Check whether the first list is a permutation of the second, and
-- if so, return the permutation: element @i@ of the result is the
-- position in the second list that element @i@ of the first list was
-- matched with.  This will also find identity permutations (i.e. when
-- the lists are the same).  Each element of the first list claims the
-- leftmost not-yet-claimed equal element of the second list; the
-- implementation is naive and quadratic.
isPermutationOf :: Eq a => [a] -> [a] -> Maybe [Int]
isPermutationOf l1 l2 = consume (map Just l2) l1
  where
    -- Walk the first list, replacing each matched slot of the second
    -- list with 'Nothing' so that it cannot be matched twice.
    consume :: Eq b => [Maybe b] -> [b] -> Maybe [Int]
    consume slots []
      -- Every slot must have been claimed; leftovers mean the second
      -- list has extra elements.
      | all (== Nothing) slots = Just []
      | otherwise = Nothing
    consume slots (y : ys) =
      case break (== Just y) slots of
        (_, []) -> Nothing
        (before, _ : after) ->
          (length before :) <$> consume (before ++ Nothing : after) ys

-- | If @l@ is an index into the array @a@, then @transposeIndex k n
-- l@ is an index to the same element in the array @transposeArray k n
-- a@.
--
-- Concretely (assuming @0 <= k < length l@), the element at position
-- @k@ of @l@ is moved to position @k + n@, and the elements in between
-- shift by one to make room.  A target position at or past the end of
-- @l@ wraps around modulo the length of @l@; a negative target moves
-- the element to the front.  An empty @l@ is returned unchanged.
transposeIndex :: Int -> Int -> [a] -> [a]
transposeIndex k n l
  -- Without this guard the wrap-around case below would compute
  -- @`mod` 0@ and throw a divide-by-zero exception.
  | null l = l
  | k + n >= len =
      -- Reduce the target into range; the recursive call then has
      -- @k + n' < len@, so it cannot loop.
      let n' = ((k + n) `mod` len) - k
       in transposeIndex k n' l
  | n < 0,
    (pre, needle : end) <- splitAt k l,
    (beg, mid) <- splitAt (length pre + n) pre =
      beg ++ [needle] ++ mid ++ end
  | (beg, needle : post) <- splitAt k l,
    (mid, end) <- splitAt n post =
      beg ++ mid ++ [needle] ++ end
  | otherwise = l
  where
    len = length l

-- | If @perm@ is conceptually a map of a transposition,
-- @isMapTranspose perm@ returns a triple: the number of leading
-- dimensions being mapped, followed by the number of dimensions in the
-- first and in the second group being transposed (in the order the
-- groups appear in @perm@).  For example, we can consider the
-- permutation @[0,1,4,5,2,3]@ as a map of a transpose, by considering
-- dimensions @[0,1]@, @[4,5]@, and @[2,3]@ as single dimensions each;
-- the result is then @Just (2,2,2)@.  Permutations without this shape,
-- including identity permutations, give 'Nothing'.
--
-- If the input is not a valid permutation, then the result is
-- undefined.
isMapTranspose :: [Int] -> Maybe (Int, Int, Int)
isMapTranspose perm =
  case leadingRun 0 perm of
    -- Everything is mapped: nothing is being transposed.
    (_, []) -> Nothing
    (mapped, rest@(start : _)) ->
      let (pretrans, posttrans) = leadingRun start rest
          numMapped = length mapped
          numPost = length posttrans
          -- The second group must be exactly the dimensions right after
          -- the mapped ones.
          postIsNext = posttrans == [numMapped .. numMapped + numPost - 1]
       in if postIsNext && not (null posttrans)
            then Just (numMapped, length pretrans, numPost)
            else Nothing
  where
    -- Split off the longest prefix of @xs@ of the form @[x, x+1, x+2, ...]@.
    leadingRun :: Int -> [Int] -> ([Int], [Int])
    leadingRun x xs =
      splitAt (length (takeWhile id (zipWith (==) xs (iterate (+ 1) x)))) xs