module Database.CQL.Protocol.Codec
( encodeByte
, decodeByte
, encodeSignedByte
, decodeSignedByte
, encodeShort
, decodeShort
, encodeSignedShort
, decodeSignedShort
, encodeInt
, decodeInt
, encodeString
, decodeString
, encodeLongString
, decodeLongString
, encodeBytes
, decodeBytes
, encodeShortBytes
, decodeShortBytes
, encodeUUID
, decodeUUID
, encodeList
, decodeList
, encodeMap
, decodeMap
, encodeMultiMap
, decodeMultiMap
, encodeSockAddr
, decodeSockAddr
, encodeConsistency
, decodeConsistency
, encodeOpCode
, decodeOpCode
, encodeColumnType
, decodeColumnType
, encodePagingState
, decodePagingState
, decodeKeyspace
, decodeTable
, decodeQueryId
, putValue
, getValue
) where
import Control.Applicative
import Control.Monad
import Data.Bits
import Data.ByteString (ByteString)
import Data.Decimal
import Data.Int
import Data.IP
import Data.List (unfoldr)
import Data.Text (Text)
import Data.UUID (UUID)
import Data.Word
import Data.Serialize hiding (decode, encode)
import Database.CQL.Protocol.Types
import Network.Socket (SockAddr (..), PortNumber (..))
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import qualified Data.Text.Encoding as T
import qualified Data.Text.Lazy as LT
import qualified Data.Text.Lazy.Encoding as LT
import qualified Data.UUID as UUID
-- | Write an unsigned single byte.
encodeByte :: Putter Word8
encodeByte = putWord8

-- | Read an unsigned single byte.
decodeByte :: Get Word8
decodeByte = getWord8

-- | Write a signed byte; its two's-complement bit pattern is emitted as one
-- octet.
encodeSignedByte :: Putter Int8
encodeSignedByte = putWord8 . fromIntegral

-- | Read one octet and reinterpret it as a signed byte.
decodeSignedByte :: Get Int8
decodeSignedByte = fromIntegral <$> getWord8

-- | Write an unsigned 16-bit value, big-endian.
encodeShort :: Putter Word16
encodeShort = putWord16be

-- | Read an unsigned 16-bit value, big-endian.
decodeShort :: Get Word16
decodeShort = getWord16be

-- | Write a signed 16-bit value, big-endian.
encodeSignedShort :: Putter Int16
encodeSignedShort = putWord16be . fromIntegral

-- | Read a signed 16-bit value, big-endian.
decodeSignedShort :: Get Int16
decodeSignedShort = fromIntegral <$> getWord16be

-- | Write a signed 32-bit value, big-endian.
encodeInt :: Putter Int32
encodeInt = putWord32be . fromIntegral

-- | Read a signed 32-bit value, big-endian.
decodeInt :: Get Int32
decodeInt = fromIntegral <$> getWord32be
-- | Write a short string: its UTF-8 encoding prefixed by a 'Word16' byte
-- length (via 'encodeShortBytes').
encodeString :: Putter Text
encodeString = encodeShortBytes . T.encodeUtf8

-- | Read a short string written by 'encodeString'.
--
-- Invalid UTF-8 makes the 'Get' fail. 'T.decodeUtf8' would instead throw an
-- exception from pure code, which escapes the parser's error handling.
decodeString :: Get Text
decodeString = decodeShortBytes >>= either failUtf8 return . T.decodeUtf8'
  where
    failUtf8 e = fail $ "decode-string: " ++ show e
-- | Write a long string: its UTF-8 encoding prefixed by an 'Int32' byte
-- length (via 'encodeBytes').
encodeLongString :: Putter LT.Text
encodeLongString = encodeBytes . LT.encodeUtf8

-- | Read a long string written by 'encodeLongString'.
--
-- A negative length (which 'decodeBytes' treats as null) is rejected
-- explicitly rather than handed to 'getLazyByteString'. Invalid UTF-8 fails
-- the 'Get' instead of throwing from pure code.
decodeLongString :: Get LT.Text
decodeLongString = do
    n <- get :: Get Int32
    when (n < 0) $
        fail $ "decode-long-string: negative length: " ++ show n
    bytes <- getLazyByteString (fromIntegral n)
    either (fail . ("decode-long-string: " ++) . show) return (LT.decodeUtf8' bytes)
-- | Write a byte blob: an 'Int32' length followed by the raw bytes.
encodeBytes :: Putter LB.ByteString
encodeBytes bs = put len >> putLazyByteString bs
  where
    len = fromIntegral (LB.length bs) :: Int32

-- | Read a byte blob written by 'encodeBytes'. A negative length denotes
-- null and yields 'Nothing'.
decodeBytes :: Get (Maybe LB.ByteString)
decodeBytes = (get :: Get Int32) >>= readBody
  where
    readBody len
        | len < 0   = return Nothing
        | otherwise = Just <$> getLazyByteString (fromIntegral len)
-- | Write a short byte string: a 'Word16' length followed by the bytes.
--
-- NOTE(review): the length is converted with 'fromIntegral' to 'Word16', so
-- inputs longer than 65535 bytes get a truncated (wrong) length prefix.
encodeShortBytes :: Putter ByteString
encodeShortBytes bs = put (fromIntegral (B.length bs) :: Word16) >> putByteString bs

-- | Read a short byte string written by 'encodeShortBytes'.
decodeShortBytes :: Get ByteString
decodeShortBytes = (get :: Get Word16) >>= getByteString . fromIntegral
-- | Write a UUID as its 16-byte binary form ('UUID.toByteString').
encodeUUID :: Putter UUID
encodeUUID uuid = putLazyByteString (UUID.toByteString uuid)

-- | Read 16 bytes and parse them as a UUID; fails if they cannot be parsed.
decodeUUID :: Get UUID
decodeUUID = getLazyByteString 16 >>= checked . UUID.fromByteString
  where
    checked Nothing  = fail "decode-uuid: invalid"
    checked (Just u) = return u
-- | Write a list of strings: a 'Word16' count followed by each string.
encodeList :: Putter [Text]
encodeList xs = put count >> mapM_ encodeString xs
  where
    count = fromIntegral (length xs) :: Word16

-- | Read a string list written by 'encodeList'.
decodeList :: Get [Text]
decodeList = (get :: Get Word16) >>= \count ->
    replicateM (fromIntegral count) decodeString

-- | Write a string-to-string map as a 'Word16' count followed by
-- alternating keys and values.
encodeMap :: Putter [(Text, Text)]
encodeMap kvs = do
    put (fromIntegral (length kvs) :: Word16)
    mapM_ putPair kvs
  where
    putPair (k, v) = encodeString k >> encodeString v

-- | Read a string map written by 'encodeMap'.
decodeMap :: Get [(Text, Text)]
decodeMap = do
    count <- get :: Get Word16
    let pair = (,) <$> decodeString <*> decodeString
    replicateM (fromIntegral count) pair

-- | Write a string-to-string-list map as a 'Word16' count followed by each
-- key and its 'encodeList'-encoded values.
encodeMultiMap :: Putter [(Text, [Text])]
encodeMultiMap entries = do
    put (fromIntegral (length entries) :: Word16)
    mapM_ putEntry entries
  where
    putEntry (k, vs) = encodeString k >> encodeList vs

-- | Read a multimap written by 'encodeMultiMap'.
decodeMultiMap :: Get [(Text, [Text])]
decodeMultiMap = do
    count <- get :: Get Word16
    let entry = (,) <$> decodeString <*> decodeList
    replicateM (fromIntegral count) entry
-- | Write a socket address: a size byte (4 for IPv4, 16 for IPv6), the
-- address words, then the port as a big-endian 32-bit value.
-- Unix-domain addresses are rejected via 'fail'.
--
-- NOTE(review): the IPv4 word is written little-endian and the IPv6 words in
-- host order. Whether the resulting octet order is correct depends on how
-- "Network.Socket" represents 'HostAddress'/'HostAddress6' and on machine
-- endianness — confirm on a big-endian host.
encodeSockAddr :: Putter SockAddr
encodeSockAddr addr = case addr of
    SockAddrInet port a -> do
        putWord8 4
        putWord32le a
        putPort port
    SockAddrInet6 port _ (a, b, c, d) _ -> do
        putWord8 16
        mapM_ putWord32host [a, b, c, d]
        putPort port
    SockAddrUnix _ -> fail "encode-socket: unix address not allowed"
  where
    putPort p = putWord32be (fromIntegral p)

-- | Read a socket address written by 'encodeSockAddr'. IPv6 addresses are
-- rebuilt with flow info and scope id set to 0.
decodeSockAddr :: Get SockAddr
decodeSockAddr = getWord8 >>= \size -> case size of
    4  -> flip SockAddrInet <$> getWord32le <*> port
    16 -> (\a p -> SockAddrInet6 p 0 a 0) <$> ipv6 <*> port
    _  -> fail $ "decode-socket: unknown: " ++ show size
  where
    port :: Get PortNumber
    port = fromIntegral <$> getWord32be

    ipv6 :: Get (Word32, Word32, Word32, Word32)
    ipv6 = (,,,) <$> getWord32host <*> getWord32host <*> getWord32host <*> getWord32host
-- | Write a consistency level as its 16-bit protocol code.
encodeConsistency :: Putter Consistency
encodeConsistency level = encodeShort $ case level of
    Any         -> 0x00
    One         -> 0x01
    Two         -> 0x02
    Three       -> 0x03
    Quorum      -> 0x04
    All         -> 0x05
    LocalQuorum -> 0x06
    EachQuorum  -> 0x07
    Serial      -> 0x08
    LocalSerial -> 0x09
    LocalOne    -> 0x0A
-- | Read a consistency level from its 16-bit protocol code; unknown codes
-- fail the 'Get'.
--
-- The codes must mirror 'encodeConsistency' exactly. In particular,
-- 'LocalOne' is 0x0A (decimal 10), not 0x10, so that encode/decode round-trip.
decodeConsistency :: Get Consistency
decodeConsistency = decodeShort >>= mapCode
  where
    mapCode 0x00 = return Any
    mapCode 0x01 = return One
    mapCode 0x02 = return Two
    mapCode 0x03 = return Three
    mapCode 0x04 = return Quorum
    mapCode 0x05 = return All
    mapCode 0x06 = return LocalQuorum
    mapCode 0x07 = return EachQuorum
    mapCode 0x08 = return Serial
    mapCode 0x09 = return LocalSerial
    mapCode 0x0A = return LocalOne
    mapCode code = fail $ "decode-consistency: unknown: " ++ show code
-- | Write a frame opcode as a single byte. Code 0x04 is not produced by any
-- constructor here.
encodeOpCode :: Putter OpCode
encodeOpCode op = encodeByte $ case op of
    OcError         -> 0x00
    OcStartup       -> 0x01
    OcReady         -> 0x02
    OcAuthenticate  -> 0x03
    OcOptions       -> 0x05
    OcSupported     -> 0x06
    OcQuery         -> 0x07
    OcResult        -> 0x08
    OcPrepare       -> 0x09
    OcExecute       -> 0x0A
    OcRegister      -> 0x0B
    OcEvent         -> 0x0C
    OcBatch         -> 0x0D
    OcAuthChallenge -> 0x0E
    OcAuthResponse  -> 0x0F
    OcAuthSuccess   -> 0x10
-- | Read a single-byte frame opcode; bytes without a matching constructor
-- (including 0x04) fail the 'Get'.
decodeOpCode :: Get OpCode
decodeOpCode = do
    byte <- decodeByte
    case byte of
        0x00 -> return OcError
        0x01 -> return OcStartup
        0x02 -> return OcReady
        0x03 -> return OcAuthenticate
        0x05 -> return OcOptions
        0x06 -> return OcSupported
        0x07 -> return OcQuery
        0x08 -> return OcResult
        0x09 -> return OcPrepare
        0x0A -> return OcExecute
        0x0B -> return OcRegister
        0x0C -> return OcEvent
        0x0D -> return OcBatch
        0x0E -> return OcAuthChallenge
        0x0F -> return OcAuthResponse
        0x10 -> return OcAuthSuccess
        _    -> fail $ "decode-opcode: unknown: " ++ show byte
-- | Write a column type as a protocol option: a 16-bit type id followed by
-- any parameters (custom class name, element types, UDT fields).
--
-- 'MaybeColumn' has no wire representation of its own and is written as its
-- inner type.
encodeColumnType :: Putter ColumnType
encodeColumnType (CustomColumn x) = encodeShort 0x0000 >> encodeString x
encodeColumnType AsciiColumn      = encodeShort 0x0001
encodeColumnType BigIntColumn     = encodeShort 0x0002
encodeColumnType BlobColumn       = encodeShort 0x0003
encodeColumnType BooleanColumn    = encodeShort 0x0004
encodeColumnType CounterColumn    = encodeShort 0x0005
encodeColumnType DecimalColumn    = encodeShort 0x0006
encodeColumnType DoubleColumn     = encodeShort 0x0007
encodeColumnType FloatColumn      = encodeShort 0x0008
encodeColumnType IntColumn        = encodeShort 0x0009
encodeColumnType TextColumn       = encodeShort 0x000A
encodeColumnType TimestampColumn  = encodeShort 0x000B
encodeColumnType UuidColumn       = encodeShort 0x000C
encodeColumnType VarCharColumn    = encodeShort 0x000D
encodeColumnType VarIntColumn     = encodeShort 0x000E
encodeColumnType TimeUuidColumn   = encodeShort 0x000F
encodeColumnType InetColumn       = encodeShort 0x0010
encodeColumnType (MaybeColumn x)  = encodeColumnType x
encodeColumnType (ListColumn x)   = encodeShort 0x0020 >> encodeColumnType x
encodeColumnType (MapColumn x y)  = encodeShort 0x0021 >> encodeColumnType x >> encodeColumnType y
encodeColumnType (SetColumn x)    = encodeShort 0x0022 >> encodeColumnType x
encodeColumnType (TupleColumn xs) = do
    encodeShort 0x0031
    -- The element count precedes the element types; 'decodeColumnType'
    -- reads it before decoding the elements.
    encodeShort (fromIntegral (length xs))
    mapM_ encodeColumnType xs
encodeColumnType (UdtColumn k n xs) = do
    encodeShort 0x0030
    encodeString (unKeyspace k)
    encodeString n
    encodeShort (fromIntegral (length xs))
    forM_ xs $ \(x, t) -> encodeString x >> encodeColumnType t
-- | Read a column type option written by 'encodeColumnType'. Nested element
-- types are decoded recursively; unknown ids fail the 'Get'.
decodeColumnType :: Get ColumnType
decodeColumnType = option
  where
    -- One full option: type id, then whatever parameters it carries.
    option = decodeShort >>= toType

    -- A 16-bit count followed by that many items.
    counted p = do
        n <- fromIntegral <$> decodeShort
        replicateM n p

    toType 0x0000 = CustomColumn <$> decodeString
    toType 0x0001 = return AsciiColumn
    toType 0x0002 = return BigIntColumn
    toType 0x0003 = return BlobColumn
    toType 0x0004 = return BooleanColumn
    toType 0x0005 = return CounterColumn
    toType 0x0006 = return DecimalColumn
    toType 0x0007 = return DoubleColumn
    toType 0x0008 = return FloatColumn
    toType 0x0009 = return IntColumn
    toType 0x000A = return TextColumn
    toType 0x000B = return TimestampColumn
    toType 0x000C = return UuidColumn
    toType 0x000D = return VarCharColumn
    toType 0x000E = return VarIntColumn
    toType 0x000F = return TimeUuidColumn
    toType 0x0010 = return InetColumn
    toType 0x0020 = ListColumn <$> option
    toType 0x0021 = MapColumn <$> option <*> option
    toType 0x0022 = SetColumn <$> option
    toType 0x0030 =
        UdtColumn
            <$> (Keyspace <$> decodeString)
            <*> decodeString
            <*> counted ((,) <$> decodeString <*> option)
    toType 0x0031 = TupleColumn <$> counted option
    toType other  = fail $ "decode-type: unknown: " ++ show other
-- | Write a paging state as a length-prefixed byte blob.
encodePagingState :: Putter PagingState
encodePagingState (PagingState bytes) = encodeBytes bytes

-- | Read a paging state; a null blob yields 'Nothing'.
decodePagingState :: Get (Maybe PagingState)
decodePagingState = fmap PagingState <$> decodeBytes
-- | Write a 'Value' as a length-prefixed value for the given protocol
-- version.
--
-- * Collections: V2 uses 'Word16' element counts and element lengths, and V3
--   uses 'Int32' for both. Elements go through 'putNative', so nested
--   collections are rejected.
-- * Tuples and UDTs (V3 only): the whole value carries one 'Int32' length.
--   Its content is the fields in order, each written by 'putValue' and
--   therefore carrying its own length prefix.
-- * @CqlMaybe Nothing@ is written as null, i.e. length @-1@. 'getValue'
--   recognises null by a negative length.
putValue :: Version -> Putter Value
putValue V3 (CqlList x) = toBytes 4 $ do
    encodeInt (fromIntegral (length x))
    mapM_ (toBytes 4 . putNative) x
putValue V2 (CqlList x) = toBytes 4 $ do
    encodeShort (fromIntegral (length x))
    mapM_ (toBytes 2 . putNative) x
putValue V3 (CqlSet x) = toBytes 4 $ do
    encodeInt (fromIntegral (length x))
    mapM_ (toBytes 4 . putNative) x
putValue V2 (CqlSet x) = toBytes 4 $ do
    encodeShort (fromIntegral (length x))
    mapM_ (toBytes 2 . putNative) x
putValue V3 (CqlMap x) = toBytes 4 $ do
    encodeInt (fromIntegral (length x))
    forM_ x $ \(k, v) -> toBytes 4 (putNative k) >> toBytes 4 (putNative v)
putValue V2 (CqlMap x) = toBytes 4 $ do
    encodeShort (fromIntegral (length x))
    forM_ x $ \(k, v) -> toBytes 2 (putNative k) >> toBytes 2 (putNative v)
-- Each field already gets its own length prefix from 'putValue', so only the
-- tuple/UDT as a whole is wrapped here.
putValue V3 (CqlTuple x) = toBytes 4 $ mapM_ (putValue V3) x
putValue V3 (CqlUdt x) = toBytes 4 $ mapM_ (putValue V3 . snd) x
putValue _ (CqlMaybe Nothing) = put (-1 :: Int32)
putValue v (CqlMaybe (Just x)) = putValue v x
putValue _ value = toBytes 4 $ putNative value

-- | Write the raw payload of a scalar 'Value' without any length prefix.
-- Callers add the prefix (see 'toBytes'). Collections, optional values,
-- tuples and UDTs are rejected via 'fail'.
--
-- NOTE(review): Inet addresses are written with 'putWord32le' (IPv4) and
-- 'putWord32host' (IPv6). Whether this yields network byte order depends on
-- the "Data.IP" host-address representation and machine endianness —
-- confirm.
putNative :: Putter Value
putNative (CqlCustom x)    = putLazyByteString x
putNative (CqlBoolean x)   = putWord8 $ if x then 1 else 0
putNative (CqlInt x)       = put x
putNative (CqlBigInt x)    = put x
putNative (CqlFloat x)     = putFloat32be x
putNative (CqlDouble x)    = putFloat64be x
putNative (CqlText x)      = putByteString (T.encodeUtf8 x)
putNative (CqlUuid x)      = encodeUUID x
putNative (CqlTimeUuid x)  = encodeUUID x
putNative (CqlTimestamp x) = put x
putNative (CqlAscii x)     = putByteString (T.encodeUtf8 x)
putNative (CqlBlob x)      = putLazyByteString x
putNative (CqlCounter x)   = put x
putNative (CqlInet x)      = case x of
    IPv4 i -> putWord32le (toHostAddress i)
    IPv6 i -> do
        let (a, b, c, d) = toHostAddress6 i
        putWord32host a
        putWord32host b
        putWord32host c
        putWord32host d
putNative (CqlVarInt x)    = integer2bytes x
putNative (CqlDecimal x)   = do
    put (fromIntegral (decimalPlaces x) :: Int32)
    integer2bytes (decimalMantissa x)
putNative v@(CqlList _)  = fail $ "putNative: collection type: " ++ show v
putNative v@(CqlSet _)   = fail $ "putNative: collection type: " ++ show v
putNative v@(CqlMap _)   = fail $ "putNative: collection type: " ++ show v
putNative v@(CqlMaybe _) = fail $ "putNative: collection type: " ++ show v
putNative v@(CqlTuple _) = fail $ "putNative: tuple type: " ++ show v
putNative v@(CqlUdt _)   = fail $ "putNative: UDT: " ++ show v

-- | Read a length-prefixed value of the given column type. This is the
-- inverse of 'putValue'.
--
-- * Collections: a negative outer length yields an empty collection (see
--   'getList'). Element sizes are 'Word16' in V2 and 'Int32' in V3.
-- * Tuples and UDTs (V3 only): the outer 'Int32' length is consumed first,
--   then each field is read with 'getValue' inside that slice.
-- * 'MaybeColumn': a negative length yields @CqlMaybe Nothing@. Any other
--   column type fails on null (see 'withBytes').
getValue :: Version -> ColumnType -> Get Value
getValue V3 (ListColumn t) = CqlList <$> (getList $ do
    len <- decodeInt
    replicateM (fromIntegral len) (withBytes 4 (getNative t)))
getValue V2 (ListColumn t) = CqlList <$> (getList $ do
    len <- decodeShort
    replicateM (fromIntegral len) (withBytes 2 (getNative t)))
getValue V3 (SetColumn t) = CqlSet <$> (getList $ do
    len <- decodeInt
    replicateM (fromIntegral len) (withBytes 4 (getNative t)))
getValue V2 (SetColumn t) = CqlSet <$> (getList $ do
    len <- decodeShort
    replicateM (fromIntegral len) (withBytes 2 (getNative t)))
getValue V3 (MapColumn t u) = CqlMap <$> (getList $ do
    len <- decodeInt
    replicateM (fromIntegral len)
        ((,) <$> withBytes 4 (getNative t) <*> withBytes 4 (getNative u)))
getValue V2 (MapColumn t u) = CqlMap <$> (getList $ do
    len <- decodeShort
    replicateM (fromIntegral len)
        ((,) <$> withBytes 2 (getNative t) <*> withBytes 2 (getNative u)))
getValue V3 (TupleColumn t) = CqlTuple <$> withBytes 4 (mapM (getValue V3) t)
getValue V3 (UdtColumn _ _ x) = withBytes 4 $ do
    let (names, types) = unzip x
    CqlUdt . zip names <$> mapM (getValue V3) types
getValue v (MaybeColumn t) = do
    n <- lookAhead (get :: Get Int32)
    if n < 0
        then uncheckedSkip 4 >> return (CqlMaybe Nothing)
        else CqlMaybe . Just <$> getValue v t
getValue _ colType = withBytes 4 $ getNative colType
-- | Read the raw payload of a scalar column value. The input is expected to
-- be exactly the value's bytes (callers slice it with 'withBytes'), because
-- several cases consume "everything remaining". Collection, optional, tuple
-- and UDT types are rejected via 'fail'.
getNative :: ColumnType -> Get Value
getNative (CustomColumn _) = CqlCustom <$> remainingBytesLazy
getNative BooleanColumn    = CqlBoolean . (/= 0) <$> getWord8
getNative IntColumn        = CqlInt <$> get
getNative BigIntColumn     = CqlBigInt <$> get
getNative FloatColumn      = CqlFloat <$> getFloat32be
getNative DoubleColumn     = CqlDouble <$> getFloat64be
getNative TextColumn       = CqlText <$> remainingUtf8
getNative VarCharColumn    = CqlText <$> remainingUtf8
getNative AsciiColumn      = CqlAscii <$> remainingUtf8
getNative BlobColumn       = CqlBlob <$> remainingBytesLazy
getNative UuidColumn       = CqlUuid <$> decodeUUID
getNative TimeUuidColumn   = CqlTimeUuid <$> decodeUUID
getNative TimestampColumn  = CqlTimestamp <$> get
getNative CounterColumn    = CqlCounter <$> get
getNative InetColumn       = CqlInet <$> do
    -- The address family is inferred from the slice length (4 or 16).
    len <- remaining
    case len of
        4  -> IPv4 . fromHostAddress <$> getWord32le
        16 -> do
            a <- (,,,) <$> getWord32host <*> getWord32host <*> getWord32host <*> getWord32host
            return $ IPv6 (fromHostAddress6 a)
        n  -> fail $ "getNative: invalid Inet length: " ++ show n
getNative VarIntColumn     = CqlVarInt <$> bytes2integer
getNative DecimalColumn    = do
    x <- get :: Get Int32
    y <- bytes2integer
    return (CqlDecimal (Decimal (fromIntegral x) y))
getNative c@(ListColumn _)     = fail $ "getNative: collection type: " ++ show c
getNative c@(SetColumn _)      = fail $ "getNative: collection type: " ++ show c
getNative c@(MapColumn _ _)    = fail $ "getNative: collection type: " ++ show c
getNative c@(MaybeColumn _)    = fail $ "getNative: collection type: " ++ show c
getNative c@(TupleColumn _)    = fail $ "getNative: tuple type: " ++ show c
getNative c@(UdtColumn _ _ _)  = fail $ "getNative: udt: " ++ show c

-- | Consume the remaining input as UTF-8 text. Invalid UTF-8 fails the 'Get'
-- instead of throwing an exception from pure code (as 'T.decodeUtf8' would).
remainingUtf8 :: Get Text
remainingUtf8 =
    remainingBytes >>= either (fail . ("getNative: " ++) . show) return . T.decodeUtf8'
-- | Run a collection parser on an 'Int32'-length-prefixed slice. A negative
-- length (null) consumes just the prefix and yields an empty list.
getList :: Get [a] -> Get [a]
getList m = lookAhead (get :: Get Int32) >>= choose
  where
    choose n
        | n < 0     = uncheckedSkip 4 >> return []
        | otherwise = withBytes 4 m
-- | Read a length prefix of @s@ bytes (2 = 'Word16', 4 = 'Int32'), take that
-- many bytes, and run the parser on that slice alone. Fails on a negative
-- (null) length, an unsupported prefix size, or a parser failure. The parser
-- need not consume the whole slice.
withBytes :: Int -> Get a -> Get a
withBytes s p = do
    len <- readLength
    when (len < 0) $
        fail "withBytes: null"
    chunk <- getBytes len
    either (fail . ("withBytes: " ++)) return (runGet p chunk)
  where
    readLength :: Get Int
    readLength
        | s == 2    = fromIntegral <$> (get :: Get Word16)
        | s == 4    = fromIntegral <$> (get :: Get Int32)
        | otherwise = fail $ "withBytes: invalid size: " ++ show s
-- | Consume all remaining input as a strict 'ByteString'.
remainingBytes :: Get ByteString
remainingBytes = do
    n <- remaining
    getByteString (fromIntegral n)

-- | Consume all remaining input as a lazy 'LB.ByteString'.
remainingBytesLazy :: Get LB.ByteString
remainingBytesLazy = do
    n <- remaining
    getLazyByteString (fromIntegral n)
-- | Render a 'Put' to bytes and emit it with a length prefix: 'Word16' when
-- @s == 2@, otherwise 'Int32'.
--
-- NOTE(review): the length is narrowed with 'fromIntegral', so payloads that
-- exceed the prefix type's range get a truncated length.
toBytes :: Int -> Put -> Put
toBytes s p = prefix >> putByteString body
  where
    body = runPut p
    len  = B.length body
    prefix
        | s == 2    = put (fromIntegral len :: Word16)
        | otherwise = put (fromIntegral len :: Int32)
-- | Write an arbitrary-precision integer as a sign byte (@signum n@
-- truncated to 'Word8': 0, 1 or 255), followed by the magnitude as a
-- cereal-serialised list of bytes, least significant byte first.
--
-- NOTE(review): this is a library-specific format, not a two's-complement
-- big-endian varint; confirm it matches what the server expects.
integer2bytes :: Putter Integer
integer2bytes n = put signByte >> put (magnitudeBytes (abs n))
  where
    signByte = fromIntegral (signum n) :: Word8

    magnitudeBytes :: Integer -> [Word8]
    magnitudeBytes = unfoldr $ \i ->
        if i == 0 then Nothing else Just (fromIntegral i, i `shiftR` 8)
-- | Read an integer written by 'integer2bytes': a sign byte, then the
-- magnitude as a cereal-serialised byte list, least significant byte first.
--
-- A sign byte of 1 means positive. Any other value (0 for zero, 255 for
-- negative) negates the magnitude, which leaves zero unchanged. Without the
-- negation, negative varints and decimal mantissas would decode as positive.
bytes2integer :: Get Integer
bytes2integer = do
    sign  <- get :: Get Word8
    bytes <- get
    let v = roll bytes
    return $! if sign == 1 then v else negate v
  where
    roll :: [Word8] -> Integer
    roll = foldr unstep 0
      where
        unstep b a = a `shiftL` 8 .|. fromIntegral b
-- | Read a keyspace name (a short string).
decodeKeyspace :: Get Keyspace
decodeKeyspace = fmap Keyspace decodeString

-- | Read a table name (a short string).
decodeTable :: Get Table
decodeTable = fmap Table decodeString

-- | Read a prepared-statement id (short bytes).
decodeQueryId :: Get (QueryId k a b)
decodeQueryId = fmap QueryId decodeShortBytes