{-# LANGUAGE OverloadedStrings #-}

module Snap.Extras.CSRF where

------------------------------------------------------------------------------
import           Control.Monad.Trans
import qualified Data.ByteString.Char8 as B
import           Data.Text             (Text)
import qualified Data.Text.Encoding    as T
import           Heist
import           Heist.Interpreted
import           Snap
import           Snap.Snaplet.Session
import qualified Text.XmlHtml          as X
------------------------------------------------------------------------------



------------------------------------------------------------------------------
-- | A splice that makes the CSRF token available to templates.  Typically we
-- use it by binding a splice and using the CSRF token provided by the session
-- snaplet as follows:
--
-- @(\"csrfToken\", csrfTokenSplice $ with session 'csrfToken')@
--
-- Where @session@ is a lens to the session snaplet.  Then you can make it
-- available to javascript code by putting a meta tag at the top of every
-- page like this:
--
-- > <meta name="csrf-token" content="${csrfToken}">
csrfTokenSplice :: Monad m
                => m Text
                -- ^ A computation in the runtime monad that gets the
                -- CSRF protection token.
                -> Splice m
csrfTokenSplice f = do
    token <- lift f
    textSplice token


------------------------------------------------------------------------------
-- | Adds a hidden _csrf input field as the first child of the bound tag.  For
-- full site protection against CSRF, you should bind this splice to the form
-- tag, and then make sure your app checks all POST requests for the presence
-- of this CSRF token and that the token is randomly generated and secure on a
-- per session basis.
secureForm :: MonadIO m
           => m Text
           -- ^ A computation in the runtime monad that gets the CSRF
           -- protection token.
           -> Splice m
secureForm mToken = do
    n <- getParamNode
    token <- lift mToken
    let input = X.Element "input"
          [("type", "hidden"), ("name", "_csrf"), ("value", token)] []
    case n of
      X.Element nm as cs -> do
        cs' <- runNodeList cs
        let newCs = if take 1 cs' == [input] then cs' else (input : cs')
        stopRecursion
        return [X.Element nm as newCs]
      _ -> return [n] -- "impossible"


------------------------------------------------------------------------------
-- | Use this function to wrap your whole site with CSRF protection.  Due to
-- security considerations, the way Snap parses file uploads
-- means that the CSRF token cannot be checked before the file uploads have
-- been handled.  This function protects your whole site except for handlers
-- of multipart/form-data forms (forms with file uploads).  To protect those
-- handlers, you have to call handleCSRF explicitly after the file has been
-- processed.
blanketCSRF :: SnapletLens v SessionManager
            -- ^ Lens to the session snaplet
            -> Handler b v ()
            -- ^ Handler to run if the CSRF check fails
            -> Handler b v ()
            -- ^ Handler to let through when successful.
            -> Handler b v ()
blanketCSRF session onFailure onSucc = do
    h <- getHeader "content-type" `fmap` getRequest
    case maybe False (B.isInfixOf "multipart/form-data") h of
      True -> onSucc
      False -> handleCSRF session onFailure onSucc


------------------------------------------------------------------------------
-- | If a request is a POST, check the CSRF token and fail with the specified
-- handler if the check fails.  If if the token is correct or if it's not a
-- POST request, then control passes through as a no-op.
handleCSRF :: SnapletLens v SessionManager
           -- ^ Lens to the session snaplet
           -> Handler b v ()
           -- ^ Handler to run on failure
           -> Handler b v ()
           -- ^ Handler to let through when successful.
           -> Handler b v ()
handleCSRF session onFailure onSucc = do
    m <- getsRequest rqMethod
    case m /= POST of
      True ->  onSucc
      False -> do
        tok <- getParam "_csrf"
        realTok <- with session csrfToken
        if tok == Just (T.encodeUtf8 realTok)
          then onSucc
          else onFailure >> getResponse >>= finishWith


-------------------------------------------------------------------------------
-- | A version of 'handleCSRF' that works as an imperative filter.
-- It's a NOOP when successful, redirs to oblivion under failure.
handleCSRF'
    :: SnapletLens v SessionManager
    -> Handler b v ()
    -- ^ On failure
    -> Handler b v ()
handleCSRF' ses onFail = handleCSRF ses onFail (return ())