-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at http://mozilla.org/MPL/2.0/.

{-# LANGUAGE MultiWayIf          #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}

module Database.CQL.IO.Pool
    ( Pool
    , create
    , destroy
    , purge
    , with

    , PoolSettings
    , defSettings
    , idleTimeout
    , maxConnections
    , maxTimeouts
    , poolStripes
    ) where

import Control.Applicative
import Control.AutoUpdate
import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Exception
import Control.Lens ((^.), makeLenses, view)
import Control.Monad.IO.Class
import Control.Monad hiding (forM_, mapM_)
import Data.Foldable (forM_, mapM_, find)
import Data.Function (on)
import Data.Hashable
import Data.IORef
import Prelude hiding (mapM_)
import Data.Sequence (Seq, ViewL (..), (|>), (><))
import Data.Time.Clock (UTCTime, NominalDiffTime, getCurrentTime, diffUTCTime)
import Data.Vector (Vector, (!))
import Database.CQL.IO.Connection (Connection)
import Database.CQL.IO.Types (Timeout, ignore)
import System.Logger hiding (create, defSettings, settings)

import qualified Data.Sequence as Seq
import qualified Data.Vector   as Vec

-----------------------------------------------------------------------------
-- API

-- | Tunable parameters of a 'Pool'. The fields are accessed through the
-- lenses of the same name without the leading underscore.
data PoolSettings = PoolSettings
    { _idleTimeout    :: !NominalDiffTime
      -- ^ How long an unreferenced connection may stay idle before the
      -- reaper removes and destroys it.
    , _maxConnections :: !Int
      -- ^ Upper bound on the number of connections held by one stripe.
    , _maxTimeouts    :: !Int
      -- ^ Timeout count above which a connection is destroyed instead of
      -- being returned to the pool after another timeout.
    , _poolStripes    :: !Int
      -- ^ Number of stripes the pool is divided into; a thread is mapped
      -- to a stripe by the hash of its thread id.
    }

-- | A striped pool of shared, reference-counted connections.
data Pool = Pool
    { _createFn    :: !(IO Connection)
      -- ^ Action that opens a new connection.
    , _destroyFn   :: !(Connection -> IO ())
      -- ^ Action that disposes of a connection.
    , _logger      :: !Logger
      -- ^ Logger used for pool diagnostics.
    , _settings    :: !PoolSettings
      -- ^ The settings this pool was created with.
    , _maxRefs     :: !Int
      -- ^ Maximum number of simultaneous borrowers of one connection.
    , _currentTime :: !(IO UTCTime)
      -- ^ Clock action; 'create' builds it with 'mkAutoUpdate' around
      -- 'getCurrentTime'.
    , _stripes     :: !(Vector Stripe)
      -- ^ The stripes, one per '_poolStripes'.
    , _finaliser   :: !(IORef ())
      -- ^ Dummy reference to which 'create' attaches a weak-reference
      -- finaliser that cancels the reaper and destroys the pool.
    }

-- | A pooled connection together with its bookkeeping data.
data Resource = Resource
    { tstamp   :: !UTCTime
      -- ^ Time of creation, refreshed whenever the resource is returned.
    , refcnt   :: !Int
      -- ^ Number of borrowers currently holding this connection.
    , timeouts :: !Int
      -- ^ Number of timeouts recorded for this connection so far.
    , value    :: !Connection
      -- ^ The connection itself.
    } deriving Show

-- | Outcome of the STM phase of 'take1'.
data Box
    = New  !(IO Resource)
      -- ^ A slot was reserved; run the action to open the connection.
    | Used !Resource
      -- ^ An existing connection was selected for sharing.
    | Empty
      -- ^ No connection is available.

-- | One independent partition of the pool.
data Stripe = Stripe
    { conns :: !(TVar (Seq Resource))
      -- ^ The connections currently registered in this stripe.
    , inUse :: !(TVar Int)
      -- ^ Number of connection slots taken, including a slot whose
      -- connection is still being opened. 'take1' waits until this equals
      -- the length of 'conns'.
    }

-- Template Haskell: generate a lens for each underscore-prefixed field of
-- 'PoolSettings' and 'Pool' (e.g. '_idleTimeout' yields 'idleTimeout').
makeLenses ''PoolSettings
makeLenses ''Pool

-- | Default settings: an idle timeout of 60 (a 'NominalDiffTime', i.e.
-- seconds), at most 2 connections per stripe, at most 16 timeouts per
-- connection and 4 stripes.
defSettings :: PoolSettings
defSettings = PoolSettings
    { _idleTimeout    = 60
    , _maxConnections = 2
    , _maxTimeouts    = 16
    , _poolStripes    = 4
    }

-- | Create a new pool.
--
-- Arguments, in order: the action opening a connection, the action
-- disposing of one, the logger, the pool settings and the maximum number
-- of simultaneous borrowers per connection.
--
-- A background thread running 'reaper' is started, and a weak-reference
-- finaliser is attached to the pool's 'finaliser' reference which cancels
-- that thread and calls 'destroy'.
--
-- NOTE(review): the reaper thread closes over the pool value, which also
-- holds the 'finaliser' reference -- confirm the weak reference can
-- actually fire while the reaper is still running.
create :: IO Connection -> (Connection -> IO ()) -> Logger -> PoolSettings -> Int -> IO Pool
create mk del g s k = do
    clock <- mkAutoUpdate defaultUpdateSettings { updateAction = getCurrentTime }
    ss    <- Vec.replicateM (s^.poolStripes) newStripe
    fin   <- newIORef ()
    let p = Pool mk del g s k clock ss fin
    r <- async (reaper p)
    void $ mkWeakIORef (p^.finaliser) (cancel r >> destroy p)
    return p
  where
    -- A stripe starts without connections and without reserved slots.
    newStripe = Stripe <$> newTVarIO Seq.empty <*> newTVarIO 0

-- | Destroy the pool's connections. Currently this is the same operation
-- as 'purge'.
destroy :: Pool -> IO ()
destroy p = purge p

-- | Run an action with a connection taken from the stripe assigned to the
-- calling thread.
--
-- Yields 'Nothing' if 'take1' could not provide a connection, otherwise
-- 'Just' the action's result. Acquisition and release run with
-- asynchronous exceptions masked; only the action itself is unmasked.
-- Exceptions raised by the action are re-thrown after the pool
-- bookkeeping in 'handlers' has run.
with :: MonadIO m => Pool -> (Connection -> IO a) -> m (Maybe a)
with p f = liftIO $ do
    s <- stripe p
    mask $ \restore -> do
        let run v = do
                x <- restore (f (value v)) `catches` handlers p s v
                -- Normal completion: hand the connection back unchanged.
                put p s v id
                return (Just x)
        take1 p s >>= maybe (return Nothing) run

-- | Remove every connection from every stripe and destroy it.
--
-- The slots of the removed connections are released in the same
-- transaction that empties the stripe, in the same way 'reaper' and
-- 'destroyR' do when they remove connections. 'take1' waits until
-- 'inUse' equals the number of registered connections, so leaving the
-- counter untouched would block later borrowers of a purged stripe.
--
-- Each call of the destroy function is wrapped in 'ignore'.
purge :: Pool -> IO ()
purge p = Vec.forM_ (p^.stripes) $ \s -> do
    removed <- atomically $ do
        cs <- swapTVar (conns s) Seq.empty
        modifyTVar' (inUse s) (subtract (Seq.length cs))
        return cs
    mapM_ (ignore . view destroyFn p . value) removed

-----------------------------------------------------------------------------
-- Internal

-- | Exception handlers applied to the action run by 'with'.
--
-- A 'Timeout' either returns the connection with its timeout counter
-- incremented or, if the counter of the borrowed resource already exceeds
-- 'maxTimeouts', destroys the connection. Any other exception destroys
-- the connection. In all cases the exception is re-thrown.
handlers :: Pool -> Stripe -> Resource -> [Handler a]
handlers p s r = [Handler timedOut, Handler failed]
  where
    timedOut :: Timeout -> IO b
    timedOut e = do
        recordTimeout
        throwIO e

    failed :: SomeException -> IO b
    failed e = do
        destroyR p s r
        throwIO e

    recordTimeout
        | timeouts r > p^.settings.maxTimeouts = do
            info (p^.logger) $ msg (show (value r) +++ val " has too many timeouts.")
            destroyR p s r
        | otherwise = put p s r incrTimeouts

-- | Try to obtain a connection from the given stripe.
--
-- The transaction only decides what to do (see 'Box'); a new connection
-- is opened outside of it. The result is 'Nothing' when the stripe is at
-- its connection limit and the least-referenced connection already has
-- 'maxRefs' borrowers.
take1 :: Pool -> Stripe -> IO (Maybe Resource)
take1 p s = do
    r <- atomically $ do
        c <- readTVar (conns s)
        u <- readTVar (inUse s)
        let n = Seq.length c
        -- Retry until every reserved slot has its connection registered
        -- in 'conns'.
        check (u == n)
        -- Least-referenced connection first. The pattern is lazy and is
        -- only forced in the branch guarded by @n > 0@.
        let r :< rr = Seq.viewl $ Seq.unstableSortBy (compare `on` refcnt) c
        if | u < p^.settings.maxConnections -> do
                -- Reserve the slot now; the connection is opened below.
                writeTVar (inUse s) $! u + 1
                mkNew p
           | n > 0 && refcnt r < p^.maxRefs -> use s r rr
           | otherwise                      -> return Empty
    case r of
        New io -> do
            -- Give the reserved slot back if opening the connection fails.
            x <- io `onException` atomically (modifyTVar' (inUse s) (subtract 1))
            atomically (modifyTVar' (conns s) (|> x))
            return (Just x)
        Used x -> return (Just x)
        Empty  -> return Nothing

-- | Hand out an existing resource: store it back at the end of the given
-- remainder with its reference count incremented. The returned 'Box'
-- carries the resource as it was passed in, i.e. before the increment.
use :: Stripe -> Resource -> Seq Resource -> STM Box
use s r rr = do
    let bumped = r { refcnt = refcnt r + 1 }
    writeTVar (conns s) $! (rr |> bumped)
    return $ Used r
{-# INLINE use #-}

-- | Produce a 'New' box whose action reads the pool clock, opens a
-- connection and wraps it as a resource with one reference and no
-- timeouts. Nothing is executed inside the transaction itself.
mkNew :: Pool -> STM Box
mkNew p = return (New open)
  where
    open = do
        now <- p^.currentTime
        c   <- p^.createFn
        return (Resource now 1 0 c)
{-# INLINE mkNew #-}

-- | Return a borrowed resource to its stripe.
--
-- The stripe's current entry for the same connection gets a fresh
-- timestamp and a decremented reference count, is passed through the
-- given function (e.g. 'id' or 'incrTimeouts') and moved to the end of
-- the sequence.
--
-- If the stripe no longer contains the connection, it has been removed in
-- the meantime by whoever destroyed it (another borrower via 'destroyR',
-- or 'purge'). It is then deliberately not inserted again: doing so would
-- put a destroyed connection back into circulation and make the stripe
-- hold more connections than 'inUse' accounts for, which 'take1' waits on.
put :: Pool -> Stripe -> Resource -> (Resource -> Resource) -> IO ()
put p s r f = do
    now <- p^.currentTime
    let updated x = f x { tstamp = now, refcnt = refcnt x - 1 }
    atomically $ do
        rs <- readTVar (conns s)
        let (xs, rr) = Seq.breakl ((value r ==) . value) rs
        case Seq.viewl rr of
            EmptyL  -> return ()
            y :< ys -> writeTVar (conns s) $! (xs >< ys) |> updated y

-- | Take a resource out of its stripe and destroy its connection.
--
-- If the stripe still contains the connection, it is removed and one slot
-- is released, both in a single transaction. The destroy function is
-- called in either case, wrapped in 'ignore'.
destroyR :: Pool -> Stripe -> Resource -> IO ()
destroyR p s r = do
    atomically $ do
        rs <- readTVar (conns s)
        forM_ (find sameConn rs) $ \_ -> do
            writeTVar (conns s) $! Seq.filter ((value r /=) . value) rs
            modifyTVar' (inUse s) (subtract 1)
    ignore (view destroyFn p (value r))
  where
    sameConn = (value r ==) . value

-- | Background loop which, once per second, removes connections that have
-- no borrowers and have been idle for longer than 'idleTimeout', releases
-- their slots and destroys them. Runs forever.
reaper :: Pool -> IO ()
reaper p = forever $ do
    threadDelay 1000000
    now <- p^.currentTime
    Vec.forM_ (p^.stripes) $ \s ->
        evict now s >>= mapM_ dispose
  where
    expired now r =
        refcnt r == 0 && diffUTCTime now (tstamp r) > p^.settings.idleTimeout

    -- Remove the expired resources of one stripe and return them.
    evict now s = atomically $ do
        (stale, okay) <- Seq.partition (expired now) <$> readTVar (conns s)
        unless (Seq.null stale) $ do
            writeTVar   (conns s) okay
            modifyTVar' (inUse s) (subtract (Seq.length stale))
        return stale

    -- Log and destroy a single removed resource.
    dispose v = ignore $ do
        trace (p^.logger) $ "reap" .= show (value v)
        p^.destroyFn $ value v

-- | Select the stripe for the calling thread: the hash of its thread id
-- modulo the configured number of stripes.
stripe :: Pool -> IO Stripe
stripe p = pick <$> myThreadId
  where
    pick tid = (p^.stripes) ! (hash tid `mod` (p^.settings.poolStripes))
{-# INLINE stripe #-}

-- | Add one to a resource's timeout counter.
incrTimeouts :: Resource -> Resource
incrTimeouts r@Resource { timeouts = n } = r { timeouts = n + 1 }
{-# INLINE incrTimeouts #-}