module Network.HTTP.Client.Request
( parseUrl
, setUriRelative
, getUri
, setUri
, browserDecompress
, alwaysDecompress
, addProxy
, applyBasicAuth
, applyBasicProxyAuth
, urlEncodedBody
, needsGunzip
, requestBuilder
, useDefaultTimeout
, setQueryString
, streamFile
, observedStreamFile
, username
, password
) where
import Data.Int (Int64)
import Data.Maybe (fromMaybe, isJust)
import Data.Monoid (mempty, mappend)
import Data.String (IsString(..))
import Data.Char (toLower)
import Control.Applicative ((<$>))
import Control.Monad (when, unless)
import Numeric (showHex)
import Data.Default.Class (Default (def))
import Blaze.ByteString.Builder (Builder, fromByteString, fromLazyByteString, toByteStringIO, flush)
import Blaze.ByteString.Builder.Char8 (fromChar, fromShow)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Lazy.Internal (defaultChunkSize)
import qualified Network.HTTP.Types as W
import Network.URI (URI (..), URIAuth (..), parseURI, relativeTo, escapeURIString, isAllowedInURI, isReserved)
import Control.Monad.IO.Class (liftIO)
import Control.Exception (Exception, toException, throw, throwIO, IOException)
import qualified Control.Exception as E
import qualified Data.CaseInsensitive as CI
import qualified Data.ByteString.Base64 as B64
import Network.HTTP.Client.Types
import Network.HTTP.Client.Util
import Network.HTTP.Client.Connection
import Network.HTTP.Client.Util (readDec, (<>))
import Data.Time.Clock
import Control.Monad.Catch (MonadThrow, throwM)
import Data.IORef
import System.IO (withBinaryFile, hTell, hFileSize, Handle, IOMode (ReadMode))
-- | Parse a URL string into a 'Request' built on top of the default request.
--
-- Characters that are not allowed in a URI are percent-escaped first. If the
-- escaped string still does not parse as a URI, 'InvalidUrlException' is
-- thrown via 'throwM'. Otherwise the work is delegated to 'setUri', which may
-- fail in the same way.
parseUrl :: MonadThrow m => String -> m Request
parseUrl s = maybe invalid (setUri def) (parseURI escaped)
  where
    escaped = escapeURIString isAllowedInURI s
    invalid = throwM $ InvalidUrlException s "Invalid URL"
-- | Resolve a possibly-relative 'URI' against the location the 'Request'
-- currently points at, then retarget the request to the result.
--
-- On network >= 2.4 the resolved URI is used directly. On older versions
-- resolution can fail, and that failure is reported as 'InvalidUrlException'.
setUriRelative :: MonadThrow m => Request -> URI -> m Request
setUriRelative req uri =
#if MIN_VERSION_network(2,4,0)
    setUri req (relativeTo uri base)
#else
    maybe (throwM $ InvalidUrlException (show uri) "Invalid URL")
          (setUri req)
          (relativeTo uri base)
#endif
  where
    base = getUri req
-- | Rebuild the absolute 'URI' that a 'Request' currently targets.
--
-- Details of the result:
--
-- * The port is always written explicitly.
-- * The user info and fragment are left empty.
-- * A query string that does not already start with a question mark gets
--   one prepended.
getUri :: Request -> URI
getUri req = URI
    { uriScheme    = scheme
    , uriAuthority = Just authority
    , uriPath      = S8.unpack (path req)
    , uriQuery     = query
    , uriFragment  = ""
    }
  where
    scheme
        | secure req = "https:"
        | otherwise  = "http:"
    authority = URIAuth
        { uriUserInfo = ""
        , uriRegName  = S8.unpack (host req)
        , uriPort     = ':' : show (port req)
        }
    rawQuery = S8.unpack (queryString req)
    query = case rawQuery of
        c : _ | c /= '?' -> '?' : rawQuery
        _                -> rawQuery
-- | Add Basic authentication derived from the user-info part of a 'URI'.
--
-- The header is added only when both the extracted user name and password
-- are non-empty. Otherwise the request is returned unchanged.
applyAnyUriBasedAuth :: URI -> Request -> Request
applyAnyUriBasedAuth uri req
    | null user || null pass = req
    | otherwise = applyBasicAuth (S8.pack user) (S8.pack pass) req
  where
    info = maybe "" uriUserInfo (uriAuthority uri)
    user = username info
    pass = password info
-- | User-name part of a URI user-info string.
--
-- This is the text before the first colon, taken from the portion that
-- precedes the at-sign (see 'authPrefix'), then passed through 'encode'.
--
-- NOTE(review): 'encode' percent-escapes reserved characters. A user name
-- containing e.g. @!@ therefore reaches the Basic auth header as @%21@. If the
-- goal is to recover the raw credentials, unescaping looks more appropriate.
-- Confirm.
username :: String -> String
username info = encode (takeWhile (/= ':') (authPrefix info))
-- | Password part of a URI user-info string.
--
-- This is the text after the first colon, up to but not including the next
-- at-sign, passed through 'encode'. It is empty when the string has no colon.
--
-- NOTE(review): the result is percent-escaped by 'encode' before it is used
-- for Basic auth. Confirm that escaping, rather than unescaping, is intended.
password :: String -> String
password info = encode (takeWhile (/= '@') afterColon)
  where
    afterColon = drop 1 (dropWhile (/= ':') info)
-- | Percent-escape every character for which 'isReserved' holds.
--
-- All other characters are left untouched. Used on the user-info pieces
-- extracted for Basic auth.
encode :: String -> String
encode = escapeURIString (\c -> not (isReserved c))
-- | The portion of a user-info string before its first at-sign.
--
-- Returns an empty string when the input contains no at-sign at all.
authPrefix :: String -> String
authPrefix u = case break (== '@') u of
    (before, _ : _) -> before
    (_, [])         -> ""
-- | Point a 'Request' at an absolute @http@ or @https@ 'URI'.
--
-- Fields overwritten on the request:
--
-- * host, port, secure flag, path and query string.
-- * An empty path becomes a single slash.
-- * Without an explicit port, 80 (http) or 443 (https) is used.
-- * Basic auth is applied when the URI carries both a user name and a
--   password.
--
-- Throws 'InvalidUrlException' via 'throwM' when:
--
-- * the URI has no authority part;
-- * the scheme, compared case-insensitively, is neither http nor https;
-- * the port cannot be parsed.
setUri :: MonadThrow m => Request -> URI -> m Request
setUri req uri = do
    isSecure <- schemeIsSecure (uriScheme uri)
    authority <- maybe (invalid "URL must be absolute") return (uriAuthority uri)
    portNum <- portFor isSecure authority
    let updated = req
            { host        = S8.pack (uriRegName authority)
            , port        = portNum
            , secure      = isSecure
            , path        = S8.pack pathStr
            , queryString = S8.pack (uriQuery uri)
            }
    return (applyAnyUriBasedAuth uri updated)
  where
    pathStr
        | null (uriPath uri) = "/"
        | otherwise          = uriPath uri

    invalid :: MonadThrow m => String -> m a
    invalid = throwM . InvalidUrlException (show uri)

    schemeIsSecure scheme = case map toLower scheme of
        "http:"  -> return False
        "https:" -> return True
        _        -> invalid "Invalid scheme"

    portFor isSecure auth = case uriPort auth of
        ':' : digits -> maybe (invalid "Invalid port") return (readDec digits)
        _ | isSecure  -> return 443
          | otherwise -> return 80
-- | Multi-line, human-readable dump of a request, meant for debugging.
--
-- Only a subset of the fields is rendered. The body, the callback fields and
-- the cookie jar, for example, are not shown.
instance Show Request where
    show req = unlines (["Request {"] ++ map render fields ++ ["}"])
      where
        render (label, value) = label ++ value
        fields =
            [ (" host = ", show (host req))
            , (" port = ", show (port req))
            , (" secure = ", show (secure req))
            , (" requestHeaders = ", show (requestHeaders req))
            , (" path = ", show (path req))
            , (" queryString = ", show (queryString req))
            , (" method = ", show (method req))
            , (" proxy = ", show (proxy req))
            , (" rawBody = ", show (rawBody req))
            , (" redirectCount = ", show (redirectCount req))
            , (" responseTimeout = ", show (responseTimeout req))
            , (" requestVersion = ", show (requestVersion req))
            ]
-- | Magic value for 'responseTimeout' meaning "no per-request timeout was
-- chosen". Presumably the connection manager substitutes its own default
-- when it sees this value (TODO confirm against the Manager code).
--
-- Timeouts are in microseconds; see the scaling in the default
-- 'getConnectionWrapper'. A real timeout is therefore never negative, so
-- the sentinel is negative on purpose. A positive sentinel would be
-- indistinguishable from a genuine, very short timeout that a caller set
-- explicitly.
useDefaultTimeout :: Maybe Int
useDefaultTimeout = Just (-3425)
-- | The default request.
--
-- A plain-HTTP @GET@ of the root path on @localhost@ port 80, with:
--
-- * an empty body and no proxy;
-- * up to 10 redirects;
-- * HTTP/1.1;
-- * 'browserDecompress' as the decompression predicate.
--
-- Callback defaults:
--
-- * 'checkStatus' turns any status outside 200-299 into a
--   'StatusCodeException'.
-- * 'responseTimeout' is 'useDefaultTimeout'.
-- * 'getConnectionWrapper' runs the connection action under the given
--   timeout, in microseconds. It throws the supplied exception if the
--   timeout expires or no time is left afterwards. Otherwise it returns the
--   remaining budget together with the action's result.
-- * 'onRequestBodyException' swallows 'IOException's and rethrows any other
--   exception.
instance Default Request where
    def = Request
        { host = "localhost"
        , port = 80
        , secure = False
        , requestHeaders = []
        , path = "/"
        , queryString = S8.empty
        , requestBody = RequestBodyLBS L.empty
        , method = "GET"
        , proxy = Nothing
        , hostAddress = Nothing
        , rawBody = False
        , decompress = browserDecompress
        , redirectCount = 10
        , checkStatus = \s@(W.Status sci _) hs cookie_jar ->
            if 200 <= sci && sci < 300
                then Nothing
                else Just $ toException $ StatusCodeException s hs cookie_jar
        , responseTimeout = useDefaultTimeout
        , getConnectionWrapper = \mtimeout exc f ->
            case mtimeout of
                Nothing -> fmap ((,) Nothing) f
                Just timeout' -> do
                    before <- getCurrentTime
                    mres <- timeout timeout' f
                    case mres of
                        Nothing -> throwIO exc
                        Just res -> do
                            now <- getCurrentTime
                            -- diffUTCTime is in seconds; scale it to
                            -- microseconds so it is comparable to timeout'.
                            let timeSpentMicro = diffUTCTime now before * 1000000
                                -- What is left of the budget for the rest
                                -- of the request: total minus time spent.
                                remainingTime = round $ fromIntegral timeout' - timeSpentMicro
                            if remainingTime <= 0
                                then throwIO exc
                                else return (Just remainingTime, res)
        , cookieJar = Just def
        , requestVersion = W.http11
        , onRequestBodyException = \se ->
            case E.fromException se of
                Just (_ :: IOException) -> return ()
                Nothing -> throwIO se
        }
-- | Allows a URL string literal to be used directly as a 'Request', via
-- 'parseUrl'.
--
-- An invalid URL is raised with 'throw', so the failure surfaces only when
-- the resulting request is evaluated.
instance IsString Request where
    fromString = either throw id . parseUrl
-- | A 'decompress' predicate that accepts every content type.
--
-- Gzip-encoded bodies are therefore always decompressed.
alwaysDecompress :: S.ByteString -> Bool
alwaysDecompress _ = True
-- | The default 'decompress' predicate.
--
-- Gzip-encoded bodies are decompressed for every content type except
-- @application/x-tar@.
browserDecompress :: S.ByteString -> Bool
browserDecompress contentType = contentType /= tarType
  where
    tarType = S8.pack "application/x-tar"
-- | Prepend an @Authorization: Basic ...@ header built from the given user
-- name and password (the base64 of user, colon, password).
--
-- Existing headers, including any earlier @Authorization@ header, are kept.
applyBasicAuth :: S.ByteString -> S.ByteString -> Request -> Request
applyBasicAuth user passwd req =
    req { requestHeaders = authorization : requestHeaders req }
  where
    token = B64.encode (S8.intercalate ":" [user, passwd])
    authorization = (CI.mk "Authorization", S8.append "Basic " token)
-- | Send the request through the proxy at the given host and port.
--
-- Any proxy already configured on the request is replaced.
addProxy :: S.ByteString -> Int -> Request -> Request
addProxy hst prt req = req { proxy = Just (Proxy hst prt) }
-- | Prepend a @Proxy-Authorization: Basic ...@ header built from the given
-- user name and password (the base64 of user, colon, password).
--
-- Existing headers are kept.
applyBasicProxyAuth :: S.ByteString -> S.ByteString -> Request -> Request
applyBasicProxyAuth user passwd req =
    req { requestHeaders = proxyAuth : requestHeaders req }
  where
    token = B64.encode (S8.intercalate ":" [user, passwd])
    proxyAuth = (CI.mk "Proxy-Authorization", S8.append "Basic " token)
-- | Turn the request into a @POST@ whose body is the given key/value pairs,
-- URL-encoded as a form.
--
-- Any existing @Content-Type@ header is replaced by
-- @application/x-www-form-urlencoded@. All other headers are kept.
urlEncodedBody :: [(S.ByteString, S.ByteString)] -> Request -> Request
urlEncodedBody fields req = req
    { requestBody    = RequestBodyLBS (L.fromChunks [encoded])
    , method         = "POST"
    , requestHeaders = (contentType, "application/x-www-form-urlencoded") : others
    }
  where
    contentType = "Content-Type"
    encoded = W.renderSimpleQuery False fields
    others = [ h | h@(name, _) <- requestHeaders req, name /= contentType ]
-- | Decide whether a response body should be gunzipped.
--
-- All of the following must hold:
--
-- * the request did not ask for the raw body;
-- * the response has @content-encoding: gzip@;
-- * the request's 'decompress' predicate accepts the response content type
--   (an empty string when no content type is present).
needsGunzip :: Request -> [W.Header] -> Bool
needsGunzip req hs'
    | rawBody req = False
    | ("content-encoding", "gzip") `notElem` hs' = False
    | otherwise = decompress req (fromMaybe "" (lookup "content-type" hs'))
-- | Write a request (request line, headers and body) to a connection.
--
-- With an @Expect: 100-continue@ header:
--
-- * only the request line and headers are written, then flushed;
-- * the body-sending action is returned instead of run, wrapped in the
--   request's 'onRequestBodyException' handler, so the caller can run it
--   later.
--
-- Otherwise everything is sent immediately and 'Nothing' is returned.
requestBuilder :: Request -> Connection -> IO (Maybe (IO ()))
requestBuilder req Connection {..}
    | expectContinue = flushHeaders >> return (Just (checkBadSend sendLater))
    | otherwise = sendNow >> return Nothing
  where
    expectContinue = Just "100-continue" == lookup "Expect" (requestHeaders req)
    -- Exceptions from the wrapped action go to the request's
    -- onRequestBodyException handler.
    checkBadSend f = f `E.catch` onRequestBodyException req
    writeBuilder = toByteStringIO connectionWrite
    -- Request line and headers ('builder'), followed by the given builder.
    writeHeadersWith = writeBuilder . (builder `mappend`)
    -- Headers only, flushed explicitly so they are written out before any
    -- body.
    flushHeaders = writeHeadersWith flush
    -- Per body kind, this triple holds:
    --   1. the length to advertise (Nothing selects chunked encoding, see
    --      contentLengthHeader);
    --   2. the action that sends headers and body now;
    --   3. the action that sends only the body (used after 100-continue).
    -- In-memory bodies are appended to the header builder and written in one
    -- go, with the headers also inside checkBadSend. Streamed bodies flush
    -- the headers first (outside checkBadSend) and then stream the body.
    (contentLength, sendNow, sendLater) =
        case requestBody req of
            RequestBodyLBS lbs ->
                let body = fromLazyByteString lbs
                    now = checkBadSend $ writeHeadersWith body
                    later = writeBuilder body
                in (Just (L.length lbs), now, later)
            RequestBodyBS bs ->
                let body = fromByteString bs
                    now = checkBadSend $ writeHeadersWith body
                    later = writeBuilder body
                in (Just (fromIntegral $ S.length bs), now, later)
            RequestBodyBuilder len body ->
                let now = checkBadSend $ writeHeadersWith body
                    later = writeBuilder body
                in (Just len, now, later)
            RequestBodyStream len stream ->
                let body = writeStream False stream
                    now = flushHeaders >> checkBadSend body
                in (Just len, now, body)
            RequestBodyStreamChunked stream ->
                let body = writeStream True stream
                    now = flushHeaders >> checkBadSend body
                in (Nothing, now, body)
    -- Pull chunks from the popper until it returns an empty chunk.
    -- In chunked mode:
    --   * each chunk is framed as <hex length> CRLF <data> CRLF;
    --   * the terminating zero-length chunk is written at the end.
    writeStream isChunked withStream =
        withStream loop
      where
        loop stream = do
            bs <- stream
            if S.null bs
                then when isChunked $ connectionWrite "0\r\n\r\n"
                else do
                    connectionWrite $
                        if isChunked
                            then S.concat
                                [ S8.pack $ showHex (S.length bs) "\r\n"
                                , bs
                                , "\r\n"
                                ]
                            else bs
                    loop stream
    -- Host value: the port is omitted when it is the default for the
    -- scheme (80 without TLS, 443 with TLS).
    hh
        | port req == 80 && not (secure req) = host req
        | port req == 443 && secure req = host req
        | otherwise = host req <> S8.pack (':' : show (port req))
    requestProtocol
        | secure req = fromByteString "https://"
        | otherwise = fromByteString "http://"
    -- Request target prefix:
    --   * non-secure request through a proxy: scheme, host and optional
    --     port go before the path (absolute form);
    --   * otherwise: only the path is sent.
    requestHostname
        | isJust (proxy req) && not (secure req)
            = requestProtocol <> fromByteString hh
        | otherwise = mempty
    -- Content-Length is skipped for an empty GET or HEAD body.
    -- An unknown length (Nothing) sends chunked Transfer-Encoding instead.
    contentLengthHeader (Just contentLength') =
        if method req `elem` ["GET", "HEAD"] && contentLength' == 0
            then id
            else (:) ("Content-Length", S8.pack $ show contentLength')
    contentLengthHeader Nothing = (:) ("Transfer-Encoding", "chunked")
    -- Accept-Encoding handling:
    --   * absent: ask for gzip;
    --   * explicitly empty: remove the header;
    --   * any other caller value: keep it.
    acceptEncodingHeader =
        case lookup "Accept-Encoding" $ requestHeaders req of
            Nothing -> (("Accept-Encoding", "gzip"):)
            Just "" -> filter (\(k, _) -> k /= "Accept-Encoding")
            Just _ -> id
    -- Add a Host header unless the caller already supplied one.
    hostHeader x =
        case lookup "Host" x of
            Nothing -> ("Host", hh) : x
            Just{} -> x
    -- Final header list: the caller's headers with the generated ones
    -- above applied.
    headerPairs :: W.RequestHeaders
    headerPairs = hostHeader
                $ acceptEncodingHeader
                $ contentLengthHeader contentLength
                $ requestHeaders req
    -- Layout: request line (method, target, HTTP version), then one line per
    -- header, then the blank line that ends the header section.
    -- A missing leading '/' on the path is added, and a missing leading '?'
    -- on a non-empty query string is added.
    builder :: Builder
    builder =
        fromByteString (method req)
        <> fromByteString " "
        <> requestHostname
        <> (case S8.uncons $ path req of
                Just ('/', _) -> fromByteString $ path req
                _ -> fromChar '/' <> fromByteString (path req))
        <> (case S8.uncons $ queryString req of
                Nothing -> mempty
                Just ('?', _) -> fromByteString $ queryString req
                _ -> fromChar '?' <> fromByteString (queryString req))
        <> (case requestVersion req of
                W.HttpVersion 1 1 -> fromByteString " HTTP/1.1\r\n"
                W.HttpVersion 1 0 -> fromByteString " HTTP/1.0\r\n"
                version ->
                    fromChar ' ' <>
                    fromShow version <>
                    fromByteString "\r\n")
        <> foldr
            (\a b -> headerPairToBuilder a <> b)
            (fromByteString "\r\n")
            headerPairs
    headerPairToBuilder (k, v) =
        fromByteString (CI.original k)
        <> fromByteString ": "
        <> fromByteString v
        <> fromByteString "\r\n"
-- | Replace the request's query string with the rendered form of the given
-- key / optional-value pairs.
--
-- The 'True' flag asks 'W.renderQuery' to prefix the result with a question
-- mark.
setQueryString :: [(S.ByteString, Maybe S.ByteString)] -> Request -> Request
setQueryString qs req = req { queryString = rendered }
  where
    rendered = W.renderQuery True qs
-- | Stream a file as the request body, without any progress reporting.
streamFile :: FilePath -> IO RequestBody
streamFile = observedStreamFile ignore
  where
    ignore _ = return ()
-- | Stream a file as a request body of known length, reporting progress to
-- the supplied callback.
--
-- How the file is read:
--
-- * The size is read once up front, opening the file briefly to do so.
-- * The file is reopened every time the body is actually sent.
-- * It is read in chunks of 'defaultChunkSize' bytes.
--
-- After every read, including one that yields an empty chunk, the callback
-- receives:
--
-- * the total size;
-- * the handle position after the read;
-- * the length of the chunk just read.
observedStreamFile :: (StreamFileStatus -> IO ()) -> FilePath -> IO RequestBody
observedStreamFile obs path = do
    total <- fromIntegral <$> withBinaryFile path ReadMode hFileSize
    let popperFor :: Handle -> Popper
        popperFor h = do
            chunk <- S.hGetSome h defaultChunkSize
            pos <- hTell h
            obs StreamFileStatus
                { fileSize      = total
                , readSoFar     = fromIntegral pos
                , thisChunkSize = S.length chunk
                }
            return chunk

        withPopper :: GivesPopper ()
        withPopper k = withBinaryFile path ReadMode (k . popperFor)
    return (RequestBodyStream total withPopper)