{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE RebindableSyntax #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -- | -- Module : Data.Array.Accelerate.Prelude -- Copyright : [2009..2017] Manuel M T Chakravarty, Gabriele Keller, Trevor L. McDonell -- [2010..2011] Ben Lever -- License : BSD3 -- -- Maintainer : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au> -- Stability : experimental -- Portability : non-portable (GHC extensions) -- -- Standard functions that are not part of the core set (directly represented in -- the AST), but are instead implemented in terms of the core set. -- module Data.Array.Accelerate.Prelude ( -- * Element-wise operations indexed, imap, -- * Zipping zipWith3, zipWith4, zipWith5, zipWith6, zipWith7, zipWith8, zipWith9, izipWith, izipWith3, izipWith4, izipWith5, izipWith6, izipWith7, izipWith8, izipWith9, zip, zip3, zip4, zip5, zip6, zip7, zip8, zip9, -- * Unzipping unzip, unzip3, unzip4, unzip5, unzip6, unzip7, unzip8, unzip9, -- * Reductions foldAll, fold1All, -- ** Specialised folds all, any, and, or, sum, product, minimum, maximum, -- * Scans prescanl, postscanl, prescanr, postscanr, -- ** Segmented scans scanlSeg, scanl'Seg, scanl1Seg, prescanlSeg, postscanlSeg, scanrSeg, scanr'Seg, scanr1Seg, prescanrSeg, postscanrSeg, -- * Shape manipulation flatten, -- * Enumeration and filling fill, enumFromN, enumFromStepN, -- * Concatenation (++), -- * Working with predicates -- ** Filtering filter, -- ** Scatter / Gather scatter, scatterIf, gather, gatherIf, -- * Permutations reverse, transpose, -- * Extracting sub-vectors init, tail, take, drop, slit, -- * Controlling execution compute, -- * Flow control IfThenElse(..), -- ** Array-level (?|), -- ** Expression-level (?), caseof, -- * Scalar iteration iterate, -- * Scalar reduction sfoldl, -- sfoldr, -- * Lifting and unlifting Lift(..), Unlift(..), lift1, lift2, lift3, ilift1, ilift2, ilift3, -- ** Tuple construction and destruction fst, afst, snd, asnd, curry, uncurry, -- ** Index construction and destruction index0, index1, unindex1, index2, unindex2, index3, unindex3, -- * Array operations with a scalar result the, null, length, -- * Sequence operations -- fromSeq, fromSeqElems, fromSeqShapes, toSeqInner, toSeqOuter2, toSeqOuter3, generateSeq, ) where -- avoid clashes with Prelude functions -- import Data.Typeable ( gcast ) import GHC.Base ( Constraint ) import Prelude ( (.), ($), Maybe(..), const, id, fromInteger, flip, undefined, fail ) -- friends import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Array.Sugar hiding ( (!), ignore, shape, size, intersect, toIndex, fromIndex ) import Data.Array.Accelerate.Language import Data.Array.Accelerate.Lift import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Classes.FromIntegral import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Data.Bits -- Element-wise operations -- ----------------------- -- | Pair each element with its index -- -- >>> let xs = fromList (Z:.5) [0..] -- >>> indexed (use xs) -- Vector (Z :. 5) [(Z :. 0,0.0),(Z :. 1,1.0),(Z :. 2,2.0),(Z :. 3,3.0),(Z :. 4,4.0)] -- -- >>> let mat = fromList (Z:.3:.4) [0..] -- >>> indexed (use mat) -- Matrix (Z :. 3 :. 4) -- [(Z :. 0 :. 0,0.0),(Z :. 0 :. 1,1.0), (Z :. 0 :. 2,2.0), (Z :. 0 :. 3,3.0), -- (Z :. 1 :. 0,4.0),(Z :. 1 :. 1,5.0), (Z :. 1 :. 2,6.0), (Z :. 1 :. 3,7.0), -- (Z :. 2 :. 0,8.0),(Z :. 2 :. 1,9.0),(Z :. 2 :. 2,10.0),(Z :. 2 :. 3,11.0)] -- indexed :: (Shape sh, Elt a) => Acc (Array sh a) -> Acc (Array sh (sh, a)) indexed xs = zip (generate (shape xs) id) xs -- | Apply a function to every element of an array and its index -- imap :: (Shape sh, Elt a, Elt b) => (Exp sh -> Exp a -> Exp b) -> Acc (Array sh a) -> Acc (Array sh b) imap f xs = zipWith f (generate (shape xs) id) xs -- | Zip three arrays with the given function, analogous to 'zipWith'. -- zipWith3 :: (Shape sh, Elt a, Elt b, Elt c, Elt d) => (Exp a -> Exp b -> Exp c -> Exp d) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) zipWith3 f as bs cs = generate (shape as `intersect` shape bs `intersect` shape cs) (\ix -> f (as ! ix) (bs ! ix) (cs ! ix)) -- | Zip four arrays with the given function, analogous to 'zipWith'. -- zipWith4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e) => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) zipWith4 f as bs cs ds = generate (shape as `intersect` shape bs `intersect` shape cs `intersect` shape ds) (\ix -> f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix)) -- | Zip five arrays with the given function, analogous to 'zipWith'. -- zipWith5 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f) => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) zipWith5 f as bs cs ds es = generate (shape as `intersect` shape bs `intersect` shape cs `intersect` shape ds `intersect` shape es) (\ix -> f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix)) -- | Zip six arrays with the given function, analogous to 'zipWith'. -- zipWith6 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g) => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) -> Acc (Array sh g) zipWith6 f as bs cs ds es fs = generate (shape as `intersect` shape bs `intersect` shape cs `intersect` shape ds `intersect` shape es `intersect` shape fs) (\ix -> f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix)) -- | Zip seven arrays with the given function, analogous to 'zipWith'. -- zipWith7 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h) => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) -> Acc (Array sh g) -> Acc (Array sh h) zipWith7 f as bs cs ds es fs gs = generate (shape as `intersect` shape bs `intersect` shape cs `intersect` shape ds `intersect` shape es `intersect` shape fs `intersect` shape gs) (\ix -> f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix) (gs ! ix)) -- | Zip eight arrays with the given function, analogous to 'zipWith'. -- zipWith8 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i) => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h -> Exp i) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) -> Acc (Array sh g) -> Acc (Array sh h) -> Acc (Array sh i) zipWith8 f as bs cs ds es fs gs hs = generate (shape as `intersect` shape bs `intersect` shape cs `intersect` shape ds `intersect` shape es `intersect` shape fs `intersect` shape gs `intersect` shape hs) (\ix -> f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix) (gs ! ix) (hs ! ix)) -- | Zip nine arrays with the given function, analogous to 'zipWith'. -- zipWith9 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j) => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h -> Exp i -> Exp j) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) -> Acc (Array sh g) -> Acc (Array sh h) -> Acc (Array sh i) -> Acc (Array sh j) zipWith9 f as bs cs ds es fs gs hs is = generate (shape as `intersect` shape bs `intersect` shape cs `intersect` shape ds `intersect` shape es `intersect` shape fs `intersect` shape gs `intersect` shape hs `intersect` shape is) (\ix -> f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix) (gs ! ix) (hs ! ix) (is ! ix)) -- | Zip two arrays with a function that also takes the element index -- izipWith :: (Shape sh, Elt a, Elt b, Elt c) => (Exp sh -> Exp a -> Exp b -> Exp c) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) izipWith f as bs = generate (shape as `intersect` shape bs) (\ix -> f ix (as ! ix) (bs ! ix)) -- | Zip three arrays with a function that also takes the element index, -- analogous to 'izipWith'. -- izipWith3 :: (Shape sh, Elt a, Elt b, Elt c, Elt d) => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) izipWith3 f as bs cs = generate (shape as `intersect` shape bs `intersect` shape cs) (\ix -> f ix (as ! ix) (bs ! ix) (cs ! ix)) -- | Zip four arrays with the given function that also takes the element index, -- analogous to 'zipWith'. -- izipWith4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e) => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d -> Exp e) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) izipWith4 f as bs cs ds = generate (shape as `intersect` shape bs `intersect` shape cs `intersect` shape ds) (\ix -> f ix (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix)) -- | Zip five arrays with the given function that also takes the element index, -- analogous to 'zipWith'. -- izipWith5 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f) => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) izipWith5 f as bs cs ds es = generate (shape as `intersect` shape bs `intersect` shape cs `intersect` shape ds `intersect` shape es) (\ix -> f ix (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix)) -- | Zip six arrays with the given function that also takes the element index, -- analogous to 'zipWith'. -- izipWith6 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g) => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) -> Acc (Array sh g) izipWith6 f as bs cs ds es fs = generate (shape as `intersect` shape bs `intersect` shape cs `intersect` shape ds `intersect` shape es `intersect` shape fs) (\ix -> f ix (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix)) -- | Zip seven arrays with the given function that also takes the element -- index, analogous to 'zipWith'. -- izipWith7 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h) => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) -> Acc (Array sh g) -> Acc (Array sh h) izipWith7 f as bs cs ds es fs gs = generate (shape as `intersect` shape bs `intersect` shape cs `intersect` shape ds `intersect` shape es `intersect` shape fs `intersect` shape gs) (\ix -> f ix (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix) (gs ! ix)) -- | Zip eight arrays with the given function that also takes the element -- index, analogous to 'zipWith'. -- izipWith8 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i) => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h -> Exp i) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) -> Acc (Array sh g) -> Acc (Array sh h) -> Acc (Array sh i) izipWith8 f as bs cs ds es fs gs hs = generate (shape as `intersect` shape bs `intersect` shape cs `intersect` shape ds `intersect` shape es `intersect` shape fs `intersect` shape gs `intersect` shape hs) (\ix -> f ix (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix) (gs ! ix) (hs ! ix)) -- | Zip nine arrays with the given function that also takes the element index, -- analogous to 'zipWith'. -- izipWith9 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j) => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h -> Exp i -> Exp j) -> Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) -> Acc (Array sh g) -> Acc (Array sh h) -> Acc (Array sh i) -> Acc (Array sh j) izipWith9 f as bs cs ds es fs gs hs is = generate (shape as `intersect` shape bs `intersect` shape cs `intersect` shape ds `intersect` shape es `intersect` shape fs `intersect` shape gs `intersect` shape hs `intersect` shape is) (\ix -> f ix (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix) (gs ! ix) (hs ! ix) (is ! ix)) -- | Combine the elements of two arrays pairwise. The shape of the result is the -- intersection of the two argument shapes. -- zip :: (Shape sh, Elt a, Elt b) => Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh (a, b)) zip = zipWith (curry lift) -- | Take three arrays and return an array of triples, analogous to zip. -- zip3 :: (Shape sh, Elt a, Elt b, Elt c) => Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh (a, b, c)) zip3 = zipWith3 (\a b c -> lift (a,b,c)) -- | Take four arrays and return an array of quadruples, analogous to zip. -- zip4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d) => Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh (a, b, c, d)) zip4 = zipWith4 (\a b c d -> lift (a,b,c,d)) -- | Take five arrays and return an array of five-tuples, analogous to zip. -- zip5 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e) => Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh (a, b, c, d, e)) zip5 = zipWith5 (\a b c d e -> lift (a,b,c,d,e)) -- | Take six arrays and return an array of six-tuples, analogous to zip. -- zip6 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f) => Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) -> Acc (Array sh (a, b, c, d, e, f)) zip6 = zipWith6 (\a b c d e f -> lift (a,b,c,d,e,f)) -- | Take seven arrays and return an array of seven-tuples, analogous to zip. -- zip7 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g) => Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) -> Acc (Array sh g) -> Acc (Array sh (a, b, c, d, e, f, g)) zip7 = zipWith7 (\a b c d e f g -> lift (a,b,c,d,e,f,g)) -- | Take seven arrays and return an array of seven-tuples, analogous to zip. -- zip8 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h) => Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) -> Acc (Array sh g) -> Acc (Array sh h) -> Acc (Array sh (a, b, c, d, e, f, g, h)) zip8 = zipWith8 (\a b c d e f g h -> lift (a,b,c,d,e,f,g,h)) -- | Take seven arrays and return an array of seven-tuples, analogous to zip. -- zip9 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i) => Acc (Array sh a) -> Acc (Array sh b) -> Acc (Array sh c) -> Acc (Array sh d) -> Acc (Array sh e) -> Acc (Array sh f) -> Acc (Array sh g) -> Acc (Array sh h) -> Acc (Array sh i) -> Acc (Array sh (a, b, c, d, e, f, g, h, i)) zip9 = zipWith9 (\a b c d e f g h i -> lift (a,b,c,d,e,f,g,h,i)) -- | The converse of 'zip', but the shape of the two results is identical to the -- shape of the argument. -- -- If the argument array is manifest in memory, 'unzip' is a no-op. -- unzip :: (Shape sh, Elt a, Elt b) => Acc (Array sh (a, b)) -> (Acc (Array sh a), Acc (Array sh b)) unzip arr = (map fst arr, map snd arr) -- | Take an array of triples and return three arrays, analogous to 'unzip'. -- unzip3 :: (Shape sh, Elt a, Elt b, Elt c) => Acc (Array sh (a, b, c)) -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c)) unzip3 xs = (map get1 xs, map get2 xs, map get3 xs) where get1 x = let (a,_,_) = untup3 x in a get2 x = let (_,b,_) = untup3 x in b get3 x = let (_,_,c) = untup3 x in c -- | Take an array of quadruples and return four arrays, analogous to 'unzip'. -- unzip4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d) => Acc (Array sh (a, b, c, d)) -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c), Acc (Array sh d)) unzip4 xs = (map get1 xs, map get2 xs, map get3 xs, map get4 xs) where get1 x = let (a,_,_,_) = untup4 x in a get2 x = let (_,b,_,_) = untup4 x in b get3 x = let (_,_,c,_) = untup4 x in c get4 x = let (_,_,_,d) = untup4 x in d -- | Take an array of 5-tuples and return five arrays, analogous to 'unzip'. -- unzip5 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e) => Acc (Array sh (a, b, c, d, e)) -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c), Acc (Array sh d), Acc (Array sh e)) unzip5 xs = (map get1 xs, map get2 xs, map get3 xs, map get4 xs, map get5 xs) where get1 x = let (a,_,_,_,_) = untup5 x in a get2 x = let (_,b,_,_,_) = untup5 x in b get3 x = let (_,_,c,_,_) = untup5 x in c get4 x = let (_,_,_,d,_) = untup5 x in d get5 x = let (_,_,_,_,e) = untup5 x in e -- | Take an array of 6-tuples and return six arrays, analogous to 'unzip'. -- unzip6 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f) => Acc (Array sh (a, b, c, d, e, f)) -> ( Acc (Array sh a), Acc (Array sh b), Acc (Array sh c) , Acc (Array sh d), Acc (Array sh e), Acc (Array sh f)) unzip6 xs = (map get1 xs, map get2 xs, map get3 xs, map get4 xs, map get5 xs, map get6 xs) where get1 x = let (a,_,_,_,_,_) = untup6 x in a get2 x = let (_,b,_,_,_,_) = untup6 x in b get3 x = let (_,_,c,_,_,_) = untup6 x in c get4 x = let (_,_,_,d,_,_) = untup6 x in d get5 x = let (_,_,_,_,e,_) = untup6 x in e get6 x = let (_,_,_,_,_,f) = untup6 x in f -- | Take an array of 7-tuples and return seven arrays, analogous to 'unzip'. -- unzip7 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g) => Acc (Array sh (a, b, c, d, e, f, g)) -> ( Acc (Array sh a), Acc (Array sh b), Acc (Array sh c) , Acc (Array sh d), Acc (Array sh e), Acc (Array sh f) , Acc (Array sh g)) unzip7 xs = ( map get1 xs, map get2 xs, map get3 xs , map get4 xs, map get5 xs, map get6 xs , map get7 xs ) where get1 x = let (a,_,_,_,_,_,_) = untup7 x in a get2 x = let (_,b,_,_,_,_,_) = untup7 x in b get3 x = let (_,_,c,_,_,_,_) = untup7 x in c get4 x = let (_,_,_,d,_,_,_) = untup7 x in d get5 x = let (_,_,_,_,e,_,_) = untup7 x in e get6 x = let (_,_,_,_,_,f,_) = untup7 x in f get7 x = let (_,_,_,_,_,_,g) = untup7 x in g -- | Take an array of 8-tuples and return eight arrays, analogous to 'unzip'. -- unzip8 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h) => Acc (Array sh (a, b, c, d, e, f, g, h)) -> ( Acc (Array sh a), Acc (Array sh b), Acc (Array sh c) , Acc (Array sh d), Acc (Array sh e), Acc (Array sh f) , Acc (Array sh g), Acc (Array sh h) ) unzip8 xs = ( map get1 xs, map get2 xs, map get3 xs , map get4 xs, map get5 xs, map get6 xs , map get7 xs, map get8 xs ) where get1 x = let (a,_,_,_,_,_,_,_) = untup8 x in a get2 x = let (_,b,_,_,_,_,_,_) = untup8 x in b get3 x = let (_,_,c,_,_,_,_,_) = untup8 x in c get4 x = let (_,_,_,d,_,_,_,_) = untup8 x in d get5 x = let (_,_,_,_,e,_,_,_) = untup8 x in e get6 x = let (_,_,_,_,_,f,_,_) = untup8 x in f get7 x = let (_,_,_,_,_,_,g,_) = untup8 x in g get8 x = let (_,_,_,_,_,_,_,h) = untup8 x in h -- | Take an array of 8-tuples and return eight arrays, analogous to 'unzip'. -- unzip9 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i) => Acc (Array sh (a, b, c, d, e, f, g, h, i)) -> ( Acc (Array sh a), Acc (Array sh b), Acc (Array sh c) , Acc (Array sh d), Acc (Array sh e), Acc (Array sh f) , Acc (Array sh g), Acc (Array sh h), Acc (Array sh i)) unzip9 xs = ( map get1 xs, map get2 xs, map get3 xs , map get4 xs, map get5 xs, map get6 xs , map get7 xs, map get8 xs, map get9 xs ) where get1 x = let (a,_,_,_,_,_,_,_,_) = untup9 x in a get2 x = let (_,b,_,_,_,_,_,_,_) = untup9 x in b get3 x = let (_,_,c,_,_,_,_,_,_) = untup9 x in c get4 x = let (_,_,_,d,_,_,_,_,_) = untup9 x in d get5 x = let (_,_,_,_,e,_,_,_,_) = untup9 x in e get6 x = let (_,_,_,_,_,f,_,_,_) = untup9 x in f get7 x = let (_,_,_,_,_,_,g,_,_) = untup9 x in g get8 x = let (_,_,_,_,_,_,_,h,_) = untup9 x in h get9 x = let (_,_,_,_,_,_,_,_,i) = untup9 x in i -- Reductions -- ---------- -- | Reduction of an array of arbitrary rank to a single scalar value. The first -- argument needs to be an /associative/ function to enable efficient parallel -- implementation. The initial element does not need to be an identity element. -- -- >>> let vec = fromList (Z:.10) [0..] -- >>> foldAll (+) 42 (use vec) -- Scalar Z [87] -- -- >>> let mat = fromList (Z:.5:.10) [0..] -- >>> foldAll (+) 0 (use mat) -- Scalar Z [1225] -- foldAll :: (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array sh a) -> Acc (Scalar a) foldAll f e arr = fold f e (flatten arr) -- | Variant of 'foldAll' that requires the reduced array to be non-empty and -- does not need a default value. The first argument must be an /associative/ -- function. -- fold1All :: (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Acc (Array sh a) -> Acc (Scalar a) fold1All f arr = fold1 f (flatten arr) -- Specialised reductions -- ---------------------- -- -- Leave the results of these as scalar arrays to make it clear that these are -- array computations, and thus can not be nested. -- | Check if all elements along the innermost dimension satisfy a predicate. -- -- >>> let mat = fromList (Z :. 4 :. 10) [1,2,3,4,5,6,7,8,9,10,1,1,1,1,1,2,2,2,2,2,2,4,6,8,10,12,14,16,18,20,1,3,5,7,9,11,13,15,17,19] :: Array DIM2 Int -- >>> mat -- Matrix (Z :. 4 :. 10) -- [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -- 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, -- 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, -- 1, 3, 5, 7, 9, 11, 13, 15, 17, 19] -- -- >>> all even (use mat) -- Vector (Z :. 4) [False,False,True,False] -- all :: (Shape sh, Elt e) => (Exp e -> Exp Bool) -> Acc (Array (sh:.Int) e) -> Acc (Array sh Bool) all f = and . map f -- | Check if any element along the innermost dimension satisfies the predicate. -- -- >>> let mat = fromList (Z :. 4 :. 10) [1,2,3,4,5,6,7,8,9,10,1,1,1,1,1,2,2,2,2,2,2,4,6,8,10,12,14,16,18,20,1,3,5,7,9,11,13,15,17,19] :: Array DIM2 Int -- >>> mat -- Matrix (Z :. 4 :. 10) -- [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -- 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, -- 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, -- 1, 3, 5, 7, 9, 11, 13, 15, 17, 19] -- -- >>> any even (use mat) -- Vector (Z :. 4) [True,True,True,False] -- any :: (Shape sh, Elt e) => (Exp e -> Exp Bool) -> Acc (Array (sh:.Int) e) -> Acc (Array sh Bool) any f = or . map f -- | Check if all elements along the innermost dimension are 'True'. -- and :: Shape sh => Acc (Array (sh:.Int) Bool) -> Acc (Array sh Bool) and = fold (&&) (constant True) -- | Check if any element along the innermost dimension is 'True'. -- or :: Shape sh => Acc (Array (sh:.Int) Bool) -> Acc (Array sh Bool) or = fold (||) (constant False) -- | Compute the sum of elements along the innermost dimension of the array. To -- find the sum of the entire array, 'flatten' it first. -- -- >>> let mat = fromList (Z:.2:.5) [0..] -- Vector (Z :. 2) [10,35] -- sum :: (Shape sh, Num e) => Acc (Array (sh:.Int) e) -> Acc (Array sh e) sum = fold (+) 0 -- | Compute the product of the elements along the innermost dimension of the -- array. To find the product of the entire array, 'flatten' it first. -- -- >>> let mat = fromList (Z:.2:.5) [0..] -- Vector (Z :. 2) [0,15120] -- product :: (Shape sh, Num e) => Acc (Array (sh:.Int) e) -> Acc (Array sh e) product = fold (*) 1 -- | Yield the minimum element along the innermost dimension of the array. To -- find find the minimum element of the entire array, 'flatten' it first. -- -- The array must not be empty. See also 'fold1'. -- -- >>> let mat = fromList (Z :. 3 :. 4) [1,4,3,8, 0,2,8,4, 7,9,8,8] -- >>> mat -- Matrix (Z :. 3 :. 4) -- [ 1, 4, 3, 8, -- 0, 2, 8, 4, -- 7, 9, 8, 8] -- -- >>> minimum (use mat) -- Vector (Z :. 3) [1,0,7] -- minimum :: (Shape sh, Ord e) => Acc (Array (sh:.Int) e) -> Acc (Array sh e) minimum = fold1 min -- | Yield the maximum element along the innermost dimension of the array. To -- find the maximum element of the entire array, 'flatten' it first. -- -- The array must not be empty. See also 'fold1'. -- -- >>> let mat = fromList (Z :. 3 :. 4) [1,4,3,8, 0,2,8,4, 7,9,8,8] -- >>> mat -- Matrix (Z :. 3 :. 4) -- [ 1, 4, 3, 8, -- 0, 2, 8, 4, -- 7, 9, 8, 8] -- -- >>> maximum (use mat) -- Vector (Z :. 3) [8,8,9] -- maximum :: (Shape sh, Ord e) => Acc (Array (sh:.Int) e) -> Acc (Array sh e) maximum = fold1 max -- Composite scans -- --------------- -- | Left-to-right pre-scan (aka exclusive scan). As for 'scan', the first -- argument must be an /associative/ function. Denotationally, we have: -- -- > prescanl f e = afst . scanl' f e -- -- >>> let vec = fromList (Z:.10) [1..10] -- >>> prescanl (+) 0 (use vec) -- Vector (Z :. 10) [0,1,3,6,10,15,21,28,36,45] -- prescanl :: (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) -> Acc (Array (sh:.Int) a) prescanl f e = afst . scanl' f e -- | Left-to-right post-scan, a variant of 'scanl1' with an initial value. As -- with 'scanl1', the array must not be empty. Denotationally, we have: -- -- > postscanl f e = map (e `f`) . scanl1 f -- -- >>> let vec = fromList (Z:.10) [1..10] -- >>> postscanl (+) 42 (use vec) -- Vector (Z :. 10) [43,45,48,52,57,63,70,78,87,97] -- postscanl :: (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) -> Acc (Array (sh:.Int) a) postscanl f e = map (e `f`) . scanl1 f -- | Right-to-left pre-scan (aka exclusive scan). As for 'scan', the first -- argument must be an /associative/ function. Denotationally, we have: -- -- > prescanr f e = afst . scanr' f e -- prescanr :: (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) -> Acc (Array (sh:.Int) a) prescanr f e = afst . scanr' f e -- | Right-to-left postscan, a variant of 'scanr1' with an initial value. -- Denotationally, we have: -- -- > postscanr f e = map (e `f`) . scanr1 f -- postscanr :: (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) -> Acc (Array (sh:.Int) a) postscanr f e = map (`f` e) . scanr1 f -- Segmented scans -- --------------- -- | Segmented version of 'scanl' along the innermost dimension of an array. The -- innermost dimension must have at least as many elements as the sum of the -- segment descriptor. -- -- >>> let seg = fromList (Z:.4) [1,4,0,3] -- >>> seg -- Vector (Z :. 4) [1,4,0,3] -- -- >>> let mat = fromList (Z:.5:.10) [0..] -- >>> mat -- Matrix (Z :. 5 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] -- -- >>> scanlSeg (+) 0 (use mat) (use seg) -- Matrix (Z :. 5 :. 12) -- [ 0, 0, 0, 1, 3, 6, 10, 0, 0, 5, 11, 18, -- 0, 10, 0, 11, 23, 36, 50, 0, 0, 15, 31, 48, -- 0, 20, 0, 21, 43, 66, 90, 0, 0, 25, 51, 78, -- 0, 30, 0, 31, 63, 96, 130, 0, 0, 35, 71, 108, -- 0, 40, 0, 41, 83, 126, 170, 0, 0, 45, 91, 138] -- scanlSeg :: forall sh e i. (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Exp e -> Acc (Array (sh:.Int) e) -> Acc (Segments i) -> Acc (Array (sh:.Int) e) scanlSeg f z arr seg = if null arr || null flags then fill (lift (sh:.sz + length seg)) z else scanl1Seg f arr' seg' where sh :. sz = unlift (shape arr) :: Exp sh :. Exp Int -- Segmented exclusive scan is implemented by first injecting the seed -- element at the head of each segment, and then performing a segmented -- inclusive scan. -- -- This is done by creating a creating a vector entirely of the seed -- element, and overlaying the input data in all places other than at the -- start of a segment. -- seg' = map (+1) seg arr' = permute const (fill (lift (sh :. sz + length seg)) z) (\ix -> let sx :. i = unlift ix :: Exp sh :. Exp Int in lift (sx :. i + fromIntegral (inc ! index1 i))) (take (length flags) arr) -- Each element in the segments must be shifted to the right one additional -- place for each successive segment, to make room for the seed element. -- Here, we make use of the fact that the vector returned by 'mkHeadFlags' -- contains non-unit entries, which indicate zero length segments. -- flags = mkHeadFlags seg inc = scanl1 (+) flags -- | Segmented version of 'scanl'' along the innermost dimension of an array. The -- innermost dimension must have at least as many elements as the sum of the -- segment descriptor. -- -- The first element of the resulting tuple is a vector of scanned values. The -- second element is a vector of segment scan totals and has the same size as -- the segment vector. -- -- >>> let seg = fromList (Z:.4) [1,4,0,3] -- >>> seg -- Vector (Z :. 4) [1,4,0,3] -- -- >>> let mat = fromList (Z:.5:.10) [0..] -- >>> mat -- Matrix (Z :. 5 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] -- -- >>> let (res,sums) = scanl'Seg (+) 0 (use mat) (use seg) -- >>> res -- Matrix (Z :. 5 :. 8) -- [ 0, 0, 1, 3, 6, 0, 5, 11, -- 0, 0, 11, 23, 36, 0, 15, 31, -- 0, 0, 21, 43, 66, 0, 25, 51, -- 0, 0, 31, 63, 96, 0, 35, 71, -- 0, 0, 41, 83, 126, 0, 45, 91] -- >>> sums -- Matrix (Z :. 5 :. 4) -- [ 0, 10, 0, 18, -- 10, 50, 0, 48, -- 20, 90, 0, 78, -- 30, 130, 0, 108, -- 40, 170, 0, 138] -- scanl'Seg :: forall sh e i. (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Exp e -> Acc (Array (sh:.Int) e) -> Acc (Segments i) -> Acc (Array (sh:.Int) e, Array (sh:.Int) e) scanl'Seg f z arr seg = if null arr then lift (arr, fill (lift (indexTail (shape arr) :. length seg)) z) else lift (body, sums) where -- Segmented scan' is implemented by deconstructing a segmented exclusive -- scan, to separate the final value and scan body. -- -- TLM: Segmented scans, and this version in particular, expend a lot of -- effort scanning flag arrays. On inspection it appears that several -- of these operations are duplicated, but this will not be picked up -- by sharing _observation_. Perhaps a global CSE-style pass would be -- beneficial. -- arr' = scanlSeg f z arr seg -- Extract the final reduction value for each segment, which is at the last -- index of each segment. -- seg' = map (+1) seg tails = zipWith (+) seg $ prescanl (+) 0 seg' sums = backpermute (lift (indexTail (shape arr') :. length seg)) (\ix -> let sz:.i = unlift ix :: Exp sh :. Exp Int in lift (sz :. fromIntegral (tails ! index1 i))) arr' -- Slice out the body of each segment. -- -- Build a head-flags representation based on the original segment -- descriptor. This contains the target length of each of the body segments, -- which is one fewer element than the actual bodies stored in arr'. Thus, -- the flags align with the last element of each body section, and when -- scanned, this element will be incremented over. -- offset = scanl1 (+) seg inc = scanl1 (+) $ permute (+) (fill (index1 $ size arr + 1) 0) (\ix -> index1' $ offset ! ix) (fill (shape seg) (1 :: Exp i)) len = offset ! index1 (length offset - 1) body = backpermute (lift (indexTail (shape arr) :. fromIntegral len)) (\ix -> let sz:.i = unlift ix :: Exp sh :. Exp Int in lift (sz :. i + fromIntegral (inc ! index1 i))) arr' -- | Segmented version of 'scanl1' along the innermost dimension. -- -- As with 'scanl1', the total number of elements considered, in this case given -- by the 'sum' of segment descriptor, must not be zero. The input vector must -- contain at least this many elements. -- -- Zero length segments are allowed, and the behaviour is as if those entries -- were not present in the segment descriptor; that is: -- -- > scanl1Seg f xs [n,0,0] == scanl1Seg f xs [n] where n /= 0 -- -- >>> let seg = fromList (Z:.4) [1,4,0,3] -- >>> seg -- Vector (Z :. 4) [1,4,0,3] -- -- >>> let mat = fromList (Z:.5:.10) [0..] -- >>> mat -- Matrix (Z :. 5 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] -- -- >>> scanl1Seg (+) (use mat) (use seg) -- Matrix (Z :. 5 :. 8) -- [ 0, 1, 3, 6, 10, 5, 11, 18, -- 10, 11, 23, 36, 50, 15, 31, 48, -- 20, 21, 43, 66, 90, 25, 51, 78, -- 30, 31, 63, 96, 130, 35, 71, 108, -- 40, 41, 83, 126, 170, 45, 91, 138] -- scanl1Seg :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Acc (Array (sh:.Int) e) -> Acc (Segments i) -> Acc (Array (sh:.Int) e) scanl1Seg f arr seg = map snd . scanl1 (segmented f) $ zip (replicate (lift (indexTail (shape arr) :. All)) (mkHeadFlags seg)) arr -- |Segmented version of 'prescanl'. -- prescanlSeg :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Exp e -> Acc (Array (sh:.Int) e) -> Acc (Segments i) -> Acc (Array (sh:.Int) e) prescanlSeg f e vec seg = afst $ scanl'Seg f e vec seg -- |Segmented version of 'postscanl'. -- postscanlSeg :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Exp e -> Acc (Array (sh:.Int) e) -> Acc (Segments i) -> Acc (Array (sh:.Int) e) postscanlSeg f e vec seg = map (f e) $ scanl1Seg f vec seg -- | Segmented version of 'scanr' along the innermost dimension of an array. The -- innermost dimension must have at least as many elements as the sum of the -- segment descriptor. -- -- >>> let seg = fromList (Z:.4) [1,4,0,3] -- >>> seg -- Vector (Z :. 4) [1,4,0,3] -- -- >>> let mat = fromList (Z:.5:.10) [0..] -- >>> mat -- Matrix (Z :. 5 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] -- -- >>> scanrSeg (+) 0 (use mat) (use seg) -- Matrix (Z :. 5 :. 12) -- [ 2, 0, 18, 15, 11, 6, 0, 0, 24, 17, 9, 0, -- 12, 0, 58, 45, 31, 16, 0, 0, 54, 37, 19, 0, -- 22, 0, 98, 75, 51, 26, 0, 0, 84, 57, 29, 0, -- 32, 0, 138, 105, 71, 36, 0, 0, 114, 77, 39, 0, -- 42, 0, 178, 135, 91, 46, 0, 0, 144, 97, 49, 0] -- scanrSeg :: forall sh e i. (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Exp e -> Acc (Array (sh:.Int) e) -> Acc (Segments i) -> Acc (Array (sh:.Int) e) scanrSeg f z arr seg = if null arr || null flags then fill (lift (sh :. sz + length seg)) z else scanr1Seg f arr' seg' where sh :. sz = unlift (shape arr) :: Exp sh :. Exp Int -- Using technique described for 'scanlSeg', where we intersperse the array -- with the seed element at the start of each segment, and then perform an -- inclusive segmented scan. -- flags = mkHeadFlags seg inc = scanl1 (+) flags seg' = map (+1) seg arr' = permute const (fill (lift (sh :. sz + length seg)) z) (\ix -> let sx :. i = unlift ix :: Exp sh :. Exp Int in lift (sx :. i + fromIntegral (inc ! index1 i) - 1)) (drop (sz - length flags) arr) -- | Segmented version of 'scanr''. -- -- >>> let seg = fromList (Z:.4) [1,4,0,3] -- >>> seg -- Vector (Z :. 4) [1,4,0,3] -- -- >>> let mat = fromList (Z:.5:.10) [0..] -- >>> mat -- Matrix (Z :. 5 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] -- -- >>> let (res,sums) = scanr'Seg (+) 0 (use mat) (use seg) -- >>> res -- Matrix (Z :. 5 :. 8) -- [ 0, 15, 11, 6, 0, 17, 9, 0, -- 0, 45, 31, 16, 0, 37, 19, 0, -- 0, 75, 51, 26, 0, 57, 29, 0, -- 0, 105, 71, 36, 0, 77, 39, 0, -- 0, 135, 91, 46, 0, 97, 49, 0] -- >>> sums -- Matrix (Z :. 5 :. 4) -- [ 2, 18, 0, 24, -- 12, 58, 0, 54, -- 22, 98, 0, 84, -- 32, 138, 0, 114, -- 42, 178, 0, 144] -- scanr'Seg :: forall sh e i. (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Exp e -> Acc (Array (sh:.Int) e) -> Acc (Segments i) -> Acc (Array (sh:.Int) e, Array (sh:.Int) e) scanr'Seg f z arr seg = if null arr then lift (arr, fill (lift (indexTail (shape arr) :. length seg)) z) else lift (body, sums) where -- Using technique described for scanl'Seg -- arr' = scanrSeg f z arr seg -- reduction values seg' = map (+1) seg heads = prescanl (+) 0 seg' sums = backpermute (lift (indexTail (shape arr') :. length seg)) (\ix -> let sz:.i = unlift ix :: Exp sh :. Exp Int in lift (sz :. fromIntegral (heads ! index1 i))) arr' -- body segments flags = mkHeadFlags seg inc = scanl1 (+) flags body = backpermute (lift (indexTail (shape arr) :. indexHead (shape flags))) (\ix -> let sz:.i = unlift ix :: Exp sh :. Exp Int in lift (sz :. i + fromIntegral (inc ! index1 i))) arr' -- | Segmented version of 'scanr1'. -- -- >>> let seg = fromList (Z:.4) [1,4,0,3] -- >>> seg -- Vector (Z :. 4) [1,4,0,3] -- -- >>> let mat = fromList (Z:.5:.10) [0..] -- >>> mat -- Matrix (Z :. 5 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] -- -- >>> scanr1Seg (+) (use mat) (use seg) -- Matrix (Z :. 5 :. 8) -- [ 0, 10, 9, 7, 4, 18, 13, 7, -- 10, 50, 39, 27, 14, 48, 33, 17, -- 20, 90, 69, 47, 24, 78, 53, 27, -- 30, 130, 99, 67, 34, 108, 73, 37, -- 40, 170, 129, 87, 44, 138, 93, 47] -- scanr1Seg :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Acc (Array (sh:.Int) e) -> Acc (Segments i) -> Acc (Array (sh:.Int) e) scanr1Seg f arr seg = map snd . scanr1 (flip (segmented f)) $ zip (replicate (lift (indexTail (shape arr) :. All)) (mkTailFlags seg)) arr -- |Segmented version of 'prescanr'. -- prescanrSeg :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Exp e -> Acc (Array (sh:.Int) e) -> Acc (Segments i) -> Acc (Array (sh:.Int) e) prescanrSeg f e vec seg = afst $ scanr'Seg f e vec seg -- |Segmented version of 'postscanr'. -- postscanrSeg :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Exp e -> Acc (Array (sh:.Int) e) -> Acc (Segments i) -> Acc (Array (sh:.Int) e) postscanrSeg f e vec seg = map (f e) $ scanr1Seg f vec seg -- Segmented scan helpers -- ---------------------- -- |Compute head flags vector from segment vector for left-scans. -- -- The vector will be full of zeros in the body of a segment, and non-zero -- otherwise. The "flag" value, if greater than one, indicates that several -- empty segments are represented by this single flag entry. This is additional -- data is used by exclusive segmented scan. -- mkHeadFlags :: (Integral i, FromIntegral i Int) => Acc (Segments i) -> Acc (Segments i) mkHeadFlags seg = init $ permute (+) zeros (\ix -> index1' (offset ! ix)) ones where (offset, len) = unlift (scanl' (+) 0 seg) zeros = fill (index1' $ the len + 1) 0 ones = fill (index1 $ size offset) 1 -- |Compute tail flags vector from segment vector for right-scans. That is, the -- flag is placed at the last place in each segment. -- mkTailFlags :: (Integral i, FromIntegral i Int) => Acc (Segments i) -> Acc (Segments i) mkTailFlags seg = init $ permute (+) zeros (\ix -> index1' (the len - 1 - offset ! ix)) ones where (offset, len) = unlift (scanr' (+) 0 seg) zeros = fill (index1' $ the len + 1) 0 ones = fill (index1 $ size offset) 1 -- |Construct a segmented version of a function from a non-segmented version. -- The segmented apply operates on a head-flag value tuple, and follows the -- procedure of Sengupta et. al. -- segmented :: (Elt e, Num i, Bits i) => (Exp e -> Exp e -> Exp e) -> Exp (i, e) -> Exp (i, e) -> Exp (i, e) segmented f a b = let (aF, aV) = unlift a (bF, bV) = unlift b in lift (aF .|. bF, bF /= 0 ? (bV, f aV bV)) -- |Index construction and destruction generalised to integral types. -- -- We generalise the segment descriptor to integral types because some -- architectures, such as GPUs, have poor performance for 64-bit types. So, -- there is a tension between performance and requiring 64-bit indices for some -- applications, and we would not like to restrict ourselves to either one. -- -- As we don't yet support non-Int dimensions in shapes, we will need to convert -- back to concrete Int. However, don't put these generalised forms into the -- base library, because it results in too many ambiguity errors. -- index1' :: (Integral i, FromIntegral i Int) => Exp i -> Exp DIM1 index1' i = lift (Z :. fromIntegral i) -- Reshaping of arrays -- ------------------- -- | Flatten the given array of arbitrary dimension into a one-dimensional -- vector. As with 'reshape', this operation performs no work. -- flatten :: forall sh e. (Shape sh, Elt e) => Acc (Array sh e) -> Acc (Vector e) flatten a | Just Refl <- matchShapeType (undefined::sh) (undefined::DIM1) = a flatten a = reshape (index1 (size a)) a -- Enumeration and filling -- ----------------------- -- | Create an array where all elements are the same value. -- -- >>> let zeros = fill (Z:.10) 0 -- Vector (Z :. 10) [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] -- fill :: (Shape sh, Elt e) => Exp sh -> Exp e -> Acc (Array sh e) fill sh c = generate sh (const c) -- | Create an array of the given shape containing the values @x@, @x+1@, etc. -- (in row-major order). -- -- >>> enumFromN (constant (Z:.5:.10)) 0 :: Array DIM2 Int -- Matrix (Z :. 5 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] -- enumFromN :: (Shape sh, Num e, FromIntegral Int e) => Exp sh -> Exp e -> Acc (Array sh e) enumFromN sh x = enumFromStepN sh x 1 -- | Create an array of the given shape containing the values @x@, @x+y@, -- @x+y+y@ etc. (in row-major order). -- -- >>> enumFromStepN (constant (Z:.5:.10)) 0 0.5 :: Array DIM2 Float -- Matrix (Z :. 5 :. 10) -- [ 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, -- 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, -- 10.0, 10.5, 11.0, 11.5, 12.0, 12.5, 13.0, 13.5, 14.0, 14.5, -- 15.0, 15.5, 16.0, 16.5, 17.0, 17.5, 18.0, 18.5, 19.0, 19.5, -- 20.0, 20.5, 21.0, 21.5, 22.0, 22.5, 23.0, 23.5, 24.0, 24.5] -- enumFromStepN :: (Shape sh, Num e, FromIntegral Int e) => Exp sh -> Exp e -- ^ x: start -> Exp e -- ^ y: step -> Acc (Array sh e) enumFromStepN sh x y = reshape sh $ generate (index1 $ shapeSize sh) (\ix -> (fromIntegral (unindex1 ix :: Exp Int) * y) + x) -- Concatenation -- ------------- -- | Concatenate innermost component of two arrays. The extent of the lower -- dimensional component is the intersection of the two arrays. -- -- >>> let m1 = fromList (Z:.5:.10) [0..] -- >>> m1 -- Matrix (Z :. 5 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] -- -- >>> let m2 = fromList (Z:.10:.3) [0..] -- >>> m2 -- Matrix (Z :. 10 :. 3) -- [ 0, 1, 2, -- 3, 4, 5, -- 6, 7, 8, -- 9, 10, 11, -- 12, 13, 14, -- 15, 16, 17, -- 18, 19, 20, -- 21, 22, 23, -- 24, 25, 26, -- 27, 28, 29] -- -- >>> use m1 ++ use m2 -- Matrix (Z :. 5 :. 13) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 3, 4, 5, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 6, 7, 8, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 9, 10, 11, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 12, 13, 14] -- infixr 5 ++ (++) :: forall sh e. (Slice sh, Shape sh, Elt e) => Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int) e) (++) xs ys = let sh1 :. n = unlift (shape xs) :: Exp sh :. Exp Int sh2 :. m = unlift (shape ys) :: Exp sh :. Exp Int in generate (lift (intersect sh1 sh2 :. n + m)) (\ix -> let sh :. i = unlift ix :: Exp sh :. Exp Int in i < n ? ( xs ! ix, ys ! lift (sh :. i-n)) ) -- TLM: If we have something like (concat . split) then the source array will -- have two use sites, but is actually safe (and better) to inline. -- Filtering -- --------- -- | Drop elements that do not satisfy the predicate. Returns the elements which -- pass the predicate, together with a segment descriptor indicating how many -- elements along each outer dimension were valid. -- -- >>> let vec = fromList (Z :. 10) [1..10] :: Vector Int -- >>> vec -- Vector (Z :. 10) [1,2,3,4,5,6,7,8,9,10] -- -- >>> filter even (use vec) -- (Vector (Z :. 5) [2,4,6,8,10], Scalar Z [5]) -- -- >>> let mat = fromList (Z :. 4 :. 10) [1,2,3,4,5,6,7,8,9,10,1,1,1,1,1,2,2,2,2,2,2,4,6,8,10,12,14,16,18,20,1,3,5,7,9,11,13,15,17,19] :: Array DIM2 Int -- >>> mat -- Matrix (Z :. 4 :. 10) -- [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -- 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, -- 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, -- 1, 3, 5, 7, 9, 11, 13, 15, 17, 19] -- -- >>> filter odd (use mat) -- (Vector (Z :. 20) [1,3,5,7,9,1,1,1,1,1,1,3,5,7,9,11,13,15,17,19], Vector (Z :. 4) [5,5,0,10]) -- filter :: forall sh e. (Shape sh, Slice sh, Elt e) => (Exp e -> Exp Bool) -> Acc (Array (sh:.Int) e) -> Acc (Vector e, Array sh Int) filter p arr -- Optimise 1-dimensional arrays, where we can avoid additional computations -- for the offset indices. | Just Refl <- matchShapeType (undefined::sh) (undefined::Z) = let keep = map p arr (target, len) = unlift $ scanl' (+) 0 (map boolToInt keep) prj ix = keep!ix ? ( index1 (target!ix), ignore ) dummy = backpermute (index1 (the len)) id arr result = permute const dummy prj arr in if null arr then lift (emptyArray, fill (constant Z) 0) else lift (result, len) filter p arr = let sz = indexTail (shape arr) keep = map p arr (target, len) = unlift $ scanl' (+) 0 (map boolToInt keep) (offset, valid) = unlift $ scanl' (+) 0 (flatten len) prj ix = if keep!ix then index1 $ offset!index1 (toIndex sz (indexTail ix)) + target!ix else ignore dummy = backpermute (index1 (the valid)) id (flatten arr) result = permute const dummy prj arr in if null arr then lift (emptyArray, fill sz 0) else lift (result, len) -- FIXME: [Permute in the filter operation] -- -- This is abusing 'permute' in that the first two arguments, the combination -- function and array of default values, are only justified because we know the -- permutation function will write to each location in the target exactly once. -- -- Instead, we should have a primitive that directly encodes the compaction -- pattern of the permutation function. This may be more efficient to compute, -- and avoids the computation of the defaults array, which is ultimately wasted -- work. -- {-# NOINLINE filter #-} {-# RULES "ACC filter/filter" forall f g arr. filter f (afst (filter g arr)) = filter (\x -> g x && f x) arr #-} -- Gather operations -- ----------------- -- | Gather elements from a source array by reading values at the given indices. -- -- >>> let input = fromList (Z:.9) [1,9,6,4,4,2,0,1,2] -- >>> let from = fromList (Z:.6) [1,3,7,2,5,3] -- >>> gather (use from) (use input) -- Vector (Z :. 6) [9,4,1,6,2,4] -- gather :: (Shape sh, Elt e) => Acc (Array sh Int) -- ^ index of source at each index to gather -> Acc (Vector e) -- ^ source values -> Acc (Array sh e) gather indices input = map (input !!) indices -- TLM NOTES: -- * (!!) has potential for later optimisation -- * We needn't fix the source array to Vector, but this matches the -- intuition that 'Int' ~ 'DIM1'. -- | Conditionally copy elements from source array to destination array -- according to an index mapping. -- -- In addition, the 'mask' vector and associated predication function specifies -- whether the element is copied or a default value is used instead. -- -- >>> let defaults = fromList (Z :. 6) [6,6,6,6,6,6] -- >>> let from = fromList (Z :. 6) [1,3,7,2,5,3] -- >>> let mask = fromList (Z :. 6) [3,4,9,2,7,5] -- >>> let input = fromList (Z :. 9) [1,9,6,4,4,2,0,1,2] -- >>> gatherIf (use from) (use mask) (> 4) (use defaults) (use input) -- Vector (Z :. 6) [6,6,1,6,2,4] -- gatherIf :: (Elt a, Elt b) => Acc (Vector Int) -- ^ source indices to gather from -> Acc (Vector a) -- ^ mask vector -> (Exp a -> Exp Bool) -- ^ predicate function -> Acc (Vector b) -- ^ default values -> Acc (Vector b) -- ^ source values -> Acc (Vector b) gatherIf from maskV pred defaults input = zipWith zf pf gatheredV where zf p g = p ? (unlift g) gatheredV = zip (gather from input) defaults pf = map pred maskV -- Scatter operations -- ------------------ -- | Overwrite elements of the destination by scattering the values of the -- source array according to the given index mapping. -- -- Note that if the destination index appears more than once in the mapping the -- result is undefined. -- -- >>> let to = fromList (Z :. 6) [1,3,7,2,5,8] -- >>> let input = fromList (Z :. 7) [1,9,6,4,4,2,5] -- >>> scatter (use to) (fill (constant (Z:.10)) 0) (use input) -- Vector (Z :. 10) [0,1,4,9,0,4,0,6,2,0] -- scatter :: Elt e => Acc (Vector Int) -- ^ destination indices to scatter into -> Acc (Vector e) -- ^ default values -> Acc (Vector e) -- ^ source values -> Acc (Vector e) scatter to defaults input = permute const defaults pf input' where pf ix = index1 (to ! ix) input' = backpermute (shape to `intersect` shape input) id input -- | Conditionally overwrite elements of the destination by scattering values of -- the source array according to a given index mapping, whenever the masking -- function resolves to 'True'. -- -- Note that if the destination index appears more than once in the mapping the -- result is undefined. -- -- >>> let to = fromList (Z :. 6) [1,3,7,2,5,8] -- >>> let mask = fromList (Z :. 6) [3,4,9,2,7,5] -- >>> let input = fromList (Z :. 7) [1,9,6,4,4,2,5] -- >>> scatterIf (use to) (use mask) (> 4) (fill (constant (Z:.10)) 0) (use input) -- Vector (Z :. 10) [0,0,0,0,0,4,0,6,2,0] -- scatterIf :: (Elt e, Elt e') => Acc (Vector Int) -- ^ destination indices to scatter into -> Acc (Vector e) -- ^ mask vector -> (Exp e -> Exp Bool) -- ^ predicate function -> Acc (Vector e') -- ^ default values -> Acc (Vector e') -- ^ source values -> Acc (Vector e') scatterIf to maskV pred defaults input = permute const defaults pf input' where pf ix = pred (maskV ! ix) ? ( index1 (to ! ix), ignore ) input' = backpermute (shape to `intersect` shape input) id input -- Permutations -- ------------ -- | Reverse the elements of a vector. -- reverse :: Elt e => Acc (Vector e) -> Acc (Vector e) reverse xs = let len = unindex1 (shape xs) pf i = len - i - 1 in backpermute (shape xs) (ilift1 pf) xs -- | Transpose the rows and columns of a matrix. -- transpose :: Elt e => Acc (Array DIM2 e) -> Acc (Array DIM2 e) transpose mat = let swap = lift1 $ \(Z:.x:.y) -> Z:.y:.x :: Z:.Exp Int:.Exp Int in backpermute (swap $ shape mat) swap mat -- Extracting sub-vectors -- ---------------------- -- | Yield the first @n@ elements in the innermost dimension of the array (plus -- all lower dimensional elements). -- -- >>> let mat = fromList (Z:.5:.10) [0..] -- >>> mat -- Matrix (Z :. 5 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] -- -- >>> take 5 (use mat) -- Matrix (Z :. 5 :. 5) -- [ 0, 1, 2, 3, 4, -- 10, 11, 12, 13, 14, -- 20, 21, 22, 23, 24, -- 30, 31, 32, 33, 34, -- 40, 41, 42, 43, 44] -- take :: forall sh e. (Slice sh, Shape sh, Elt e) => Exp Int -> Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int) e) take n acc = let n' = the (unit (n `min` sz)) sh :. sz = unlift (shape acc) :: Exp sh :. Exp Int in backpermute (lift (sh :. n')) id acc -- | Yield all but the first @n@ elements along the innermost dimension of the -- array (plus all lower dimensional elements). -- -- >>> let mat = fromList (Z:.5:.10) [0..] -- >>> mat -- Matrix (Z :. 5 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] -- -- >>> drop 7 (use mat) -- Matrix (Z :. 5 :. 3) -- [ 7, 8, 9, -- 17, 18, 19, -- 27, 28, 29, -- 37, 38, 39, -- 47, 48, 49] -- drop :: forall sh e. (Slice sh, Shape sh, Elt e) => Exp Int -> Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int) e) drop n acc = let n' = the (unit n) sh :. sz = unlift (shape acc) :: Exp sh :. Exp Int index ix = let j :. i = unlift ix :: Exp sh :. Exp Int in lift (j :. i + n') in backpermute (lift (sh :. 0 `max` (sz - n'))) index acc -- | Yield all but the elements in the last index of the innermost dimension. -- -- >>> let mat = fromList (Z:.5:.10) [0..] -- >>> mat -- Matrix (Z :. 5 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] -- -- >>> init (use mat) -- Matrix (Z :. 5 :. 9) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, -- 40, 41, 42, 43, 44, 45, 46, 47, 48] -- init :: forall sh e. (Slice sh, Shape sh, Elt e) => Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int) e) init acc = let sh :. sz = unlift (shape acc) :: Exp sh :. Exp Int in backpermute (lift (sh :. sz `min` (sz - 1))) id acc -- | Yield all but the first element along the innermost dimension of an array. -- The innermost dimension must not be empty. -- -- >>> let mat = fromList (Z:.5:.10) [0..] -- >>> mat -- Matrix (Z :. 5 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] -- -- >>> tail (use mat) -- Matrix (Z :. 5 :. 9) -- [ 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 11, 12, 13, 14, 15, 16, 17, 18, 19, -- 21, 22, 23, 24, 25, 26, 27, 28, 29, -- 31, 32, 33, 34, 35, 36, 37, 38, 39, -- 41, 42, 43, 44, 45, 46, 47, 48, 49] -- tail :: forall sh e. (Slice sh, Shape sh, Elt e) => Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int) e) tail acc = let sh :. sz = unlift (shape acc) :: Exp sh :. Exp Int index ix = let j :. i = unlift ix :: Exp sh :. Exp Int in lift (j :. i + 1) in backpermute (lift (sh :. 0 `max` (sz - 1))) index acc -- | Yield a slit (slice) of the innermost indices of an array. Denotationally, -- we have: -- -- > slit i n = take n . drop i -- slit :: forall sh e. (Slice sh, Shape sh, Elt e) => Exp Int -> Exp Int -> Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int) e) slit m n acc = let m' = the (unit m) n' = the (unit n) sh :. sz = unlift (shape acc) :: Exp sh :. Exp Int index ix = let j :. i = unlift ix :: Exp sh :. Exp Int in lift (j :. i + m') in backpermute (lift (sh :. (n' `min` ((sz - m') `max` 0)))) index acc -- Controlling execution -- --------------------- -- | Force an array expression to be evaluated, preventing it from fusing with -- other operations. Forcing operations to be computed to memory, rather than -- being fused into their consuming function, can sometimes improve performance. -- For example, computing a matrix 'transpose' could provide better memory -- locality for the subsequent operation. Preventing fusion to split large -- operations into several simpler steps could also help by reducing register -- pressure. -- -- Preventing fusion also means that the individual operations are available to -- be executed concurrently with other kernels. In particular, consider using -- this if you have a series of operations that are compute bound rather than -- memory bound. -- -- Here is the synthetic example: -- -- > loop :: Exp Int -> Exp Int -- > loop ticks = -- > let clockRate = 900000 -- kHz -- > in while (\i -> i < clockRate * ticks) (+1) 0 -- > -- > test :: Acc (Vector Int) -- > test = -- > zip3 -- > (compute $ map loop (use $ fromList (Z:.1) [10])) -- > (compute $ map loop (use $ fromList (Z:.1) [10])) -- > (compute $ map loop (use $ fromList (Z:.1) [10])) -- > -- -- Without the use of 'compute', the operations are fused together and the three -- long-running loops are executed sequentially in a single kernel. Instead, the -- individual operations can now be executed concurrently, potentially reducing -- overall runtime. -- compute :: Arrays a => Acc a -> Acc a compute = id >-> id -- Flow control -- ------------ -- | Infix version of 'acond'. If the predicate evaluates to 'True', the first -- component of the tuple is returned, else the second. -- -- Enabling the @RebindableSyntax@ extension will allow you to use the standard -- if-then-else syntax instead. -- infix 0 ?| (?|) :: Arrays a => Exp Bool -> (Acc a, Acc a) -> Acc a c ?| (t, e) = acond c t e -- | An infix version of 'cond'. If the predicate evaluates to 'True', the first -- component of the tuple is returned, else the second. -- -- Enabling the @RebindableSyntax@ extension will allow you to use the standard -- if-then-else syntax instead. -- infix 0 ? (?) :: Elt t => Exp Bool -> (Exp t, Exp t) -> Exp t c ? (t, e) = cond c t e -- | A case-like control structure -- caseof :: (Elt a, Elt b) => Exp a -- ^ case subject -> [(Exp a -> Exp Bool, Exp b)] -- ^ list of cases to attempt -> Exp b -- ^ default value -> Exp b caseof _ [] e = e caseof x ((p,b):l) e = cond (p x) b (caseof x l e) -- | For use with @-XRebindableSyntax@, this class provides 'ifThenElse' lifted -- to both scalar and array types. -- class IfThenElse t where type EltT t a :: Constraint ifThenElse :: EltT t a => Exp Bool -> t a -> t a -> t a instance IfThenElse Exp where type EltT Exp t = Elt t ifThenElse = cond instance IfThenElse Acc where type EltT Acc a = Arrays a ifThenElse = acond -- Scalar iteration -- ---------------- -- | Repeatedly apply a function a fixed number of times -- iterate :: forall a. Elt a => Exp Int -> (Exp a -> Exp a) -> Exp a -> Exp a iterate n f z = let step :: (Exp Int, Exp a) -> (Exp Int, Exp a) step (i, acc) = ( i+1, f acc ) in snd $ while (\v -> fst v < n) (lift1 step) (lift (constant 0, z)) -- Scalar bulk operations -- ---------------------- -- | Reduce along an innermost slice of an array /sequentially/, by applying a -- binary operator to a starting value and the array from left to right. -- sfoldl :: forall sh a b. (Shape sh, Slice sh, Elt a, Elt b) => (Exp a -> Exp b -> Exp a) -> Exp a -> Exp sh -> Acc (Array (sh :. Int) b) -> Exp a sfoldl f z ix xs = let step :: (Exp Int, Exp a) -> (Exp Int, Exp a) step (i, acc) = ( i+1, acc `f` (xs ! lift (ix :. i)) ) (_ :. n) = unlift (shape xs) :: Exp sh :. Exp Int in snd $ while (\v -> fst v < n) (lift1 step) (lift (constant 0, z)) -- Tuples -- ------ -- |Extract the first component of a scalar pair. -- fst :: forall a b. (Elt a, Elt b) => Exp (a, b) -> Exp a fst e = let (x, _::Exp b) = unlift e in x -- |Extract the first component of an array pair. {-# NOINLINE[1] afst #-} afst :: forall a b. (Arrays a, Arrays b) => Acc (a, b) -> Acc a afst a = let (x, _::Acc b) = unlift a in x -- |Extract the second component of a scalar pair. -- snd :: forall a b. (Elt a, Elt b) => Exp (a, b) -> Exp b snd e = let (_:: Exp a, y) = unlift e in y -- | Extract the second component of an array pair asnd :: forall a b. (Arrays a, Arrays b) => Acc (a, b) -> Acc b asnd a = let (_::Acc a, y) = unlift a in y -- |Converts an uncurried function to a curried function. -- curry :: Lift f (f a, f b) => (f (Plain (f a), Plain (f b)) -> f c) -> f a -> f b -> f c curry f x y = f (lift (x, y)) -- |Converts a curried function to a function on pairs. -- uncurry :: Unlift f (f a, f b) => (f a -> f b -> f c) -> f (Plain (f a), Plain (f b)) -> f c uncurry f t = let (x, y) = unlift t in f x y -- Shapes and indices -- ------------------ -- | The one index for a rank-0 array. -- index0 :: Exp Z index0 = lift Z -- | Turn an 'Int' expression into a rank-1 indexing expression. -- index1 :: Elt i => Exp i -> Exp (Z :. i) index1 i = lift (Z :. i) -- | Turn a rank-1 indexing expression into an 'Int' expression. -- unindex1 :: Elt i => Exp (Z :. i) -> Exp i unindex1 ix = let Z :. i = unlift ix in i -- | Creates a rank-2 index from two Exp Int`s -- index2 :: (Elt i, Slice (Z :. i)) => Exp i -> Exp i -> Exp (Z :. i :. i) index2 i j = lift (Z :. i :. j) -- | Destructs a rank-2 index to an Exp tuple of two Int`s. -- unindex2 :: forall i. (Elt i, Slice (Z :. i)) => Exp (Z :. i :. i) -> Exp (i, i) unindex2 ix = let Z :. i :. j = unlift ix :: Z :. Exp i :. Exp i in lift (i, j) -- | Create a rank-3 index from three Exp Int`s -- index3 :: (Elt i, Slice (Z :. i), Slice (Z :. i :. i)) => Exp i -> Exp i -> Exp i -> Exp (Z :. i :. i :. i) index3 k j i = lift (Z :. k :. j :. i) -- | Destruct a rank-3 index into an Exp tuple of Int`s unindex3 :: forall i. (Elt i, Slice (Z :. i), Slice (Z :. i :. i)) => Exp (Z :. i :. i :. i) -> Exp (i, i, i) unindex3 ix = let Z :. k :. j :. i = unlift ix :: Z :. Exp i :. Exp i :. Exp i in lift (k, j, i) -- Array operations with a scalar result -- ------------------------------------- -- | Extract the element of a singleton array. -- -- > the xs == xs ! Z -- the :: Elt e => Acc (Scalar e) -> Exp e the = (!index0) -- | Test whether an array is empty. -- null :: (Shape sh, Elt e) => Acc (Array sh e) -> Exp Bool null arr = size arr == 0 -- | Get the length of a vector. -- length :: Elt e => Acc (Vector e) -> Exp Int length = unindex1 . shape {-- -- Sequence operations -- -------------------------------------- -- | Reduce a sequence by appending all the shapes and all the elements in two -- separate vectors. -- fromSeq :: (Shape sh, Elt a) => Seq [Array sh a] -> Seq (Vector sh, Vector a) fromSeq = foldSeqFlatten f (lift (emptyArray, emptyArray)) where f x sh1 a1 = let (sh0, a0) = unlift x in lift (sh0 ++ sh1, a0 ++ a1) fromSeqElems :: (Shape sh, Elt a) => Seq [Array sh a] -> Seq (Vector a) fromSeqElems = foldSeqFlatten f emptyArray where f a0 _ a1 = a0 ++ a1 fromSeqShapes :: (Shape sh, Elt a) => Seq [Array sh a] -> Seq (Vector sh) fromSeqShapes = foldSeqFlatten f emptyArray where f sh0 sh1 _ = sh0 ++ sh1 -- | Sequence an array on the innermost dimension. -- toSeqInner :: (Shape sh, Elt a) => Acc (Array (sh :. Int) a) -> Seq [Array sh a] toSeqInner a = toSeq (Any :. Split) a -- | Sequence a 2-dimensional array on the outermost dimension. -- toSeqOuter2 :: Elt a => Acc (Array DIM2 a) -> Seq [Array DIM1 a] toSeqOuter2 a = toSeq (Z :. Split :. All) a -- | Sequence a 3-dimensional array on the outermost dimension. toSeqOuter3 :: Elt a => Acc (Array DIM3 a) -> Seq [Array DIM2 a] toSeqOuter3 a = toSeq (Z :. Split :. All :. All) a -- | Generate a scalar sequence of a fixed given length, by applying -- the given scalar function at each index. generateSeq :: Elt a => Exp Int -> (Exp Int -> Exp a) -> Seq [Scalar a] generateSeq n f = toSeq (Z :. Split) (generate (index1 n) (f . unindex1)) --} -- Utilities -- --------- emptyArray :: (Shape sh, Elt e) => Acc (Array sh e) emptyArray = use (fromList empty []) -- Utilities -- --------- matchShapeType :: forall s t. (Shape s, Shape t) => s -> t -> Maybe (s :~: t) matchShapeType _ _ | Just Refl <- matchTupleType (eltType (undefined::s)) (eltType (undefined::t)) = gcast Refl matchShapeType _ _ = Nothing