{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NamedFieldPuns #-}
module Web.LTI13 (
Role(..)
, ContextClaim(..)
, UncheckedLtiTokenClaims(..)
, LtiTokenClaims(..)
, validateLtiToken
, LTI13Exception(..)
, PlatformInfo(..)
, Issuer
, ClientId
, SessionStore(..)
, AuthFlowConfig(..)
, RequestParams
, initiate
, handleAuthResponse
) where
import Control.Monad (when, (>=>))
import qualified Control.Monad.Fail as Fail
import Control.Monad.IO.Class (liftIO, MonadIO)

import Control.Exception.Safe (MonadCatch, catch, throwM, Typeable, Exception, MonadThrow, throw)
import Data.Aeson (eitherDecode, FromJSON (parseJSON), ToJSON(toJSON, toEncoding), Object,
    object, pairs, withObject, withText, (.:), (.:?), (.=))
import qualified Data.Aeson as A
import Data.Aeson.Types (Parser)
import qualified Data.Map.Strict as Map
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8, decodeUtf8)
import Jose.Jwa (JwsAlg(RS256))
import qualified Jose.Jwk as Jwk
import Network.HTTP.Client (responseBody, Manager, HttpException, parseRequest, parseUrlThrow, httpLbs)
import qualified Network.HTTP.Types.URI as URI
import qualified Web.OIDC.Client.Discovery.Provider as P
import Web.OIDC.Client.IdTokenFlow (getValidIdTokenClaims)
import qualified Web.OIDC.Client.Settings as O
import Web.OIDC.Client.Tokens (nonce, aud, otherClaims, iss, IdTokenClaims)
import Web.OIDC.Client.Types (Nonce, SessionStore(..))
-- | Parse @field@ out of @obj@ and accept it only if it equals @fixedVal@.
--
-- Any other value makes the parser fail with a message that names both the
-- field and the value it was required to have.
parseFixed :: (FromJSON a, Eq a, Show a) => Object -> Text -> a -> Parser a
parseFixed obj field fixedVal = do
    actual <- obj .: field
    when (actual /= fixedVal) $
        fail $ "field " ++ show field ++ " was not the required value " ++ show fixedVal
    return actual
-- | A membership role carried in the LTI roles claim.
--
-- The five named constructors correspond to the LIS v2 membership URIs that
-- 'roleFromString' recognises; any other role string is kept unchanged
-- inside 'Other'.
data Role
    = Administrator
    | ContentDeveloper
    | Instructor
    | Learner
    | Mentor
    | Other Text
    deriving (Show)
-- | Decode a role URI into a 'Role'.
--
-- The five LIS v2 membership URIs listed below (exact, case-sensitive match)
-- map to their constructors; every other string is wrapped in 'Other' as-is.
roleFromString :: Text -> Role
roleFromString uri = maybe (Other uri) id (lookup uri knownRoles)
  where
    knownRoles :: [(Text, Role)]
    knownRoles =
        [ ("http://purl.imsglobal.org/vocab/lis/v2/membership#Administrator", Administrator)
        , ("http://purl.imsglobal.org/vocab/lis/v2/membership#ContentDeveloper", ContentDeveloper)
        , ("http://purl.imsglobal.org/vocab/lis/v2/membership#Instructor", Instructor)
        , ("http://purl.imsglobal.org/vocab/lis/v2/membership#Learner", Learner)
        , ("http://purl.imsglobal.org/vocab/lis/v2/membership#Mentor", Mentor)
        ]
-- | Encode a 'Role' as text.
--
-- Named constructors become their LIS v2 membership URI (the inverse of
-- 'roleFromString' for those constructors); 'Other' yields its wrapped text.
roleToString :: Role -> Text
roleToString role = case role of
    Administrator    -> "http://purl.imsglobal.org/vocab/lis/v2/membership#Administrator"
    ContentDeveloper -> "http://purl.imsglobal.org/vocab/lis/v2/membership#ContentDeveloper"
    Instructor       -> "http://purl.imsglobal.org/vocab/lis/v2/membership#Instructor"
    Learner          -> "http://purl.imsglobal.org/vocab/lis/v2/membership#Learner"
    Mentor           -> "http://purl.imsglobal.org/vocab/lis/v2/membership#Mentor"
    Other uri        -> uri
-- | A role is a JSON string decoded with 'roleFromString'; it never fails
-- on string input (unknown strings become 'Other').
instance FromJSON Role where
    parseJSON = withText "Role" (\txt -> pure (roleFromString txt))
-- | A role is emitted as the JSON string produced by 'roleToString'.
instance ToJSON Role where
    toJSON role = A.String (roleToString role)
-- | The LTI context claim, stored in the token under the 'claimContext' key.
data ContextClaim = ContextClaim
    { contextId :: Text
      -- ^ required; the 'FromJSON' instance rejects values longer than 255 characters
    , contextLabel :: Maybe Text
      -- ^ optional @label@
    , contextTitle :: Maybe Text
      -- ^ optional @title@
    }
    deriving (Show)
-- | Decode a context object: @id@ is required (at most 255 characters),
-- @label@ and @title@ are optional.
instance FromJSON ContextClaim where
    parseJSON = withObject "ContextClaim" $ \obj -> do
        cid <- (obj .: "id") >>= limitLength 255
        label <- obj .:? "label"
        title <- obj .:? "title"
        return (ContextClaim cid label title)
-- | Encode a context object with keys @id@, @label@, @title@ in that order.
-- Missing optional fields are written as JSON @null@.
instance ToJSON ContextClaim where
    toJSON claim =
        object
            [ "id" .= contextId claim
            , "label" .= contextLabel claim
            , "title" .= contextTitle claim
            ]
    toEncoding claim =
        pairs $ mconcat
            [ "id" .= contextId claim
            , "label" .= contextLabel claim
            , "title" .= contextTitle claim
            ]
-- | LTI-specific claims of an ID token, as decoded from JSON but before the
-- checks in 'validateLtiToken' have been applied.
data UncheckedLtiTokenClaims = UncheckedLtiTokenClaims
    { messageType :: Text
      -- ^ the 'FromJSON' instance only accepts @"LtiResourceLinkRequest"@
    , ltiVersion :: Text
      -- ^ the 'FromJSON' instance only accepts @"1.3.0"@
    , deploymentId :: Text
      -- ^ at most 255 characters when decoded
    , targetLinkUri :: Text
    , roles :: [Role]
    , email :: Maybe Text
    , context :: Maybe ContextClaim
    }
    deriving (Show)
-- | LTI claims that have been passed through 'validateLtiToken'.
--
-- NOTE(review): the constructor is exported (@LtiTokenClaims(..)@ in the
-- module export list), so a value of this type does not by itself prove that
-- validation happened.
newtype LtiTokenClaims
    = LtiTokenClaims UncheckedLtiTokenClaims
    deriving (Show)
-- | Return the text unchanged when it has at most @len@ characters;
-- otherwise fail (via 'Fail.MonadFail') with "String is too long".
limitLength :: (Fail.MonadFail m) => Int -> Text -> m Text
limitLength maxLen txt
    | T.length txt > maxLen = fail "String is too long"
    | otherwise = return txt
-- | JSON keys of the LTI claims that the 'FromJSON' and 'ToJSON' instances
-- of 'UncheckedLtiTokenClaims' read and write.
claimMessageType, claimVersion, claimDeploymentId, claimTargetLinkUri,
    claimRoles, claimContext :: Text
claimMessageType = "https://purl.imsglobal.org/spec/lti/claim/message_type"
claimVersion = "https://purl.imsglobal.org/spec/lti/claim/version"
claimDeploymentId = "https://purl.imsglobal.org/spec/lti/claim/deployment_id"
claimTargetLinkUri = "https://purl.imsglobal.org/spec/lti/claim/target_link_uri"
claimRoles = "https://purl.imsglobal.org/spec/lti/claim/roles"
claimContext = "https://purl.imsglobal.org/spec/lti/claim/context"
-- | Decode the LTI claims from an ID token's claim object.
--
-- The message type must be @"LtiResourceLinkRequest"@ and the version
-- @"1.3.0"@. The deployment id is limited to 255 characters. Email and
-- context are optional. Fields are parsed in declaration order, so the first
-- failing field determines the error.
instance FromJSON UncheckedLtiTokenClaims where
    parseJSON = withObject "LtiTokenClaims" $ \obj -> do
        msgType <- parseFixed obj claimMessageType "LtiResourceLinkRequest"
        version <- parseFixed obj claimVersion "1.3.0"
        depId <- (obj .: claimDeploymentId) >>= limitLength 255
        target <- obj .: claimTargetLinkUri
        roleList <- obj .: claimRoles
        mail <- obj .:? "email"
        ctx <- obj .:? claimContext
        pure UncheckedLtiTokenClaims
            { messageType = msgType
            , ltiVersion = version
            , deploymentId = depId
            , targetLinkUri = target
            , roles = roleList
            , email = mail
            , context = ctx
            }
-- | Encode the LTI claims under the same keys the 'FromJSON' instance reads.
-- Absent optional fields (@email@, context) are written as JSON @null@.
instance ToJSON UncheckedLtiTokenClaims where
    toJSON tok =
        object
            [ claimMessageType .= messageType tok
            , claimVersion .= ltiVersion tok
            , claimDeploymentId .= deploymentId tok
            , claimTargetLinkUri .= targetLinkUri tok
            , claimRoles .= roles tok
            , "email" .= email tok
            , claimContext .= context tok
            ]
    toEncoding tok =
        pairs $ mconcat
            [ claimMessageType .= messageType tok
            , claimVersion .= ltiVersion tok
            , claimDeploymentId .= deploymentId tok
            , claimTargetLinkUri .= targetLinkUri tok
            , claimRoles .= roles tok
            , "email" .= email tok
            , claimContext .= context tok
            ]
-- | Apply the LTI-specific checks to decoded ID-token claims and, if they all
-- pass, re-tag the claims as 'LtiTokenClaims'.
--
-- The checks run in this order, and the first failure is returned as 'Left':
--
-- * @iss@ must equal 'platformIssuer'
--   ("issuer does not match platform issuer");
-- * @aud@ must contain exactly one entry, equal to 'platformClientId'
--   ("aud is invalid"). Tokens with several audiences are rejected;
-- * a @nonce@ must be present ("nonce missing").
--
-- Each check passes its own input along, so the pipeline stays correct even
-- if a stage is later changed to transform the claims.
validateLtiToken
    :: PlatformInfo
    -> IdTokenClaims UncheckedLtiTokenClaims
    -> Either Text (IdTokenClaims LtiTokenClaims)
validateLtiToken pinfo claims =
    markValid <$> (issuerMatches >=> audContainsClientId >=> hasNonce) claims
  where
    issuerMatches c
        | iss c == platformIssuer pinfo = Right c
        | otherwise = Left "issuer does not match platform issuer"

    audContainsClientId c
        | length (aud c) == 1 && platformClientId pinfo `elem` aud c = Right c
        | otherwise = Left "aud is invalid"

    hasNonce c = case nonce c of
        Just _ -> Right c
        Nothing -> Left "nonce missing"

    -- Type-changing record update: only the payload type changes.
    markValid
        :: IdTokenClaims UncheckedLtiTokenClaims
        -> IdTokenClaims LtiTokenClaims
    markValid tok = tok { otherClaims = LtiTokenClaims (otherClaims tok) }
-- | Errors raised by the LTI 1.3 launch flow in this module.
data LTI13Exception
    = InvalidHandshake Text
      -- ^ a required request parameter was missing ('lookupOrThrow')
    | DiscoveryException Text
      -- ^ the fetched JWK set could not be decoded ('getJwkSet')
    | GotHttpException HttpException
      -- ^ an http-client exception, wrapped by 'rethrow'
    | InvalidLtiToken Text
      -- ^ the ID token failed a nonce or 'validateLtiToken' check
    deriving (Show, Typeable)

instance Exception LTI13Exception
-- | This tool's client id at a platform; sent as @client_id@ by 'initiate'
-- and expected as the token's sole @aud@ entry by 'validateLtiToken'.
type ClientId = Text

-- | A platform identifier, compared against the token's @iss@ claim.
type Issuer = Text

-- | What the tool knows about one platform registration.
data PlatformInfo = PlatformInfo
    { platformIssuer :: Issuer
      -- ^ must equal the ID token's @iss@
    , platformClientId :: ClientId
      -- ^ this tool's client id at the platform
    , platformOidcAuthEndpoint :: Text
      -- ^ URL that 'initiate' appends the authentication request query to
    , jwksUrl :: String
      -- ^ URL the platform's JWK set is fetched from in 'handleAuthResponse'
    }
-- | Callbacks and settings the application supplies to the launch flow.
data AuthFlowConfig m = AuthFlowConfig
    { getPlatformInfo :: (Issuer, Maybe ClientId) -> m PlatformInfo
      -- ^ look up a platform from the login request's @iss@ and optional
      -- @client_id@ (called by 'initiate')
    , haveSeenNonce :: Nonce -> m Bool
      -- ^ 'True' if the nonce was used before; 'handleAuthResponse' then
      -- rejects the token. This module never records nonces itself, so
      -- presumably the callback must (TODO confirm with callers)
    , myRedirectUri :: Text
      -- ^ sent as @redirect_uri@ in the authentication request
    , sessionStore :: SessionStore m
      -- ^ holds the generated state and nonce between 'initiate' and
      -- 'handleAuthResponse'
    }
-- | Re-throw an http-client 'HttpException' wrapped in 'GotHttpException',
-- so HTTP failures surface as an 'LTI13Exception'.
--
-- Only 'MonadThrow' is needed to throw. Every 'MonadCatch' is also a
-- 'MonadThrow', so existing call sites keep working.
rethrow :: (MonadThrow m) => HttpException -> m a
rethrow = throwM . GotHttpException
-- | Download the JWK set at @fromUrl@ and return its keys.
--
-- Throws:
--
-- * 'GotHttpException' for any http-client failure, including an invalid URL
--   or a non-2xx response status;
-- * 'DiscoveryException' when a successful response body does not decode as
--   a JWK set.
getJwkSet
    :: Manager
    -> String
    -> IO [Jwk.Jwk]
getJwkSet manager fromUrl = do
    body <- fetch `catch` rethrow
    case Jwk.keys <$> eitherDecode body of
        Right keys -> return keys
        Left err -> throwM $ DiscoveryException ("Failed to decode JwkSet: " <> T.pack err)
  where
    -- parseUrlThrow (unlike parseRequest) makes httpLbs throw on a non-2xx
    -- status. That way an error page is reported as an HTTP failure instead
    -- of being handed to the JSON decoder and mislabelled as a decode error.
    fetch = do
        req <- parseUrlThrow fromUrl
        responseBody <$> httpLbs req manager
-- | Fetch a required request parameter. If it is absent, throw
-- 'InvalidHandshake' with the message "Missing `<name>`".
lookupOrThrow :: (MonadThrow m) => Text -> Map.Map Text Text -> m Text
lookupOrThrow name map_ =
    maybe (throw (InvalidHandshake ("Missing `" <> name <> "`")))
          return
          (Map.lookup name map_)
-- | Request parameters by name, e.g. the login or form_post fields.
type RequestParams = Map.Map T.Text T.Text
-- | Handle a login-initiation request and build the authentication request.
--
-- Required params are @iss@, @login_hint@ and @target_link_uri@. They are
-- checked in that order, and the first missing one raises 'InvalidHandshake'.
-- @target_link_uri@ must be present but is otherwise unused here. Optional
-- params are @lti_message_hint@ (forwarded) and @client_id@ (passed to
-- 'getPlatformInfo').
--
-- The function generates a fresh nonce and state, saves them in the
-- session store, and returns @(issuer, clientId, url)@. @url@ is the
-- platform's OIDC auth endpoint followed by @?@ and the query string.
--
-- NOTE(review): this assumes 'platformOidcAuthEndpoint' carries no query
-- string of its own, since @?@ is always appended.
initiate :: (MonadIO m) => AuthFlowConfig m -> RequestParams -> m (Issuer, ClientId, Text)
initiate cfg params = do
    -- Each lookup is bound separately (no partial list pattern). Names avoid
    -- shadowing the imported `iss` / `nonce` accessors.
    issuer <- require "iss"
    loginHint <- require "login_hint"
    _ <- require "target_link_uri"
    let messageHint = Map.lookup "lti_message_hint" params
        gotCid = Map.lookup "client_id" params
    PlatformInfo
        { platformOidcAuthEndpoint = endpoint
        , platformClientId = clientId } <- getPlatformInfo cfg (issuer, gotCid)
    let ss = sessionStore cfg
    nonceVal <- sessionStoreGenerate ss
    state <- sessionStoreGenerate ss
    sessionStoreSave ss state nonceVal
    let query = URI.simpleQueryToQuery $
            [ ("scope", "openid")
            , ("response_type", "id_token")
            , ("client_id", encodeUtf8 clientId)
            , ("redirect_uri", encodeUtf8 $ myRedirectUri cfg)
            , ("login_hint", encodeUtf8 loginHint)
            , ("state", state)
            , ("response_mode", "form_post")
            , ("nonce", nonceVal)
            , ("prompt", "none")
            ] ++ maybe [] (\mh -> [("lti_message_hint", encodeUtf8 mh)]) messageHint
    return (issuer, clientId, endpoint <> decodeUtf8 (URI.renderQuery True query))
  where
    -- Throws InvalidHandshake (in IO) when the parameter is missing.
    require name = liftIO (lookupOrThrow name params)
-- | Build an 'O.OIDC' value for 'getValidIdTokenClaims' in
-- 'handleAuthResponse'.
--
-- Only two things are filled in: the platform's JWK set and RS256 as the
-- single accepted ID-token signing algorithm. Every other field is a
-- placeholder that this module never reads.
--
-- NOTE(review): this relies on oidc-client's ID-token validation not reading
-- the placeholders either; confirm against the oidc-client version in use.
-- Forcing a placeholder raises an 'error' that names the field, instead of
-- an anonymous 'undefined'.
fakeOidc :: [Jwk.Jwk] -> O.OIDC
fakeOidc jset = O.OIDC
    { O.oidcProvider = P.Provider
        { P.configuration = P.Configuration
            { P.idTokenSigningAlgValuesSupported = [ P.JwsAlgJson RS256 ]
            , P.issuer = unused "issuer"
            , P.authorizationEndpoint = unused "authorizationEndpoint"
            , P.tokenEndpoint = unused "tokenEndpoint"
            , P.userinfoEndpoint = unused "userinfoEndpoint"
            , P.revocationEndpoint = unused "revocationEndpoint"
            , P.jwksUri = unused "jwksUri"
            , P.responseTypesSupported = unused "responseTypesSupported"
            , P.subjectTypesSupported = unused "subjectTypesSupported"
            , P.scopesSupported = unused "scopesSupported"
            , P.tokenEndpointAuthMethodsSupported = unused "tokenEndpointAuthMethodsSupported"
            , P.claimsSupported = unused "claimsSupported"
            }
        , P.jwkSet = jset
        }
    , O.oidcAuthorizationServerUrl = unused "oidcAuthorizationServerUrl"
    , O.oidcTokenEndpoint = unused "oidcTokenEndpoint"
    , O.oidcClientId = unused "oidcClientId"
    , O.oidcClientSecret = unused "oidcClientSecret"
    , O.oidcRedirectUri = unused "oidcRedirectUri"
    }
  where
    unused :: String -> a
    unused field =
        error ("Web.LTI13.fakeOidc: placeholder field `" ++ field ++ "` was evaluated")
-- | Handle the platform's authentication response, i.e. the form_post
-- requested by 'initiate'.
--
-- Steps:
--
-- 1. Read the required @state@ and @id_token@ params ('InvalidHandshake' if
--    either is missing, checked in that order).
-- 2. Fetch the platform's JWK set from 'jwksUrl' (see 'getJwkSet' for its
--    exceptions).
-- 3. Verify the token with 'getValidIdTokenClaims' against the session store.
--    Any exceptions it throws come from oidc-client and propagate unchanged.
-- 4. Reject a missing nonce, or one for which 'haveSeenNonce' returns 'True'
--    ('InvalidLtiToken').
-- 5. Apply 'validateLtiToken' ('InvalidLtiToken' with its message on failure).
--
-- Returns the @state@ parameter together with the validated claims.
handleAuthResponse :: (MonadIO m)
    => Manager
    -> AuthFlowConfig m
    -> RequestParams
    -> PlatformInfo
    -> m (Text, IdTokenClaims LtiTokenClaims)
handleAuthResponse mgr cfg params pinfo = do
    -- Bound one at a time instead of with a partial list pattern.
    state <- require "state"
    idToken <- require "id_token"
    jwkSet <- liftIO $ getJwkSet mgr (jwksUrl pinfo)
    let ss = sessionStore cfg
        oidc = fakeOidc jwkSet
    toCheck <- getValidIdTokenClaims ss oidc (encodeUtf8 state) (pure $ encodeUtf8 idToken)
    nonceSeen <- case nonce toCheck of
        Just n -> haveSeenNonce cfg n
        Nothing -> liftIO $ throw $ InvalidLtiToken "missing nonce"
    when nonceSeen $ liftIO $ throw $ InvalidLtiToken "nonce seen before"
    case validateLtiToken pinfo toCheck of
        Left err -> liftIO $ throw $ InvalidLtiToken err
        Right tok -> return (state, tok)
  where
    -- Throws InvalidHandshake (in IO) when the parameter is missing.
    require name = liftIO (lookupOrThrow name params)