module Network.Wai.Middleware.RealIp
( realIp
, realIpHeader
, realIpTrusted
, defaultTrusted
, ipInRange
) where
import qualified Data.ByteString.Char8 as B8 (split, unpack)
import qualified Data.IP as IP
import Data.Maybe (fromMaybe, listToMaybe, mapMaybe)
import Network.HTTP.Types (HeaderName, RequestHeaders)
import Network.Wai (Middleware, remoteHost, requestHeaders)
import Text.Read (readMaybe)
-- | Rewrite the request's remote address using the @X-Forwarded-For@ header.
--
-- This is 'realIpHeader' specialised to @X-Forwarded-For@, so proxies are
-- trusted only when they fall inside one of the 'defaultTrusted' ranges.
realIp :: Middleware
realIp = realIpHeader forwardedFor
  where
    forwardedFor :: HeaderName
    forwardedFor = "X-Forwarded-For"
-- | Like 'realIp', but read the forwarded client address from @header@.
--
-- An address counts as a trusted proxy when it lies in any of the
-- 'defaultTrusted' ranges (as decided by 'ipInRange').
realIpHeader :: HeaderName -> Middleware
realIpHeader header = realIpTrusted header inDefaultRanges
  where
    inDefaultRanges addr = any (ipInRange addr) defaultTrusted
-- | Like 'realIpHeader', but with a caller-supplied predicate that decides
-- which addresses are trusted proxies.
--
-- The request's 'remoteHost' is rewritten only when:
--
--   * the connecting peer's 'remoteHost' converts to an 'IP.IP'
--     (via 'IP.fromSockAddr'),
--   * that peer address satisfies @isTrusted@, and
--   * 'findRealIp' recovers a client address from @header@.
--
-- The new 'remoteHost' combines the recovered address with the original
-- peer's port. In every other case the request is passed on unchanged.
realIpTrusted :: HeaderName -> (IP.IP -> Bool) -> Middleware
realIpTrusted header isTrusted app req = app (rewrite req)
  where
    rewrite r = case IP.fromSockAddr (remoteHost r) of
        Just (peer, port)
            | isTrusted peer
            , Just client <- findRealIp (requestHeaders r) header isTrusted
            -> r { remoteHost = IP.toSockAddr (client, port) }
        _ -> r
-- | Address ranges that 'realIp' and 'realIpHeader' treat as trusted
-- proxies.
--
-- The ranges are IPv4 loopback and the RFC 1918 private networks,
-- followed by IPv6 loopback and the @fc00::/7@ unique-local block.
defaultTrusted :: [IP.IPRange]
defaultTrusted = privateV4 ++ privateV6
  where
    privateV4 :: [IP.IPRange]
    privateV4 =
        [ "127.0.0.0/8"
        , "10.0.0.0/8"
        , "172.16.0.0/12"
        , "192.168.0.0/16"
        ]

    privateV6 :: [IP.IPRange]
    privateV6 =
        [ "::1/128"
        , "fc00::/7"
        ]
-- | Test whether an address lies inside a range.
--
-- Same-family pairs are compared directly with 'IP.isMatchedTo'.
-- Cross-family pairs are handled as follows:
--
--   * An IPv4 address checked against an IPv6 range is first converted
--     with 'IP.ipv4ToIPv6'.
--   * An IPv6 address checked against an IPv4 range matches only if it is
--     an IPv4-mapped address (@::ffff:a.b.c.d@) whose embedded IPv4
--     address lies in the range. Any other IPv6 address never matches an
--     IPv4 range.
ipInRange :: IP.IP -> IP.IPRange -> Bool
ipInRange (IP.IPv4 ip) (IP.IPv4Range r) = ip `IP.isMatchedTo` r
ipInRange (IP.IPv6 ip) (IP.IPv6Range r) = ip `IP.isMatchedTo` r
ipInRange (IP.IPv4 ip) (IP.IPv6Range r) = IP.ipv4ToIPv6 ip `IP.isMatchedTo` r
ipInRange (IP.IPv6 ip) (IP.IPv4Range r) =
    -- A dual-stack (IPv6) socket reports IPv4 peers as ::ffff:a.b.c.d, and
    -- IP.fromSockAddr turns that into an IPv6 value. Unwrap the embedded
    -- IPv4 address so IPv4 ranges such as 127.0.0.0/8 still apply;
    -- previously these peers never matched any IPv4 range.
    case IP.fromIPv6b ip of
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] ->
            IP.toIPv4 [a, b, c, d] `IP.isMatchedTo` r
        _ -> False
-- | Pick the client address out of the forwarding header(s).
--
-- Every occurrence of @header@ in @reqHeaders@ is split on commas, in
-- order, and each entry is parsed as an 'IP.IP' with 'readMaybe'. Entries
-- that fail to parse are dropped. From the remaining addresses:
--
--   * if any fail @isTrusted@, the rightmost such address is returned;
--   * otherwise (every address is trusted) the leftmost address is
--     returned;
--   * with no parsable addresses the result is 'Nothing'.
findRealIp :: RequestHeaders -> HeaderName -> (IP.IP -> Bool) -> Maybe IP.IP
findRealIp reqHeaders header isTrusted
    | null untrusted = listToMaybe candidates
    | otherwise      = listToMaybe (reverse untrusted)
  where
    candidates :: [IP.IP]
    candidates =
        mapMaybe (readMaybe . B8.unpack)
            [ entry
            | (name, value) <- reqHeaders
            , name == header
            , entry <- B8.split ',' value
            ]

    untrusted :: [IP.IP]
    untrusted = filter (not . isTrusted) candidates