{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE UndecidableInstances  #-}

-- | Decoding values from Postgres wire format to Haskell.

module Preql.Wire.Decode where

import Preql.Wire.Errors
import Preql.Wire.Internal

import Control.Exception (try)
import Control.Monad.Except
import Data.IORef (newIORef)
import GHC.TypeNats
import Preql.Imports

import qualified Data.Vector as V
import qualified Data.Vector.Sized as VS
import qualified Database.PostgreSQL.LibPQ as PQ

-- | Decode every row of a query result with a 'RowDecoder'.
--
-- Before any row is decoded, each of the decoder's @n@ expected column types
-- is resolved to an OID with @lookupType@ and compared against the OID that
-- libpq reports ('PQ.ftype') for the same column of @result@. If any column
-- disagrees, no rows are decoded and all mismatches are returned together as
-- 'PgTypeMismatch'. Otherwise every row is decoded in order with 'decodeRow'
-- from a fresh 'DecoderState' starting at row 0, column 0. A 'FieldError'
-- thrown while decoding is caught and returned as 'DecoderError'.
decodeVector :: KnownNat n =>
    (PgType -> IO (Either QueryError PQ.Oid)) -- ^ resolve an expected 'PgType' to its OID
    -> RowDecoder n a                          -- ^ row decoder, carrying the expected column types
    -> PQ.Result                               -- ^ result set to decode
    -> IO (Either QueryError (Vector a))
decodeVector lookupType rd@(RowDecoder pgtypes _parsers) result = do
    mismatches <- catMaybes . VS.toList <$>
        for (VS.zip (VS.enumFromN 0) pgtypes) checkColumn
    case mismatches of
        [] -> do
            PQ.Row ntuples <- PQ.ntuples result
            ref <- newIORef (DecoderState result 0 0)
            -- 'first DecoderError' fixes the exception type caught by 'try'
            -- to 'FieldError'; other exceptions propagate unchanged.
            first DecoderError <$>
                try (V.replicateM (fromIntegral ntuples) (decodeRow ref rd result))
        _ -> pure (Left (PgTypeMismatch mismatches))
  where
    -- Check one (column, expected type) pair. Returns 'Nothing' when the
    -- column's actual OID equals the looked-up OID of the expected type.
    -- A failed lookup ('Left') is also reported as a mismatch rather than
    -- propagated, so the caller sees which column's type could not be
    -- matched.
    checkColumn (col@(PQ.Col colIndex), expected) = do
        actual <- PQ.ftype result col
        lookedUp <- lookupType expected
        case lookedUp of
            Right oid | actual == oid -> pure Nothing
            _ -> do
                m_name <- PQ.fname result col
                -- Column names are decoded leniently so a non-UTF-8 name
                -- cannot mask the type mismatch being reported.
                let columnName = decodeUtf8With lenientDecode <$> m_name
                pure $ Just TypeMismatch
                    { expected = expected
                    , actual = actual
                    , column = fromIntegral colIndex
                    , columnName = columnName
                    }