{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
{-# OPTIONS_HADDOCK -ignore-exports #-}
module Network.OAuth.OAuth2.Internal where
import Control.Applicative
import Control.Arrow (second)
import Control.Monad.Catch
import Data.Aeson
import Data.Aeson.Types
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import Data.Maybe
import Data.Monoid
import Data.Text (Text, pack)
import Data.Text.Encoding
import GHC.Generics
import Lens.Micro
import Lens.Micro.Extras
import Network.HTTP.Conduit as C
import qualified Network.HTTP.Types as H
import URI.ByteString
import URI.ByteString.Aeson ()
-- | Client credentials and endpoint configuration for one OAuth2 provider.
data OAuth2 = OAuth2
  { oauthClientId :: Text
    -- ^ Client identifier; sent as @client_id@ by 'authorizationUrl'.
  , oauthClientSecret :: Text
    -- ^ Client secret. NOTE(review): not read in this module section;
    -- presumably used for client authentication elsewhere.
  , oauthOAuthorizeEndpoint :: URI
    -- ^ Authorization endpoint that 'authorizationUrl' extends.
  , oauthAccessTokenEndpoint :: URI
    -- ^ Token endpoint used by 'accessTokenUrl'' and 'refreshAccessTokenUrl'.
  , oauthCallback :: Maybe URI
    -- ^ Optional redirect URI, sent as @redirect_uri@ when present.
  }
  deriving (Show, Eq)
-- | An access token. The JSON instances are newtype-derived, so the token
-- encodes and decodes as a bare JSON string.
newtype AccessToken = AccessToken {atoken :: Text}
  deriving (Show, FromJSON, ToJSON)

-- | A refresh token; consumed by 'refreshAccessTokenUrl'.
newtype RefreshToken = RefreshToken {rtoken :: Text}
  deriving (Show, FromJSON, ToJSON)

-- | An ID token, kept as opaque 'Text'.
newtype IdToken = IdToken {idtoken :: Text}
  deriving (Show, FromJSON, ToJSON)

-- | The code exchanged at the token endpoint; sent as @code@ by
-- 'accessTokenUrl''.
newtype ExchangeToken = ExchangeToken {extoken :: Text}
  deriving (Show, FromJSON, ToJSON)
-- | A decoded token response. In JSON, each field appears under its
-- snake_case name (e.g. @access_token@, @expires_in@).
data OAuth2Token = OAuth2Token
  { accessToken :: AccessToken
    -- ^ The access token (required when decoding).
  , refreshToken :: Maybe RefreshToken
    -- ^ Optional refresh token.
  , expiresIn :: Maybe Int
    -- ^ Optional lifetime. NOTE(review): units presumably seconds; confirm.
  , tokenType :: Maybe Text
    -- ^ Optional token type string.
  , idToken :: Maybe IdToken
    -- ^ Optional ID token.
  }
  deriving (Show, Generic)
-- | Generic decoding in which each camelCase record field is read from its
-- snake_case key (via @camelTo2 '_'@), e.g. 'expiresIn' reads @expires_in@.
instance FromJSON OAuth2Token where
  parseJSON = genericParseJSON snakeCaseOptions
    where
      snakeCaseOptions = defaultOptions {fieldLabelModifier = camelTo2 '_'}
-- | Encode an 'OAuth2Token' with snake_case keys (e.g. @access_token@,
-- @refresh_token@). These are the same keys the 'FromJSON' instance reads.
--
-- Both methods are defined explicitly. If only 'toEncoding' were overridden,
-- 'toJSON' would fall back to the generic default options. It would then emit
-- camelCase keys (@accessToken@) that disagree with 'toEncoding' and that the
-- parser rejects, because @access_token@ is required.
instance ToJSON OAuth2Token where
  toJSON = genericToJSON defaultOptions {fieldLabelModifier = camelTo2 '_'}
  toEncoding = genericToEncoding defaultOptions {fieldLabelModifier = camelTo2 '_'}
-- | An error response from the provider, or a locally built decode failure
-- (see 'mkDecodeOAuth2Error').
data OAuth2Error a = OAuth2Error
  { error :: Either Text a
    -- ^ 'Right' when the @error@ value parsed as @a@; otherwise the value
    -- parsed as plain 'Text' ('Left').
  , errorDescription :: Maybe Text
    -- ^ Optional @error_description@ value.
  , errorUri :: Maybe (URIRef Absolute)
    -- ^ Optional @error_uri@ value.
  }
  deriving (Show, Eq, Generic)
-- | Decode a provider error object.
--
-- The @error@ value is first tried as the caller's @err@ type ('Right'). If
-- that fails, it is read as plain 'Text' ('Left'). @error_description@ and
-- @error_uri@ are optional. Any non-object input fails with
-- @"Expected an object"@.
instance FromJSON err => FromJSON (OAuth2Error err) where
  parseJSON raw =
    case raw of
      Object obj ->
        OAuth2Error
          <$> ((obj .: "error") >>= errorCode)
          <*> (obj .:? "error_description")
          <*> (obj .:? "error_uri")
      _ -> fail "Expected an object"
    where
      -- Prefer the typed interpretation; fall back to raw text.
      errorCode v = (Right <$> parseJSON v) <|> (Left <$> parseJSON v)
-- | Encode an 'OAuth2Error' using the keys the 'FromJSON' instance reads:
-- @error@, @error_description@ and @error_uri@.
--
-- The 'error' field is written as its bare payload, not wrapped as
-- @{"Left": ...}@ or @{"Right": ...}@, so the 'FromJSON' instance can read
-- the output back. One caveat: a 'Left' payload that also parses as @err@
-- decodes as 'Right'. Absent optional fields are written as @null@, which
-- the parser's @.:?@ accepts. 'toJSON' and 'toEncoding' are defined together
-- so that they produce the same document.
instance ToJSON err => ToJSON (OAuth2Error err) where
  toJSON (OAuth2Error err desc uri) =
    object
      [ "error" .= either toJSON toJSON err
      , "error_description" .= desc
      , "error_uri" .= uri
      ]
  toEncoding (OAuth2Error err desc uri) =
    pairs
      ( "error" .= either toJSON toJSON err
          <> "error_description" .= desc
          <> "error_uri" .= uri
      )
-- | Decode a provider error body. This never fails: if the body cannot be
-- decoded, the decoder message is wrapped with 'mkDecodeOAuth2Error'.
parseOAuth2Error :: FromJSON err => BSL.ByteString -> OAuth2Error err
parseOAuth2Error string =
  case eitherDecode string of
    Right decoded -> decoded
    Left message -> mkDecodeOAuth2Error string message
-- | Build an 'OAuth2Error' that describes a failure to decode a response.
--
-- The 'error' field is @Left "Decode error"@. The description embeds the
-- decoder's message and the original response body. The body is shown as
-- decoded UTF-8 when valid. Otherwise the raw bytes are shown, which keeps
-- this function total: plain 'decodeUtf8' would throw from pure code on a
-- malformed body.
mkDecodeOAuth2Error
  :: BSL.ByteString -- ^ response body that failed to decode
  -> String         -- ^ decoder error message
  -> OAuth2Error err
mkDecodeOAuth2Error response err =
  OAuth2Error
    (Left "Decode error")
    (Just $ pack $ "Error: " <> err <> "\n Original Response:\n" <> renderedBody)
    Nothing
  where
    strictBody = BSL.toStrict response
    -- 'decodeUtf8'' is total; on invalid UTF-8, fall back to the
    -- ByteString's 'Show' rendering.
    renderedBody = either (const (show strictBody)) show (decodeUtf8' strictBody)
-- | Request parameters as key/value byte strings, as returned alongside the
-- 'URI' by 'accessTokenUrl' and 'refreshAccessTokenUrl'.
type PostBody = [(BS.ByteString, BS.ByteString)]

-- | Query-string parameters as key/value byte strings.
type QueryParams = [(BS.ByteString, BS.ByteString)]

-- | Either an 'OAuth2Error' or a successful value.
type OAuth2Result err a = Either (OAuth2Error err) a
-- | Build the authorization URL. It appends @client_id@,
-- @response_type=code@ and, when a callback is configured, @redirect_uri@
-- after any query parameters already present on the endpoint.
authorizationUrl :: OAuth2 -> URI
authorizationUrl oa = over (queryL . queryPairsL) addParts (oauthOAuthorizeEndpoint oa)
  where
    addParts existing = existing ++ queryParts
    queryParts =
      ("client_id", encodeUtf8 (oauthClientId oa))
        : ("response_type", "code")
        : maybe [] (\cb -> [("redirect_uri", serializeURIRef' cb)]) (oauthCallback oa)
-- | Token request for exchanging an authorization code. This is
-- 'accessTokenUrl'' with @grant_type=authorization_code@.
accessTokenUrl :: OAuth2 -> ExchangeToken -> (URI, PostBody)
accessTokenUrl oa code = accessTokenUrl' oa code grantType
  where
    grantType = Just "authorization_code"
-- | Token-endpoint URI plus body parameters for a code exchange.
--
-- The body always contains @code@. It also contains @redirect_uri@ when a
-- callback is configured and @grant_type@ when one is supplied, in that
-- order.
accessTokenUrl'
  :: OAuth2
  -> ExchangeToken -- ^ code to exchange
  -> Maybe Text    -- ^ optional @grant_type@ value
  -> (URI, PostBody)
accessTokenUrl' oa code gt = (oauthAccessTokenEndpoint oa, params)
  where
    params =
      [("code", encodeUtf8 (extoken code))]
        ++ maybe [] (\cb -> [("redirect_uri", serializeURIRef' cb)]) (oauthCallback oa)
        ++ maybe [] (\g -> [("grant_type", encodeUtf8 g)]) gt
-- | Token-endpoint URI plus body parameters for a refresh-token request:
-- @grant_type=refresh_token@ followed by @refresh_token@.
refreshAccessTokenUrl :: OAuth2 -> RefreshToken -> (URI, PostBody)
refreshAccessTokenUrl oa token =
  ( oauthAccessTokenEndpoint oa
  , [ ("grant_type", "refresh_token")
    , ("refresh_token", encodeUtf8 (rtoken token))
    ]
  )
-- | Add an @access_token@ query parameter after the URI's existing query
-- parameters.
appendAccessToken :: URIRef a -> AccessToken -> URIRef a
appendAccessToken uri t = over (queryL . queryPairsL) (++ accessTokenToParam t) uri
-- | The single query parameter @access_token=<token>@, UTF-8 encoded.
accessTokenToParam :: AccessToken -> [(BS.ByteString, BS.ByteString)]
accessTokenToParam (AccessToken tok) = [("access_token", encodeUtf8 tok)]
-- | Add query parameters to a URI. Despite the name, the given parameters
-- are placed before the URI's existing query parameters.
appendQueryParams :: [(BS.ByteString, BS.ByteString)] -> URIRef a -> URIRef a
appendQueryParams params uri =
  over (queryL . queryPairsL) (\existing -> params ++ existing) uri
-- | Convert a 'URI' into an http-client 'Request'.
--
-- Only the @http@ and @https@ schemes are accepted. Any other scheme throws
-- 'InvalidUrlException' through 'throwM'. The path and query pairs are
-- copied over. Host and port come from the URI's authority. When the URI has
-- no port, 80 is used for @http@ and 443 for @https@. When the URI has no
-- authority, the host keeps 'defaultRequest''s value.
uriToRequest :: MonadThrow m => URI -> m Request
uriToRequest uri = do
  ssl <- schemeIsSecure (view (uriSchemeL . schemeBSL) uri)
  let defaultPort = (if ssl then 443 else 80) :: Int
      queryItems = second Just <$> view (queryL . queryPairsL) uri
      uriHost = preview (authorityL . _Just . authorityHostL . hostBSL) uri
      uriPort = preview (authorityL . _Just . authorityPortL . _Just . portNumberL) uri
      baseReq =
        setQueryString queryItems (defaultRequest {secure = ssl, path = view pathL uri})
      withHost = maybe baseReq (\h -> set hostLens h baseReq) uriHost
  return (set portLens (fromMaybe defaultPort uriPort) withHost)
  where
    -- Map the scheme to the 'secure' flag, rejecting anything unsupported.
    schemeIsSecure "http" = return False
    schemeIsSecure "https" = return True
    schemeIsSecure s = throwM $ InvalidUrlException (show uri) ("Invalid scheme: " ++ show s)
-- | Rebuild a 'URI' from a 'Request'.
--
-- The scheme follows 'secure' and the port is always written explicitly.
-- No user-info or fragment is produced. The query is re-parsed from the
-- request's query string with 'H.parseSimpleQuery'.
requestToUri :: Request -> URI
requestToUri req = URI reqScheme reqAuthority (path req) reqQuery Nothing
  where
    reqScheme
      | secure req = Scheme "https"
      | otherwise = Scheme "http"
    reqAuthority = Just (Authority Nothing (Host (host req)) (Just (Port (port req))))
    reqQuery = Query (H.parseSimpleQuery (queryString req))
-- | Lens onto a 'Request''s 'C.host' field.
hostLens :: Lens' Request BS.ByteString
hostLens f req = fmap (\newHost -> req {C.host = newHost}) (f (C.host req))
{-# INLINE hostLens #-}
-- | Lens onto a 'Request''s 'C.port' field.
portLens :: Lens' Request Int
portLens f req = fmap (\newPort -> req {C.port = newPort}) (f (C.port req))
{-# INLINE portLens #-}