-- | This module is designed to be imported qualified. It provides unsafe
-- operations over reference-counted aliases, for building new primitives
-- that need to work around the linearity guarantees of the safe API.
--
-- @
-- import qualified Data.Linear.Alias.Unsafe as Unsafe
-- @
{-# LANGUAGE LinearTypes, NoImplicitPrelude, QualifiedDo, UnicodeSyntax,
   BlockArguments #-}
module Data.Linear.Alias.Unsafe where

import Prelude.Linear
import Control.Functor.Linear as Linear
import Control.Monad.IO.Class.Linear

import qualified Control.Concurrent.Counter as Counter

import Data.Linear.Alias.Internal

import qualified Unsafe.Linear as Unsafe

-- | Unsafely increment the reference counter of an aliased resource.
--
-- The returned alias has the same release action, the same counter and the
-- same value as the input; only the counter has been bumped by one.
-- 'Unsafe.toLinear' is needed because @counter@ is used twice: once to
-- increment it and once to rebuild the alias.
inc :: MonadIO m => Alias m' a %1 -> m (Alias m' a)
inc = Unsafe.toLinear \(Alias f counter a) -> Linear.do
  -- Increment the reference count. The 'Int' returned by 'Counter.add' is
  -- not needed here, so it is dropped (it is unrestricted inside 'Ur').
  Ur _ <- liftSystemIOU (Counter.add counter 1)
  pure (Alias f counter a)

-- | Unsafely decrement the reference counter of an aliased resource and get
-- the aliased value back as unrestricted (it's really, really quite unsafe).
--
-- This doesn't free the resource if the reference count reaches 0: the
-- alias's release action (its first field, of type @a %1 -> m' ()@) is
-- discarded without being run.
dec :: MonadIO m => Alias m' a %1 -> m (Ur a)
dec = Unsafe.toLinear \(Alias _ counter a) -> Linear.do
  -- Decrement the reference count. The 'Int' returned by 'Counter.sub' is
  -- not needed here, so it is dropped (it is unrestricted inside 'Ur').
  Ur _ <- liftSystemIOU (Counter.sub counter 1)
  pure (Ur a)

-- | Unsafely get an aliased value. All counters are kept unchanged.
--
-- The alias is taken with a regular (non-linear) arrow. Nothing is
-- incremented, decremented or released, and the caller still holds the
-- alias afterwards.
get :: Alias m' a -> a
get (Alias _ _ a) = a