{-# LANGUAGE QuasiQuotes #-}

module PostgREST.Config.Database
  ( queryDbSettings
  , queryPgVersion
  ) where

import PostgREST.Config.PgVersion (PgVersion (..))

import qualified Hasql.Decoders             as HD
import qualified Hasql.Encoders             as HE
import qualified Hasql.Pool                 as SQL
import           Hasql.Session              (Session, statement)
import qualified Hasql.Statement            as SQL
import qualified Hasql.Transaction          as SQL
import qualified Hasql.Transaction.Sessions as SQL

import Text.InterpolatedString.Perl6 (q)

import Protolude

-- | Query the PostgreSQL server version.
--
-- Reads @server_version_num@ (cast to integer) and @server_version@ from the
-- current session and packs them into a 'PgVersion'. Exactly one row is
-- expected ('HD.singleRow'). The statement is run unprepared (last
-- 'SQL.Statement' argument is 'False').
queryPgVersion :: Session PgVersion
queryPgVersion =
  statement () versionStatement
  where
    versionStatement =
      SQL.Statement
        "SELECT current_setting('server_version_num')::integer, current_setting('server_version')"
        HE.noParams
        decodeVersion
        False
    -- Column order must match the SELECT list: numeric version first, then text.
    decodeVersion =
      HD.singleRow (PgVersion <$> column HD.int4 <*> column HD.text)

-- | Fetch the @pgrst.*@ settings configured for the current role, using a
-- connection from @pool@.
--
-- The query runs inside a read-only, READ COMMITTED transaction. When
-- @prepared@ is 'True' the transaction uses 'SQL.transaction'; otherwise it
-- uses 'SQL.unpreparedTransaction' (presumably for connection setups that
-- cannot use prepared statements — confirm with callers).
--
-- Returns the @(key, value)@ pairs produced by 'dbSettingsStatement', or the
-- pool's 'SQL.UsageError' on failure.
queryDbSettings :: SQL.Pool -> Bool -> IO (Either SQL.UsageError [(Text, Text)])
queryDbSettings pool prepared =
  SQL.use pool session
  where
    session = runTransaction SQL.ReadCommitted SQL.Read fetchSettings
    fetchSettings = SQL.statement () dbSettingsStatement
    runTransaction
      | prepared  = SQL.transaction
      | otherwise = SQL.unpreparedTransaction

-- | Get db settings from the connection role. Global settings will be
-- overridden by database-specific settings.
--
-- Only role settings whose name starts with @pgrst.@ are returned, and that
-- prefix is removed from the returned key. PostgreSQL stores each entry of
-- @pg_db_role_setting.setconfig@ as a single @name=value@ string. The key is
-- the text before the first @=@ and the value is everything after it.
--
-- Splitting only on the FIRST @=@ matters: values may themselves contain @=@
-- (for example base64-encoded secrets ending in @=@ padding). A previous
-- version used @split_part(setting, '=', 2)@, which silently truncated such
-- values at their first embedded @=@.
--
-- When the same key is set both globally (@setdatabase = 0@) and for the
-- current database, @DISTINCT ON (key) ... ORDER BY key, setdatabase DESC@
-- keeps the database-specific row.
dbSettingsStatement :: SQL.Statement () [(Text, Text)]
dbSettingsStatement = SQL.Statement sql HE.noParams decodeSettings False
  where
    sql :: ByteString
    sql = [q|
      with
      role_setting as (
        select setdatabase, unnest(setconfig) as setting from pg_catalog.pg_db_role_setting
        where setrole = current_user::regrole::oid
          and setdatabase in (0, (select oid from pg_catalog.pg_database where datname = current_catalog))
      ),
      kv_settings as (
        select setdatabase, split_part(setting, '=', 1) as k, substr(setting, strpos(setting, '=') + 1) as value from role_setting
        where setting like 'pgrst.%'
      )
      select distinct on (key) replace(k, 'pgrst.', '') as key, value
      from kv_settings
      order by key, setdatabase desc;
    |]
    -- One (key, value) pair per row, both non-null text columns.
    decodeSettings :: HD.Result [(Text, Text)]
    decodeSettings = HD.rowList $ (,) <$> column HD.text <*> column HD.text

-- | Row decoder for a single column that must not be NULL, built from the
-- given value decoder.
column :: HD.Value a -> HD.Row a
column valueDecoder = HD.column (HD.nonNullable valueDecoder)