{-# LANGUAGE GADTs #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | This module provides destination arrays
--
-- == What are destination arrays? What are they good for?
--
-- Destination arrays are write-only arrays that are allocated only once,
-- so you do not have to rely on GHC's fusion mechanisms to remove
-- unnecessary allocations.
--
-- The current status quo for computations that thread a write-only array
-- along is to rely on fusion. While the optimizations in, say,
-- "Data.Vector" are quite good at ensuring GHC fuses, they aren't
-- foolproof and can be broken by simple refactorings.
--
-- Avoiding extra allocations of a write-only array is easy in C, with
-- something the functional programming world calls destination passing style,
-- or DPS for short.
--
-- Here is a C function written in DPS: it takes the destination array @res@
-- as an argument and writes its result into it:
--
-- @
-- // ((a + b) * c) for vectors a,b and scalar c
-- void apbxc(int size, int *a, int *b, int c, int *res){
--   for (int i=0; i<size;++i){res[i]=a[i]+b[i];}
--   mult(size, c, res);
-- }
--
-- void mult(int size, int scalar, int* vec){
--   for (int i=0; i<size; ++i){vec[i] *= scalar;}
-- }
-- @
--
-- == Example: Stencil computation
--
-- One possible use of destination arrays could be the stencil computation
-- typically called
-- [jacobi](https://en.wikipedia.org/wiki/Iterative_Stencil_Loops#Example:_2D_Jacobi_iteration).
-- Here we show one time step of this computation in a single dimension:
--
-- @
-- jacobi1d :: Int -> Vector Double -> Vector Double
-- jacobi1d n oldA = case stepArr n oldA of
--   newB -> stepArr n newB
--
-- -- One time step over an array of length n (assumed >= 2): each interior
-- -- cell becomes 0.33 times the sum of itself and its two neighbours, and
-- -- the two boundary cells are copied unchanged, so every cell is written.
-- stepArr :: Int -> Vector Double -> Vector Double
-- stepArr n oldArr = alloc n $ \\newArr ->
--   split 1 newArr & \\(first, rest) ->
--     fill (oldArr ! 0) first & \\() -> fillArr rest oldArr 1
--   where
--     fillArr :: DArray Double %1-> Vector Double -> Int -> ()
--     fillArr newA oldA ix
--       | ix == (n-1) = newA & fill (oldA ! ix)
--       | True = split 1 newA & \\(cell, rest) ->
--           fill (0.33 * ((oldA ! (ix-1)) + (oldA ! ix) + (oldA ! (ix+1)))) cell &
--             \\() -> fillArr rest oldA (ix+1)
-- @
--
-- We can be sure that @stepArr@ only allocates one array. In certain
-- variations and implementations of the jacobi kernel or similar dense array
-- computations, ensuring one allocation with @Data.Vector@'s fusion oriented
-- implementation may not be trivial.
--
-- For reference, the C equivalent of this code is the following:
--
-- @
-- static void jacobi_1d_time_step(int n, double *A, double *B){
--   int i;
--   for (i = 1; i < n - 1; i++)
--     B[i] = 0.33333 * (A[i-1] + A[i] + A[i + 1]);
--   for (i = 1; i < n - 1; i++)
--     A[i] = 0.33333 * (B[i-1] + B[i] + B[i + 1]);
-- }
-- @
--
-- This example is taken from the
-- [polybench test-suite](https://web.cse.ohio-state.edu/~pouchet.2/software/polybench/)
-- of dense array codes.
--
-- == Aside: Why do we need linear types?
--
-- Linear types prevent ambiguous writes to the destination array.
-- For example, the following function can never be given a linear type, so
-- the ambiguity of whether @arr@ ends up holding 3s or 4s cannot arise:
--
-- @
--  nonLinearUse :: DArray Int -> ()
--  nonLinearUse arr = case (replicate 3 arr, replicate 4 arr) of
--    ((),()) -> ()
-- @
--
-- Furthermore, this API is safely implemented by mutating an underlying array
-- which is good for performance. The API is safe because linear types
-- enforce the fact that each reference to an underlying mutable array
-- (and there can be more than one by using @split@) is
-- linearly threaded through functions and at the end consumed by one of our
-- write functions.
--
-- Lastly, linear types are used to ensure that each cell in the destination
-- array is written to exactly once. This is because the only way to create and
-- use a destination array is via
--
-- @
-- alloc :: Int -> (DArray a %1-> ()) %1-> Vector a
-- @
--
-- and the only way to really consume a @DArray@ is via our API
-- which requires you to completely fill the array.
--
module Data.Array.Destination
  (
  -- * The Data Type
    DArray
  -- * Create and use a @DArray@
  , alloc
  , size
  -- * Ways to write to a @DArray@
  , replicate
  , split
  , mirror
  , fromFunction
  , fill
  , dropEmpty
  )
  where

import Data.Vector (Vector, (!))
import qualified Data.Vector as Vector
import Data.Vector.Mutable (MVector)
import qualified Data.Vector.Mutable as MVector
import GHC.Exts (RealWorld)
import qualified Prelude as Prelude
import System.IO.Unsafe (unsafeDupablePerformIO)
import GHC.Stack
import Data.Unrestricted.Linear
import Prelude.Linear hiding (replicate)
import qualified Unsafe.Linear as Unsafe

-- | A destination array, or @DArray@, is a write-only array that is filled
-- by some computation which ultimately returns an array.
data DArray a where
  -- NOTE: declared in GADT syntax with an unrestricted arrow (@->@, not
  -- @%1->@). Under LinearTypes this makes the wrapped 'MVector' an
  -- unrestricted field. The functions in this module rely on that: they
  -- consume the 'DArray' linearly but use the inner vector several times
  -- (e.g. 'MVector.length' followed by a write).
  DArray :: MVector RealWorld a -> DArray a

-- XXX: use of Vector in types is temporary. I will probably move away from
-- vectors and implement most stuff in terms of Array# and MutableArray#
-- eventually, anyway. This would allow to move the MutableArray logic to
-- linear IO, possibly, and segregate the unsafe casts to the Linear IO
-- module.

-- | @'alloc' n k@ allocates a fresh array of length @n@, hands it to the
-- writer @k@ as a 'DArray', and returns the written array as an immutable
-- 'Vector'.
--
-- @n@ must be non-negative. A negative @n@ raises an error here rather than
-- being passed on to 'MVector.unsafeNew'.
alloc :: Int -> (DArray a %1-> ()) %1-> Vector a
alloc n writer =
  if n < 0
    then error "Destination.alloc: negative size" $ writer
    else
      -- The mutable array is frozen before @writer@ runs. 'Vector.unsafeFreeze'
      -- shares the buffer, so the writes @writer@ performs through @dest@ show
      -- up in @vec@. 'lseq' consumes the @()@ produced by @writer@, so all of
      -- its writes are done before @vec@ is returned.
      (\(Ur dest, vec) -> writer (DArray dest) `lseq` vec) $
        unsafeDupablePerformIO Prelude.$ do
          destArray <- MVector.unsafeNew n
          vec <- Vector.unsafeFreeze destArray
          Prelude.return (Ur destArray, vec)

-- | Get the size of a destination array.
--
-- The array is returned alongside its size so that it can still be written
-- to afterwards; reading the length does not touch any cell.
size :: DArray a %1-> (Ur Int, DArray a)
size (DArray mvec) = (Ur (MVector.length mvec), DArray mvec)

-- | @'replicate' x dest@ fills every cell of @dest@ with the constant @x@.
replicate :: a -> DArray a %1-> ()
replicate a = fromFunction (const a)

-- | @'fill' a dest@ writes @a@ into the single cell of @dest@.
--
-- Caution: @'fill' a dest@ fails with an error if @dest@ does not have
-- length exactly one.
fill :: HasCallStack => a %1-> DArray a %1-> ()
fill a (DArray mvec) =
  if MVector.length mvec /= 1
    then error "Destination.fill: requires a destination of size 1" $ a
    else
      -- 'MVector.write' is not linear in the element it stores. The cast is
      -- acceptable because the element is stored exactly once and nothing
      -- else in this function uses it.
      a & Unsafe.toLinear (\x -> unsafeDupablePerformIO (MVector.write mvec 0 x))

-- | @'dropEmpty' dest@ consumes an empty destination array and fails with an
-- error if @dest@ still has cells left to write.
dropEmpty :: HasCallStack => DArray a %1-> ()
dropEmpty (DArray mvec)
  | MVector.length mvec > 0 = error "Destination.dropEmpty on non-empty array."
  | otherwise = mvec `seq` ()

-- | @'split' n dest = (destl, destr)@ such that @destl@ holds the first @n@
-- cells of @dest@ and @destr@ holds the rest.
--
-- 'split' is total: if @n@ is larger than the length of @dest@, then
-- @destl@ is the whole of @dest@ and @destr@ is empty.
split :: Int -> DArray a %1-> (DArray a, DArray a)
split n (DArray mvec)
  -- Both halves are slices of the same underlying buffer, so writes through
  -- either one land in the array that 'alloc' eventually returns.
  | (ml, mr) <- MVector.splitAt n mvec = (DArray ml, DArray mr)

-- | @'mirror' v f dest@ fills @dest@ from the vector @v@: cell @i@ of @dest@
-- receives @f (v ! i)@. Elements of @v@ beyond the length of @dest@ are
-- ignored.
--
-- Errors if the given vector is smaller than the destination array.
mirror :: HasCallStack => Vector a -> (a %1-> b) -> DArray b %1-> ()
mirror v f arr =
  size arr & \(Ur sz, arr') ->
    if Vector.length v < sz
      then error "Destination.mirror: argument smaller than DArray" $ arr'
      else fromFunction (\t -> f (v ! t)) arr'

-- | Fill a destination array using the given index-to-value function:
-- @'fromFunction' f dest@ writes @f i@ into cell @i@ for every index @i@ of
-- @dest@.
fromFunction :: (Int -> b) -> DArray b %1-> ()
fromFunction f (DArray mvec) = unsafeDupablePerformIO Prelude.$
  -- The use of the mutable array is linear, since getting the length does
  -- not touch any elements, and each write fills in exactly one slot, so
  -- each slot of the destination array is filled.
  Prelude.mapM_ (\m -> MVector.unsafeWrite mvec m (f m)) [0 .. n - 1]
  where
    n = MVector.length mvec