{-# LANGUAGE TypeOperators, ExplicitForAll, FlexibleContexts #-}
module Data.Array.Repa.Operators.IndexSpace
( reshape
, append, (++)
, transpose
, extract
, backpermute, unsafeBackpermute
, backpermuteDft, unsafeBackpermuteDft
, extend, unsafeExtend
, slice, unsafeSlice)
where
import Data.Array.Repa.Index
import Data.Array.Repa.Slice
import Data.Array.Repa.Base
import Data.Array.Repa.Repr.Delayed
import Data.Array.Repa.Operators.Traversal
import Data.Array.Repa.Shape as S
import Prelude hiding ((++), traverse)
import qualified Prelude as P
-- | Name of this module, used as the prefix of error messages raised here
--   (see 'reshape').
stage :: String
stage = "Data.Array.Repa.Operators.IndexSpace"
-- | Impose a new shape on the elements of an array.
--
--   Each index of the result is turned into a linear offset with 'toIndex'
--   in the new shape, and that offset is turned back into a source index
--   with 'fromIndex' in the source's extent.
--
--   Calls 'error' if the new extent does not hold exactly as many elements
--   as the source array.
reshape :: ( Shape sh1, Shape sh2
           , Source r1 e)
        => sh2                  -- ^ Extent of the result.
        -> Array r1 sh1 e       -- ^ Source array.
        -> Array D  sh2 e
reshape sh2 arr
        | S.size sh2 /= S.size (extent arr)
        = error
        $ stage P.++ ".reshape: reshaped array will not match size of the original"

        -- Sizes agree, so every linear offset produced by 'toIndex' maps back
        -- to an index inside the source; hence the unchecked read.
        | otherwise
        = fromFunction sh2
        $ unsafeIndex arr . fromIndex (extent arr) . toIndex sh2
{-# INLINE [2] reshape #-}
-- | Append two arrays along the rightmost (@Int@) component of their shape.
--
--   * The result's rightmost dimension is the sum of the two inputs'
--     rightmost dimensions.
--   * The remaining dimensions are the intersection ('intersectDim') of the
--     two inputs' remaining dimensions.
--   * A result index whose rightmost component @i@ is below the first array's
--     length @n@ reads the first array at @i@. Any other index reads the
--     second array at @i - n@.
--
--   '(++)' is an infix synonym with the same type.
append, (++)
        :: ( Shape sh
           , Source r1 e, Source r2 e)
        => Array r1 (sh :. Int) e
        -> Array r2 (sh :. Int) e
        -> Array D  (sh :. Int) e
append arr1 arr2
        = unsafeTraverse2 arr1 arr2 fnExtent fnElem
        -- NOTE(review): 'unsafeTraverse2' presumably skips bounds checks. Every
        -- index produced lies inside both inputs: the lower dimensions are an
        -- intersection, and the rightmost one is split at @n@.
        where
          -- Length of the first array along the appended dimension.
          (_ :. n) = extent arr1

          fnExtent (sh1 :. i) (sh2 :. j)
                = intersectDim sh1 sh2 :. (i + j)

          fnElem f1 f2 (sh :. i)
                | i < n     = f1 (sh :. i)
                | otherwise = f2 (sh :. (i - n))
{-# INLINE [2] append #-}
-- Infix synonym for 'append'. Its type signature is shared with 'append'.
(++) = append
{-# INLINE (++) #-}
-- | Swap the two rightmost dimensions of an array.
--
--   The result has extent @sh :. n :. m@ when the source has extent
--   @sh :. m :. n@. The result element at @sh :. i :. j@ is the source
--   element at @sh :. j :. i@, so transposing twice yields the original.
transpose
        :: (Shape sh, Source r e)
        => Array r (sh :. Int :. Int) e
        -> Array D (sh :. Int :. Int) e
transpose arr
        = unsafeTraverse arr swapExtent swapIndex
        where
          swapExtent (sh :. m :. n)    = sh :. n :. m
          swapIndex get (sh :. i :. j) = get (sh :. j :. i)
{-# INLINE [2] transpose #-}
-- | Extract a sub-range of elements from an array.
--
--   The result has extent @sz@. Its element at @ix@ is the source element at
--   @addDim start ix@.
--
--   No range check is made on @start@ or @sz@: elements are read with
--   'unsafeIndex'. NOTE(review): presumably the caller must keep the whole
--   requested range inside the source extent; out-of-range reads are not
--   caught here.
extract :: (Shape sh, Source r e)
        => sh                   -- ^ Starting index in the source array.
        -> sh                   -- ^ Extent of the result.
        -> Array r sh e         -- ^ Source array.
        -> Array D sh e
extract start sz arr
        = fromFunction sz (\ix -> arr `unsafeIndex` (addDim start ix))
{-# INLINE [2] extract #-}
-- | Backwards permutation of an array's elements.
--
--   The result has the given extent. Each of its indices is mapped through
--   the permutation function to an index of the source, and the element
--   stored there is used.
--
--   'backpermute' builds the result with 'traverse'. 'unsafeBackpermute' uses
--   'unsafeTraverse' instead (presumably skipping bounds checks).
backpermute, unsafeBackpermute
        :: forall r sh1 sh2 e
        .  ( Shape sh1
           , Source r e)
        => sh2                  -- ^ Extent of the result array.
        -> (sh2 -> sh1)         -- ^ Maps each result index to a source index.
        -> Array r  sh1 e       -- ^ Source array.
        -> Array D  sh2 e
backpermute newExtent perm arr
        = traverse arr (const newExtent) (. perm)
{-# INLINE [2] backpermute #-}
-- Variant of 'backpermute' that builds the result with 'unsafeTraverse'
-- rather than 'traverse'. Its type signature is shared with 'backpermute'.
unsafeBackpermute newExtent perm arr
        = unsafeTraverse arr (const newExtent) (. perm)
{-# INLINE [2] unsafeBackpermute #-}
-- | Backwards permutation with default values.
--
--   The result has the extent of the default array. For each result index
--   @ix@:
--
--   * If the index function returns @Just ix'@, the source element at @ix'@
--     is used.
--   * If it returns 'Nothing', the default array's element at @ix@ is used.
--
--   'backpermuteDft' reads with 'index'. 'unsafeBackpermuteDft' reads with
--   'unsafeIndex' (presumably unchecked).
backpermuteDft, unsafeBackpermuteDft
        :: forall r1 r2 sh1 sh2 e
        .  ( Shape sh1,   Shape sh2
           , Source r1 e, Source r2 e)
        => Array r2 sh2 e       -- ^ Default values; also fixes the result extent.
        -> (sh2 -> Maybe sh1)   -- ^ Maps each result index to a source index, if any.
        -> Array r1 sh1 e       -- ^ Source array.
        -> Array D  sh2 e
backpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) fnElem
        where
          fnElem ix = maybe (arrDft `index` ix) (arrSrc `index`) (fnIndex ix)
{-# INLINE [2] backpermuteDft #-}
-- Variant of 'backpermuteDft' that reads both the source and the default
-- array with 'unsafeIndex' rather than 'index'. Its type signature is shared
-- with 'backpermuteDft'.
unsafeBackpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) fnElem
        where
          fnElem ix = maybe (arrDft `unsafeIndex` ix) (arrSrc `unsafeIndex`) (fnIndex ix)
{-# INLINE [2] unsafeBackpermuteDft #-}
-- | Extend an array according to a slice specification.
--
--   The result has extent @fullOfSlice sl (extent arr)@. Each result index is
--   mapped back to a source index with @sliceOfFull sl@.
--
--   'extend' is built on 'backpermute', and 'unsafeExtend' on
--   'unsafeBackpermute'.
extend, unsafeExtend
        :: ( Slice sl
           , Shape (SliceShape sl)
           , Source r e)
        => sl                           -- ^ Slice specification.
        -> Array r (SliceShape sl) e    -- ^ Source array.
        -> Array D (FullShape  sl) e
extend sl arr
        = backpermute
                (fullOfSlice sl (extent arr))
                (sliceOfFull sl)
                arr
{-# INLINE [2] extend #-}
-- Variant of 'extend' built on 'unsafeBackpermute' rather than 'backpermute'.
-- Its type signature is shared with 'extend'.
unsafeExtend sl arr
        = unsafeBackpermute
                (fullOfSlice sl (extent arr))
                (sliceOfFull sl)
                arr
{-# INLINE [2] unsafeExtend #-}
-- | Take a slice from an array according to a slice specification.
--
--   The result has extent @sliceOfFull sl (extent arr)@. Each result index is
--   mapped back to a source index with @fullOfSlice sl@.
--
--   'slice' is built on 'backpermute', and 'unsafeSlice' on
--   'unsafeBackpermute'.
slice, unsafeSlice
        :: ( Slice sl
           , Shape (FullShape sl)
           , Source r e)
        => Array r (FullShape sl) e     -- ^ Source array.
        -> sl                           -- ^ Slice specification.
        -> Array D (SliceShape sl) e
slice arr sl
        = backpermute
                (sliceOfFull sl (extent arr))
                (fullOfSlice sl)
                arr
{-# INLINE [2] slice #-}
-- Variant of 'slice' built on 'unsafeBackpermute' rather than 'backpermute'.
-- Its type signature is shared with 'slice'.
unsafeSlice arr sl
        = unsafeBackpermute
                (sliceOfFull sl (extent arr))
                (fullOfSlice sl)
                arr
{-# INLINE [2] unsafeSlice #-}