{-# LANGUAGE DoAndIfThenElse #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Database.PostgreSQL.Copy.Escape (
EscapeCopyValue(..),
escapeCopyRow,
) where
import Data.ByteString (ByteString)
import Data.ByteString.Internal (createAndTrim)
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Data.List (foldl')
import Foreign
import Foreign.C
import GHC.IO (unsafeDupablePerformIO)
#if !MIN_VERSION_base(4,8,0)
import Data.Monoid
#endif
#if MIN_VERSION_base(4,9,0) && !MIN_VERSION_base(4,11,0)
import Data.Semigroup
#endif
import qualified Data.ByteString as B
#if !MIN_VERSION_base(4,5,0)
-- | Infix spelling of 'mappend'.
--
-- Compiled in only when building against a base release older than 4.5
-- (see the CPP guard), so that the rest of this module can combine values
-- with @<>@ on such releases.
(<>) :: Monoid m => m -> m -> m
lhs <> rhs = lhs `mappend` rhs
infixr 6 <>
{-# INLINE (<>) #-}
#endif
-- | Shape shared by the low-level escaping routines used in this module.
--
-- Arguments, in the order 'emitEscape' supplies them:
--
--   1. pointer to the input bytes,
--   2. number of input bytes,
--   3. pointer to the current output position.
--
-- The returned pointer is threaded on as the next output position, so it is
-- expected to point just past the last byte written.
type Escaper = Ptr CUChar -> CSize -> Ptr CUChar -> IO (Ptr CUChar)
-- | A write action over a raw output buffer: given the current write
-- position it performs some IO and returns a new position.
--
-- 'runEmit' starts such an action at the beginning of a buffer and uses the
-- distance to the position it returns as the length of the output.
newtype Emit = Emit (Ptr CUChar -> IO (Ptr CUChar))
#if MIN_VERSION_base(4,9,0)
-- | Sequential composition: run the left writer, then start the right
-- writer at the position the left one returned.
instance Semigroup Emit where
    Emit writeFirst <> Emit writeSecond = Emit $ \start -> do
        next <- writeFirst start
        writeSecond next
#endif
-- | 'mempty' writes nothing and returns the output position unchanged;
-- combining two writers runs them one after the other, feeding the position
-- returned by the first into the second.
instance Monoid Emit where
    mempty =
        Emit return
#if !MIN_VERSION_base(4,11,0)
    -- Before base 4.11 'mappend' has no default implementation.  Without
    -- this definition 'mappend' and 'mconcat' on 'Emit' are bottom, and so
    -- is '<>' wherever it is the Monoid operator rather than the Semigroup
    -- one.  From base 4.11 on, the class default @mappend = (<>)@ applies.
    mappend (Emit a) (Emit b) =
        Emit (\ptr0 -> a ptr0 >>= b)
#endif
-- | Run an 'Emit' against a freshly allocated buffer of @bufsize@ bytes and
-- return the bytes it wrote.
--
-- The emitter is started at the beginning of the buffer; the distance from
-- there to the position it returns becomes the length of the result.  That
-- length is validated only after the emitter has already run, so @bufsize@
-- must be a genuine upper bound on what the emitter writes.  A negative
-- length, or one larger than @bufsize@, is reported through 'error'.
runEmit :: Int -> Emit -> IO ByteString
runEmit bufsize (Emit writer) =
    createAndTrim bufsize $ \bufStart -> do
        bufEnd <- writer (castPtr bufStart)
        checked (bufEnd `minusPtr` bufStart)
  where
    -- Sanity-check the number of bytes written before handing it back.
    checked written
        | written < 0 =
            error "Database.PostgreSQL.Copy.Escape.runEmit: len < 0"
        | written > bufsize =
            error "Database.PostgreSQL.Copy.Escape.runEmit: buffer overflow"
        | otherwise =
            return written
-- | Write a single byte at the current output position and advance the
-- position by one.
emitByte :: CUChar -> Emit
emitByte c = Emit writeOne
  where
    writeOne dest = do
        poke dest c
        -- Force the new pointer so no thunk is carried between emitters.
        return $! dest `plusPtr` 1
-- | Build an 'Emit' that runs the given 'Escaper' over the contents of a
-- 'ByteString', writing at the current output position.
--
-- The escaper receives the input pointer, the input length and the output
-- pointer; whatever pointer it returns becomes the next output position.
emitEscape :: Escaper -> ByteString -> Emit
emitEscape escaper bs = Emit feed
  where
    feed dest = unsafeUseAsCStringLen bs (escapeInto dest)
    escapeInto dest (src, srcLen) =
        escaper (castPtr src) (fromIntegral srcLen) dest
-- | Values that can be written into an output buffer in escaped form.
class Escape a where
    -- | Write action that produces the escaped bytes of the value.
    escapeEmit :: a -> Emit
    -- | Size in bytes of the buffer that 'escape' allocates before running
    -- 'escapeEmit'.  'runEmit' reports an error if more than this is
    -- written, so the value must never be an underestimate.
    escapeUpperBound :: a -> Int
-- | Render a value to a 'ByteString': allocate 'escapeUpperBound' bytes and
-- run the value's 'escapeEmit' action into them via 'runEmit'.
escape :: Escape a => a -> IO ByteString
escape a = runEmit capacity emitter
  where
    capacity = escapeUpperBound a
    emitter  = escapeEmit a
-- C escaping routines with the 'Escaper' signature.  Their implementation is
-- not part of this module.  Both are imported as @unsafe@ foreign calls.
--
-- NOTE(review): the size bounds in 'escapeUpperBound' (2x for text, 5x for
-- bytea) presumably mirror the worst-case expansion of these routines;
-- confirm against the C source.
--
-- | Escaper applied to 'EscapeCopyText' payloads.
foreign import ccall unsafe
    c_postgresql_copy_escape_text :: Escaper
-- | Escaper applied to 'EscapeCopyBytea' payloads.
foreign import ccall unsafe
    c_postgresql_copy_escape_bytea :: Escaper
-- | A single field of a row passed to 'escapeCopyRow'.
data EscapeCopyValue
    = EscapeCopyNull
      -- ^ Written as the two bytes 92 and 78 (ASCII backslash followed by
      -- @N@).
    | EscapeCopyText !ByteString
      -- ^ Payload escaped with @c_postgresql_copy_escape_text@.
    | EscapeCopyBytea !ByteString
      -- ^ Payload escaped with @c_postgresql_copy_escape_bytea@.
    deriving Show
-- | How each kind of field is written, and how much room it may need.
instance Escape EscapeCopyValue where
    -- Bytes 92 and 78 are ASCII backslash and 'N'.
    escapeEmit EscapeCopyNull =
        emitByte 92 <> emitByte 78
    escapeEmit (EscapeCopyText payload) =
        emitEscape c_postgresql_copy_escape_text payload
    escapeEmit (EscapeCopyBytea payload) =
        emitEscape c_postgresql_copy_escape_bytea payload

    -- Two bytes for the null marker; otherwise a fixed multiple of the
    -- payload length (2x for text, 5x for bytea).
    escapeUpperBound EscapeCopyNull            = 2
    escapeUpperBound (EscapeCopyText payload)  = B.length payload * 2
    escapeUpperBound (EscapeCopyBytea payload) = B.length payload * 5
-- | A whole row: the list of its fields, in order.  Exists so that a row
-- can have its own 'Escape' instance, distinct from that of a single field.
newtype EscapeCopyRow = EscapeCopyRow [EscapeCopyValue]
    deriving Show
-- | A row is its fields separated by a tab (byte 9) and terminated by a
-- newline (byte 10).  An empty row is just the newline.
instance Escape EscapeCopyRow where
    escapeEmit (EscapeCopyRow fields) =
        case fields of
            []           -> endOfRow
            first : rest -> escapeEmit first <> foldr prependField endOfRow rest
      where
        -- Every field after the first is preceded by a separator.
        prependField field remainder =
            separator <> escapeEmit field <> remainder
        separator = emitByte 9
        endOfRow  = emitByte 10

    -- Each field contributes its own bound plus one byte for the tab or
    -- newline that follows it; an empty row needs one byte for the newline.
    escapeUpperBound (EscapeCopyRow []) = 1
    escapeUpperBound (EscapeCopyRow fields) =
        foldl' (\total field -> total + escapeUpperBound field + 1) 0 fields
-- | Escape a list of fields as one row and return the resulting bytes.
--
-- The work is done in 'IO' by 'escape' and exposed as a pure function
-- through 'unsafeDupablePerformIO'.
escapeCopyRow :: [EscapeCopyValue] -> ByteString
escapeCopyRow = unsafeDupablePerformIO . escape . EscapeCopyRow