{-# LANGUAGE CPP, DeriveDataTypeable, DeriveFunctor, DeriveGeneric,
    FlexibleContexts, OverloadedStrings, RecordWildCards #-}

module Network.Wreq.Cache
    (
      shouldCache
    , validateEntry
    , cacheStore
    ) where

import Control.Applicative
import Control.Lens ((^?), (^.), (^..), folded, non, pre, to)
import Control.Monad (guard)
import Data.Attoparsec.ByteString.Char8 as A
import Data.CaseInsensitive (mk)
import Data.Foldable (forM_)
import Data.HashSet (HashSet)
import Data.Hashable (Hashable)
import Data.IntSet (IntSet)
import Data.IORef (newIORef)
import Data.List (sort)
import Data.Maybe (listToMaybe)
import Data.Monoid (First(..), mconcat)
import Data.Time.Clock (UTCTime, addUTCTime, getCurrentTime)
import Data.Time.Format (parseTimeM)
import Data.Time.Locale.Compat (defaultTimeLocale)
import Data.Typeable (Typeable)
import GHC.Generics (Generic)
import Network.HTTP.Types (HeaderName, Method)
import Network.Wreq.Internal.Lens
import Network.Wreq.Internal.Types
import Network.Wreq.Lens
import qualified Data.ByteString.Char8 as B
import qualified Data.HashSet as HashSet
import qualified Data.IntSet as IntSet
import qualified Network.Wreq.Cache.Store as Store

#if MIN_VERSION_base(4,6,0)
import Data.IORef (atomicModifyIORef')
#else
import Data.IORef (IORef, atomicModifyIORef)

atomicModifyIORef' :: IORef a -> (a -> (a, b)) -> IO b
atomicModifyIORef' = atomicModifyIORef
#endif

cacheStore :: Int -> IO (Run body -> Run body)
cacheStore capacity = do
  cache <- newIORef (Store.empty capacity)
  return $ \run req -> do
    let url = reqURL req
    before <- getCurrentTime
    mresp <- atomicModifyIORef' cache $ \s ->
      case Store.lookup url s of
        Nothing       -> (s, Nothing)
        Just (ce, s') ->
          case validateEntry before ce of
            n@Nothing -> (Store.delete url s, n)
            resp      -> (s', resp)
    case mresp of
      Just resp -> return resp
      Nothing -> do
        resp <- run req
        after <- getCurrentTime
        forM_ (shouldCache after req resp) $ \ce ->
          atomicModifyIORef' cache $ \s -> (Store.insert url ce s, ())
        return resp

cacheableStatuses :: IntSet
cacheableStatuses = IntSet.fromList [200, 203, 300, 301, 410]

cacheableMethods :: HashSet Method
cacheableMethods = HashSet.fromList ["GET", "HEAD", "OPTIONS"]

possiblyCacheable :: Request -> Response body -> Bool
possiblyCacheable req resp =
    (req ^. method) `HashSet.member` cacheableMethods &&
    (resp ^. responseStatus . statusCode) `IntSet.member` cacheableStatuses

computeExpiration :: UTCTime -> [CacheResponse Seconds] -> Maybe UTCTime
computeExpiration now crs = do
  guard $ and [NoCache [] `notElem` crs, NoStore `notElem` crs]
  age <- listToMaybe $ sort [age | MaxAge age <- crs]
  return $! fromIntegral age `addUTCTime` now

validateEntry :: UTCTime -> CacheEntry body -> Maybe (Response body)
validateEntry now CacheEntry{..} =
  case entryExpires of
    Nothing          -> Just entryResponse
    Just e | e > now -> Just entryResponse
    _                -> Nothing

shouldCache :: UTCTime -> Req -> Response body -> Maybe (CacheEntry body)
shouldCache now (Req _ req) resp = do
  guard (possiblyCacheable req resp)
  let crs = resp ^.. responseHeader "Cache-Control" . atto_ parseCacheResponse .
                     folded . to simplifyCacheResponse
      dateHeader name = responseHeader name . to parseDate . folded
      mexpires = case crs of
                   [] -> resp ^? dateHeader "Expires"
                   _  -> computeExpiration now crs
      created = resp ^. pre (dateHeader "Date") . non now
  case mexpires of
    Just expires | expires <= created                -> empty
    Nothing      | req ^. method == "GET" &&
                   not (B.null (req ^. queryString)) -> empty
    _ -> return $ CacheEntry created mexpires resp

type Seconds = Int

data CacheResponse age = Public
                       | Private [HeaderName]
                       | NoCache [HeaderName]
                       | NoStore
                       | NoTransform
                       | MustRevalidate
                       | ProxyRevalidate
                       | MaxAge age
                       | SMaxAge age
                       | Extension
                       deriving (Eq, Show, Functor, Typeable, Generic)

instance Hashable age => Hashable (CacheResponse age)

simplifyCacheResponse :: CacheResponse age -> CacheResponse age
simplifyCacheResponse (Private _) = Private []
simplifyCacheResponse (NoCache _) = NoCache []
simplifyCacheResponse cr          = cr

parseCacheResponse :: A.Parser [CacheResponse Seconds]
parseCacheResponse = commaSep1 body
  where
    body = "public" *> pure Public
       <|> "private" *> (Private <$> (eq headerNames <|> pure []))
       <|> "no-cache" *> (NoCache <$> (eq headerNames <|> pure []))
       <|> "no-store" *> pure NoStore
       <|> "no-transform" *> pure NoTransform
       <|> "must-revalidate" *> pure MustRevalidate
       <|> "proxy-revalidate" *> pure ProxyRevalidate
       <|> "max-age" *> eq (MaxAge <$> decimal)
       <|> "s-maxage" *> eq (SMaxAge <$> decimal)
    headerNames = A.char '"' *> commaSep1 hdr <* A.char '"'
    hdr = mk <$> A.takeWhile1 (inClass "a-zA-Z0-9_-")
    commaSep1 p = (p <* skipSpace) `sepBy1` (A.char ',' *> skipSpace)
    eq p = skipSpace *> A.char '=' *> skipSpace *> p

parseDate :: B.ByteString -> Maybe UTCTime
parseDate s = getFirst . mconcat . map tryout $ [
    "%a, %d %b %Y %H:%M:%S %Z"
  , "%A, %d-%b-%y %H:%M:%S %Z"
  , "%a %b %e %H:%M:%S %Y"
  ]
  where tryout fmt = First $ parseTimeM True defaultTimeLocale fmt (B.unpack s)