{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE UnboxedTuples #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}

-- |
-- This module provides a pure linear interface for arrays with in-place
-- mutation.
--
-- To use these mutable arrays, create a linear computation of type
-- @Array a %1-> Ur b@ and feed it to 'alloc' or 'fromList'.
--
-- == A Tiny Example
--
-- >>> :set -XLinearTypes
-- >>> :set -XNoImplicitPrelude
-- >>> import Prelude.Linear
-- >>> import qualified Data.Array.Mutable.Linear as Array
-- >>> :{
--  isFirstZero :: Array.Array Int %1-> Ur Bool
--  isFirstZero arr =
--    Array.get 0 arr
--      & \(Ur val, arr') -> arr' `lseq` Ur (val == 0)
-- :}
--
-- >>> unur $ Array.fromList [0..10] isFirstZero
-- True
-- >>> unur $ Array.fromList [1,2,3] isFirstZero
-- False
module Data.Array.Mutable.Linear
  ( -- * Mutable Linear Arrays
    Array,
    -- * Performing Computations with Arrays
    alloc,
    allocBeside,
    fromList,
    -- * Modifications
    set,
    unsafeSet,
    resize,
    map,
    -- * Accessors
    get,
    unsafeGet,
    size,
    slice,
    toList,
    freeze,
    -- * Mutable-style interface
    read,
    unsafeRead,
    write,
    unsafeWrite
  )
where

import Data.Unrestricted.Linear
import GHC.Stack
import Data.Array.Mutable.Unlifted.Linear (Array#)
import qualified Data.Array.Mutable.Unlifted.Linear as Unlifted
import qualified Data.Functor.Linear as Data
import qualified Data.Vector as Vector
import qualified Data.Vector.Mutable as MVector
import Prelude.Linear ((&), forget)
import qualified Data.Primitive.Array as Prim
import System.IO.Unsafe (unsafeDupablePerformIO)
import Prelude hiding (read, map)

-- # Data types
-------------------------------------------------------------------------------

-- | A mutable array meant to be used linearly: obtain one through 'alloc',
-- 'allocBeside' or 'fromList' and thread it through linear functions.
-- It is a lifted wrapper around the 'Array#' type from
-- "Data.Array.Mutable.Unlifted.Linear".
data Array a = Array (Array# a)

-- # Creation
-------------------------------------------------------------------------------

-- | Allocate a constant array of the given size, with every cell holding the
-- given initial value, and run the given linear computation on it.
--
-- The size must be non-negative, otherwise this errors.
alloc :: HasCallStack => Int -> a -> (Array a %1-> Ur b) %1-> Ur b
alloc s x f
  | s >= 0 = Unlifted.alloc s x (\arr -> f (Array arr))
  | otherwise =
      -- 'f' is linear, so it must be consumed on this branch as well.
      -- Applying an 'error' of linear function type to @f undefined@ does
      -- that; evaluating the application forces the 'error' first, so 'f'
      -- never actually runs.
      (error ("Array.alloc: negative size: " ++ show s) :: r %1-> r)
        (f undefined)

-- | Allocate a constant array given a size and an initial value, using
-- another array as a uniqueness proof.
--
-- Returns the pair @(newArray, proofArray)@. The size must be non-negative,
-- otherwise this errors.
allocBeside :: HasCallStack => Int -> a -> Array b %1-> (Array a, Array b)
allocBeside s x (Array proof)
  | s < 0 =
      -- The proof array is linear and must be consumed before erroring.
      Unlifted.lseq
        proof
        (error ("Array.allocBeside: negative size: " ++ show s))
  | otherwise = wrap (Unlifted.allocBeside s x proof)
  where
    -- 'Unlifted.allocBeside' yields the fresh array (element type of the
    -- seed) first and the proof array second.
    wrap :: (# Array# a, Array# b #) %1-> (Array a, Array b)
    wrap (# fresh, proof' #) = (Array fresh, Array proof')

-- | Allocate an array holding the elements of the given list, in order, and
-- run the given linear computation on it.
fromList :: forall a b. HasCallStack => [a] -> (Array a %1-> Ur b) %1-> Ur b
fromList list f =
  alloc
    (Prelude.length list)
    -- Placeholder seed: every cell is overwritten by 'insert' below.
    (error "invariant violation: uninitialized array position")
    (\arr -> f (insert arr))
  where
    -- Overwrite each placeholder cell with the corresponding list element.
    insert :: Array a %1-> Array a
    insert = doWrites (zip list [0 ..])

    -- The indices come from zipping the list with [0 ..], so they are all
    -- within the array 'alloc' created with the list's length; the
    -- unchecked 'unsafeSet' is therefore sufficient.
    doWrites :: [(a, Int)] -> Array a %1-> Array a
    doWrites [] arr = arr
    doWrites ((a, ix) : xs) arr = doWrites xs (unsafeSet ix a arr)

-- # Mutations and Reads
-------------------------------------------------------------------------------

-- | Return the number of elements in the array, together with the array
-- itself so that it can still be used.
size :: Array a %1-> (Ur Int, Array a)
size (Array arr) = rewrap (Unlifted.size arr)
  where
    rewrap :: (# Ur Int, Array# a #) %1-> (Ur Int, Array a)
    rewrap (# Ur len, arr' #) = (Ur len, Array arr')

-- | Set the value at the given index. The index must be non-negative and
-- less than the array's 'size', otherwise this errors.
set :: HasCallStack => Int -> a -> Array a %1-> Array a
set i x arr = assertIndexInRange i arr & unsafeSet i x

-- | Same as 'set', but does not do bounds-checking. The behaviour is
-- undefined if an out-of-bounds index is provided.
unsafeSet :: Int -> a -> Array a %1-> Array a
unsafeSet ix val (Array arr) = Array (Unlifted.set ix val arr)

-- | Get the value at the given index, together with the array. The index
-- must be non-negative and less than the array's 'size', otherwise this
-- errors.
get :: HasCallStack => Int -> Array a %1-> (Ur a, Array a)
get i arr = assertIndexInRange i arr & unsafeGet i

-- | Same as 'get', but does not do bounds-checking. The behaviour is
-- undefined if an out-of-bounds index is provided.
unsafeGet :: Int -> Array a %1-> (Ur a, Array a)
unsafeGet ix (Array arr) = rewrap (Unlifted.get ix arr)
  where
    rewrap :: (# Ur a, Array# a #) %1-> (Ur a, Array a)
    rewrap (# Ur val, arr' #) = (Ur val, Array arr')

-- | Resize an array. Given a target size, a seed value and an array, produce
-- an array of the target size. New cells are filled with the seed value,
-- and existing cells that still fit are copied over.
--
-- The target size must be non-negative, otherwise this errors.
--
-- @
-- let b = resize n x a,
--   then size b = n,
--   and b[i] = a[i] for i < min (size a) n,
--   and b[i] = x for size a <= i < n.
-- @
resize :: forall a. HasCallStack => Int -> a -> Array a %1-> Array a
resize newSize seed (Array arr)
  | newSize >= 0 = fill (Unlifted.allocBeside newSize seed arr)
  | otherwise =
      -- The input array is linear and must be consumed before erroring.
      Unlifted.lseq arr (error "Trying to resize to a negative size.")
  where
    -- 'Unlifted.allocBeside' returns the fresh array (seeded with 'seed')
    -- first and the original array second; copy the old contents into the
    -- fresh one starting at index 0.
    fill :: (# Array# a, Array# a #) %1-> Array a
    fill (# fresh, old #) = keepDestination (Unlifted.copyInto 0 old fresh)

    -- Discard the source of the copy and keep the destination.
    keepDestination :: (# Array# a, Array# a #) %1-> Array a
    keepDestination (# src, dst #) = Unlifted.lseq src (Array dst)


-- | Return the elements of the array as a lazy list. The array is consumed.
toList :: Array a %1-> Ur [a]
toList (Array inner) = Unlifted.toList inner

-- | Copy a slice of the array, starting from the given offset and copying
-- the given number of elements. Returns the pair @(oldArray, slice)@.
--
-- Both the start offset and the target size must be non-negative, and
-- @offset + targetSize@ must not exceed the size of the input array;
-- otherwise this errors.
--
-- @
-- let b = slice i n a,
--   then size b = n,
--   and b[j] = a[i+j] for 0 <= j < n
-- @
slice
  :: HasCallStack
  => Int -- ^ Start offset
  -> Int -- ^ Target size
  -> Array a %1-> (Array a, Array a)
slice from targetSize arr =
  size arr & \case
    (Ur s, Array old)
      | outOfBounds s ->
          Unlifted.lseq old (error "Slice index out of bounds.")
      | otherwise ->
          doCopy
            (Unlifted.allocBeside
               targetSize
               (error "invariant violation: uninitialized array index")
               old)
  where
    -- Rejects negative arguments (a negative offset or size must never
    -- reach 'Unlifted.copyInto' / 'Unlifted.allocBeside'). It is written
    -- without computing @from + targetSize@, which could overflow: once
    -- @0 <= from <= s@ holds, @s - from@ is exact and non-negative.
    outOfBounds :: Int -> Bool
    outOfBounds s =
      from < 0 || targetSize < 0 || from > s || targetSize > s - from

    -- The fresh slice comes first, the original array second; copy the
    -- elements starting at 'from' of the original into the slice.
    doCopy :: (# Array# a, Array# a #) %1-> (Array a, Array a)
    doCopy (# new, old #) = wrap (Unlifted.copyInto from old new)

    wrap :: (# Array# a, Array# a #) %1-> (Array a, Array a)
    wrap (# old, new #) = (Array old, Array new)

-- | /O(1)/ Convert an 'Array' to an immutable 'Vector.Vector' (from the
-- @vector@ package). The array is consumed.
freeze :: Array a %1-> Ur (Vector.Vector a)
freeze (Array arr) = Unlifted.freeze toVector arr
  where
    -- Reinterpret the frozen primitive array as a 'Vector.Vector' using the
    -- unsafe (non-copying) thaw/freeze primitives: thaw it, view the whole
    -- buffer as a mutable vector, then freeze that view.
    --
    -- This round trip is only needed because the 'Vector.Vector'
    -- constructor is hidden. Once it is exposed, this could be replaced with
    -- something safer like: @toVector frozen = Vector 0 (sizeof frozen) frozen@
    toVector frozen = unsafeDupablePerformIO $ do
      marr <- Prim.unsafeThawArray (Prim.Array frozen)
      let len = Prim.sizeofMutableArray marr
      Vector.unsafeFreeze (MVector.MVector 0 len marr)

-- | Apply the given function to every element of the array.
map :: (a -> b) -> Array a %1-> Array b
map fn (Array elems) = Array (Unlifted.map fn elems)

-- # Mutation-style API
-------------------------------------------------------------------------------

-- | Same as 'set', but takes the 'Array' as the first parameter.
write :: HasCallStack => Array a %1-> Int -> a -> Array a
write arr ix val = set ix val arr

-- | Same as 'unsafeSet', but takes the 'Array' as the first parameter.
-- No bounds-checking is done.
unsafeWrite :: Array a %1-> Int -> a -> Array a
unsafeWrite arr i a = unsafeSet i a arr

-- | Same as 'get', but takes the 'Array' as the first parameter.
read :: HasCallStack => Array a %1-> Int -> (Ur a, Array a)
read arr ix = get ix arr

-- | Same as 'unsafeGet', but takes the 'Array' as the first parameter.
-- No bounds-checking is done.
unsafeRead :: Array a %1-> Int -> (Ur a, Array a)
unsafeRead arr ix = unsafeGet ix arr

-- # Instances
-------------------------------------------------------------------------------

-- | Consuming an 'Array' hands the wrapped 'Array#' to 'Unlifted.lseq'.
instance Consumable (Array a) where
  consume :: Array a %1-> ()
  consume (Array inner) = Unlifted.lseq inner ()

-- | Duplicating an 'Array' delegates to 'Unlifted.dup2' and wraps both
-- resulting unlifted arrays.
instance Dupable (Array a) where
  dup2 :: Array a %1-> (Array a, Array a)
  dup2 (Array inner) = rewrap (Unlifted.dup2 inner)
    where
      rewrap :: (# Array# a, Array# a #) %1-> (Array a, Array a)
      rewrap (# first, second #) = (Array first, Array second)

-- | Linear functor instance: delegates to 'map', turning the linear
-- function into an unrestricted one with 'forget'.
instance Data.Functor Array where
  fmap :: (a %1-> b) -> Array a %1-> Array b
  fmap f = map (forget f)

-- # Internal library
-------------------------------------------------------------------------------

-- | Return the array unchanged if the index satisfies @0 <= i < size@;
-- otherwise consume the array and raise an error.
assertIndexInRange :: HasCallStack => Int -> Array a %1-> Array a
assertIndexInRange i arr =
  size arr & \(Ur len, arr') ->
    if i < 0 || i >= len
      then arr' `lseq` error "Array: index out of bounds"
      else arr'