{-# LANGUAGE TypeApplications, MultiWayIf, OverloadedStrings #-}

module HsDev.Database.SQLite.Transaction (
	TransactionType(..),
	Retries(..), def, noRetry, retryForever, retryN,

	-- * Transactions
	withTransaction, beginTransaction, commitTransaction, rollbackTransaction,
	transaction, transaction_,

	-- * Retry functions
	retry, retry_
	) where

import Control.Concurrent
import Control.Monad.Catch
import Control.Monad.IO.Class
import Data.Default
import Database.SQLite.Simple as SQL hiding (withTransaction)

import HsDev.Server.Types (SessionMonad, serverSqlDatabase)

-- | Three types of transactions
data TransactionType = Deferred | Immediate | Exclusive
	deriving (Eq, Ord, Read, Show)

-- | Retry config
data Retries = Retries {
	retriesIntervals :: [Int],
	retriesError :: SQLError -> Bool }

instance Default Retries where
	def = Retries (replicate 10 100000) $ \e -> sqlError e `elem` [ErrorBusy, ErrorLocked]

-- | Don't retry
noRetry :: Retries
noRetry = Retries [] (const False)

-- | Retry forever
retryForever :: Int -> Retries
retryForever interval = def { retriesIntervals = repeat interval }

-- | Retry with interval N times
retryN :: Int -> Int -> Retries
retryN interval times = def { retriesIntervals = replicate times interval }

-- | Run actions inside transaction
withTransaction :: (MonadIO m, MonadMask m) => Connection -> TransactionType -> Retries -> m a -> m a
withTransaction conn t rs act = mask $ \restore -> do
	mretry restore (beginTransaction conn t)
	(restore act <* mretry restore (commitTransaction conn)) `onException` rollbackTransaction conn
	where
		mretry restore' fn = mretry' (retriesIntervals rs) where
			mretry' [] = fn
			mretry' (tm:tms) = catch @_ @SQLError fn $ \e -> if
				| retriesError rs e -> do
						_ <- restore' $ liftIO $ threadDelay tm
						mretry' tms
				| otherwise -> throwM e

-- | Begin transaction
beginTransaction :: MonadIO m => Connection -> TransactionType -> m ()
beginTransaction conn t = liftIO $ SQL.execute_ conn $ case t of
	Deferred -> "begin transaction;"
	Immediate -> "begin immediate transaction;"
	Exclusive -> "begin exclusive transaction;"

-- | Commit transaction
commitTransaction :: MonadIO m => Connection -> m ()
commitTransaction conn = liftIO $ SQL.execute_ conn "commit transaction;"

-- | Rollback transaction
rollbackTransaction :: MonadIO m => Connection -> m ()
rollbackTransaction conn = liftIO $ SQL.execute_ conn "rollback transaction;"

-- | Run transaction in @SessionMonad@
transaction :: SessionMonad m => TransactionType -> Retries -> m a -> m a
transaction t rs act = do
	conn <- serverSqlDatabase
	withTransaction conn t rs act

-- | Transaction with default retries config
transaction_ :: SessionMonad m => TransactionType -> m a -> m a
transaction_ t = transaction t def

-- | Retry operation
retry :: (MonadIO m, MonadCatch m) => Retries -> m a -> m a
retry rs = retry' (retriesIntervals rs) where
	retry' [] fn = fn
	retry' (tm:tms) fn = catch @_ @SQLError fn $ \e -> if
		| retriesError rs e -> do
			liftIO $ threadDelay tm
			retry' tms fn
		| otherwise -> throwM e

-- | Retry with default params
retry_ :: (MonadIO m, MonadCatch m) => m a -> m a
retry_ = retry def