-- | Working with 'Foreign.Ptr's in a way that prevents use after free
--
-- >>> :set -XPostfixOperators
-- >>> import Control.Monad.Scoped.Internal
-- >>> scoped do x <- mut (69 :: Word); x .= 42; (x ?)
-- 42
module Control.Monad.Scoped.Ptr
  ( Ptr
  , mut
  , (.=)
  , (?)
  )
where

import Control.Monad.Scoped.Internal (Scoped, ScopedResource (unsafeUnwrapScopedResource), bracketScoped, (:<))
import Foreign qualified
import UnliftIO (MonadIO (liftIO), MonadUnliftIO)

-- | A mutable 'Foreign.Ptr' tied to the scope @s@: it can be read from and written to.
type Ptr s a = ScopedResource s (Foreign.Ptr a)

-- | Acquire mutable memory for the duration of a scope.
--
-- The memory is allocated and initialised with the given value via
-- 'Foreign.new', and 'Foreign.free' is registered as the matching release
-- action through 'bracketScoped', so the value is automatically dropped at
-- the end of the scope.
mut :: (Foreign.Storable a, MonadUnliftIO m) => a -> Scoped (s : ss) m (Ptr s a)
mut a = bracketScoped (liftIO $ Foreign.new a) (liftIO . Foreign.free)

-- | Write a value to a pointer.
--
-- The underlying 'Foreign.Ptr' is unwrapped from its scoped wrapper and the
-- value is stored with 'Foreign.poke', overwriting whatever was there before.
-- The pointer's scope @s@ must satisfy @s ':<' ss@ for the scope stack @ss@
-- the action runs in.
(.=) :: (Foreign.Storable a, MonadIO m, s :< ss) => Ptr s a -> a -> Scoped ss m ()
(.=) ptr = liftIO . Foreign.poke (unsafeUnwrapScopedResource ptr)

-- | Read a value from a pointer.
--
-- The underlying 'Foreign.Ptr' is unwrapped from its scoped wrapper and the
-- current value is loaded with 'Foreign.peek'. The pointer's scope @s@ must
-- satisfy @s ':<' ss@ for the scope stack @ss@ the action runs in.
-- With @PostfixOperators@ enabled this can be written postfix, as @(x ?)@.
(?) :: (Foreign.Storable a, MonadIO m, s :< ss) => Ptr s a -> Scoped ss m a
(?) ptr = liftIO (Foreign.peek (unsafeUnwrapScopedResource ptr))