module Database.PostgreSQL.PQTypes.Internal.Connection (
Connection(..)
, ConnectionData(..)
, withConnectionData
, ConnectionStats(..)
, ConnectionSettings(..)
, def
, ConnectionSourceM(..)
, ConnectionSource(..)
, simpleSource
, poolSource
, connect
, disconnect
) where
import Control.Arrow (first)
import Control.Concurrent
import Control.Monad
import Control.Monad.Base
import Control.Monad.Catch
import Data.Default.Class
import Data.Function
import Data.Pool
import Data.Time.Clock
import Foreign.C.String
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import GHC.Exts
import qualified Control.Exception as E
import qualified Data.ByteString as BS
import qualified Data.Foldable as F
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Database.PostgreSQL.PQTypes.Internal.C.Interface
import Database.PostgreSQL.PQTypes.Internal.C.Types
import Database.PostgreSQL.PQTypes.Internal.Composite
import Database.PostgreSQL.PQTypes.Internal.Error
import Database.PostgreSQL.PQTypes.Internal.Utils
-- | Settings used by 'connect' to open a database connection.
data ConnectionSettings = ConnectionSettings
  { csConnInfo       :: !T.Text
    -- ^ libpq connection string, handed to @PQconnectStart@ by 'connect'.
  , csClientEncoding :: !(Maybe T.Text)
    -- ^ Client encoding that 'connect' sets with @PQsetClientEncoding@;
    -- 'Nothing' skips that step.
  , csComposites     :: ![T.Text]
    -- ^ Composite type names that 'connect' registers on the new connection.
  } deriving (Eq, Ord, Show)
-- | Default settings: an empty connection string, the @UTF-8@ client
-- encoding and no composite types to register.
instance Default ConnectionSettings where
  def = ConnectionSettings
    { csConnInfo       = T.empty
    , csClientEncoding = Just "UTF-8"
    , csComposites     = []
    }
-- | Per-connection counters. They start at 'initialStats' when 'connect'
-- creates a connection and are reset to it whenever 'poolSource' hands a
-- pooled connection out.
data ConnectionStats = ConnectionStats
  { statsQueries :: !Int
    -- ^ Query counter.
  , statsRows    :: !Int
    -- ^ Row counter.
  , statsValues  :: !Int
    -- ^ Value counter.
  , statsParams  :: !Int
    -- ^ Parameter counter.
  } deriving (Eq, Ord, Show)
-- | Statistics with every counter set to zero.
initialStats :: ConnectionStats
initialStats = ConnectionStats
  { statsQueries = 0
  , statsRows    = 0
  , statsValues  = 0
  , statsParams  = 0
  }
-- | State of an open connection.
data ConnectionData = ConnectionData
  { cdFrgnPtr :: !(ForeignPtr (Ptr PGconn))
    -- ^ Cell holding the libpq connection pointer. 'connect' attaches the
    -- 'c_ptr_PQfinishPtr' finalizer to it, and 'disconnect' calls
    -- 'c_PQfinishPtr' on it.
  , cdPtr     :: !(Ptr PGconn)
    -- ^ The libpq connection pointer itself (the value 'connect' stores in
    -- 'cdFrgnPtr').
  , cdStats   :: !ConnectionStats
    -- ^ Statistics gathered for this connection.
  }
-- | Handle to a database connection.
--
-- The 'MVar' holds 'Just' the connection data while the connection is open;
-- 'disconnect' replaces it with 'Nothing'.
newtype Connection = Connection
  { unConnection :: MVar (Maybe ConnectionData)
  }
-- | Run an action on the data of an open connection and store the updated
-- 'ConnectionData' it returns next to its result.
--
-- The connection's 'MVar' is taken with 'modifyMVar' for the whole action.
-- As a result, concurrent users of one 'Connection' are serialized, and the
-- previous data is kept if the action throws.
--
-- If the connection has already been closed, this fails via 'hpqTypesError'.
-- The message is the given function name followed by @: no connection@.
withConnectionData
  :: Connection
  -> String
  -> (ConnectionData -> IO (ConnectionData, r))
  -> IO r
withConnectionData (Connection mvc) fname f = modifyMVar mvc $
  maybe (hpqTypesError $ fname ++ ": no connection") (fmap (first Just) . f)
-- | A source of connections usable in monad @m@.
--
-- 'withConnection' passes a 'Connection' to the given callback and returns
-- the callback's result.
newtype ConnectionSourceM m = ConnectionSourceM
  { withConnection :: forall r. (Connection -> m r) -> m r
  }
-- | A 'ConnectionSourceM' that works in every monad @m@ satisfying
-- @MkConstraint m cs@.
--
-- NOTE(review): presumably 'MkConstraint' turns the list @cs@ into the
-- conjunction of its constraints applied to @m@; confirm against its
-- definition.
newtype ConnectionSource (cs :: [(* -> *) -> Constraint]) = ConnectionSource
  { unConnectionSource :: forall m. MkConstraint m cs => ConnectionSourceM m
  }
-- | A 'ConnectionSource' that opens a fresh connection with 'connect' for
-- every 'withConnection' call.
--
-- The connection is closed with 'disconnect' afterwards, including when the
-- callback throws, because the two are paired with 'bracket'.
simpleSource
  :: ConnectionSettings
  -> ConnectionSource [MonadBase IO, MonadMask]
simpleSource cs = ConnectionSource ConnectionSourceM
  { withConnection = \useConn ->
      bracket (liftBase (connect cs)) (\conn -> liftBase (disconnect conn)) useConn
  }
-- | Create a 'ConnectionSource' backed by a 'Pool' of connections.
--
-- Connections in the pool are opened with 'connect' and closed with
-- 'disconnect'. The 'Int', 'NominalDiffTime' and 'Int' arguments are
-- forwarded to 'createPool', in that order, as the number of stripes, the
-- idle time and the maximum number of resources.
--
-- Every connection handed out has its statistics reset to 'initialStats'
-- first. If the callback throws, the connection is destroyed instead of
-- being returned to the pool.
--
-- NOTE(review): 'onException' only reacts to thrown exceptions. In a monad
-- that can short-circuit without throwing (e.g. 'ExceptT'), presumably
-- neither 'destroyResource' nor 'putResource' would run. Confirm that no
-- such monad is used with this source.
poolSource
  :: ConnectionSettings
  -> Int
  -> NominalDiffTime
  -> Int
  -> IO (ConnectionSource [MonadBase IO, MonadMask])
poolSource cs numStripes idleTime maxResources = do
  pool <- createPool (connect cs) disconnect numStripes idleTime maxResources
  return $ ConnectionSource ConnectionSourceM
    { withConnection = \action ->
        withPooled pool $ \conn -> resetStats conn >> action conn
    }
  where
    -- Take a connection from the pool and run the action on it.
    -- On success the connection goes back to the pool; on an exception it
    -- is destroyed.
    withPooled resPool action = mask $ \restore -> do
      (conn, localPool) <- liftBase $ takeResource resPool
      result <- restore (action conn)
        `onException` liftBase (destroyResource resPool localPool conn)
      liftBase $ putResource localPool conn
      return result

    -- Zero the statistics of a connection before it is handed out.
    resetStats (Connection mv) = liftBase . modifyMVar_ mv $
      return . fmap (\cd -> cd { cdStats = initialStats })
-- | Open a connection described by 'csConnInfo'.
--
-- Connecting proceeds in these steps:
--
-- * The connection is started with @PQconnectStart@.
-- * It is driven by @PQconnectPoll@, waiting on its socket with
--   'threadWaitRead' / 'threadWaitWrite' as requested.
-- * The connection status is checked to be @CONNECTION_OK@.
-- * The client encoding from 'csClientEncoding' is applied, if one is given.
-- * 'c_PQinitTypes' is called.
-- * The composite types listed in 'csComposites' are registered.
--
-- Any failure is thrown. If the failure happens after the libpq connection
-- object has been created, the connection is closed immediately with
-- 'c_PQfinishPtr'. That includes an asynchronous exception arriving while
-- waiting on the socket. Closing eagerly means the socket is not left open
-- until the garbage collector runs the 'ForeignPtr' finalizer.
connect :: ConnectionSettings -> IO Connection
connect ConnectionSettings{..} = E.mask $ \restore -> do
  -- Runs masked, so the connection cannot escape between PQconnectStart and
  -- the cleanup handler below being installed.
  fconn <- BS.useAsCString (T.encodeUtf8 csConnInfo) openConnection
  withForeignPtr fconn $ \connPtr -> do
    conn <- peek connPtr
    -- c_PQfinishPtr on a live ForeignPtr is what 'disconnect' does too.
    -- The finalizer attached in openConnection therefore stays safe after
    -- this eager close.
    restore (setupConnection fconn conn)
      `E.onException` c_PQfinishPtr connPtr
  where
    fname = "connect"

    -- Start a non-blocking connection and wrap the libpq handle in a
    -- ForeignPtr cell whose finalizer closes it.
    openConnection :: CString -> IO (ForeignPtr (Ptr PGconn))
    openConnection conninfo = do
      conn <- c_PQconnectStart conninfo
      when (conn == nullPtr) $
        throwError "PQconnectStart returned a null pointer"
      connPtr <- mallocForeignPtr
      withForeignPtr connPtr (`poke` conn)
      addForeignPtrFinalizer c_ptr_PQfinishPtr connPtr
      return connPtr

    -- Finish establishing the connection, validate it, configure it and
    -- wrap it in a 'Connection'.
    setupConnection :: ForeignPtr (Ptr PGconn) -> Ptr PGconn -> IO Connection
    setupConnection fconn conn = do
      waitUntilConnected conn
      status <- c_PQstatus conn
      when (status /= c_CONNECTION_OK) $
        throwLibPQError conn fname
      F.forM_ csClientEncoding $ \enc -> do
        res <- BS.useAsCString (T.encodeUtf8 enc) (c_PQsetClientEncoding conn)
        when (res == -1) $
          throwLibPQError conn fname
      c_PQinitTypes conn
      registerComposites conn csComposites
      Connection <$> newMVar (Just ConnectionData {
          cdFrgnPtr = fconn
        , cdPtr     = conn
        , cdStats   = initialStats
        })

    -- Call PQconnectPoll until it asks for neither reading nor writing,
    -- blocking only this Haskell thread on the socket in between.
    waitUntilConnected :: Ptr PGconn -> IO ()
    waitUntilConnected conn = fix $ \loop -> do
        ps <- c_PQconnectPoll conn
        if | ps == c_PGRES_POLLING_READING -> (threadWaitRead =<< getFd) >> loop
           | ps == c_PGRES_POLLING_WRITING -> (threadWaitWrite =<< getFd) >> loop
           | otherwise                     -> return ()
      where
        getFd = do
          fd <- c_PQsocket conn
          when (fd == -1) $
            throwError "invalid file descriptor"
          return fd

    throwError = hpqTypesError . (fname ++) . (": " ++)
-- | Close a connection with 'c_PQfinishPtr' and mark it as closed by storing
-- 'Nothing' in its 'MVar'.
--
-- Throws 'HPQTypesError' if the connection has already been closed; the
-- 'MVar' is then left unchanged.
disconnect :: Connection -> IO ()
disconnect (Connection mvconn) = modifyMVar_ mvconn $ \mconn ->
  Nothing <$ finish mconn
  where
    finish (Just cd) = withForeignPtr (cdFrgnPtr cd) c_PQfinishPtr
    finish Nothing   = E.throwIO (HPQTypesError "disconnect: no connection (shouldn't happen)")