{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
module Database.CQL.Protocol.Codec
( encodeByte
, decodeByte
, encodeSignedByte
, decodeSignedByte
, encodeShort
, decodeShort
, encodeSignedShort
, decodeSignedShort
, encodeInt
, decodeInt
, encodeString
, decodeString
, encodeLongString
, decodeLongString
, encodeBytes
, decodeBytes
, encodeShortBytes
, decodeShortBytes
, encodeUUID
, decodeUUID
, encodeList
, decodeList
, encodeMap
, decodeMap
, encodeMultiMap
, decodeMultiMap
, encodeSockAddr
, decodeSockAddr
, encodeConsistency
, decodeConsistency
, encodeOpCode
, decodeOpCode
, encodePagingState
, decodePagingState
, decodeKeyspace
, decodeTable
, decodeColumnType
, decodeQueryId
, putValue
, getValue
) where
import Control.Applicative
import Control.Monad
import Data.Bits
import Data.ByteString (ByteString)
import Data.Decimal
import Data.Int
import Data.IP
#ifdef INCOMPATIBLE_VARINT
import Data.List (unfoldr)
#else
import Data.List (foldl')
#endif
import Data.Text (Text)
import Data.UUID (UUID)
import Data.Word
import Data.Serialize hiding (decode, encode)
import Database.CQL.Protocol.Types
import Network.Socket (SockAddr (..), PortNumber (..))
import Prelude
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import qualified Data.Text.Encoding as T
import qualified Data.Text.Lazy as LT
import qualified Data.Text.Lazy.Encoding as LT
import qualified Data.UUID as UUID
encodeByte :: Putter Word8
encodeByte = put
decodeByte :: Get Word8
decodeByte = get
encodeSignedByte :: Putter Int8
encodeSignedByte = put
decodeSignedByte :: Get Int8
decodeSignedByte = get
encodeShort :: Putter Word16
encodeShort = put
decodeShort :: Get Word16
decodeShort = get
encodeSignedShort :: Putter Int16
encodeSignedShort = put
decodeSignedShort :: Get Int16
decodeSignedShort = get
encodeInt :: Putter Int32
encodeInt = put
decodeInt :: Get Int32
decodeInt = get
encodeString :: Putter Text
encodeString = encodeShortBytes . T.encodeUtf8
decodeString :: Get Text
decodeString = T.decodeUtf8 <$> decodeShortBytes
encodeLongString :: Putter LT.Text
encodeLongString = encodeBytes . LT.encodeUtf8
decodeLongString :: Get LT.Text
decodeLongString = do
n <- get :: Get Int32
LT.decodeUtf8 <$> getLazyByteString (fromIntegral n)
encodeBytes :: Putter LB.ByteString
encodeBytes bs = do
put (fromIntegral (LB.length bs) :: Int32)
putLazyByteString bs
decodeBytes :: Get (Maybe LB.ByteString)
decodeBytes = do
n <- get :: Get Int32
if n < 0
then return Nothing
else Just <$> getLazyByteString (fromIntegral n)
encodeShortBytes :: Putter ByteString
encodeShortBytes bs = do
put (fromIntegral (B.length bs) :: Word16)
putByteString bs
decodeShortBytes :: Get ByteString
decodeShortBytes = do
n <- get :: Get Word16
getByteString (fromIntegral n)
encodeUUID :: Putter UUID
encodeUUID = putLazyByteString . UUID.toByteString
decodeUUID :: Get UUID
decodeUUID = do
uuid <- UUID.fromByteString <$> getLazyByteString 16
maybe (fail "decode-uuid: invalid") return uuid
encodeList :: Putter [Text]
encodeList sl = do
put (fromIntegral (length sl) :: Word16)
mapM_ encodeString sl
decodeList :: Get [Text]
decodeList = do
n <- get :: Get Word16
replicateM (fromIntegral n) decodeString
encodeMap :: Putter [(Text, Text)]
encodeMap m = do
put (fromIntegral (length m) :: Word16)
forM_ m $ \(k, v) -> encodeString k >> encodeString v
decodeMap :: Get [(Text, Text)]
decodeMap = do
n <- get :: Get Word16
replicateM (fromIntegral n) ((,) <$> decodeString <*> decodeString)
encodeMultiMap :: Putter [(Text, [Text])]
encodeMultiMap mm = do
put (fromIntegral (length mm) :: Word16)
forM_ mm $ \(k, v) -> encodeString k >> encodeList v
decodeMultiMap :: Get [(Text, [Text])]
decodeMultiMap = do
n <- get :: Get Word16
replicateM (fromIntegral n) ((,) <$> decodeString <*> decodeList)
encodeSockAddr :: Putter SockAddr
encodeSockAddr (SockAddrInet p a) = do
putWord8 4
putWord32le a
putWord32be (fromIntegral p)
encodeSockAddr (SockAddrInet6 p _ (a, b, c, d) _) = do
putWord8 16
putWord32host a
putWord32host b
putWord32host c
putWord32host d
putWord32be (fromIntegral p)
encodeSockAddr (SockAddrUnix _) = fail "encode-socket: unix address not allowed"
#if MIN_VERSION_network(2,6,1) && !MIN_VERSION_network(3,0,0)
encodeSockAddr (SockAddrCan _) = fail "encode-socket: can address not allowed"
#endif
decodeSockAddr :: Get SockAddr
decodeSockAddr = do
n <- getWord8
case n of
4 -> do
i <- getIPv4
p <- getPort
return $ SockAddrInet p i
16 -> do
i <- getIPv6
p <- getPort
return $ SockAddrInet6 p 0 i 0
_ -> fail $ "decode-socket: unknown: " ++ show n
where
getPort :: Get PortNumber
getPort = fromIntegral <$> getWord32be
getIPv4 :: Get Word32
getIPv4 = getWord32le
getIPv6 :: Get (Word32, Word32, Word32, Word32)
getIPv6 = (,,,) <$> getWord32host <*> getWord32host <*> getWord32host <*> getWord32host
encodeConsistency :: Putter Consistency
encodeConsistency Any = encodeShort 0x00
encodeConsistency One = encodeShort 0x01
encodeConsistency Two = encodeShort 0x02
encodeConsistency Three = encodeShort 0x03
encodeConsistency Quorum = encodeShort 0x04
encodeConsistency All = encodeShort 0x05
encodeConsistency LocalQuorum = encodeShort 0x06
encodeConsistency EachQuorum = encodeShort 0x07
encodeConsistency Serial = encodeShort 0x08
encodeConsistency LocalSerial = encodeShort 0x09
encodeConsistency LocalOne = encodeShort 0x0A
decodeConsistency :: Get Consistency
decodeConsistency = decodeShort >>= mapCode
where
mapCode 0x00 = return Any
mapCode 0x01 = return One
mapCode 0x02 = return Two
mapCode 0x03 = return Three
mapCode 0x04 = return Quorum
mapCode 0x05 = return All
mapCode 0x06 = return LocalQuorum
mapCode 0x07 = return EachQuorum
mapCode 0x08 = return Serial
mapCode 0x09 = return LocalSerial
mapCode 0x0A = return LocalOne
mapCode code = fail $ "decode-consistency: unknown: " ++ show code
encodeOpCode :: Putter OpCode
encodeOpCode OcError = encodeByte 0x00
encodeOpCode OcStartup = encodeByte 0x01
encodeOpCode OcReady = encodeByte 0x02
encodeOpCode OcAuthenticate = encodeByte 0x03
encodeOpCode OcOptions = encodeByte 0x05
encodeOpCode OcSupported = encodeByte 0x06
encodeOpCode OcQuery = encodeByte 0x07
encodeOpCode OcResult = encodeByte 0x08
encodeOpCode OcPrepare = encodeByte 0x09
encodeOpCode OcExecute = encodeByte 0x0A
encodeOpCode OcRegister = encodeByte 0x0B
encodeOpCode OcEvent = encodeByte 0x0C
encodeOpCode OcBatch = encodeByte 0x0D
encodeOpCode OcAuthChallenge = encodeByte 0x0E
encodeOpCode OcAuthResponse = encodeByte 0x0F
encodeOpCode OcAuthSuccess = encodeByte 0x10
decodeOpCode :: Get OpCode
decodeOpCode = decodeByte >>= mapCode
where
mapCode 0x00 = return OcError
mapCode 0x01 = return OcStartup
mapCode 0x02 = return OcReady
mapCode 0x03 = return OcAuthenticate
mapCode 0x05 = return OcOptions
mapCode 0x06 = return OcSupported
mapCode 0x07 = return OcQuery
mapCode 0x08 = return OcResult
mapCode 0x09 = return OcPrepare
mapCode 0x0A = return OcExecute
mapCode 0x0B = return OcRegister
mapCode 0x0C = return OcEvent
mapCode 0x0D = return OcBatch
mapCode 0x0E = return OcAuthChallenge
mapCode 0x0F = return OcAuthResponse
mapCode 0x10 = return OcAuthSuccess
mapCode word = fail $ "decode-opcode: unknown: " ++ show word
decodeColumnType :: Get ColumnType
decodeColumnType = decodeShort >>= toType
where
toType 0x0000 = CustomColumn <$> decodeString
toType 0x0001 = return AsciiColumn
toType 0x0002 = return BigIntColumn
toType 0x0003 = return BlobColumn
toType 0x0004 = return BooleanColumn
toType 0x0005 = return CounterColumn
toType 0x0006 = return DecimalColumn
toType 0x0007 = return DoubleColumn
toType 0x0008 = return FloatColumn
toType 0x0009 = return IntColumn
toType 0x000A = return TextColumn
toType 0x000B = return TimestampColumn
toType 0x000C = return UuidColumn
toType 0x000D = return VarCharColumn
toType 0x000E = return VarIntColumn
toType 0x000F = return TimeUuidColumn
toType 0x0010 = return InetColumn
toType 0x0011 = return DateColumn
toType 0x0012 = return TimeColumn
toType 0x0013 = return SmallIntColumn
toType 0x0014 = return TinyIntColumn
toType 0x0020 = ListColumn <$> (decodeShort >>= toType)
toType 0x0021 = MapColumn <$> (decodeShort >>= toType) <*> (decodeShort >>= toType)
toType 0x0022 = SetColumn <$> (decodeShort >>= toType)
toType 0x0030 = do
_ <- decodeString
t <- decodeString
UdtColumn t <$> do
n <- fromIntegral <$> decodeShort
replicateM n ((,) <$> decodeString <*> (decodeShort >>= toType))
toType 0x0031 = TupleColumn <$> do
n <- fromIntegral <$> decodeShort
replicateM n (decodeShort >>= toType)
toType other = fail $ "decode-type: unknown: " ++ show other
encodePagingState :: Putter PagingState
encodePagingState (PagingState s) = encodeBytes s
decodePagingState :: Get (Maybe PagingState)
decodePagingState = fmap PagingState <$> decodeBytes
putValue :: Version -> Putter Value
putValue _ (CqlCustom x) = toBytes $ putLazyByteString x
putValue _ (CqlBoolean x) = toBytes $ putWord8 $ if x then 1 else 0
putValue _ (CqlInt x) = toBytes $ put x
putValue _ (CqlBigInt x) = toBytes $ put x
putValue _ (CqlFloat x) = toBytes $ putFloat32be x
putValue _ (CqlDouble x) = toBytes $ putFloat64be x
putValue _ (CqlText x) = toBytes $ putByteString (T.encodeUtf8 x)
putValue _ (CqlUuid x) = toBytes $ encodeUUID x
putValue _ (CqlTimeUuid x) = toBytes $ encodeUUID x
putValue _ (CqlTimestamp x) = toBytes $ put x
putValue _ (CqlAscii x) = toBytes $ putByteString (T.encodeUtf8 x)
putValue _ (CqlBlob x) = toBytes $ putLazyByteString x
putValue _ (CqlCounter x) = toBytes $ put x
putValue _ (CqlInet x) = toBytes $ case x of
IPv4 i -> putWord32le (toHostAddress i)
IPv6 i -> do
let (a, b, c, d) = toHostAddress6 i
putWord32host a
putWord32host b
putWord32host c
putWord32host d
putValue _ (CqlVarInt x) = toBytes $ integer2bytes x
putValue _ (CqlDecimal x) = toBytes $ do
put (fromIntegral (decimalPlaces x) :: Int32)
integer2bytes (decimalMantissa x)
putValue V4 (CqlDate x) = toBytes $ put x
putValue _ v@(CqlDate _) = fail $ "putValue: date: " ++ show v
putValue V4 (CqlTime x) = toBytes $ put x
putValue _ v@(CqlTime _) = fail $ "putValue: time: " ++ show v
putValue V4 (CqlSmallInt x) = toBytes $ put x
putValue _ v@(CqlSmallInt _) = fail $ "putValue: smallint: " ++ show v
putValue V4 (CqlTinyInt x) = toBytes $ put x
putValue _ v@(CqlTinyInt _) = fail $ "putValue: tinyint: " ++ show v
putValue v (CqlUdt x) = toBytes $ mapM_ (putValue v . snd) x
putValue v (CqlList x) = toBytes $ do
encodeInt (fromIntegral (length x))
mapM_ (putValue v) x
putValue v (CqlSet x) = toBytes $ do
encodeInt (fromIntegral (length x))
mapM_ (putValue v) x
putValue v (CqlMap x) = toBytes $ do
encodeInt (fromIntegral (length x))
forM_ x $ \(k, w) -> putValue v k >> putValue v w
putValue v (CqlTuple x) = toBytes $ mapM_ (putValue v) x
putValue _ (CqlMaybe Nothing) = put (-1 :: Int32)
putValue v (CqlMaybe (Just x)) = putValue v x
getValue :: Version -> ColumnType -> Get Value
getValue v (ListColumn t) = CqlList <$> getList (do
len <- decodeInt
replicateM (fromIntegral len) (getValue v t))
getValue v (SetColumn t) = CqlSet <$> getList (do
len <- decodeInt
replicateM (fromIntegral len) (getValue v t))
getValue v (MapColumn t u) = CqlMap <$> getList (do
len <- decodeInt
replicateM (fromIntegral len) ((,) <$> getValue v t <*> getValue v u))
getValue v (TupleColumn t) = withBytes $ CqlTuple <$> mapM (getValue v) t
getValue v (MaybeColumn t) = do
n <- lookAhead (get :: Get Int32)
if n < 0
then uncheckedSkip 4 >> return (CqlMaybe Nothing)
else CqlMaybe . Just <$> getValue v t
getValue _ (CustomColumn _) = withBytes $ CqlCustom <$> remainingBytesLazy
getValue _ BooleanColumn = withBytes $ CqlBoolean . (/= 0) <$> getWord8
getValue _ IntColumn = withBytes $ CqlInt <$> get
getValue _ BigIntColumn = withBytes $ CqlBigInt <$> get
getValue _ FloatColumn = withBytes $ CqlFloat <$> getFloat32be
getValue _ DoubleColumn = withBytes $ CqlDouble <$> getFloat64be
getValue _ TextColumn = withBytes $ CqlText . T.decodeUtf8 <$> remainingBytes
getValue _ VarCharColumn = withBytes $ CqlText . T.decodeUtf8 <$> remainingBytes
getValue _ AsciiColumn = withBytes $ CqlAscii . T.decodeUtf8 <$> remainingBytes
getValue _ BlobColumn = withBytes $ CqlBlob <$> remainingBytesLazy
getValue _ UuidColumn = withBytes $ CqlUuid <$> decodeUUID
getValue _ TimeUuidColumn = withBytes $ CqlTimeUuid <$> decodeUUID
getValue _ TimestampColumn = withBytes $ CqlTimestamp <$> get
getValue _ CounterColumn = withBytes $ CqlCounter <$> get
getValue _ InetColumn = withBytes $ CqlInet <$> do
len <- remaining
case len of
4 -> IPv4 . fromHostAddress <$> getWord32le
16 -> do
a <- (,,,) <$> getWord32host <*> getWord32host <*> getWord32host <*> getWord32host
return $ IPv6 (fromHostAddress6 a)
n -> fail $ "getNative: invalid Inet length: " ++ show n
getValue V4 DateColumn = withBytes $ CqlDate <$> get
getValue _ DateColumn = fail "getNative: date type"
getValue V4 TimeColumn = withBytes $ CqlTime <$> get
getValue _ TimeColumn = fail "getNative: time type"
getValue V4 SmallIntColumn = withBytes $ CqlSmallInt <$> get
getValue _ SmallIntColumn = fail "getNative: smallint type"
getValue V4 TinyIntColumn = withBytes $ CqlTinyInt <$> get
getValue _ TinyIntColumn = fail "getNative: tinyint type"
getValue _ VarIntColumn = withBytes $ CqlVarInt <$> bytes2integer
getValue _ DecimalColumn = withBytes $ do
x <- get :: Get Int32
y <- bytes2integer
return (CqlDecimal (Decimal (fromIntegral x) y))
getValue v (UdtColumn _ x) = withBytes $ CqlUdt <$> do
let (n, t) = unzip x
zip n <$> mapM (getValue v) t
getList :: Get [a] -> Get [a]
getList m = do
n <- lookAhead (get :: Get Int32)
if n < 0 then uncheckedSkip 4 >> return []
else withBytes m
withBytes :: Get a -> Get a
withBytes p = do
n <- fromIntegral <$> (get :: Get Int32)
when (n < 0) $
fail $ "withBytes: null (length = " ++ show n ++ ")"
b <- getBytes n
case runGet p b of
Left e -> fail $ "withBytes: " ++ e
Right x -> return x
remainingBytes :: Get ByteString
remainingBytes = remaining >>= getByteString . fromIntegral
remainingBytesLazy :: Get LB.ByteString
remainingBytesLazy = remaining >>= getLazyByteString . fromIntegral
toBytes :: Put -> Put
toBytes p = do
let bytes = runPut p
put (fromIntegral (B.length bytes) :: Int32)
putByteString bytes
#ifdef INCOMPATIBLE_VARINT
integer2bytes :: Putter Integer
integer2bytes n = do
put sign
put (unroll (abs n))
where
sign = fromIntegral (signum n) :: Word8
unroll :: Integer -> [Word8]
unroll = unfoldr step
where
step 0 = Nothing
step i = Just (fromIntegral i, i `shiftR` 8)
bytes2integer :: Get Integer
bytes2integer = do
sign <- get
bytes <- get
let v = roll bytes
return $! if sign == (1 :: Word8) then v else - v
where
roll :: [Word8] -> Integer
roll = foldr unstep 0
where
unstep b a = a `shiftL` 8 .|. fromIntegral b
#else
integer2bytes :: Putter Integer
integer2bytes n
| n == 0 = putWord8 0x00
| n == -1 = putWord8 0xFF
| n < 0 = do
let bytes = explode (-1) n
unless (head bytes >= 0x80) $
putWord8 0xFF
mapM_ putWord8 bytes
| otherwise = do
let bytes = explode 0 n
unless (head bytes < 0x80) $
putWord8 0x00
mapM_ putWord8 bytes
explode :: Integer -> Integer -> [Word8]
explode x n = loop n []
where
loop !i !acc
| i == x = acc
| otherwise = loop (i `shiftR` 8) (fromIntegral i : acc)
bytes2integer :: Get Integer
bytes2integer = do
msb <- getWord8
bytes <- B.unpack <$> remainingBytes
if msb < 0x80
then return (implode (msb:bytes))
else return (- (implode (map complement (msb:bytes)) + 1))
implode :: [Word8] -> Integer
implode = foldl' fun 0
where
fun i b = i `shiftL` 8 .|. fromIntegral b
#endif
decodeKeyspace :: Get Keyspace
decodeKeyspace = Keyspace <$> decodeString
decodeTable :: Get Table
decodeTable = Table <$> decodeString
decodeQueryId :: Get (QueryId k a b)
decodeQueryId = QueryId <$> decodeShortBytes