module Database.Seakale.PostgreSQL
( Connection
, ConnectInfo(..)
, connect
, connectString
, disconnect
, Request
, RequestT
, runRequest
, runRequestT
, Select
, SelectT
, runSelect
, runSelectT
, Store
, StoreT
, runStore
, runStoreT
, HasConnection(..)
, PSQL(..)
, defaultPSQL
, SeakaleError(..)
, T.Query(..)
, Field(..)
, Row
, ColumnInfo(..)
, QueryData
, Vector(..)
, cons, (<:>)
, nil, (<:|)
, Zero
, One
, Two
, Three
, Four
, Five
, Six
, Seven
, Eight
, Nine
, Ten
) where
import Control.Monad.Except
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Trans.Free
import Data.Maybe
import Data.Monoid
import Data.Word
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as BSL
import System.IO
import Database.PostgreSQL.LibPQ hiding (Row, status)
import Database.Seakale.Types
hiding (runQuery, runExecute, EmptyQuery)
import qualified Database.Seakale.Request.Internal as I
import qualified Database.Seakale.Store.Internal as I
import qualified Database.Seakale.Types as T
data ConnectInfo = ConnectInfo
{ ciHostname :: String
, ciPort :: Word16
, ciUsername :: String
, ciPassword :: String
, ciDatabase :: String
}
toConnectionString :: ConnectInfo -> BS.ByteString
toConnectionString ConnectInfo{..} =
"host=" <> quote ciHostname
<> " port=" <> BS.pack (show ciPort)
<> " user=" <> quote ciUsername
<> " password=" <> quote ciPassword
<> " dbname=" <> quote ciDatabase
where
quote :: String -> BS.ByteString
quote = ("'" <>) . (<> "'") . escapeQuotes "" . BS.pack
escapeQuotes :: BS.ByteString -> BS.ByteString -> BS.ByteString
escapeQuotes _ "" = ""
escapeQuotes prefix s =
let (start, end) = fmap (BS.drop 1) $ BS.break (=='\'') s
in prefix <> start <> escapeQuotes "''" end
connect :: ConnectInfo -> IO Connection
connect = connectString . toConnectionString
connectString :: BS.ByteString -> IO Connection
connectString = connectdb
disconnect :: Connection -> IO ()
disconnect = finish
class Monad m => HasConnection m where
withConn :: (Connection -> m a) -> m a
instance Monad m => HasConnection (ReaderT Connection m) where
withConn f = f =<< ask
instance HasConnection m => HasConnection (ExceptT e m) where
withConn f = do
eRes <- lift $ withConn $ runExceptT . f
either throwError return eRes
instance HasConnection m => HasConnection (StateT s m) where
withConn f = do
s <- get
(x, s') <- lift $ withConn $ flip runStateT s . f
put s'
return x
type TypeCache = [(Oid, BS.ByteString)]
data PSQL = PSQL { psqlLogQueries :: Bool }
defaultPSQL :: PSQL
defaultPSQL = PSQL False
instance Backend PSQL where
type ColumnType PSQL = BS.ByteString
type MonadBackend PSQL m = ( HasConnection m
, MonadState TypeCache m
, MonadIO m
)
runQuery backend = runExceptT . runQuery backend
runExecute backend = runExceptT . runExecute backend
type RequestT = I.RequestT PSQL
type Request = I.Request PSQL
runRequestT :: (HasConnection m, MonadIO m) => PSQL -> RequestT m a
-> m (Either SeakaleError a)
runRequestT backend =
fmap fst . flip runStateT [] . I.runRequestT backend . hoistFreeT lift
runRequest :: (HasConnection m, MonadIO m) => PSQL -> Request a
-> m (Either SeakaleError a)
runRequest backend = runRequestT backend . hoistFreeT (return . runIdentity)
type SelectT = I.SelectT PSQL
type Select = I.Select PSQL
runSelectT :: Monad m => SelectT m a -> RequestT m a
runSelectT = I.runSelectT
runSelect :: Select a -> Request a
runSelect = I.runSelect
type StoreT m = I.StoreT PSQL m
type Store = I.Store PSQL
runStoreT :: Monad m => StoreT m a -> RequestT m a
runStoreT = I.runStoreT
runStore :: Store a -> Request a
runStore = I.runStore
runQuery :: MonadBackend PSQL m => PSQL -> BSL.ByteString
-> ExceptT BS.ByteString m ([ColumnInfo PSQL], [Row PSQL])
runQuery PSQL{..} lazyReq = do
let req = BSL.toStrict lazyReq
when psqlLogQueries $ liftIO $ BS.hPutStrLn stderr $ "runQuery: " <> req
res <- exec' req
let _until i = takeWhile (/= i) $ iterate (+1) 0
ncols <- liftIO $ nfields res
nrows <- liftIO $ ntuples res
let colIndices = _until ncols
cols <- forM colIndices $ \col -> do
name <- liftIO $ fname res col
oid <- liftIO $ ftype res col
typ <- resolveType oid
return ColumnInfo
{ colInfoName = name
, colInfoType = typ
}
rows <- liftIO $ forM (_until nrows) $ \row -> do
forM colIndices $ \col -> do
mValue <- getvalue' res row col
return Field { fieldValue = mValue }
return (cols, rows)
runExecute :: MonadBackend PSQL m => PSQL -> BSL.ByteString
-> ExceptT BS.ByteString m Integer
runExecute PSQL{..} lazyReq = do
let req = BSL.toStrict lazyReq
when psqlLogQueries $ liftIO $ BS.hPutStrLn stderr $ "runExecute: " <> req
res <- exec' req
mBS <- liftIO $ cmdTuples res
case fmap (reads . BS.unpack) mBS of
Just ((n,"") : _) -> return n
_ -> throwError $ "Can't get number of rows affected for " <> req
resolveType :: MonadBackend PSQL m => Oid
-> ExceptT BS.ByteString m BS.ByteString
resolveType oid = do
cache <- get
case lookup oid cache of
Just n -> return n
Nothing -> do
cache' <- getTypes
put cache'
case lookup oid cache' of
Just n -> return n
Nothing -> throwError $ BS.pack $ "Can't resolve type for " ++ show oid
getTypes :: MonadBackend PSQL m => ExceptT BS.ByteString m TypeCache
getTypes = do
res <- exec' "SELECT oid, typname FROM pg_type"
let _until i = takeWhile (/= i) $ iterate (+1) 0
nrows <- liftIO $ ntuples res
forM (_until nrows) $ \row -> do
oidBS <- liftIO $ getvalue' res row 0
nameBS <- liftIO $ getvalue' res row 1
case (fmap (reads . BS.unpack) oidBS, nameBS) of
(Just ((i,"") : _), Just name) ->
return (Oid (fromInteger i), name)
_ -> throwError "Can't read types from pg_type"
exec' :: MonadBackend PSQL m => BS.ByteString
-> ExceptT BS.ByteString m Result
exec' req = do
res <- withConn $ \conn -> do
mRes <- liftIO $ exec conn req
case mRes of
Nothing -> do
err <- liftIO $ fromMaybe "Fatal error" <$> errorMessage conn
throwError err
Just r -> return r
let err bs = do
msg <- liftIO $ fromMaybe bs <$> resultErrorMessage res
throwError msg
status <- liftIO $ resultStatus res
case status of
EmptyQuery -> err "Empty query"
CopyIn -> err "COPY not supported"
CopyOut -> err "COPY not supported"
BadResponse -> err "Bad response"
FatalError -> err "Fatal error"
_ -> return ()
return res