--
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
--   http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.
--

{-# LANGUAGE CPP #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Thrift.Protocol.Binary
    ( module Thrift.Protocol
    , BinaryProtocol(..)
    , versionMask
    , version1
    ) where

import Control.Exception ( throw )
import Control.Monad
import Data.Bits
import Data.ByteString.Lazy.Builder
import Data.Functor
import Data.Int
import Data.Monoid
import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
import Data.Word

import Thrift.Protocol
import Thrift.Transport
import Thrift.Types

import qualified Data.Attoparsec.ByteString as P
import qualified Data.Attoparsec.ByteString.Lazy as LP
import qualified Data.Binary as Binary
import qualified Data.ByteString.Lazy as LBS
import qualified Data.HashMap.Strict as Map
import qualified Data.Text.Lazy as LT

-- | Mask selecting the upper 16 bits of the first word of a strict message
-- header, which is where the protocol version lives (0xffff0000).
versionMask :: Int32
versionMask = 0xffff `shiftL` 16

-- | Strict binary protocol version word (0x80010000). The high bit is set and
-- version number 1 sits in the upper half. When a header is written, the
-- message type is OR-ed into the low byte.
version1 :: Int32
version1 = 0x8001 `shiftL` 16

-- | The Thrift binary protocol layered over a transport of type @a@.
-- The constructor carries a @Transport a@ constraint, so matching on
-- 'BinaryProtocol' makes the transport operations available for the
-- wrapped value.
data BinaryProtocol a = Transport a => BinaryProtocol a

-- | Unwrap the transport that a 'BinaryProtocol' reads from and writes to.
getTransport :: Transport t => BinaryProtocol t -> t
getTransport proto = case proto of
  BinaryProtocol transport -> transport

-- NOTE: Writing goes through 'Builder's and reading goes through
-- Data.Binary. Both treat the multi-byte values used here as BIG ENDIAN and
-- convert to and from the host's native byte order as needed.
instance Transport t => Protocol (BinaryProtocol t) where
    readByte proto = tReadAll (getTransport proto) 1

    -- Emit the strict message header (version|type, name, seqid), let the
    -- caller write the body, then flush the transport.
    writeMessage proto (name, msgType, seqId) body = do
      tWrite transport header
      body
      tFlush transport
      where
        transport = getTransport proto
        header = toLazyByteString $ mconcat
          [ buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum msgType)))
          , buildBinaryValue (TString (encodeUtf8 name))
          , buildBinaryValue (TI32 seqId)
          ]

    -- Parse the strict message header and pass (name, type, seqid) to the
    -- continuation. A first word whose upper half is not version1 is rejected
    -- with PE_BAD_VERSION. The message type is taken from the low byte.
    readMessage proto handler = runParser proto header >>= handler
      where
        header = do
          TI32 ver <- parseBinaryValue T_I32
          unless (ver .&. versionMask == version1) $
            throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
          TString name <- parseBinaryValue T_STRING
          TI32 seqId <- parseBinaryValue T_I32
          return (decodeUtf8 name, toEnum $ fromIntegral $ ver .&. 0xFF, seqId)

    writeVal proto = tWrite (getTransport proto) . toLazyByteString . buildBinaryValue
    readVal proto = runParser proto . parseBinaryValue

instance Transport t => StatelessProtocol (BinaryProtocol t) where
    -- Encoding needs no transport: render the value straight to bytes.
    serializeVal _ = toLazyByteString . buildBinaryValue
    -- A parse failure is fatal: attoparsec's message is raised via 'error'.
    deserializeVal _ ty bs =
      either error id . LP.eitherResult $ LP.parse (parseBinaryValue ty) bs

-- Writing functions

-- | Encode a 'ThriftVal' in the binary protocol wire format.
--
-- * Structs: their fields, followed by a T_STOP byte.
-- * Maps, lists and sets: the element type byte(s), a big-endian Int32
--   element count, then the elements.
-- * Scalars: written big-endian.
-- * Strings and binaries: a big-endian Int32 byte length (not range-checked),
--   then the raw bytes.
buildBinaryValue :: ThriftVal -> Builder
buildBinaryValue value = case value of
  TStruct fields -> buildBinaryStruct fields <> buildType T_STOP
  TMap kt vt kvs -> buildType kt <> buildType vt <> count kvs <> buildBinaryMap kvs
  TList ty xs    -> collection ty xs
  TSet ty xs     -> collection ty xs
  TBool b        -> word8 (if b then 1 else 0)
  TByte b        -> int8 b
  TI16 i         -> int16BE i
  TI32 i         -> int32BE i
  TI64 i         -> int64BE i
  TDouble d      -> doubleBE d
  TString s      -> int32BE (fromIntegral (LBS.length s)) <> lazyByteString s
  TBinary s      -> buildBinaryValue (TString s)
  where
    -- Element-count prefix shared by maps, lists and sets.
    count :: [a] -> Builder
    count xs = int32BE (fromIntegral (length xs))
    -- Lists and sets share one layout: element type, count, elements.
    collection ty xs = buildType ty <> count xs <> buildBinaryList xs

-- | Encode struct fields in the map's traversal order. Each field is written
-- as its type byte, its big-endian Int16 field id, and its value. Field names
-- are not written, and the closing T_STOP byte is left to the caller.
buildBinaryStruct :: Map.HashMap Int16 (LT.Text, ThriftVal) -> Builder
buildBinaryStruct fields =
  Map.foldrWithKey
    (\fieldId (_name, val) rest ->
       mconcat [buildTypeOf val, int16BE fieldId, buildBinaryValue val] <> rest)
    mempty
    fields

-- | Encode map entries as consecutive key/value pairs, in list order.
--
-- This uses 'foldMap' (a right fold) rather than a lazy left fold. The
-- builder is then produced incrementally, instead of only after an O(n)
-- chain of accumulator thunks has been built. This also matches how
-- 'buildBinaryList' encodes its elements.
buildBinaryMap :: [(ThriftVal, ThriftVal)] -> Builder
buildBinaryMap = foldMap entry
  where
    entry (key, val) = buildBinaryValue key <> buildBinaryValue val

-- | Encode list or set elements back to back, in list order.
buildBinaryList :: [ThriftVal] -> Builder
buildBinaryList = foldMap buildBinaryValue

-- Reading functions

-- | Parse a value of the given type from binary protocol input. This is the
-- inverse of 'buildBinaryValue'.
--
-- For maps, lists and sets, the element types are read from the wire rather
-- than taken from the supplied type. A type with no decoder here (e.g.
-- T_STOP) is a fatal 'error'.
parseBinaryValue :: ThriftType -> P.Parser ThriftVal
parseBinaryValue ty = case ty of
  T_STRUCT tmap -> TStruct <$> parseBinaryStruct tmap
  T_MAP _ _ -> do
    kt <- parseType
    vt <- parseType
    n <- bigEndian 4
    TMap kt vt <$> parseBinaryMap kt vt n
  T_LIST _ -> do
    et <- parseType
    n <- bigEndian 4
    TList et <$> parseBinaryList et n
  T_SET _ -> do
    et <- parseType
    n <- bigEndian 4
    TSet et <$> parseBinaryList et n
  T_BOOL   -> TBool . (/= 0) <$> P.anyWord8
  T_BYTE   -> TByte <$> bigEndian 1
  T_I16    -> TI16 <$> bigEndian 2
  T_I32    -> TI32 <$> bigEndian 4
  T_I64    -> TI64 <$> bigEndian 8
  T_DOUBLE -> TDouble . bsToDouble <$> P.take 8
  T_STRING -> parseBinaryString TString
  T_BINARY -> parseBinaryString TBinary
  _        -> error $ "Cannot read value of type " ++ show ty
  where
    -- Take exactly @width@ bytes and decode them with Data.Binary, which is
    -- big-endian for the integer types used here.
    bigEndian :: Binary.Binary b => Int -> P.Parser b
    bigEndian width = Binary.decode . LBS.fromStrict <$> P.take width

-- | Parse a big-endian Int32 byte length followed by that many raw bytes.
-- The bytes are wrapped with the supplied constructor, which is 'TString' or
-- 'TBinary' at the call sites in 'parseBinaryValue'.
parseBinaryString :: (LBS.ByteString -> a) -> P.Parser a
parseBinaryString wrap = do
  len :: Int32 <- Binary.decode . LBS.fromStrict <$> P.take 4
  wrap . LBS.fromStrict <$> P.take (fromIntegral len)

-- | Parse struct fields until the T_STOP marker, keyed by field id.
--
-- The name slot of each field is always left empty (@""@), because the
-- binary encoding does not carry field names. The supplied 'TypeMap' refines
-- the wire type where the wire alone is ambiguous:
--
--   * A T_STRING on the wire whose schema type is T_BINARY is decoded as
--     'TBinary'.
--   * A nested struct is decoded with the schema's own struct type for that
--     field, so the same refinement also applies inside it. The T_STRUCT
--     produced by 'parseType' from a single wire byte cannot carry that
--     schema information.
parseBinaryStruct :: TypeMap -> P.Parser (Map.HashMap Int16 (LT.Text, ThriftVal))
parseBinaryStruct tmap = Map.fromList <$> P.manyTill parseField (matchType T_STOP)
  where
    parseField = do
      t <- parseType
      n <- Binary.decode . LBS.fromStrict <$> P.take 2
      v <- case (t, Map.lookup n tmap) of
             -- Strings and binaries share a wire type; the schema decides.
             (T_STRING, Just (_, T_BINARY)) -> parseBinaryValue T_BINARY
             -- The wire only says "struct"; use the schema's nested type map.
             (T_STRUCT _, Just (_, nested@(T_STRUCT _))) -> parseBinaryValue nested
             _ -> parseBinaryValue t
      return (n, ("", v))

-- | Parse @n@ key/value pairs (key first, then value) with the given key and
-- value types. A non-positive count yields an empty list.
parseBinaryMap :: ThriftType -> ThriftType -> Int32 -> P.Parser [(ThriftVal, ThriftVal)]
parseBinaryMap kt vt n = replicateM (fromIntegral n) pair
  where
    pair = (,) <$> parseBinaryValue kt <*> parseBinaryValue vt

-- | Parse @n@ consecutive elements of the given type. A non-positive count
-- yields an empty list.
parseBinaryList :: ThriftType -> Int32 -> P.Parser [ThriftVal]
parseBinaryList ty n = replicateM (fromIntegral n) (parseBinaryValue ty)



-- | Encode a type as its one-byte wire tag, which is its 'Enum' index.
buildType :: ThriftType -> Builder
buildType = word8 . fromIntegral . fromEnum

-- | Encode the one-byte wire tag of a value's own type.
buildTypeOf :: ThriftVal -> Builder
buildTypeOf val = buildType (getTypeOf val)

-- | Read one byte and interpret it as a type tag via 'toEnum'.
parseType :: P.Parser ThriftType
parseType = do
  tag <- P.anyWord8
  return (toEnum (fromIntegral tag))

-- | Succeed only if the next byte is the wire tag of the given type. On
-- success, the byte is consumed and that type is returned.
matchType :: ThriftType -> P.Parser ThriftType
matchType t = P.word8 (fromIntegral (fromEnum t)) >> return t