{-# LANGUAGE BangPatterns      #-}
{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}

module Database.CQL.Protocol.Codec
    ( encodeByte
    , decodeByte

    , encodeSignedByte
    , decodeSignedByte

    , encodeShort
    , decodeShort

    , encodeSignedShort
    , decodeSignedShort

    , encodeInt
    , decodeInt

    , encodeString
    , decodeString

    , encodeLongString
    , decodeLongString

    , encodeBytes
    , decodeBytes

    , encodeShortBytes
    , decodeShortBytes

    , encodeUUID
    , decodeUUID

    , encodeList
    , decodeList

    , encodeMap
    , decodeMap

    , encodeMultiMap
    , decodeMultiMap

    , encodeSockAddr
    , decodeSockAddr

    , encodeConsistency
    , decodeConsistency

    , encodeOpCode
    , decodeOpCode

    , encodePagingState
    , decodePagingState

    , decodeKeyspace
    , decodeTable
    , decodeColumnType
    , decodeQueryId

    , putValue
    , getValue
    ) where

import Control.Applicative
import Control.Monad
import Data.Bits
import Data.ByteString (ByteString)
import Data.Decimal
import Data.Int
import Data.IP
#ifdef INCOMPATIBLE_VARINT
import Data.List (unfoldr)
#else
import Data.List (foldl')
#endif
import Data.Text (Text)
import Data.UUID (UUID)
import Data.Word
import Data.Serialize hiding (decode, encode)
import Database.CQL.Protocol.Types
import Network.Socket (SockAddr (..), PortNumber)
import Prelude

import qualified Data.ByteString         as B
import qualified Data.ByteString.Lazy    as LB
import qualified Data.Text.Encoding      as T
import qualified Data.Text.Lazy          as LT
import qualified Data.Text.Lazy.Encoding as LT
import qualified Data.UUID               as UUID

------------------------------------------------------------------------------
-- Byte

encodeByte :: Putter Word8
encodeByte = put

decodeByte :: Get Word8
decodeByte = get

------------------------------------------------------------------------------
-- Signed Byte

encodeSignedByte :: Putter Int8
encodeSignedByte = put

decodeSignedByte :: Get Int8
decodeSignedByte = get

------------------------------------------------------------------------------
-- Short

encodeShort :: Putter Word16
encodeShort = put

decodeShort :: Get Word16
decodeShort = get

------------------------------------------------------------------------------
-- Signed Short

encodeSignedShort :: Putter Int16
encodeSignedShort = put

decodeSignedShort :: Get Int16
decodeSignedShort = get

------------------------------------------------------------------------------
-- Int

encodeInt :: Putter Int32
encodeInt = put

decodeInt :: Get Int32
decodeInt = get

------------------------------------------------------------------------------
-- String

encodeString :: Putter Text
encodeString = encodeShortBytes . T.encodeUtf8

decodeString :: Get Text
decodeString = T.decodeUtf8 <$> decodeShortBytes

------------------------------------------------------------------------------
-- Long String

encodeLongString :: Putter LT.Text
encodeLongString = encodeBytes . LT.encodeUtf8

decodeLongString :: Get LT.Text
decodeLongString = do
    n <- get :: Get Int32
    LT.decodeUtf8 <$> getLazyByteString (fromIntegral n)

------------------------------------------------------------------------------
-- Bytes

encodeBytes :: Putter LB.ByteString
encodeBytes bs = do
    put (fromIntegral (LB.length bs) :: Int32)
    putLazyByteString bs

decodeBytes :: Get (Maybe LB.ByteString)
decodeBytes = do
    n <- get :: Get Int32
    if n < 0
        then return Nothing
        else Just <$> getLazyByteString (fromIntegral n)

------------------------------------------------------------------------------
-- Short Bytes

encodeShortBytes :: Putter ByteString
encodeShortBytes bs = do
    put (fromIntegral (B.length bs) :: Word16)
    putByteString bs

decodeShortBytes :: Get ByteString
decodeShortBytes = do
    n <- get :: Get Word16
    getByteString (fromIntegral n)

------------------------------------------------------------------------------
-- UUID

encodeUUID :: Putter UUID
encodeUUID = putLazyByteString . UUID.toByteString

decodeUUID :: Get UUID
decodeUUID = do
    uuid <- UUID.fromByteString <$> getLazyByteString 16
    maybe (fail "decode-uuid: invalid") return uuid

------------------------------------------------------------------------------
-- String List

encodeList :: Putter [Text]
encodeList sl = do
    put (fromIntegral (length sl) :: Word16)
    mapM_ encodeString sl

decodeList :: Get [Text]
decodeList = do
    n <- get :: Get Word16
    replicateM (fromIntegral n) decodeString

------------------------------------------------------------------------------
-- String Map

encodeMap :: Putter [(Text, Text)]
encodeMap m = do
    put (fromIntegral (length m) :: Word16)
    forM_ m $ \(k, v) -> encodeString k >> encodeString v

decodeMap :: Get [(Text, Text)]
decodeMap = do
    n <- get :: Get Word16
    replicateM (fromIntegral n) ((,) <$> decodeString <*> decodeString)

------------------------------------------------------------------------------
-- String Multi-Map

encodeMultiMap :: Putter [(Text, [Text])]
encodeMultiMap mm = do
    put (fromIntegral (length mm) :: Word16)
    forM_ mm $ \(k, v) -> encodeString k >> encodeList v

decodeMultiMap :: Get [(Text, [Text])]
decodeMultiMap = do
    n <- get :: Get Word16
    replicateM (fromIntegral n) ((,) <$> decodeString <*> decodeList)

------------------------------------------------------------------------------
-- Inet Address

encodeSockAddr :: Putter SockAddr
encodeSockAddr (SockAddrInet p a) = do
    putWord8 4
    putWord32le a
    putWord32be (fromIntegral p)
encodeSockAddr (SockAddrInet6 p _ (a, b, c, d) _) = do
    putWord8 16
    putWord32host a
    putWord32host b
    putWord32host c
    putWord32host d
    putWord32be (fromIntegral p)
encodeSockAddr (SockAddrUnix _) = error "encode-socket: unix address not supported"
#if MIN_VERSION_network(2,6,1) && !MIN_VERSION_network(3,0,0)
encodeSockAddr (SockAddrCan _) = error "encode-socket: can address not supported"
#endif

decodeSockAddr :: Get SockAddr
decodeSockAddr = do
    n <- getWord8
    case n of
        4  -> do
            i <- getIPv4
            p <- getPort
            return $ SockAddrInet p i
        16 -> do
            i <- getIPv6
            p <- getPort
            return $ SockAddrInet6 p 0 i 0
        _  -> fail $ "decode-socket: unknown: " ++ show n
  where
    getPort :: Get PortNumber
    getPort = fromIntegral <$> getWord32be

    getIPv4 :: Get Word32
    getIPv4 = getWord32le

    getIPv6 :: Get (Word32, Word32, Word32, Word32)
    getIPv6 = (,,,) <$> getWord32host <*> getWord32host <*> getWord32host <*> getWord32host

------------------------------------------------------------------------------
-- Consistency

encodeConsistency :: Putter Consistency
encodeConsistency Any         = encodeShort 0x00
encodeConsistency One         = encodeShort 0x01
encodeConsistency Two         = encodeShort 0x02
encodeConsistency Three       = encodeShort 0x03
encodeConsistency Quorum      = encodeShort 0x04
encodeConsistency All         = encodeShort 0x05
encodeConsistency LocalQuorum = encodeShort 0x06
encodeConsistency EachQuorum  = encodeShort 0x07
encodeConsistency Serial      = encodeShort 0x08
encodeConsistency LocalSerial = encodeShort 0x09
encodeConsistency LocalOne    = encodeShort 0x0A

decodeConsistency :: Get Consistency
decodeConsistency = decodeShort >>= mapCode
      where
        mapCode 0x00 = return Any
        mapCode 0x01 = return One
        mapCode 0x02 = return Two
        mapCode 0x03 = return Three
        mapCode 0x04 = return Quorum
        mapCode 0x05 = return All
        mapCode 0x06 = return LocalQuorum
        mapCode 0x07 = return EachQuorum
        mapCode 0x08 = return Serial
        mapCode 0x09 = return LocalSerial
        mapCode 0x0A = return LocalOne
        mapCode code = fail $ "decode-consistency: unknown: " ++ show code

------------------------------------------------------------------------------
-- OpCode

encodeOpCode :: Putter OpCode
encodeOpCode OcError         = encodeByte 0x00
encodeOpCode OcStartup       = encodeByte 0x01
encodeOpCode OcReady         = encodeByte 0x02
encodeOpCode OcAuthenticate  = encodeByte 0x03
encodeOpCode OcOptions       = encodeByte 0x05
encodeOpCode OcSupported     = encodeByte 0x06
encodeOpCode OcQuery         = encodeByte 0x07
encodeOpCode OcResult        = encodeByte 0x08
encodeOpCode OcPrepare       = encodeByte 0x09
encodeOpCode OcExecute       = encodeByte 0x0A
encodeOpCode OcRegister      = encodeByte 0x0B
encodeOpCode OcEvent         = encodeByte 0x0C
encodeOpCode OcBatch         = encodeByte 0x0D
encodeOpCode OcAuthChallenge = encodeByte 0x0E
encodeOpCode OcAuthResponse  = encodeByte 0x0F
encodeOpCode OcAuthSuccess   = encodeByte 0x10

decodeOpCode :: Get OpCode
decodeOpCode = decodeByte >>= mapCode
  where
    mapCode 0x00 = return OcError
    mapCode 0x01 = return OcStartup
    mapCode 0x02 = return OcReady
    mapCode 0x03 = return OcAuthenticate
    mapCode 0x05 = return OcOptions
    mapCode 0x06 = return OcSupported
    mapCode 0x07 = return OcQuery
    mapCode 0x08 = return OcResult
    mapCode 0x09 = return OcPrepare
    mapCode 0x0A = return OcExecute
    mapCode 0x0B = return OcRegister
    mapCode 0x0C = return OcEvent
    mapCode 0x0D = return OcBatch
    mapCode 0x0E = return OcAuthChallenge
    mapCode 0x0F = return OcAuthResponse
    mapCode 0x10 = return OcAuthSuccess
    mapCode word = fail $ "decode-opcode: unknown: " ++ show word

------------------------------------------------------------------------------
-- ColumnType

decodeColumnType :: Get ColumnType
decodeColumnType = decodeShort >>= toType
  where
    toType 0x0000 = CustomColumn <$> decodeString
    toType 0x0001 = return AsciiColumn
    toType 0x0002 = return BigIntColumn
    toType 0x0003 = return BlobColumn
    toType 0x0004 = return BooleanColumn
    toType 0x0005 = return CounterColumn
    toType 0x0006 = return DecimalColumn
    toType 0x0007 = return DoubleColumn
    toType 0x0008 = return FloatColumn
    toType 0x0009 = return IntColumn
    toType 0x000A = return TextColumn
    toType 0x000B = return TimestampColumn
    toType 0x000C = return UuidColumn
    toType 0x000D = return VarCharColumn
    toType 0x000E = return VarIntColumn
    toType 0x000F = return TimeUuidColumn
    toType 0x0010 = return InetColumn
    toType 0x0011 = return DateColumn
    toType 0x0012 = return TimeColumn
    toType 0x0013 = return SmallIntColumn
    toType 0x0014 = return TinyIntColumn
    toType 0x0020 = ListColumn  <$> (decodeShort >>= toType)
    toType 0x0021 = MapColumn   <$> (decodeShort >>= toType) <*> (decodeShort >>= toType)
    toType 0x0022 = SetColumn   <$> (decodeShort >>= toType)
    toType 0x0030 = do
        _ <- decodeString -- Keyspace (not used by this library)
        t <- decodeString -- Type name
        UdtColumn t <$> do
            n <- fromIntegral <$> decodeShort
            replicateM n ((,) <$> decodeString <*> (decodeShort >>= toType))
    toType 0x0031 = TupleColumn <$> do
        n <- fromIntegral <$> decodeShort
        replicateM n (decodeShort >>= toType)
    toType other  = fail $ "decode-type: unknown: " ++ show other

------------------------------------------------------------------------------
-- Paging State

encodePagingState :: Putter PagingState
encodePagingState (PagingState s) = encodeBytes s

decodePagingState :: Get (Maybe PagingState)
decodePagingState = fmap PagingState <$> decodeBytes

------------------------------------------------------------------------------
-- Value

putValue :: Version -> Putter Value
putValue _ (CqlCustom x)    = toBytes $ putLazyByteString x
putValue _ (CqlBoolean x)   = toBytes $ putWord8 $ if x then 1 else 0
putValue _ (CqlInt x)       = toBytes $ put x
putValue _ (CqlBigInt x)    = toBytes $ put x
putValue _ (CqlFloat x)     = toBytes $ putFloat32be x
putValue _ (CqlDouble x)    = toBytes $ putFloat64be x
putValue _ (CqlText x)      = toBytes $ putByteString (T.encodeUtf8 x)
putValue _ (CqlUuid x)      = toBytes $ encodeUUID x
putValue _ (CqlTimeUuid x)  = toBytes $ encodeUUID x
putValue _ (CqlTimestamp x) = toBytes $ put x
putValue _ (CqlAscii x)     = toBytes $ putByteString (T.encodeUtf8 x)
putValue _ (CqlBlob x)      = toBytes $ putLazyByteString x
putValue _ (CqlCounter x)   = toBytes $ put x
putValue _ (CqlInet x)      = toBytes $ case x of
    IPv4 i -> putWord32le (toHostAddress i)
    IPv6 i -> do
        let (a, b, c, d) = toHostAddress6 i
        putWord32host a
        putWord32host b
        putWord32host c
        putWord32host d
putValue _ (CqlVarInt x)   = toBytes $ integer2bytes x
putValue _ (CqlDecimal x)  = toBytes $ do
    put (fromIntegral (decimalPlaces x) :: Int32)
    integer2bytes (decimalMantissa x)
putValue V4   (CqlDate x)     = toBytes $ put x
putValue _  v@(CqlDate _)     = error $ "putValue: date: " ++ show v
putValue V4   (CqlTime x)     = toBytes $ put x
putValue _  v@(CqlTime _)     = error $ "putValue: time: " ++ show v
putValue V4   (CqlSmallInt x) = toBytes $ put x
putValue _  v@(CqlSmallInt _) = error $ "putValue: smallint: " ++ show v
putValue V4   (CqlTinyInt x)  = toBytes $ put x
putValue _  v@(CqlTinyInt _)  = error $ "putValue: tinyint: " ++ show v
putValue v    (CqlUdt   x)    = toBytes $ mapM_ (putValue v . snd) x
putValue v    (CqlList x)     = toBytes $ do
    encodeInt (fromIntegral (length x))
    mapM_ (putValue v) x
putValue v (CqlSet x) = toBytes $ do
    encodeInt (fromIntegral (length x))
    mapM_ (putValue v) x
putValue v (CqlMap x) = toBytes $ do
    encodeInt (fromIntegral (length x))
    forM_ x $ \(k, w) -> putValue v k >> putValue v w
putValue v (CqlTuple x)        = toBytes $ mapM_ (putValue v) x
putValue _ (CqlMaybe Nothing)  = put (-1 :: Int32)
putValue v (CqlMaybe (Just x)) = putValue v x

getValue :: Version -> ColumnType -> Get Value
getValue v (ListColumn t) = CqlList <$> getList (do
    len <- decodeInt
    replicateM (fromIntegral len) (getValue v t))
getValue v (SetColumn t) = CqlSet <$> getList (do
    len <- decodeInt
    replicateM (fromIntegral len) (getValue v t))
getValue v (MapColumn t u) = CqlMap <$> getList (do
    len <- decodeInt
    replicateM (fromIntegral len) ((,) <$> getValue v t <*> getValue v u))
getValue v (TupleColumn t) = withBytes $ CqlTuple <$> mapM (getValue v) t
getValue v (MaybeColumn t) = do
    n <- lookAhead (get :: Get Int32)
    if n < 0
        then uncheckedSkip 4 >> return (CqlMaybe Nothing)
        else CqlMaybe . Just <$> getValue v t
getValue _ (CustomColumn _) = withBytes $ CqlCustom <$> remainingBytesLazy
getValue _ BooleanColumn    = withBytes $ CqlBoolean . (/= 0) <$> getWord8
getValue _ IntColumn        = withBytes $ CqlInt <$> get
getValue _ BigIntColumn     = withBytes $ CqlBigInt <$> get
getValue _ FloatColumn      = withBytes $ CqlFloat  <$> getFloat32be
getValue _ DoubleColumn     = withBytes $ CqlDouble <$> getFloat64be
getValue _ TextColumn       = withBytes $ CqlText . T.decodeUtf8 <$> remainingBytes
getValue _ VarCharColumn    = withBytes $ CqlText . T.decodeUtf8 <$> remainingBytes
getValue _ AsciiColumn      = withBytes $ CqlAscii . T.decodeUtf8 <$> remainingBytes
getValue _ BlobColumn       = withBytes $ CqlBlob <$> remainingBytesLazy
getValue _ UuidColumn       = withBytes $ CqlUuid <$> decodeUUID
getValue _ TimeUuidColumn   = withBytes $ CqlTimeUuid <$> decodeUUID
getValue _ TimestampColumn  = withBytes $ CqlTimestamp <$> get
getValue _ CounterColumn    = withBytes $ CqlCounter <$> get
getValue _ InetColumn       = withBytes $ CqlInet <$> do
    len <- remaining
    case len of
        4  -> IPv4 . fromHostAddress <$> getWord32le
        16 -> do
            a <- (,,,) <$> getWord32host <*> getWord32host <*> getWord32host <*> getWord32host
            return $ IPv6 (fromHostAddress6 a)
        n  -> fail $ "getNative: invalid Inet length: " ++ show n
getValue V4 DateColumn     = withBytes $ CqlDate <$> get
getValue _  DateColumn     = fail "getNative: date type"
getValue V4 TimeColumn     = withBytes $ CqlTime <$> get
getValue _  TimeColumn     = fail "getNative: time type"
getValue V4 SmallIntColumn = withBytes $ CqlSmallInt <$> get
getValue _  SmallIntColumn = fail "getNative: smallint type"
getValue V4 TinyIntColumn  = withBytes $ CqlTinyInt <$> get
getValue _  TinyIntColumn  = fail "getNative: tinyint type"
getValue _  VarIntColumn   = withBytes $ CqlVarInt <$> bytes2integer
getValue _  DecimalColumn  = withBytes $ do
    x <- get :: Get Int32
    y <- bytes2integer
    return (CqlDecimal (Decimal (fromIntegral x) y))
getValue v (UdtColumn _ x) = withBytes $ CqlUdt <$> do
    let (n, t) = unzip x
    zip n <$> mapM (getValue v) t

getList :: Get [a] -> Get [a]
getList m = do
    n <- lookAhead (get :: Get Int32)
    if n < 0 then uncheckedSkip 4 >> return []
             else withBytes m

withBytes :: Get a -> Get a
withBytes p = do
    n <- fromIntegral <$> (get :: Get Int32)
    when (n < 0) $
        fail $ "withBytes: null (length = " ++ show n ++ ")"
    b <- getBytes n
    case runGet p b of
        Left  e -> fail $ "withBytes: " ++ e
        Right x -> return x

remainingBytes :: Get ByteString
remainingBytes = remaining >>= getByteString . fromIntegral

remainingBytesLazy :: Get LB.ByteString
remainingBytesLazy = remaining >>= getLazyByteString . fromIntegral

toBytes :: Put -> Put
toBytes p = do
    let bytes = runPut p
    put (fromIntegral (B.length bytes) :: Int32)
    putByteString bytes

#ifdef INCOMPATIBLE_VARINT

-- 'integer2bytes' and 'bytes2integer' implementations are taken
-- from cereal's instance declaration of 'Serialize' for 'Integer'
-- except that no distinction between small and large integers is made.
-- Cf. to LICENSE for copyright details.
integer2bytes :: Putter Integer
integer2bytes n = do
    put sign
    put (unroll (abs n))
  where
    sign = fromIntegral (signum n) :: Word8

    unroll :: Integer -> [Word8]
    unroll = unfoldr step
      where
        step 0 = Nothing
        step i = Just (fromIntegral i, i `shiftR` 8)

bytes2integer :: Get Integer
bytes2integer = do
    sign  <- get
    bytes <- get
    let v = roll bytes
    return $! if sign == (1 :: Word8) then v else - v
  where
    roll :: [Word8] -> Integer
    roll = foldr unstep 0
      where
        unstep b a = a `shiftL` 8 .|. fromIntegral b

#else

integer2bytes :: Putter Integer
integer2bytes n
    | n == 0  = putWord8 0x00
    | n == -1 = putWord8 0xFF
    | n <  0  = do
        let bytes = explode (-1) n
        unless (head bytes >= 0x80) $
            putWord8 0xFF
        mapM_ putWord8 bytes
    | otherwise = do
        let bytes = explode 0 n
        unless (head bytes < 0x80) $
            putWord8 0x00
        mapM_ putWord8 bytes

explode :: Integer -> Integer -> [Word8]
explode x n = loop n []
  where
    loop !i !acc
        | i == x    = acc
        | otherwise = loop (i `shiftR` 8) (fromIntegral i : acc)

bytes2integer :: Get Integer
bytes2integer = do
    msb   <- getWord8
    bytes <- B.unpack <$> remainingBytes
    if msb < 0x80
        then return (implode (msb:bytes))
        else return (- (implode (map complement (msb:bytes)) + 1))

implode :: [Word8] -> Integer
implode = foldl' fun 0
  where
    fun i b = i `shiftL` 8 .|. fromIntegral b

#endif
------------------------------------------------------------------------------
-- Various

decodeKeyspace :: Get Keyspace
decodeKeyspace = Keyspace <$> decodeString

decodeTable :: Get Table
decodeTable = Table <$> decodeString

decodeQueryId :: Get (QueryId k a b)
decodeQueryId = QueryId <$> decodeShortBytes