{-# LANGUAGE MagicHash #-}
-- | Evaluate an array by breaking it up into linear chunks and filling
--   each chunk in parallel.
module Data.Array.Repa.Eval.Chunked
	( fillLinearS
        , fillBlock2S
        , fillChunkedP
	, fillChunkedIOP)
where
import Data.Array.Repa.Index
import Data.Array.Repa.Eval.Gang

import GHC.Exts
import Prelude		as P

-------------------------------------------------------------------------------
-- | Sequentially fill a buffer of the given length.
--
--   * Indices are visited in increasing order, starting at @0@ and stopping
--     just before the element count.
--
--   * For each index the value is obtained from the element function and
--     handed to the update function together with that index.
--
fillLinearS
        :: Int                  -- ^ Number of elements to write.
        -> (Int -> a -> IO ())  -- ^ Action that stores a value at a linear index.
        -> (Int -> a)           -- ^ Produces the value for a linear index.
        -> IO ()

fillLinearS !(I# count) write getElem
 = loop 0#
 where  -- Walk the indices upwards; nothing is written once the position
        -- reaches the element count (so a count of zero writes nothing).
        loop !pos
         = if pos >=# count
                then return ()
                else write (I# pos) (getElem (I# pos)) >> loop (pos +# 1#)
{-# INLINE [0] fillLinearS #-}


-------------------------------------------------------------------------------
-- | Sequentially fill a rectangular block of a rank-2 array.
--
--   * Filling block by block can be more cache-efficient than one linear
--     sweep over a rank-2 array.
--
--   * Rows of the block are visited in order of increasing @y@, and within
--     a row the elements are written in order of increasing @x@.
--
--   * The value for position @(y, x)@ is written to the linear index
--     @x + y * width@, where @width@ is the width of the whole array.
--
fillBlock2S
        :: (Int  -> a -> IO ()) -- ^ Action that stores a value at a linear index.
        -> (DIM2 -> a)          -- ^ Produces the value for a rank-2 index.
        -> Int#                 -- ^ Width of the whole array.
        -> Int#                 -- ^ x0: x coordinate at which the block starts.
        -> Int#                 -- ^ y0: y coordinate at which the block starts.
        -> Int#                 -- ^ w0: width of the block to fill.
        -> Int#                 -- ^ h0: height of the block to fill.
        -> IO ()

fillBlock2S write getElem !imageWidth !x0 !y0 !w0 h0
 = rows y0 firstIx
 where  -- Exclusive upper bounds of the block in each dimension.
        !xEnd           = x0 +# w0
        !yEnd           = y0 +# h0

        -- Linear index of the first element of the block.
        !firstIx        = x0 +# (y0 *# imageWidth)

        -- Fill the block one row at a time. `rowIx` is the linear index of
        -- the first element of row `y` that lies inside the block; moving
        -- down one row advances it by the width of the whole array.
        {-# INLINE rows #-}
        rows !y !rowIx
         = if y >=# yEnd
                then return ()
                else cols x0 rowIx >> rows (y +# 1#) (rowIx +# imageWidth)

         where  {-# INLINE cols #-}
                -- Fill the part of row `y` covered by the block, keeping the
                -- column and the linear index in step with each other.
                cols !x !ix
                 = if x >=# xEnd
                        then return ()
                        else write (I# ix) (getElem (Z :. (I# y) :. (I# x)))
                          >> cols (x +# 1#) (ix +# 1#)

{-# INLINE [0] fillBlock2S #-}


-------------------------------------------------------------------------------
-- | Fill a buffer in parallel.
--
--   * The index range is divided into contiguous linear chunks, one per
--     thread of the gang, and every thread fills its own chunk linearly.
--
fillChunkedP
        :: Int                  -- ^ Number of elements to write.
        -> (Int -> a -> IO ())  -- ^ Action that stores a value at a linear index.
        -> (Int -> a)           -- ^ Produces the value for a linear index.
        -> IO ()

fillChunkedP !(I# len) write getElem
 = gangIO theGang $ \(I# thread) ->
        -- A thread owns the indices from its own boundary up to (but not
        -- including) the boundary of the next thread.
        loop (boundary thread) (boundary (thread +# 1#))

 where
        -- Decide how to split the work across the threads.
        -- When the length is not a multiple of the gang size, each of the
        -- first `extra` threads is given one additional element.
        !(I# threads)   = gangSize theGang
        !perThread      = len `quotInt#` threads
        !extra          = len `remInt#`  threads

        -- Index at which the chunk of thread `t` begins.
        {-# INLINE boundary #-}
        boundary t
         = if t <# extra
                then t *# (perThread +# 1#)
                else (t *# perThread) +# extra

        -- Write every element of a single chunk, in increasing index order.
        {-# INLINE loop #-}
        loop !ix !stop
         = if ix >=# stop
                then return ()
                else write (I# ix) (getElem (I# ix)) >> loop (ix +# 1#) stop
{-# INLINE [0] fillChunkedP #-}


-------------------------------------------------------------------------------
-- | Fill a buffer in parallel, where every thread first runs its own IO
--   action to obtain the function it will read elements with.
--
--   * The index range is divided into contiguous linear chunks, one per
--     thread of the gang, and every thread fills its own chunk linearly.
--
fillChunkedIOP
        :: Int  -- ^ Number of elements to write.
        -> (Int -> a -> IO ())
                -- ^ Action that stores a value at a linear index.
        -> (Int -> IO (Int -> IO a))
                -- ^ Builds the action used to fetch the value for an index.
                --   It is given the thread number, which allows some
                --   per-thread initialisation to be done.
        -> IO ()

fillChunkedIOP !(I# len) write mkGetElem
 = gangIO theGang $ \(I# thread) ->
        -- A thread owns the indices from its own boundary up to (but not
        -- including) the boundary of the next thread.
        runChunk thread (boundary thread) (boundary (thread +# 1#))

 where
        -- Decide how to split the work across the threads.
        -- When the length is not a multiple of the gang size, each of the
        -- first `extra` threads is given one additional element.
        !(I# threads)   = gangSize theGang
        !perThread      = len `quotInt#` threads
        !extra          = len `remInt#`  threads

        -- Index at which the chunk of thread `t` begins.
        {-# INLINE boundary #-}
        boundary t
         = if t <# extra
                then t *# (perThread +# 1#)
                else (t *# perThread) +# extra

        -- Build the element getter for this thread, then use it to fill the
        -- chunk between the two given indices.
        {-# INLINE runChunk #-}
        runChunk !t !from !to
         = mkGetElem (I# t) >>= \getElem -> loop getElem from to

        -- Fetch each element of the chunk with the supplied getter and pass
        -- the result on to the write function, in increasing index order.
        {-# INLINE loop #-}
        loop !getElem !from !to
         = step from
         where  step !ix
                 = if ix >=# to
                        then return ()
                        else getElem (I# ix) >>= write (I# ix) >> step (ix +# 1#)
{-# INLINE [0] fillChunkedIOP #-}