{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}

module Database.CQL.Protocol.Request
    ( Request           (..)
    , pack
    , encodeRequest
    , getOpCode

      -- ** Options
    , Options           (..)
    , encodeOptions

      -- ** Startup
    , Startup           (..)
    , encodeStartup

      -- ** Auth Response
    , AuthResponse      (..)
    , encodeAuthResponse

      -- ** Register
    , Register          (..)
    , EventType         (..)
    , encodeRegister
    , encodeEventType

      -- ** Query
    , Query             (..)
    , QueryParams       (..)
    , SerialConsistency (..)
    , encodeQuery
    , encodeQueryParams

      -- ** Batch
    , Batch             (..)
    , BatchQuery        (..)
    , BatchType         (..)
    , encodeBatch
    , encodeBatchType
    , encodeBatchQuery

      -- ** Prepare
    , Prepare           (..)
    , encodePrepare

      -- ** Execute
    , Execute           (..)
    , encodeExecute
    ) where

import Control.Applicative
import Control.Monad (when)
import Data.Bits
import Data.ByteString.Lazy (ByteString)
import Data.Foldable (traverse_)
import Data.Int
import Data.Text (Text)
import Data.Maybe (isJust)
import Data.Monoid
import Data.Serialize hiding (decode, encode)
import Data.Word
import Database.CQL.Protocol.Tuple
import Database.CQL.Protocol.Codec
import Database.CQL.Protocol.Types
import Database.CQL.Protocol.Header
import Prelude

import qualified Data.ByteString.Lazy as LB

------------------------------------------------------------------------------
-- Request

-- | The type corresponding to the protocol request frame.
--
-- The type parameter 'k' denotes the kind of request. It is present to allow
-- distinguishing read operations from write operations. Use 'R' for read,
-- 'W' for write and 'S' for schema related operations.
--
-- 'a' represents the argument type and 'b' the return type of this request.
data Request k a b
    = RqStartup  !Startup
    | RqOptions  !Options
    | RqRegister !Register
    | RqBatch    !Batch
    | RqAuthResp !AuthResponse
    | RqPrepare  !(Prepare k a b)
    | RqQuery    !(Query k a b)
    | RqExecute  !(Execute k a b)
    deriving Show

encodeRequest :: Tuple a => Version -> Putter (Request k a b)
encodeRequest _ (RqStartup  r) = encodeStartup r
encodeRequest _ (RqOptions  r) = encodeOptions r
encodeRequest _ (RqRegister r) = encodeRegister r
encodeRequest v (RqBatch    r) = encodeBatch v r
encodeRequest _ (RqAuthResp r) = encodeAuthResponse r
encodeRequest _ (RqPrepare  r) = encodePrepare r
encodeRequest v (RqQuery    r) = encodeQuery v r
encodeRequest v (RqExecute  r) = encodeExecute v r

-- | Serialise the given request, optionally using compression.
-- The result is either an error description in case of failure or a binary
-- protocol frame, including 'Header', 'Length' and body.
pack :: Tuple a
     => Version       -- ^ protocol version, which determines the encoding
     -> Compression   -- ^ compression to use
     -> Bool          -- ^ enable/disable tracing
     -> StreamId      -- ^ the stream Id to use
     -> Request k a b -- ^ the actual request to serialise
     -> Either String ByteString
pack v c t i r = do
    body <- runCompression c (runPutLazy $ encodeRequest v r)
    let len = Length . fromIntegral $ LB.length body
    return . runPutLazy $ do
        encodeHeader v RqHeader mkFlags i (getOpCode r) len
        putLazyByteString body
  where
    runCompression f x = maybe compressError return (shrink f $ x)
    compressError      = Left "pack: compression failure"

    mkFlags = (if t then tracing else mempty)
        <> (if algorithm c /= None then compress else mempty)

-- | Get the protocol 'OpCode' corresponding to the given 'Request'.
getOpCode :: Request k a b -> OpCode
getOpCode (RqQuery _)    = OcQuery
getOpCode (RqExecute _)  = OcExecute
getOpCode (RqPrepare _)  = OcPrepare
getOpCode (RqBatch _)    = OcBatch
getOpCode (RqRegister _) = OcRegister
getOpCode (RqOptions _)  = OcOptions
getOpCode (RqStartup _)  = OcStartup
getOpCode (RqAuthResp _) = OcAuthResponse

------------------------------------------------------------------------------
-- STARTUP

-- | A startup request which is used when initialising a connection to the
-- server. It specifies the CQL version to use and optionally the
-- compression algorithm.
data Startup = Startup !CqlVersion !CompressionAlgorithm deriving Show

encodeStartup :: Putter Startup
encodeStartup (Startup v c) =
    encodeMap $ ("CQL_VERSION", mapVersion v) : mapCompression c
  where
    mapVersion :: CqlVersion -> Text
    mapVersion Cqlv300        = "3.0.0"
    mapVersion (CqlVersion s) = s

    mapCompression :: CompressionAlgorithm -> [(Text, Text)]
    mapCompression Snappy = [("COMPRESSION", "snappy")]
    mapCompression LZ4    = [("COMPRESSION", "lz4")]
    mapCompression None   = []

------------------------------------------------------------------------------
-- AUTH_RESPONSE

-- | A request send in response to a previous authentication challenge.
newtype AuthResponse = AuthResponse LB.ByteString deriving Show

encodeAuthResponse :: Putter AuthResponse
encodeAuthResponse (AuthResponse b) = encodeBytes b

------------------------------------------------------------------------------
-- OPTIONS

-- | An options request, send prior to 'Startup' to request the server's
-- startup options.
data Options = Options deriving Show

encodeOptions :: Putter Options
encodeOptions _ = return ()

------------------------------------------------------------------------------
-- QUERY

-- | A CQL query (select, insert, etc.).
data Query k a b = Query !(QueryString k a b) !(QueryParams a) deriving Show

encodeQuery :: Tuple a => Version -> Putter (Query k a b)
encodeQuery v (Query (QueryString s) p) =
    encodeLongString s >> encodeQueryParams v p

------------------------------------------------------------------------------
-- EXECUTE

-- | Executes a prepared query.
data Execute k a b = Execute !(QueryId k a b) !(QueryParams a) deriving Show

encodeExecute :: Tuple a => Version -> Putter (Execute k a b)
encodeExecute v (Execute (QueryId q) p) =
    encodeShortBytes q >> encodeQueryParams v p

------------------------------------------------------------------------------
-- PREPARE

-- | Prepare a query for later execution (cf. 'Execute').
newtype Prepare k a b = Prepare (QueryString k a b) deriving Show

encodePrepare :: Putter (Prepare k a b)
encodePrepare (Prepare (QueryString p)) = encodeLongString p

------------------------------------------------------------------------------
-- REGISTER

-- | Register's the connection this request is made through, to receive
-- server events.
newtype Register = Register [EventType] deriving Show

encodeRegister :: Putter Register
encodeRegister (Register t) = do
    encodeShort (fromIntegral (length t))
    mapM_ encodeEventType t

-- | Event types to register.
data EventType
    = TopologyChangeEvent -- ^ events related to change in the cluster topology
    | StatusChangeEvent   -- ^ events related to change of node status
    | SchemaChangeEvent   -- ^ events related to schema change
    deriving Show

encodeEventType :: Putter EventType
encodeEventType TopologyChangeEvent = encodeString "TOPOLOGY_CHANGE"
encodeEventType StatusChangeEvent   = encodeString "STATUS_CHANGE"
encodeEventType SchemaChangeEvent   = encodeString "SCHEMA_CHANGE"

------------------------------------------------------------------------------
-- BATCH

-- | Allows executing a list of queries (prepared or not) as a batch.
data Batch = Batch
    { batchType              :: !BatchType
    , batchQuery             :: [BatchQuery]
    , batchConsistency       :: !Consistency
    , batchSerialConsistency :: Maybe SerialConsistency
    } deriving Show

data BatchType
    = BatchLogged   -- ^ default, uses a batch log for atomic application
    | BatchUnLogged -- ^ skip the batch log
    | BatchCounter  -- ^ used for batched counter updates
    deriving (Show)

encodeBatch :: Version -> Putter Batch
encodeBatch v (Batch t q c s) = do
    encodeBatchType t
    encodeShort (fromIntegral (length q))
    mapM_ (encodeBatchQuery v) q
    encodeConsistency c
    when (v == V3) $ do
        put batchFlags
        traverse_ encodeConsistency (mapCons <$> s)
  where
    batchFlags :: Word8
    batchFlags = if isJust s then 0x10 else 0x0

encodeBatchType :: Putter BatchType
encodeBatchType BatchLogged   = putWord8 0
encodeBatchType BatchUnLogged = putWord8 1
encodeBatchType BatchCounter  = putWord8 2

-- | A GADT to unify queries and prepared queries both of which can be used
-- in batch requests.
data BatchQuery where
    BatchQuery :: (Show a, Tuple a, Tuple b)
               => !(QueryString W a b)
               -> !a
               -> BatchQuery

    BatchPrepared :: (Show a, Tuple a, Tuple b)
                  => !(QueryId W a b)
                  -> !a
                  -> BatchQuery

deriving instance Show BatchQuery

encodeBatchQuery :: Version -> Putter BatchQuery
encodeBatchQuery n (BatchQuery (QueryString q) v) = do
    putWord8 0
    encodeLongString q
    store n v
encodeBatchQuery n (BatchPrepared (QueryId i) v)  = do
    putWord8 1
    encodeShortBytes i
    store n v

------------------------------------------------------------------------------
-- Query Parameters

-- | Query parameters.
data QueryParams a = QueryParams
    { consistency       :: !Consistency -- ^ consistency leven to use
    , skipMetaData      :: !Bool        -- ^ skip metadata in response
    , values            :: a            -- ^ query arguments
    , pageSize          :: Maybe Int32  -- ^ desired result set size
    , queryPagingState  :: Maybe PagingState
    , serialConsistency :: Maybe SerialConsistency
    } deriving Show

-- | Consistency level for the serial phase of conditional updates.
data SerialConsistency
    = SerialConsistency
    | LocalSerialConsistency
    deriving Show

encodeQueryParams :: forall a. Tuple a => Version -> Putter (QueryParams a)
encodeQueryParams v p = do
    encodeConsistency (consistency p)
    put queryFlags
    store v (values p)
    traverse_ encodeInt         (pageSize p)
    traverse_ encodePagingState (queryPagingState p)
    traverse_ encodeConsistency (mapCons <$> serialConsistency p)
  where
    queryFlags :: Word8
    queryFlags =
            (if hasValues                    then 0x01 else 0x0)
        .|. (if skipMetaData p               then 0x02 else 0x0)
        .|. (if isJust (pageSize p)          then 0x04 else 0x0)
        .|. (if isJust (queryPagingState p)  then 0x08 else 0x0)
        .|. (if isJust (serialConsistency p) then 0x10 else 0x0)

    hasValues = untag (count :: Tagged a Int) /= 0

mapCons :: SerialConsistency -> Consistency
mapCons SerialConsistency      = Serial
mapCons LocalSerialConsistency = LocalSerial