{-# LANGUAGE StrictData #-}
{-# LANGUAGE NoFieldSelectors #-}

module Sq.Pool
   ( Pool
   , pool
   , subPool
   , readTransaction
   , commitTransaction
   , rollbackTransaction
   )
where

import Control.Concurrent
import Control.DeepSeq
import Control.Exception.Safe qualified as Ex
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Resource.Extra qualified as R
import Data.Acquire qualified as A
import Data.Acquire.Internal qualified as A
import Data.Pool qualified as P
import Data.Word
import Di.Df1 qualified as Di
import GHC.Records
import Prelude hiding (Read, log, read)

import Sq.Connection
import Sq.Mode
import Sq.Support

--------------------------------------------------------------------------------

-- | Identifier of a 'Pool', obtained from 'newUnique' by 'newPoolId'.
-- It is attached to the pool's log messages and exposed as its @id@ field.
newtype PoolId = PoolId Word64
   deriving newtype (Eq, Ord, Show)
   deriving newtype (NFData, Di.ToValue)

-- | Allocate a fresh 'PoolId' by wrapping a value from 'newUnique'.
newPoolId :: (MonadIO m) => m PoolId
newPoolId = fmap PoolId newUnique

--------------------------------------------------------------------------------

-- | Pool of connections to a SQLite database.
--
-- * @p@ indicates whether 'Read'-only or read-'Write' 'Statement's are
-- supported by this 'Pool'.
--
-- * Obtain with 'Sq.readPool', 'Sq.writePool' or 'Sq.tempPool'.
--
-- * It's safe and efficient to use a 'Pool' concurrently as is.
-- Concurrency is handled internally.
data Pool (p :: Mode) where
   -- | A 'Read'-only pool: an identifier plus a pool of reader connections.
   Pool_Read
      :: PoolId
      -> P.Pool (A.Allocated (Connection Read))
      -> Pool Read
   -- | A read-'Write' pool: an identifier, one writer connection, and a
   -- pool of reader connections.
   Pool_Write
      :: PoolId
      -> Connection Write
      -> P.Pool (A.Allocated (Connection Read))
      -> Pool Write

-- | Use 'subPool' to obtain the 'Read'-only subset from a read-'Write' 'Pool'.
--
-- * Useful if you are passing the 'Pool' as an argument to some code,
-- and you want to ensure that it can't perform 'Write' operations on it.
--
-- * The “new” 'Pool' is not new. It shares all the underlying resources with the
-- original one, including their lifetime.
subPool :: Pool 'Write -> Pool 'Read
subPool = \case
   -- Dropping the writer connection here is fine: the original 'Write'
   -- pool remains the one responsible for resource management.
   Pool_Write pid _writer readers -> Pool_Read pid readers

-- | Forces the identifier and the reader pool to WHNF, and the writer
-- 'Connection' (when present) through its own 'NFData' instance.
instance NFData (Pool p) where
   rnf = \case
      Pool_Read !_ !_ -> ()
      Pool_Write !_ writer !_ -> rnf writer

-- | Exposes the 'PoolId' of either kind of 'Pool' as its @id@ field.
instance HasField "id" (Pool p) PoolId where
   getField (Pool_Read pid _) = pid
   getField (Pool_Write pid _ _) = pid

-- | Acquire a 'Pool' in the mode given by the 'SMode'.
--
-- * A pool of 'Read' connections is always created. In 'SWrite' mode, a
-- single 'Write' connection is additionally acquired, after the reader pool.
--
-- * Log messages from these connections carry the @pool@ and @pool-mode@
-- attributes.
pool :: SMode p -> Di.Df1 -> Settings -> A.Acquire (Pool p)
pool smode di0 cs = do
   pId <- newPoolId
   let di1 = Di.attr "pool-mode" smode (Di.attr "pool" pId di0)
   readers <- ppoolConnRead di1
   case smode of
      SRead -> pure (Pool_Read pId readers)
      SWrite -> do
         writer <- connection SWrite di1 cs
         pure (Pool_Write pId writer readers)
  where
   -- Pool of 'Read' connections. Each pooled connection is freed by running
   -- its release action with 'A.ReleaseNormal'; releasing this 'A.Acquire'
   -- calls 'P.destroyAllResources' on the pool.
   ppoolConnRead
      :: Di.Df1 -> A.Acquire (P.Pool (A.Allocated (Connection Read)))
   ppoolConnRead diR =
      R.acquire1
         ( \restore -> do
            let A.Acquire open = connection SRead diR cs
            caps <- getNumCapabilities
            -- At least 8, and never fewer than the number of capabilities.
            let maxResources = max 8 caps
            P.newPool
               ( P.defaultPoolConfig
                  (open restore)
                  (\(A.Allocated _ release) -> release A.ReleaseNormal)
                  60 -- idle seconds before an unused connection is freed
                  maxResources
               )
         )
         P.destroyAllResources

-- | Acquire a read-only transaction.
--
-- * You may need this function if you are using one of 'Sq.embed',
-- 'Sq.foldIO' or 'Sq.streamIO'. Otherwise, just use 'Sq.read'.
--
-- * Works on both 'Read' and 'Write' pools: the transaction runs on a
-- 'Read' connection taken from the pool's reader connections.
readTransaction :: Pool mode -> A.Acquire (Transaction 'Read)
readTransaction p = do
   conn <- poolConnectionRead p
   connectionReadTransaction conn

-- | Acquire a read-write transaction where changes are finally committed to
-- the database unless there is an unhandled exception during the transaction,
-- in which case they are rolled back.
--
-- * You may need this function if you are using one of 'Sq.embed',
-- 'Sq.foldIO' or 'Sq.streamIO'. Otherwise, just use 'Sq.commit'.
--
-- * The transaction runs on the pool's single 'Write' connection.
commitTransaction :: Pool Write -> A.Acquire (Transaction 'Write)
commitTransaction = \case
   Pool_Write _ writer _ -> connectionWriteTransaction True writer

-- | Acquire a read-write transaction where changes are always rolled back.
-- This is mostly useful for testing purposes.
--
-- * You may need this function if you are using one of 'Sq.embed',
-- 'Sq.foldIO' or 'Sq.streamIO'. Otherwise, just use 'Sq.commit'.
--
-- * An equivalent behavior can be achieved by
-- 'Control.Exception.Safe.bracket'ing changes between 'Sq.savepoint' and
-- 'Sq.rollbackTo' in a transaction obtained from 'commitTransaction', or by
-- using 'Ex.throwM' and 'Ex.catch' within 'Transactional'. However, using a
-- 'rollbackTransaction' is much faster than using 'Sq.Savepoint's.
--
-- * The transaction runs on the pool's single 'Write' connection.
rollbackTransaction :: Pool Write -> A.Acquire (Transaction 'Write)
rollbackTransaction = \case
   Pool_Write _ writer _ -> connectionWriteTransaction False writer

-- | Acquire one 'Read' 'Connection' from the reader connections of a 'Pool'
-- of either mode.
--
-- * While blocked waiting for a connection, a watchdog thread prints a message
-- to stdout every 10 seconds. It is killed as soon as a connection is
-- obtained, and again (harmlessly) when this 'A.Acquire' is released.
--
-- * On a normal release, the connection is put back into the pool for reuse.
-- On a release caused by an exception, the connection's own release action is
-- run with that exception, and the connection is removed from the pool with
-- 'P.destroyResource' instead of being reused.
poolConnectionRead :: Pool mode -> A.Acquire (Connection Read)
poolConnectionRead p = do
   let ppr = case p of
         Pool_Write _ _ x -> x
         Pool_Read _ x -> x
   tid <- flip R.mkAcquire1 killThread $ forkIO $ forever do
      threadDelay 10_000_000
      putStrLn "Waited 10 seconds to acquire a database connection from the pool"
   fmap (\(A.Allocated c _, _) -> c) do
      R.mkAcquireType1
         ( do
            taken <- P.takeResource ppr
            -- 'killThread' (i.e. 'throwTo') is always interruptible, even
            -- inside 'mask' and even when it doesn't block. If an async
            -- exception were delivered there, the resource we just took would
            -- never be returned to the pool, permanently losing a slot. The
            -- watchdog quickly reaches an interruptible point ('threadDelay'),
            -- so waiting for it uninterruptibly is safe.
            Ex.uninterruptibleMask_ (killThread tid)
            pure taken
         )
         \(a@(A.Allocated _ rel), lp) t -> case t of
            A.ReleaseExceptionWith _ ->
               -- NOTE(review): 'P.destroyResource' runs the pool's free
               -- action, which (in 'pool') calls this same release with
               -- 'A.ReleaseNormal'. This relies on that release being
               -- idempotent; confirm against how 'connection' is built.
               rel t `Ex.finally` P.destroyResource ppr lp a
            _ -> P.putResource lp a