{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TemplateHaskell #-}
module Yesod.Auth.LTI13 (
PlatformInfo(..)
, Issuer
, ClientId
, Nonce
, authLTI13
, authLTI13WithWidget
, YesodAuthLTI13(..)
, getLtiIss
, getLtiSub
, getLtiToken
, LtiTokenClaims(..)
, UncheckedLtiTokenClaims(..)
, ContextClaim(..)
, Role(..)
) where
import Yesod.Core.Widget
import Yesod.Auth (Auth, Route(PluginR), setCredsRedirect, Creds(..), authHttpManager, AuthHandler, AuthPlugin(..), YesodAuth)
import Web.LTI13
import qualified Data.Aeson as A
import Data.Text (Text)
import qualified Data.Map.Strict as Map
import Crypto.Random (getRandomBytes)
import qualified Crypto.PubKey.RSA as RSA
import Yesod.Core.Types (TypedContent)
import Yesod.Core (toTypedContent, permissionDenied, setSession, lookupSession, redirect,
deleteSession, lookupSessionBS, setSessionBS, runRequestBody,
getRequest, MonadHandler, SessionMap, notFound, getUrlRender)
import qualified Data.ByteString.Base64.URL as B64
import Web.OIDC.Client.Tokens (IdTokenClaims(..))
import Yesod.Core (YesodRequest(reqGetParams))
import Control.Exception.Safe (Exception, throwIO)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Web.OIDC.Client (Nonce)
import Yesod.Core.Handler (getRouteToParent)
import qualified Data.Text.Encoding as E
import qualified Data.ByteString.Lazy as LBS
import qualified Data.ByteString as BS
import Jose.Jwk (JwkSet(..), Jwk(..), generateRsaKeyPair, KeyUse(Sig))
import Data.Time (getCurrentTime)
import Jose.Jwt (KeyId(UTCKeyId))
import Jose.Jwa (Alg(Signed), JwsAlg(RS256))
data YesodAuthLTI13Exception
= LTIException Text LTI13Exception
| BadRequest Text Text
| CorruptJwks Text Text
deriving (Show)
instance Exception YesodAuthLTI13Exception
dispatchAuthRequest
:: YesodAuthLTI13 master
=> PluginName
-> Text
-> [Text]
-> AuthHandler master TypedContent
dispatchAuthRequest name "GET" ["initiate"] =
unifyParams GET >>= dispatchInitiate name
dispatchAuthRequest name "POST" ["initiate"] =
unifyParams POST >>= dispatchInitiate name
dispatchAuthRequest name "POST" ["authenticate"] =
dispatchAuthenticate name
dispatchAuthRequest name "GET" ["jwks"] =
dispatchJwks name
dispatchAuthRequest _ _ _ = notFound
data Method = GET
| POST
unifyParams
:: MonadHandler m
=> Method
-> m RequestParams
unifyParams GET = do
req <- getRequest
return $ Map.fromList $ reqGetParams req
unifyParams POST = do
(params, _) <- runRequestBody
return $ Map.fromList params
prefixSession :: Text -> Text -> Text
prefixSession name datum =
"_lti13_" <> name <> "_" <> datum
myCid :: Text -> Text
myCid = flip prefixSession $ "clientId"
myIss :: Text -> Text
myIss = flip prefixSession $ "iss"
myState :: Text -> Text
myState = flip prefixSession $ "state"
myNonce :: Text -> Text
myNonce = flip prefixSession $ "nonce"
mkSessionStore :: MonadHandler m => Text -> SessionStore m
mkSessionStore name =
SessionStore
{ sessionStoreGenerate = gen
, sessionStoreSave = sessionSave
, sessionStoreGet = sessionGet
, sessionStoreDelete = sessionDelete
}
where
gen = liftIO $ (B64.encode <$> getRandomBytes 33)
sname = myState name
nname = myNonce name
sessionSave state nonce = do
setSessionBS sname state
setSessionBS nname nonce
return ()
sessionGet = do
state <- lookupSessionBS sname
nonce <- lookupSessionBS nname
return (state, nonce)
sessionDelete = do
deleteSession sname
deleteSession nname
type PluginName = Text
makeCfg
:: MonadHandler m
=> PluginName
-> ((Issuer, Maybe ClientId) -> m PlatformInfo)
-> (Nonce -> m Bool)
-> Text
-> AuthFlowConfig m
makeCfg name pinfo seenNonce callback =
AuthFlowConfig
{ getPlatformInfo = pinfo
, haveSeenNonce = seenNonce
, myRedirectUri = callback
, sessionStore = mkSessionStore name
}
createNewJwk :: IO Jwk
createNewJwk = do
kid <- UTCKeyId <$> getCurrentTime
let use = Sig
alg = Signed RS256
(_, priv) <- generateRsaKeyPair 256 kid use $ Just alg
return priv
dispatchJwks
:: YesodAuthLTI13 master
=> PluginName
-> AuthHandler master TypedContent
dispatchJwks name = do
jwks <- retrieveOrInsertJwks makeJwks
JwkSet privs <- maybe (liftIO $ throwIO $ CorruptJwks name "json decode failed")
pure (A.decodeStrict jwks)
let pubs = JwkSet $ map rsaPrivToPub privs
return $ toTypedContent $ A.toJSON pubs
where makeJwks = (LBS.toStrict . A.encode) <$> makeJwkSet
makeJwkSet = fmap (\jwk -> JwkSet {keys = [jwk]}) createNewJwk
rsaPrivToPub :: Jwk -> Jwk
rsaPrivToPub (RsaPrivateJwk privKey mId mUse mAlg) =
RsaPublicJwk (RSA.private_pub privKey) mId mUse mAlg
rsaPrivToPub _ = error "rsaPrivToPub called on a Jwk that's not a RsaPrivateJwk"
dispatchInitiate
:: YesodAuthLTI13 master
=> PluginName
-> RequestParams
-> AuthHandler master TypedContent
dispatchInitiate name params = do
let url = PluginR name ["authenticate"]
tm <- getRouteToParent
render <- getUrlRender
let authUrl = render $ tm url
let cfg = makeCfg name retrievePlatformInfo checkSeenNonce authUrl
(iss, cid, redir) <- initiate cfg params
setSession (myIss name) iss
setSession (myCid name) cid
redirect redir
type State = Text
checkCSRFToken :: MonadHandler m => State -> Maybe State -> m ()
checkCSRFToken state state' = do
if state' /= Just state then do
permissionDenied "Bad CSRF token"
else
return ()
makeUserId :: Issuer -> Text -> Text
makeUserId iss name = name <> "@@" <> iss
dispatchAuthenticate :: YesodAuthLTI13 m => PluginName -> AuthHandler m TypedContent
dispatchAuthenticate name = do
mgr <- authHttpManager
maybeIss <- lookupSession $ myIss name
iss <- maybe (liftIO . throwIO $ BadRequest name "missing `iss` cookie")
pure
maybeIss
cid <- lookupSession $ myCid name
deleteSession $ myIss name
deleteSession $ myCid name
state' <- lookupSession $ myState name
pinfo <- retrievePlatformInfo (iss, cid)
let cfg = makeCfg name retrievePlatformInfo checkSeenNonce undefined
(params', _) <- runRequestBody
let params = Map.fromList params'
(state, tok) <- handleAuthResponse mgr cfg params pinfo
checkCSRFToken state state'
let LtiTokenClaims ltiClaims = otherClaims tok
ltiClaimsJson = E.decodeUtf8 $ LBS.toStrict $ A.encode ltiClaims
let IdTokenClaims { sub } = tok
myCreds = Creds {
credsPlugin = name
, credsIdent = makeUserId iss sub
, credsExtra = [("ltiIss", iss), ("ltiSub", sub), ("ltiToken", ltiClaimsJson)]
}
setCredsRedirect myCreds
getLtiIss :: SessionMap -> Maybe Issuer
getLtiIss sess =
E.decodeUtf8 <$> Map.lookup "ltiIss" sess
getLtiSub :: SessionMap -> Maybe Issuer
getLtiSub sess =
E.decodeUtf8 <$> Map.lookup "ltiSub" sess
getLtiToken :: SessionMap -> Maybe UncheckedLtiTokenClaims
getLtiToken sess =
(Map.lookup "ltiToken" sess)
>>= A.decodeStrict
class (YesodAuth site)
=> YesodAuthLTI13 site where
checkSeenNonce :: Nonce -> AuthHandler site (Bool)
retrievePlatformInfo :: (Issuer, Maybe ClientId) -> AuthHandler site (PlatformInfo)
retrieveOrInsertJwks
:: (IO BS.ByteString)
-> AuthHandler site (BS.ByteString)
authLTI13 :: YesodAuthLTI13 m => AuthPlugin m
authLTI13 = authLTI13WithWidget login
where
login _ = [whamlet|<p>Go to your Learning Management System to log in via LTI 1.3|]
authLTI13WithWidget :: YesodAuthLTI13 m => ((Route Auth -> Route m) -> WidgetFor m ()) -> AuthPlugin m
authLTI13WithWidget login = do
AuthPlugin name (dispatchAuthRequest name) login
where
name = "lti13"