{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}

-- | This module provides facilities for patching incoming 'Request's to
-- correct the value of 'rqClientAddr' (and 'rqClientPort') if the Snap server
-- is running behind a proxy.
--
-- Example usage:
--
-- @
-- m :: Snap ()
-- m = undefined  -- code goes here
--
-- applicationHandler :: Snap ()
-- applicationHandler = behindProxy X_Forwarded_For m
-- @
--
module Snap.Util.Proxy
  ( ProxyType(..)
  , behindProxy
  ) where

------------------------------------------------------------------------------
import           Control.Applicative   (Alternative ((<|>)))
import           Control.Monad         (mfilter)
import qualified Data.ByteString.Char8 as S (breakEnd, dropWhile, null, readInt, spanEnd)
import           Data.Char             (isSpace)
import           Data.Maybe            (fromMaybe)
import           Snap.Core             (MonadSnap, Request (rqClientAddr, rqClientPort), getHeader, modifyRequest)
------------------------------------------------------------------------------


------------------------------------------------------------------------------
-- | What kind of proxy is this? Affects which headers 'behindProxy' pulls the
-- original remote address (and port) from.
--
-- Currently only proxy servers that send @X-Forwarded-For@ or @Forwarded-For@
-- (optionally with @X-Forwarded-Port@ or @Forwarded-Port@) are supported.
data ProxyType = NoProxy          -- ^ no proxy, leave the request alone
               | X_Forwarded_For  -- ^ Use the @Forwarded-For@ or
                                  --   @X-Forwarded-For@ header
  deriving (Read, Show, Eq, Ord)


------------------------------------------------------------------------------
-- | Rewrite 'rqClientAddr' if we're behind a proxy.
--
-- With 'NoProxy' the wrapped handler runs unchanged. With 'X_Forwarded_For'
-- the request's 'rqClientAddr' (and, when a forwarded port header is
-- present, 'rqClientPort') is first rewritten from the forwarding headers,
-- and then the wrapped handler runs.
--
-- NOTE: the forwarding headers are ordinary request headers, so they can be
-- supplied by any client. Only enable this when the server really sits
-- behind a proxy that sets them.
--
-- Example:
--
-- @
-- ghci> :set -XOverloadedStrings
-- ghci> import qualified "Data.Map" as M
-- ghci> import qualified "Snap.Test" as T
-- ghci> let r = T.get \"\/foo\" M.empty >> T.addHeader \"X-Forwarded-For\" \"1.2.3.4\"
-- ghci> let h = 'Snap.Core.getsRequest' 'rqClientAddr' >>= 'Snap.Core.writeBS'
-- ghci> T.runHandler r h
-- HTTP\/1.1 200 OK
-- server: Snap\/test
-- date: Fri, 08 Aug 2014 14:32:29 GMT
--
-- 127.0.0.1
-- ghci> T.runHandler r ('behindProxy' 'X_Forwarded_For' h)
-- HTTP\/1.1 200 OK
-- server: Snap\/test
-- date: Fri, 08 Aug 2014 14:33:02 GMT
--
-- 1.2.3.4
-- @
behindProxy :: MonadSnap m => ProxyType -> m a -> m a
behindProxy NoProxy         = id
behindProxy X_Forwarded_For = (modifyRequest xForwardedFor >>)
{-# INLINE behindProxy #-}


------------------------------------------------------------------------------
-- | Rewrite 'rqClientAddr' and 'rqClientPort' from proxy forwarding headers.
--
-- * Address: taken from @Forwarded-For@, or from @X-Forwarded-For@ when the
--   former is absent. Only the last comma-separated entry is used, with
--   surrounding whitespace trimmed. If no header is present or the entry is
--   empty, the original 'rqClientAddr' is kept.
--
-- * Port: taken from @Forwarded-Port@, or from @X-Forwarded-Port@, using the
--   same last-entry rule. The entry must be entirely a decimal number in
--   @0..65535@; otherwise the original 'rqClientPort' is kept.
xForwardedFor :: Request -> Request
xForwardedFor req = req { rqClientAddr = ip
                        , rqClientPort = port
                        }
  where
    -- Value of the first of the two headers that is present on the request,
    -- preferring the first name given.
    firstHeader primary fallback =
        getHeader primary req <|> getHeader fallback req

    -- Last comma-separated element of a header value, with leading and
    -- trailing whitespace removed. NOTE(review): the last element is
    -- presumably the hop appended by the proxy closest to this server.
    extract = fst . S.spanEnd isSpace . S.dropWhile isSpace
            . snd . S.breakEnd (== ',')

    ip = fromMaybe (rqClientAddr req) $
         mfilter (not . S.null) $
         fmap extract (firstHeader "Forwarded-For" "X-Forwarded-For")

    port = fromMaybe (rqClientPort req) $
           validPort =<< fmap extract
                              (firstHeader "Forwarded-Port" "X-Forwarded-Port")

    -- Header values are untrusted: reject trailing junk, negative numbers
    -- and out-of-range values, rather than keeping whatever integer prefix
    -- 'S.readInt' happens to find.
    validPort s = case S.readInt s of
        Just (n, rest) | S.null rest && n >= 0 && n <= 65535 -> Just n
        _                                                    -> Nothing
{-# INLINE xForwardedFor #-}