-- | Infer the remote IP address using headers
module Network.Wai.Middleware.RealIp
    ( realIp
    , realIpHeader
    , realIpTrusted
    , defaultTrusted
    , ipInRange
    ) where

import qualified Data.ByteString.Char8 as B8 (split, unpack)
import qualified Data.IP as IP
import Data.Maybe (fromMaybe, listToMaybe, mapMaybe)
import Network.HTTP.Types (HeaderName, RequestHeaders)
import Network.Wai (Middleware, remoteHost, requestHeaders)
import Text.Read (readMaybe)

-- | Infer the remote IP address from the @X-Forwarded-For@ header,
-- trusting requests from any loopback or private IP address (see
-- 'defaultTrusted'). See 'realIpHeader' and 'realIpTrusted' for more
-- information and options.
--
-- @since 3.1.5
realIp :: Middleware
realIp = realIpHeader "X-Forwarded-For"

-- | Infer the remote IP address using the given header, trusting
-- requests from any loopback or private IP address listed in
-- 'defaultTrusted'. See 'realIpTrusted' for more information and options.
--
-- @since 3.1.5
realIpHeader :: HeaderName -> Middleware
realIpHeader header =
    realIpTrusted header $ \ip -> any (ipInRange ip) defaultTrusted

-- | Infer the remote IP address using the given header, but only if the
-- request came from an IP that is trusted by the provided predicate.
--
-- The last non-trusted address is used to replace the 'remoteHost' in
-- the 'Request', unless all present IP addresses are trusted, in which
-- case the first address is used. The port of the original 'remoteHost'
-- is kept. Invalid IP addresses are ignored. The 'remoteHost' value
-- remains unaltered in three cases: no valid IP addresses are found, the
-- immediate peer is not trusted, or 'IP.fromSockAddr' cannot decode the
-- peer's socket address.
--
-- Examples:
--
-- @ realIpTrusted "X-Forwarded-For" $ flip ipInRange "10.0.0.0/8" @
--
-- @ realIpTrusted "X-Real-Ip" $ \\ip -> any (ipInRange ip) defaultTrusted @
--
-- @since 3.1.5
realIpTrusted :: HeaderName -> (IP.IP -> Bool) -> Middleware
realIpTrusted header isTrusted app req respond = app req' respond
  where
    -- Any step yielding Nothing falls back to the untouched request.
    req' = fromMaybe req $ do
        (ip, port) <- IP.fromSockAddr (remoteHost req)
        -- Only a trusted peer (e.g. our own proxy) may tell us who the real
        -- client is; an untrusted peer could forge the header itself.
        ip' <- if isTrusted ip
                   then findRealIp (requestHeaders req) header isTrusted
                   else Nothing
        Just $ req { remoteHost = IP.toSockAddr (ip', port) }

-- | Loopback and private-use IP ranges:
--
-- * IPv4 loopback (@127.0.0.0/8@)
-- * the RFC 1918 private networks (@10.0.0.0/8@, @172.16.0.0/12@,
--   @192.168.0.0/16@)
-- * IPv6 loopback (@::1/128@)
-- * IPv6 unique local addresses (@fc00::/7@, RFC 4193)
--
-- @since 3.1.5
defaultTrusted :: [IP.IPRange]
defaultTrusted =
    [ "127.0.0.0/8"    -- IPv4 loopback
    , "10.0.0.0/8"     -- RFC 1918
    , "172.16.0.0/12"  -- RFC 1918
    , "192.168.0.0/16" -- RFC 1918
    , "::1/128"        -- IPv6 loopback
    , "fc00::/7"       -- IPv6 unique local (RFC 4193)
    ]

-- | Check if the given IP address is in the given range.
--
-- IPv4 addresses can be checked against IPv6 ranges (the address is
-- converted with 'IP.ipv4ToIPv6' first), but testing an IPv6 address
-- against an IPv4 range is always 'False'.
--
-- @since 3.1.5
ipInRange :: IP.IP -> IP.IPRange -> Bool
ipInRange (IP.IPv4 ip) (IP.IPv4Range r) = ip `IP.isMatchedTo` r
ipInRange (IP.IPv6 ip) (IP.IPv6Range r) = ip `IP.isMatchedTo` r
ipInRange (IP.IPv4 ip) (IP.IPv6Range r) = IP.ipv4ToIPv6 ip `IP.isMatchedTo` r
-- Remaining case: IPv6 address vs. IPv4 range; no conversion is attempted.
ipInRange _ _ = False


-- | Pick the client address out of the (possibly repeated) @header@.
--
-- Every value of @header@ in the request is split on commas. Each field
-- that parses as an 'IP.IP' (via its 'Read' instance) is kept in order;
-- fields that do not parse are dropped. The result is the last address
-- that @isTrusted@ rejects. When every parsed address is trusted, the
-- first address is returned instead. The result is 'Nothing' when no
-- field parses.
--
-- Fields are not trimmed here. Whitespace after a comma (as in
-- @a, b@) is left to the 'Read' instance of 'IP.IP'.
findRealIp :: RequestHeaders -> HeaderName -> (IP.IP -> Bool) -> Maybe IP.IP
findRealIp reqHeaders header isTrusted =
    case (nonTrusted, ips) of
      ([], xs) -> listToMaybe xs
      (xs, _)  -> listToMaybe $ reverse xs
  where
    -- account for repeated headers
    headerVals = [ v | (k, v) <- reqHeaders, k == header ]
    ips = mapMaybe (readMaybe . B8.unpack) $ concatMap (B8.split ',') headerVals
    nonTrusted = filter (not . isTrusted) ips