module Network.Wai.Middleware.RealIp
( realIp
, realIpHeader
, realIpTrusted
, defaultTrusted
, ipInRange
) where
import qualified Data.ByteString.Char8 as B8 (unpack, split)
import qualified Data.IP as IP
import Data.Maybe (fromMaybe, mapMaybe, listToMaybe)
import Network.HTTP.Types (HeaderName, RequestHeaders)
import Network.Wai
import Text.Read (readMaybe)
-- | Middleware that replaces a request's 'remoteHost' using the
-- @X-Forwarded-For@ header.
--
-- This is 'realIpHeader' applied to @"X-Forwarded-For"@; see 'realIpHeader'
-- and 'realIpTrusted' for the exact rules.
--
-- NOTE(review): the string literal is used at type 'HeaderName', which
-- presumably relies on @OverloadedStrings@ being enabled for this module;
-- confirm the pragma is present.
realIp :: Middleware
realIp = realIpHeader "X-Forwarded-For"
-- | Middleware that replaces a request's 'remoteHost' using the given header,
-- trusting exactly those addresses that fall inside one of the
-- 'defaultTrusted' ranges.
--
-- This is 'realIpTrusted' with a predicate built from 'ipInRange' and
-- 'defaultTrusted'.
realIpHeader :: HeaderName -> Middleware
realIpHeader header = realIpTrusted header trustedByDefault
  where
    -- An address is trusted when at least one default range contains it.
    trustedByDefault ip = any (ipInRange ip) defaultTrusted
-- | Middleware that replaces a request's 'remoteHost' with an address taken
-- from the given header, but only when the connecting peer is trusted.
--
-- The request is passed to the application unchanged when any of the
-- following holds:
--
-- * 'IP.fromSockAddr' cannot turn the current 'remoteHost' into an IP/port
--   pair;
-- * the peer's address does not satisfy @isTrusted@;
-- * 'findRealIp' finds no usable address in the header.
--
-- Otherwise 'remoteHost' is rebuilt from the address chosen by 'findRealIp'
-- together with the peer's original port.
realIpTrusted :: HeaderName -> (IP.IP -> Bool) -> Middleware
realIpTrusted header isTrusted app req respond = app req' respond
  where
    -- The rewritten request, or the original one if any step yields Nothing.
    req' :: Request
    req' = fromMaybe req $ do
        (ip, port) <- IP.fromSockAddr (remoteHost req)
        -- Only consult the header when the direct peer itself is trusted.
        ip' <-
            if isTrusted ip
                then findRealIp (requestHeaders req) header isTrusted
                else Nothing
        Just $ req { remoteHost = IP.toSockAddr (ip', port) }
-- | The address ranges that 'realIpHeader' (and therefore 'realIp') treats as
-- trusted: IPv4/IPv6 loopback and the common private-use ranges.
--
-- NOTE(review): the elements are string literals used at type 'IP.IPRange',
-- which presumably relies on @OverloadedStrings@ being enabled for this
-- module; confirm the pragma is present.
defaultTrusted :: [IP.IPRange]
defaultTrusted =
    [ "127.0.0.0/8"
    , "10.0.0.0/8"
    , "172.16.0.0/12"
    , "192.168.0.0/16"
    , "::1/128"
    , "fc00::/7"
    ]
-- | Check whether an address lies inside a range.
--
-- * IPv4 address, IPv4 range: tested directly with 'IP.isMatchedTo'.
-- * IPv6 address, IPv6 range: tested directly with 'IP.isMatchedTo'.
-- * IPv4 address, IPv6 range: the address is first converted with
--   'IP.ipv4ToIPv6' and then tested.
-- * Any other combination (an IPv6 address against an IPv4 range) is 'False'.
ipInRange :: IP.IP -> IP.IPRange -> Bool
ipInRange (IP.IPv4 ip) (IP.IPv4Range r) = ip `IP.isMatchedTo` r
ipInRange (IP.IPv6 ip) (IP.IPv6Range r) = ip `IP.isMatchedTo` r
ipInRange (IP.IPv4 ip) (IP.IPv6Range r) = IP.ipv4ToIPv6 ip `IP.isMatchedTo` r
ipInRange _ _ = False
-- | Pick the client address out of the given header.
--
-- Every occurrence of @header@ in @reqHeaders@ is considered, in order. Each
-- value is split on @','@ and every piece is parsed with 'readMaybe'; pieces
-- that do not parse are dropped.
--
-- The result is:
--
-- * the last parsed address that does not satisfy @isTrusted@, if there is
--   one;
-- * otherwise the first parsed address (all of them are trusted);
-- * 'Nothing' when no address could be parsed at all.
--
-- NOTE(review): pieces are not trimmed before parsing, so whether an entry
-- with surrounding spaces (e.g. after @", "@) is accepted depends on the
-- 'Read' instance of 'IP.IP'; confirm.
findRealIp :: RequestHeaders -> HeaderName -> (IP.IP -> Bool) -> Maybe IP.IP
findRealIp reqHeaders header isTrusted =
    case (nonTrusted, ips) of
        ([], xs) -> listToMaybe xs
        (xs, _) -> listToMaybe $ reverse xs
  where
    -- The header may be repeated; collect every matching value.
    headerVals = [v | (k, v) <- reqHeaders, k == header]

    -- All successfully parsed addresses, in header order.
    ips :: [IP.IP]
    ips = mapMaybe (readMaybe . B8.unpack) $ concatMap (B8.split ',') headerVals

    -- The parsed addresses that the predicate rejects.
    nonTrusted :: [IP.IP]
    nonTrusted = filter (not . isTrusted) ips