{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NoMonomorphismRestriction #-}

module Database.Memcached.Binary.Internal where

import Network

import Foreign.Ptr
import Foreign.Storable
import Foreign.Marshal.Utils
import Foreign.Marshal.Alloc

import System.IO

import Control.Monad
import Control.Exception

import Data.Word
import Data.Pool
import Data.Storable.Endian
import qualified Data.HashMap.Strict as H
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Unsafe as S

import Database.Memcached.Binary.Types
import Database.Memcached.Binary.Types.Exception
import Database.Memcached.Binary.Internal.Definition

-- | A memcached connection: a 'Pool' of 'Handle's opened with 'connect''.
-- Individual handles are borrowed from the pool with 'useConnection'.
newtype Connection = Connection (Pool Handle)

-- | Open a pooled 'Connection' for the given 'ConnectInfo', run the action
-- with it, and destroy the pool afterwards (also when the action throws).
-- Everything runs under 'withSocketsDo'.
withConnection :: ConnectInfo -> (Connection -> IO a) -> IO a
withConnection info action =
    withSocketsDo (bracket (connect info) close action)

-- | Create the connection pool for the given 'ConnectInfo'.
--
-- New handles are opened (and authenticated) by 'connect''. The pool is
-- created with one stripe, 'connectionIdleTime' as the idle time and
-- 'numConnection' as the maximum number of handles.
--
-- When a handle is evicted, a @quit@ request is sent and the handle is
-- closed. The close runs even if sending @quit@ fails (e.g. the server has
-- already dropped the socket), so the descriptor is never leaked.
connect :: ConnectInfo -> IO Connection
connect i = fmap Connection $
    createPool (connect' i) release 1
    (connectionIdleTime i) (numConnection i)
  where
    release h = quit h `finally` hClose h

-- | Open a new handle to 'connectHost' / 'connectPort' and authenticate it.
--
-- With no credentials in 'connectAuth' the freshly opened handle is
-- returned as-is. Otherwise each credential is tried in order, each on a
-- fresh connection. A handle whose authentication attempt fails is closed
-- before moving on.
--
-- For every credential but the last, failures are swallowed and the next
-- credential is tried. These include memcached error responses and
-- 'IOError's raised while authenticating. The last credential's failure is
-- rethrown to the caller.
connect' :: ConnectInfo -> IO Handle
connect' i = loop (connectAuth i)
  where
    open = connectTo (connectHost i) (connectPort i)

    -- The handle is being thrown away, so an error while closing it is
    -- not interesting.
    discard h = hClose h `catch` \(_ :: IOError) -> return ()

    loop [] = open

    -- Last credential: its failure is reported to the caller, but the
    -- handle is closed first.
    loop [a] = do
        h <- open
        auth a (\_ -> return h) throwIO h `onException` discard h

    -- Not the last credential: on failure close this handle and move on.
    -- The outcome is inspected outside the exception handler, so failures
    -- raised while trying the remaining credentials are not caught here
    -- (which previously caused those credentials to be retried twice).
    loop (a:as) = do
        h  <- open
        ok <- try (auth a (\_ -> return True) (\_ -> return False) h)
        case ok of
            Right True          -> return h
            Right False         -> discard h >> loop as
            Left (_ :: IOError) -> discard h >> loop as

-- | Shut the connection down by destroying every handle held by its pool.
close :: Connection -> IO ()
close conn = case conn of
    Connection pool -> destroyAllResources pool

-- | Run an action with a handle borrowed from the connection's pool; the
-- borrowing and returning is managed by 'withResource'.
useConnection :: (Handle -> IO a) -> Connection -> IO a
useConnection action (Connection pool) = withResource pool action

-- | Write a single byte at the given address.
pokeWord8 :: Ptr a -> Word8 -> IO ()
pokeWord8 ptr byte = poke (castPtr ptr) byte

-- | Write a 16-bit word at the given address in big-endian byte order.
pokeWord16be :: Ptr a -> Word16 -> IO ()
pokeWord16be ptr = poke (castPtr ptr) . BE

-- | Write a 32-bit word at the given address in big-endian byte order.
pokeWord32be :: Ptr a -> Word32 -> IO ()
pokeWord32be ptr = poke (castPtr ptr) . BE

-- | Write a 64-bit word at the given address in big-endian byte order.
pokeWord64be :: Ptr a -> Word64 -> IO ()
pokeWord64be ptr = poke (castPtr ptr) . BE

-- | Read a single byte from the given address.
peekWord8 :: Ptr a -> IO Word8
peekWord8 ptr = peek (castPtr ptr)

-- | Read a big-endian 16-bit word from the given address.
peekWord16be :: Ptr a -> IO Word16
peekWord16be ptr = do
    BE w <- peek (castPtr ptr)
    return w

-- | Read a big-endian 32-bit word from the given address.
peekWord32be :: Ptr a -> IO Word32
peekWord32be ptr = do
    BE w <- peek (castPtr ptr)
    return w

-- | Read a big-endian 64-bit word from the given address.
peekWord64be :: Ptr a -> IO Word64
peekWord64be ptr = do
    BE w <- peek (castPtr ptr)
    return w

-- | Copy the bytes of a strict 'S.ByteString' to the given address.
-- The destination must have room for @S.length v@ bytes.
pokeByteString :: Ptr a -> S.ByteString -> IO ()
pokeByteString dst bytes =
    S.unsafeUseAsCStringLen bytes $ \(src, n) -> copyBytes (castPtr dst) src n

-- | Copy every chunk of a lazy 'L.ByteString' to consecutive addresses
-- starting at the given pointer. The destination must have room for
-- @L.length v@ bytes.
pokeLazyByteString :: Ptr a -> L.ByteString -> IO ()
pokeLazyByteString dst lazyBytes = foldM_ step 0 (L.toChunks lazyBytes)
  where
    -- Write one chunk at the current offset and return the next offset.
    step offset chunk = do
        pokeByteString (plusPtr dst offset) chunk
        return (offset + S.length chunk)

-- | Type tag for pointers to a 24-byte response header buffer
-- (allocated by 'peekResponse').
data Header
-- | Type tag for pointers to an outgoing request buffer built by
-- 'mallocRequest'.
data Request

-- | Allocate and fill a request packet: a 24-byte header followed by the
-- extras, the key and the value.
--
-- * @elen@ / @epoke@: length of the extras section and the action that
--   writes it (given a pointer to the start of the extras).
-- * @vlen@ / @vpoke@: length of the value section and the action that
--   writes it (given a pointer to the start of the value).
--
-- The returned buffer is @24 + S.length key + elen + vlen@ bytes long and
-- must be released with 'free' by the caller. If any write (including the
-- caller-supplied @epoke@ / @vpoke@) throws, the buffer is freed before the
-- exception propagates.
mallocRequest :: OpCode -> Key -> Word8 -> (Ptr Request -> IO ())
              -> Int -> (Ptr Request -> IO ()) -> Word32 -> CAS -> IO (Ptr Request)
mallocRequest (OpCode o) key elen epoke vlen vpoke opaque (CAS cas) = do
    let klen = S.length key
        tlen = klen + fromIntegral elen + vlen
    p <- mallocBytes (24 + tlen)
    flip onException (free p) $ do
        -- header (field offsets follow the memcached binary protocol)
        pokeWord8             p     0x80  -- request magic byte
        pokeWord8    (plusPtr p  1) o
        -- NOTE(review): keys longer than 65535 bytes are truncated here.
        pokeWord16be (plusPtr p  2) (fromIntegral klen)
        pokeWord8    (plusPtr p  4) elen
        pokeWord8    (plusPtr p  5) 0x00
        pokeWord16be (plusPtr p  6) 0x00
        pokeWord32be (plusPtr p  8) (fromIntegral tlen)
        pokeWord32be (plusPtr p 12) opaque
        pokeWord64be (plusPtr p 16) cas
        -- body: extras, then key, then value
        epoke (plusPtr p 24)
        pokeByteString (plusPtr p $ 24 + fromIntegral elen) key
        vpoke (plusPtr p $ 24 + fromIntegral elen + klen)
    return p
{-# INLINE mallocRequest #-}

-- | Build a request packet with 'mallocRequest', write it to the handle and
-- flush. The temporary buffer is freed afterwards, also on exceptions.
sendRequest :: OpCode -> Key -> Word8 -> (Ptr Request -> IO ())
            -> Int -> (Ptr Request -> IO ()) -> Word32 -> CAS -> Handle -> IO ()
sendRequest op key elen epoke vlen vpoke opaque cas h =
    bracket acquire free transmit
  where
    acquire      = mallocRequest op key elen epoke vlen vpoke opaque cas
    packetSize   = 24 + S.length key + fromIntegral elen + vlen
    transmit req = hPutBuf h req packetSize >> hFlush h
{-# INLINE sendRequest #-}

-- | Continuation run when an operation fails; it receives the
-- 'MemcachedException' describing the failure.
type Failure a = MemcachedException -> IO a

-- | Read a 24-byte response header from the handle and dispatch on it.
--
-- * Fewer than 24 bytes read: @failure DataReadFailed@.
-- * Non-zero status (bytes 6-7): the body (total-length bytes, from bytes
--   8-11) is read and handed to @failure@ as a 'MemcachedException'
--   together with the status.
-- * Otherwise: @success@ receives a pointer to the header. The buffer is
--   stack-scoped ('allocaBytes'), so the pointer is only valid inside
--   @success@.
peekResponse :: (Ptr Header -> IO a) -> Failure a -> Handle -> IO a
peekResponse success failure h = allocaBytes 24 readHeader
  where
    readHeader hdr = do
        got <- hGetBuf h hdr 24
        if got == 24 then checkStatus hdr else failure DataReadFailed

    checkStatus hdr = do
        status <- peekWord16be (plusPtr hdr 6)
        if status /= 0 then readError status hdr else success hdr

    readError status hdr = do
        bodyLen <- peekWord32be (plusPtr hdr 8)
        body    <- S.hGet h (fromIntegral bodyLen)
        failure (MemcachedException status body)
{-# INLINE peekResponse #-}

-- | Send a request (with opaque value 0), then read the response header.
-- @success@ receives the handle and the header pointer; @failure@ receives
-- any error reported by 'peekResponse'.
withRequest :: OpCode -> Key -> Word8 -> (Ptr Request -> IO ())
            -> Int -> (Ptr Request -> IO ()) -> CAS
            -> (Handle -> Ptr Header -> IO a) -> Failure a -> Handle -> IO a
withRequest op key elen epoke vlen vpoke cas success failure h =
    sendRequest op key elen epoke vlen vpoke 0 cas h
        >> peekResponse (success h) failure h

-- | Extras-length field (byte 4) of a response header.
getExtraLength :: Ptr Header -> IO Word8
getExtraLength = peekWord8 . (`plusPtr` 4)

-- | Key-length field (bytes 2-3, big-endian) of a response header.
getKeyLength :: Ptr Header -> IO Word16
getKeyLength = peekWord16be . (`plusPtr` 2)

-- | Total-body-length field (bytes 8-11, big-endian) of a response header.
getTotalLength :: Ptr Header -> IO Word32
getTotalLength = peekWord32be . (`plusPtr` 8)

-- | CAS field (bytes 16-23, big-endian) of a response header.
getCAS :: Ptr Header -> IO CAS
getCAS p = do
    w <- peekWord64be (plusPtr p 16)
    return (CAS w)

-- | Opaque field (bytes 12-15, big-endian) of a response header.
getOpaque :: Ptr Header -> IO Word32
getOpaque = peekWord32be . (`plusPtr` 12)

-- | Section writer that writes nothing; used for empty extras/value parts.
nop :: Ptr Request -> IO ()
nop = const (return ())

-- | Read the body that follows the response header at @p@ and split it into
-- (extras, key, value). The value length is the total body length minus the
-- extras and key lengths.
inspectResponse :: Handle -> Ptr Header
                -> IO (S.ByteString, S.ByteString, L.ByteString)
inspectResponse h p = do
    extraLen <- fromIntegral `fmap` getExtraLength p
    keyLen   <- fromIntegral `fmap` getKeyLength   p
    totalLen <- fromIntegral `fmap` getTotalLength p
    extras   <- S.hGet h extraLen
    key      <- S.hGet h keyLen
    value    <- L.hGet h (totalLen - extraLen - keyLen)
    return (extras, key, value)

-- | Success continuation for get-style responses.
--
-- The response body is laid out as extras, key, value. The first four bytes
-- of the extras hold the item flags (big-endian). Any further extras bytes
-- and the key (if the server echoed one) are read and discarded. As a
-- result @success@ receives only the value bytes, exactly total-length
-- bytes are consumed, and the handle stays aligned on the next response.
-- If the flags cannot be read in full, @failure DataReadFailed@ runs.
--
-- The header buffer @p@ is reused to receive the flags, so all header
-- fields are read before it is overwritten.
getSuccessCallback :: (Flags -> Value -> IO a) -> Failure a
                   -> Handle -> Ptr Header -> IO a
getSuccessCallback success failure h p = do
    elen <- fromIntegral `fmap` getExtraLength p
    klen <- fromIntegral `fmap` getKeyLength   p
    tlen <- fromIntegral `fmap` getTotalLength p
    len  <- hGetBuf h p 4
    if len /= 4
        then failure DataReadFailed
        else do
            flags <- peekWord32be p
            -- Skip extras beyond the 4-byte flags, then the key, so that
            -- only the value remains before the next response.
            let skip = (elen - 4) + klen :: Int
            when (skip > 0) $ void (S.hGet h skip)
            value <- L.hGet h (tlen - elen - klen)
            success flags value

-- | Fetch the item stored under @key@. @success@ receives its flags and
-- value; @failure@ receives any error.
get :: (Flags -> Value -> IO a) -> Failure a
    -> Key -> Handle -> IO a
get success failure key =
    withRequest opGet key 0 nop 0 nop (CAS 0) onHit failure
  where
    onHit = getSuccessCallback success failure

-- | Like 'get', but @success@ also receives the item's CAS value.
getWithCAS :: (CAS -> Flags -> Value -> IO a) -> Failure a
           -> Key -> Handle -> IO a
getWithCAS success failure key =
    withRequest opGet key 0 nop 0 nop (CAS 0) onHit failure
  where
    -- The CAS is read from the header before getSuccessCallback reuses the
    -- header buffer.
    onHit h p = do
        c <- getCAS p
        getSuccessCallback (success c) failure h p

-- | Store @value@ under @key@ with the given opcode (set/add/replace
-- family). The 8-byte extras hold the flags followed by the expiry; the
-- response body is ignored and @success@ runs on a zero status.
setAddReplace :: IO a -> Failure a -> OpCode -> CAS
              -> Key -> Value -> Flags -> Expiry -> Handle -> IO a
setAddReplace success failure o cas key value flags expiry =
    withRequest o key 8 writeExtras (fromIntegral (L.length value)) writeValue
        cas (\_ _ -> success) failure
  where
    writeExtras p = do
        pokeWord32be p flags
        pokeWord32be (p `plusPtr` 4) expiry
    writeValue p = pokeLazyByteString p value

-- | Like 'setAddReplace', but @success@ receives the CAS value from the
-- response header.
setAddReplaceWithCAS :: (CAS -> IO a) -> Failure a -> OpCode -> CAS
                     -> Key -> Value -> Flags -> Expiry -> Handle -> IO a
setAddReplaceWithCAS success failure o cas key value flags expiry =
    withRequest o key 8 writeExtras (fromIntegral (L.length value)) writeValue
        cas onStored failure
  where
    writeExtras p = do
        pokeWord32be p flags
        pokeWord32be (p `plusPtr` 4) expiry
    writeValue p = pokeLazyByteString p value
    onStored _ hdr = getCAS hdr >>= success

-- | Delete @key@ (optionally guarded by @cas@); the response body is
-- ignored.
delete :: IO a -> Failure a -> CAS -> Key -> Handle -> IO a
delete success failure cas key =
    withRequest opDelete key 0 nop 0 nop cas (const (const success)) failure

-- | Increment or decrement a counter with the given opcode. @success@
-- receives the counter's new value.
incrDecr :: (Word64 -> IO a) -> Failure a -> OpCode -> CAS
         -> Key -> Delta -> Initial -> Expiry -> Handle -> IO a
incrDecr success failure op cas key delta initial expiry =
    withRequest op key 20 writeExtras 0 nop cas readCounter failure
  where
    -- 20-byte extras: delta (8 bytes), initial value (8), expiry (4).
    writeExtras p = do
        pokeWord64be p delta
        pokeWord64be (p `plusPtr` 8) initial
        pokeWord32be (p `plusPtr` 16) expiry

    -- The 8-byte body holds the new value; the header buffer is reused to
    -- receive it.
    readCounter h p = do
        n <- hGetBuf h p 8
        if n == 8
            then peekWord64be p >>= success
            else failure DataReadFailed

-- | Send a @quit@ request and read its response, ignoring whether it reports
-- success or failure.
quit :: Handle -> IO ()
quit h = do
    sendRequest opQuit "" 0 nop 0 nop 0 (CAS 0) h
    peekResponse ignore ignore h
  where
    ignore _ = return ()

-- | Send a flush request without extras; the response body is ignored.
flushAll :: IO a -> Failure a -> Handle -> IO a
flushAll success failure =
    withRequest opFlush "" 0 nop 0 nop (CAS 0) (const (const success)) failure

-- | Send a flush request whose 4-byte extras carry the given expiry.
flushWithin :: IO a -> Failure a -> Expiry -> Handle -> IO a
flushWithin success failure w =
    withRequest opFlush "" 4 writeExpiry 0 nop (CAS 0) (\_ _ -> success) failure
  where
    writeExpiry p = pokeWord32be p w

-- | Send a no-op request; @success@ runs when the server answers with a
-- zero status.
noOp :: IO a -> Failure a -> Handle -> IO a
noOp success failure =
    withRequest opNoOp "" 0 nop 0 nop (CAS 0) acknowledged failure
  where
    acknowledged _ _ = success

-- | Request the server version; @success@ receives the whole response body.
version :: (S.ByteString -> IO a) -> Failure a -> Handle -> IO a
version success failure =
    withRequest opVersion "" 0 nop 0 nop (CAS 0) readBody failure
  where
    readBody h p = do
        n    <- getTotalLength p
        body <- S.hGet h (fromIntegral n)
        success body

-- | Append or prepend @value@ to the item under @key@ (selected by @op@);
-- no extras are sent and the response body is ignored.
appendPrepend :: IO a -> Failure a -> OpCode -> CAS
              -> Key -> Value -> Handle -> IO a
appendPrepend success failure op cas key value =
    withRequest op key 0 nop (fromIntegral (L.length value)) writeValue
        cas (\_ _ -> success) failure
  where
    writeValue p = pokeLazyByteString p value

-- | Request server statistics and collect them into a map from stat name to
-- stat value.
--
-- A single @stat@ request is sent. The server replies with one response per
-- statistic (key = name, remaining body = value) and ends the sequence with
-- a response whose total body length is 0. Failures reported by
-- 'peekResponse' are thrown with 'throwIO'.
stats :: Handle -> IO (H.HashMap S.ByteString S.ByteString)
stats h = do
    -- Send the request exactly once. Each extra request would make the
    -- server emit another full batch of responses, and those would be left
    -- unread on the handle, corrupting every later command on it.
    sendRequest opStat "" 0 nop 0 nop 0 (CAS 0) h
    loop H.empty
  where
    loop m = peekResponse (collect m) throwIO h

    collect m p = do
        tl <- getTotalLength p
        if tl == 0
            then return m
            else do
                kl <- getKeyLength p
                k  <- S.hGet h (fromIntegral kl)
                v  <- S.hGet h (fromIntegral tl - fromIntegral kl)
                loop (H.insert k v m)

-- | Send a verbosity request whose 4-byte extras carry the level @v@.
verbosity :: IO a -> Failure a -> Word32 -> Handle -> IO a
verbosity success failure v =
    withRequest opVerbosity "" 4 writeLevel 0 nop (CAS 0) (\_ _ -> success) failure
  where
    writeLevel p = pokeWord32be p v

-- | Send @op@ for @key@ with a 4-byte expiry extras section and decode the
-- response like 'get' (flags and value passed to @success@).
touch :: (Flags -> Value -> IO a) -> Failure a -> OpCode
      -> Key -> Expiry -> Handle -> IO a
touch success failure op key e =
    withRequest op key 4 writeExpiry 0 nop (CAS 0) onHit failure
  where
    writeExpiry p = pokeWord32be p e
    onHit = getSuccessCallback success failure

-- | Ask the server for its SASL mechanisms; @success@ receives the whole
-- response body.
saslListMechs :: (S.ByteString -> IO a) -> Failure a
              -> Handle -> IO a
saslListMechs success failure =
    withRequest opSaslListMechs "" 0 nop 0 nop (CAS 0) readBody failure
  where
    readBody h p = do
        n    <- getTotalLength p
        body <- S.hGet h (fromIntegral n)
        success body

-- | Authenticate the handle with SASL @PLAIN@.
--
-- The request key is @"PLAIN"@ and the value is @NUL user NUL password@
-- (an empty authorization identity). On a zero status the response body is
-- read and passed to @success@; on any failure @next@ runs with the error.
auth :: Auth -> (S.ByteString -> IO a) -> Failure a -> Handle -> IO a
auth (Plain user pass) success next h = do
    sendRequest opSaslAuth "PLAIN" 0 nop credLen writeCred 0 (CAS 0) h
    peekResponse readBody next h
  where
    userLen = S.length user
    credLen = userLen + S.length pass + 2

    writeCred p = do
        pokeWord8      p                           0
        pokeByteString (p `plusPtr` 1)             user
        pokeWord8      (p `plusPtr` (userLen + 1)) 0
        pokeByteString (p `plusPtr` (userLen + 2)) pass

    readBody p = getTotalLength p >>= S.hGet h . fromIntegral >>= success