module Database.PostgreSQL.Simple.Copy
( copy
, copy_
, CopyOutResult(..)
, getCopyData
, putCopyData
, putCopyEnd
, putCopyError
) where
import Control.Applicative
import Control.Concurrent
import Control.Exception ( throwIO )
import qualified Data.Attoparsec.ByteString.Char8 as P
import Data.Typeable(Typeable)
import Data.Int(Int64)
import qualified Data.ByteString.Char8 as B
import qualified Database.PostgreSQL.LibPQ as PQ
import Database.PostgreSQL.Simple
import Database.PostgreSQL.Simple.Types
import Database.PostgreSQL.Simple.Internal
-- | Issue a @COPY ... FROM STDIN@ or @COPY ... TO STDOUT@ statement,
-- substituting @qs@ into the query template before sending it.
--
-- On success the connection is left in a COPY state.  Use 'putCopyData',
-- 'putCopyEnd' and 'putCopyError' to feed a @COPY FROM STDIN@, or
-- 'getCopyData' to drain a @COPY TO STDOUT@.  A statement that does not put
-- the connection into a COPY state raises a 'QueryError'.
copy :: ToRow params => Connection -> Query -> params -> IO ()
copy conn template qs =
    formatQuery conn template qs
      >>= doCopy "Database.PostgreSQL.Simple.Copy.copy" conn template
-- | Like 'copy', but for a statement without parameters: the query text is
-- sent to the server exactly as written.
copy_ :: Connection -> Query -> IO ()
copy_ conn template@(Query rawSql) =
    doCopy "Database.PostgreSQL.Simple.Copy.copy_" conn template rawSql
-- | Shared implementation of 'copy' and 'copy_': execute the already
-- formatted query @q@ and check that the server switched the connection into
-- a COPY state.
--
-- * @funcName@ prefixes error messages so the failing entry point is
--   identifiable.
-- * @template@ is the original query, attached to any 'QueryError' thrown.
--
-- Returns normally only for the 'PQ.CopyIn' and 'PQ.CopyOut' statuses.
-- 'PQ.BadResponse', 'PQ.NonfatalError' and 'PQ.FatalError' are reported via
-- 'throwResultError'.  Every other status (for example the statement was not
-- a COPY at all) raises a 'QueryError' naming the status.
doCopy :: B.ByteString -> Connection -> Query -> B.ByteString -> IO ()
doCopy funcName conn template q = do
    result <- exec conn q
    status <- PQ.resultStatus result
    let err = throwIO $ QueryError
                (B.unpack funcName ++ " " ++ show status)
                template
    case status of
      PQ.CopyOut       -> return ()
      PQ.CopyIn        -> return ()
      PQ.BadResponse   -> throwResultError funcName result status
      PQ.NonfatalError -> throwResultError funcName result status
      PQ.FatalError    -> throwResultError funcName result status
      -- EmptyQuery, CommandOk, TuplesOk, plus any status that newer
      -- postgresql-libpq versions expose (libpq also defines COPY_BOTH and
      -- SINGLE_TUPLE).  Without this fallback such a status would crash
      -- with an unhelpful non-exhaustive pattern-match error.
      _                -> err
-- | The outcome of one 'getCopyData' call while reading the output of a
-- @COPY ... TO STDOUT@ statement.
data CopyOutResult
    = CopyOutRow  !B.ByteString
      -- ^ Raw COPY data, as delivered by a single 'PQ.getCopyData' call.
    | CopyOutDone !Int64
      -- ^ The COPY has finished.  Carries the row count parsed from the
      --   server's @COPY n@ command tag.
    deriving (Eq, Show, Typeable)
-- | Fetch the next piece of output from a @COPY ... TO STDOUT@ statement
-- started with 'copy' or 'copy_'.
--
-- Keep calling this while it returns 'CopyOutRow'.  'CopyOutDone' marks the
-- end of the data and carries the row count reported by the server.  A
-- libpq-level failure is thrown as a 'SqlError' whose hint field holds this
-- function's name.
getCopyData :: Connection -> IO CopyOutResult
getCopyData conn = withConnection conn fetchNext
  where
    funcName = "Database.PostgreSQL.Simple.Copy.getCopyData"

    -- Mode flag passed to 'PQ.getCopyData'.  On Windows the call is made in
    -- blocking mode, so "would block" is never expected there.  Elsewhere it
    -- is asynchronous, and we wait on the socket ourselves.
#if defined(mingw32_HOST_OS)
    asyncMode = False
#else
    asyncMode = True
#endif

    fetchNext pqconn = PQ.getCopyData pqconn asyncMode >>= \outcome ->
      case outcome of
        PQ.CopyOutRow chunk -> return $! CopyOutRow chunk
        PQ.CopyOutDone      -> CopyOutDone <$> getCopyCommandTag funcName pqconn
        PQ.CopyOutError     -> PQ.errorMessage pqconn >>= \mmsg ->
          throwIO SqlError
            { sqlState       = ""
            , sqlExecStatus  = FatalError
            , sqlErrorMsg    = maybe "" id mmsg
            , sqlErrorDetail = ""
            , sqlErrorHint   = funcName
            }
#if defined(mingw32_HOST_OS)
        -- The blocking call used on Windows should never report this.
        PQ.CopyOutWouldBlock ->
          fail (B.unpack funcName ++ ": the impossible happened")
#else
        -- No complete chunk is buffered yet.  Wait for the socket to become
        -- readable, let libpq read the pending input, then retry.
        PQ.CopyOutWouldBlock -> PQ.socket pqconn >>= \mfd ->
          case mfd of
            Nothing -> throwIO (fdError funcName)
            Just fd -> do
              threadWaitRead fd
              _ <- PQ.consumeInput pqconn
              fetchNext pqconn
#endif
-- | Send one chunk of data to the server during a @COPY ... FROM STDIN@
-- started with 'copy' or 'copy_'.  The bytes are handed to libpq unchanged.
-- Presumably they must already be in the format the COPY statement
-- expects; confirm against the COPY options in use.
putCopyData :: Connection -> B.ByteString -> IO ()
putCopyData conn dat =
    withConnection conn (doCopyIn funcName (`PQ.putCopyData` dat))
  where
    funcName = "Database.PostgreSQL.Simple.Copy.putCopyData"
-- | Finish a @COPY ... FROM STDIN@ successfully.  Returns the number of rows
-- copied, parsed from the server's @COPY n@ command tag.
putCopyEnd :: Connection -> IO Int64
putCopyEnd conn = withConnection conn finish
  where
    funcName = "Database.PostgreSQL.Simple.Copy.putCopyEnd"

    finish pqconn = do
        doCopyIn funcName (`PQ.putCopyEnd` Nothing) pqconn
        getCopyCommandTag funcName pqconn
-- | Abort a @COPY ... FROM STDIN@, passing @err@ to libpq as the failure
-- message.  Any results the server sends back are then read and discarded.
putCopyError :: Connection -> B.ByteString -> IO ()
putCopyError conn err = withConnection conn abort
  where
    funcName = "Database.PostgreSQL.Simple.Copy.putCopyError"

    abort pqconn =
        doCopyIn funcName (flip PQ.putCopyEnd (Just err)) pqconn
          >> consumeResults pqconn
-- | Run a libpq COPY-in step (sending data, or ending the copy) until libpq
-- accepts it.
--
-- When libpq reports it would block, wait until the connection's socket is
-- writable and retry the same action.  'PQ.CopyInError' is thrown as a
-- 'SqlError' carrying libpq's connection error message, with @funcName@ in
-- the hint field.  A missing socket is reported via 'fdError'.
doCopyIn :: B.ByteString -> (PQ.Connection -> IO PQ.CopyInResult)
         -> PQ.Connection -> IO ()
doCopyIn funcName action = attempt
  where
    attempt pqconn = action pqconn >>= \outcome -> case outcome of
        PQ.CopyInOk         -> return ()
        PQ.CopyInWouldBlock -> PQ.socket pqconn >>= maybe
            (throwIO (fdError funcName))
            (\fd -> threadWaitWrite fd >> attempt pqconn)
        PQ.CopyInError      -> do
            mmsg <- PQ.errorMessage pqconn
            throwIO SqlError
              { sqlState       = ""
              , sqlExecStatus  = FatalError
              , sqlErrorMsg    = maybe "" id mmsg
              , sqlErrorDetail = ""
              , sqlErrorHint   = funcName
              }
-- | Read the final result of a COPY and return the row count from its
-- command tag (for example @COPY 42@).
--
-- Remaining results are drained from the connection as soon as the first
-- result has been fetched.  This leaves no unread results queued, even when
-- the command status is missing or cannot be parsed.  Those failures are
-- raised with 'fail' (an 'IOError').  A parse failure also appends libpq's
-- connection error message, when there is one.
getCopyCommandTag :: B.ByteString -> PQ.Connection -> IO Int64
getCopyCommandTag funcName pqconn = do
    result  <- maybe (fail errCmdStatus) return =<< PQ.getResult pqconn
    -- Drain before any check below that may throw.  Otherwise a missing
    -- command status would leave unread results queued on the connection.
    -- The fetched result stays usable after further getResult calls.
    consumeResults pqconn
    cmdStat <- maybe (fail errCmdStatus) return =<< PQ.cmdStatus result
    let rowCount = P.string "COPY " *> (P.decimal <* P.endOfInput)
    case P.parseOnly rowCount cmdStat of
      Left _ -> do
        mmsg <- PQ.errorMessage pqconn
        fail $ errCmdStatusFmt
            ++ maybe "" (\msg -> "\nConnection error: " ++ B.unpack msg) mmsg
      Right n -> return $! n
  where
    errCmdStatus    = B.unpack funcName ++ ": failed to fetch command status"
    errCmdStatusFmt = B.unpack funcName ++ ": failed to parse command status"
-- | Repeatedly call 'PQ.getResult', discarding each result, until it
-- returns 'Nothing'.
consumeResults :: PQ.Connection -> IO ()
consumeResults pqconn =
    PQ.getResult pqconn >>= maybe (return ()) (const (consumeResults pqconn))