{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
module PgNamed
(
NamedParam (..)
, Name (..)
, (=?)
, PgNamedError (..)
, WithNamedError
, extractNames
, namesToRow
, queryNamed
, executeNamed
) where
import Control.Monad.Except (MonadError (throwError))
import Control.Monad.IO.Class (MonadIO (liftIO))
import Data.Bifunctor (bimap)
import Data.ByteString (ByteString)
import Data.Char (isAlphaNum)
import Data.Int (Int64)
import Data.List (lookup)
import Data.List.NonEmpty (NonEmpty (..), toList)
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8)
import GHC.Exts (IsString)
import qualified Data.ByteString.Char8 as BS
import qualified Database.PostgreSQL.Simple as PG
import qualified Database.PostgreSQL.Simple.ToField as PG
import qualified Database.PostgreSQL.Simple.Types as PG
-- | Name of a named query parameter: a thin wrapper around the 'Text' label
-- that follows the question mark in a query.
--
-- All instances are borrowed from the underlying 'Text'; in particular the
-- 'IsString' instance allows a 'Name' to be built from a string literal.
newtype Name = Name
    { unName :: Text  -- ^ The wrapped text of the name.
    } deriving newtype (Eq, Ord, Show, IsString)
-- | A query argument together with the name it is referred to by.
--
-- Both fields are strict.
data NamedParam = NamedParam
    { namedParamName  :: !Name
      -- ^ Name identifying this argument.
    , namedParamParam :: !PG.Action
      -- ^ The argument itself, already rendered to a 'PG.Action'.
    } deriving stock (Show)
-- | Errors reported by the named-parameter functions of this module.
data PgNamedError
    -- | A name occurring in the query has no matching entry in the list of
    -- supplied 'NamedParam's. Carries the missing name.
    = PgNamedParam Name
    -- | The query contains no named parameters at all. Carries the query.
    | PgNoNames PG.Query
    -- | The query contains a question mark that is not followed by a name.
    -- Carries the query.
    | PgEmptyName PG.Query
    deriving stock (Eq)
-- | Constraint alias for monads that can throw a 'PgNamedError' through
-- 'MonadError'.
type WithNamedError = MonadError PgNamedError
-- | Renders a human-readable message: a fixed prefix followed by a
-- constructor-specific description. For the query-carrying constructors the
-- raw query bytes are appended.
instance Show PgNamedError where
    show err = "PostgreSQL named parameter error: " ++ details
      where
        details :: String
        details = case err of
            PgNamedParam n ->
                concat ["Named parameter '", show n, "' is not specified"]
            PgNoNames (PG.Query q) ->
                "Query has no names but was called with named functions: "
                    ++ BS.unpack q
            PgEmptyName (PG.Query q) ->
                "Query contains an empty name: " ++ BS.unpack q
-- | Find the 'PG.Action' of the first parameter in the list whose name equals
-- the given one, or 'Nothing' when no parameter carries that name.
lookupName :: Name -> [NamedParam] -> Maybe PG.Action
lookupName n params =
    lookup n [ (namedParamName p, namedParamParam p) | p <- params ]
-- | Turn a query with named placeholders into a positional one.
--
-- Every @?name@ occurrence is replaced by a bare @?@, and the names are
-- returned in the order in which they occur in the query (repeated names are
-- repeated in the result).
--
-- A name is a non-empty run of ASCII letters, ASCII digits and underscores
-- directly following the question mark.
--
-- Fails with
--
-- * 'PgEmptyName' when a question mark is not directly followed by a name
--   character;
-- * 'PgNoNames' when the query contains no question mark at all.
extractNames
    :: PG.Query
    -> Either PgNamedError (PG.Query, NonEmpty Name)
extractNames qr = go (PG.fromQuery qr) >>= \case
    (_, [])         -> Left $ PgNoNames qr
    (q, name:names) -> Right (PG.Query q, name :| names)
  where
    -- Walk the query one placeholder at a time, collecting the rewritten
    -- query text and the names seen so far.
    go :: ByteString -> Either PgNamedError (ByteString, [Name])
    go str = case BS.break (== '?') str of
        (before, after)
            -- No question mark left (this also covers the empty input).
            | BS.null after -> Right (before, [])
            | BS.null name  -> Left $ PgEmptyName qr
            | otherwise     ->
                bimap ((before <> "?") <>) (Name (decodeUtf8 name) :)
                    <$> go rest
          where
            -- 'after' is non-empty here only when it starts with the '?'
            -- that 'BS.break' stopped at, so dropping one byte skips exactly
            -- that question mark.
            (name, rest) = BS.span isNameChar (BS.drop 1 after)

    -- The query is inspected byte by byte, so every 'Char' seen here is a
    -- single byte. Bytes >= 0x80 must never be accepted: 'isAlphaNum' would
    -- classify them by their Latin-1 meaning, the name could then end in the
    -- middle of a UTF-8 sequence, and 'decodeUtf8' would throw. Restricting
    -- names to ASCII guarantees the decoded bytes are always valid UTF-8.
    isNameChar :: Char -> Bool
    isNameChar c = c == '_' || (c < '\x80' && isAlphaNum c)
-- | Resolve every name to the 'PG.Action' supplied for it, keeping the order
-- of the names.
--
-- Throws 'PgNamedParam' for a name that has no entry in the parameter list.
namesToRow
    :: forall m . WithNamedError m
    => NonEmpty Name
    -> [NamedParam]
    -> m (NonEmpty PG.Action)
namesToRow names params = traverse resolve names
  where
    -- Look a single name up, turning a miss into a thrown error.
    resolve name =
        maybe (throwError (PgNamedParam name)) pure (lookupName name params)
infix 1 =?

-- | Pair a name with a value, converting the value with 'PG.toField', to
-- build a 'NamedParam'.
(=?) :: (PG.ToField a) => Name -> a -> NamedParam
name =? value = NamedParam
    { namedParamName  = name
    , namedParamParam = PG.toField value
    }
{-# INLINE (=?) #-}
-- | Run a query that uses named placeholders and return its result rows.
--
-- The named query and parameters are first converted with 'withNamedArgs'
-- (whose errors are thrown in @m@); the resulting positional query is then
-- passed to 'PG.query'.
queryNamed
    :: (MonadIO m, WithNamedError m, PG.FromRow res)
    => PG.Connection
    -> PG.Query
    -> [NamedParam]
    -> m [res]
queryNamed conn qNamed params = do
    (q, actions) <- withNamedArgs qNamed params
    liftIO (PG.query conn q (toList actions))
-- | Run a statement that uses named placeholders and return the 'Int64'
-- result of 'PG.execute'.
--
-- The named query and parameters are first converted with 'withNamedArgs'
-- (whose errors are thrown in @m@); the resulting positional query is then
-- passed to 'PG.execute'.
executeNamed
    :: (MonadIO m, WithNamedError m)
    => PG.Connection
    -> PG.Query
    -> [NamedParam]
    -> m Int64
executeNamed conn qNamed params = do
    (q, actions) <- withNamedArgs qNamed params
    liftIO (PG.execute conn q (toList actions))
-- | Convert a named query and its parameters into a positional query plus the
-- matching row of arguments.
--
-- Any error from 'extractNames' is rethrown in @m@; a name without a supplied
-- parameter is reported by 'namesToRow'.
withNamedArgs
    :: WithNamedError m
    => PG.Query
    -> [NamedParam]
    -> m (PG.Query, NonEmpty PG.Action)
withNamedArgs qNamed namedArgs =
    either throwError pure (extractNames qNamed) >>= \(q, names) ->
        (,) q <$> namesToRow names namedArgs