{-# language MagicHash #-}
{-# language MultiParamTypeClasses #-}
{-# language ScopedTypeVariables #-}
{-# language BangPatterns #-}
{-# language FlexibleInstances #-}
{-# language RoleAnnotations #-}
{- OPTIONS_GHC -ddump-simpl #-}


{- |
This module is a drop-in replacement for @Control.Concurrent.STM.TArray@
in the @stm@ package. It has the same fundamental inefficiency of the
classic @TArray@, but it's a /little/ faster and more compact.
Specifically, this implementation uses two fewer words of memory
and one fewer indirection per element.
We also add an 'MArray' instance for working in 'IO' that the @stm@
version lacks.
Finally, the 'Eq' instance for the official @TArray@ is currently a little broken
thanks to a bug in the instance for @Data.Array.Array@ (See GHC Gitlab issue
#18700). We fix that bug here.
-}

module Data.Primitive.TArray.Classic (TArray) where
import GHC.Conc (STM, TVar, newTVar, readTVar, writeTVar
                , newTVarIO, readTVarIO, atomically)
import Data.Primitive.Unlifted.Array
import Data.Array.Base (MArray (..))
import Data.Ix (Ix, rangeSize)
import GHC.Exts (TVar#, RealWorld)

-- | A transactional array: an unlifted array holding one 'TVar' per
-- element, stored together with the bounds it was created with.
--
-- Elements are addressed by the zero-based 'Int' offset that the 'MArray'
-- methods 'unsafeRead' and 'unsafeWrite' receive.
data TArray i a = TArray
  { _lb :: !i
    -- The lower bound.
  , _ub :: !i
    -- The upper bound.
  , range :: !Int
    -- A cache of @rangeSize (_lb, _ub)@, used to make sure an index is
    -- really in range.
  , arr :: !(UnliftedArray_ (TVar# RealWorld a) (TVar a))
    -- The 'TVar's themselves, stored unlifted.
  }
-- The index type is nominal, the element type representational.
-- NOTE(review): presumably the index is nominal because the position of an
-- element depends on the index type's 'Ix' instance -- confirm.
type role TArray nominal representational

instance Eq i => Eq (TArray i a) where
  -- There's no way for TVars to move from one TArray to another, so two
  -- arrays are equal iff they are both empty with the same bounds, or they
  -- are actually the same array. There's no "safe" way to check whether
  -- they are the same array (though `unsafeCoerce#` together with
  -- `sameMutableUnliftedArray#` would do it), so instead we do a quick size
  -- check and then compare the first TVar of each.
  --
  -- Note: The instance in stm leans on the instance for @Array@ in @base@.
  -- As of base-4.14.0.0, that instance is broken. See GHC Gitlab issue
  -- #18700. It looks like that's probably going to get fixed, so we fix it
  -- here.
  TArray lb1 ub1 range1 arr1 == TArray lb2 ub2 range2 arr2
      -- Arrays of different sizes can never be equal.
    | range1 /= range2 = False
      -- If the arrays are both empty, then they may still have been
      -- created with different bounds (e.g., (2,1) and (1,0)), so we
      -- check.
    | range1 == 0 = lb1 == lb2 && ub1 == ub2
      -- If the arrays are not empty, but the first TVar of each is the
      -- same, then they must have been created by the *same* newArray
      -- action. Therefore they are sure to have the same bounds, and
      -- are equal.
    | otherwise = indexUnliftedArray arr1 0 == indexUnliftedArray arr2 0

-- | Array operations inside an 'STM' transaction. Each element access is a
-- plain 'readTVar' or 'writeTVar' on the element's 'TVar'.
instance MArray TArray e STM where
  -- The bounds are stored in the array itself.
  getBounds (TArray l u _ _) = pure (l, u)

  -- Allocate one fresh TVar per element, all holding the same initial
  -- value, then pack them into an array.
  newArray b e = listTArray b <$> rep (rangeSize b) (newTVar e)

  -- The stm version defines newArray_, but the default does the
  -- same thing.

  -- The offset is not checked here; callers must pass an in-range offset.
  unsafeRead tarr i = readTVar (indexUnliftedArray (arr tarr) i)

  unsafeWrite tarr i e = writeTVar (indexUnliftedArray (arr tarr) i) e

  -- The element count is cached at construction time.
  getNumElements (TArray _ _ n _) = pure n

-- | Writes are slow in 'IO': every 'unsafeWrite' runs its own 'atomically'
-- transaction. Reads and allocation use 'readTVarIO' and 'newTVarIO' and
-- so do not start a transaction.
instance MArray TArray e IO where
  -- The bounds are stored in the array itself.
  getBounds (TArray l u _ _) = pure (l, u)

  -- Allocate one fresh TVar per element, all holding the same initial
  -- value, then pack them into an array.
  newArray b e = listTArray b <$> rep (rangeSize b) (newTVarIO e)

  -- The stm version defines newArray_, but the default does the
  -- same thing.

  -- The offset is not checked here; callers must pass an in-range offset.
  unsafeRead tarr i = readTVarIO (indexUnliftedArray (arr tarr) i)

  unsafeWrite tarr i e =
    atomically (writeTVar (indexUnliftedArray (arr tarr) i) e)

  -- The element count is cached at construction time.
  getNumElements (TArray _ _ n _) = pure n

-- | Stolen from stm:
-- Like 'replicateM' but uses an accumulator to prevent stack overflows.
-- Unlike 'replicateM' the returned list is in reversed order.
-- This doesn't matter though since this function is only used to create
-- arrays with identical elements.
--
-- @rep n m@ runs @m@ exactly @n@ times. A count of zero or less runs
-- nothing and yields the empty list, as 'replicateM' does.
--
-- TODO: For `IO`, we should surely build the array directly, rather
-- than first making a list. For STM, I'm *guessing* this would be a
-- safe place to use unsafeIOtoSTM to do the same.
rep :: Monad m => Int -> m a -> m [a]
rep n m = go n []
  where
    -- The guard inspects the counter on every step, so it is always
    -- evaluated and no chain of pending subtractions builds up.
    go remaining acc
      | remaining <= (0 :: Int) = return acc
      | otherwise = do
          x <- m
          go (remaining - 1) (x : acc)

-- | Build a 'TArray' with the given bounds from a list of already-created
-- 'TVar's.
--
-- The element count is computed once with 'rangeSize' and used both as the
-- cached size and as the length passed to 'unliftedArrayFromListN'.
-- NOTE(review): the list is presumably required to contain exactly
-- @rangeSize (l, u)@ elements; what 'unliftedArrayFromListN' does on a
-- mismatch is not visible from this module -- confirm.
listTArray :: Ix i => (i, i) -> [TVar e] -> TArray i e
listTArray (l, u) tvs = TArray l u n (unliftedArrayFromListN n tvs)
  where
    -- Forced eagerly so the size is computed before the array is built.
    !n = rangeSize (l, u)