module Web.Minion.Examples.Jwt (app) where

import Control.Monad (forM_)
import Control.Monad.Reader (MonadIO (liftIO), ReaderT (runReaderT), asks)
import Crypto.JOSE (JWK, bestJWSAlg, fromOctets, newJWSHeader, runJOSE)
import Crypto.JWT (JWTError, encodeCompact, signJWT)
import Data.Aeson (FromJSON, ToJSON)
import Data.ByteString.Lazy qualified as Bytes.Lazy
import Data.Functor (($>))
import Data.Text.Encoding qualified
import Data.Text.IO qualified
import GHC.Generics (Generic)
import Network.HTTP.Types.Status qualified as Http
import System.Environment (getArgs)
import Text.Read (readMaybe)
import Web.Minion

import Web.Minion.Auth.Jwt

-- | Monad that handlers and auth callbacks run in: 'IO' with read access to 'Env'.
type M = ReaderT Env IO

-- | Claims payload embedded in the JWTs.
--
-- 'showJwts' signs it, and 'jwtSettings' decodes it back to obtain the
-- 'UserId'. The JSON instances come from the generic aeson defaults. With
-- those defaults a single-field record is encoded as an object with a
-- @userId@ key.
newtype JwtUserInfo = JwtUserInfo {userId :: UserId}
  deriving stock (Generic)
  deriving anyclass (FromJSON, ToJSON)

-- | Identifier of an authenticated user.
--
-- Every instance is newtype-derived from 'Int'. It therefore shows, reads
-- and (de)serialises exactly like a plain integer, e.g. @read "42"@.
newtype UserId = UserId Int
  deriving newtype (FromJSON, Show, Read, ToJSON)

-- | Reader environment of 'M'.
--
-- 'authCtx' is the heterogeneous auth context that 'myAuth' reads with
-- @asks authCtx@. In this example it holds only the JWT settings.
newtype Env = Env
  {authCtx :: HList '[JwtAuthSettings M JwtUserInfo UserId]}

-- | Entry point of the example.
--
-- First prints a signed JWT for every user id given on the command line
-- (see 'showJwts'). It then returns an application that serves 'api' in 'M'.
-- The reader environment carries only the JWT auth settings consumed by
-- 'myAuth'.
app :: IO (ApplicationM IO)
app = do
  showJwts
  pure $ \req resp -> runReaderT (serve api req resp) env
  where
    -- The context never changes between requests, so build it once.
    env = Env (jwtSettings :# HNil)

-- | Route tree: @GET api/auth@, guarded by 'myAuth' (Bearer JWT).
--
-- The authenticated 'UserId' is handed to 'authEndpoint'.
api :: Router Void M
api = "api" /> "auth" /> myAuth .> handle GET authEndpoint

-- | Handler for the authenticated endpoint.
--
-- Logs @User <id>@ to stdout and answers with an empty body.
authEndpoint :: UserId -> M NoBody
authEndpoint userId = liftIO (putStrLn $ "User " <> show userId) $> NoBody

-- | Bearer-JWT authentication that yields the caller's 'UserId'.
--
-- The tokens carry a 'JwtUserInfo' payload. The auth context is taken from
-- the reader environment via 'authCtx'. The failure callback receives an
-- @'AuthResult' 'Void'@, so it can never hold a user. It always rejects the
-- request by throwing a 401 error with an empty body.
myAuth ::
  ValueCombinator Void (WithReq M (Auth '[Bearer JwtUserInfo] UserId)) ts M
myAuth = auth @'[Bearer JwtUserInfo] @UserId (asks authCtx) \makeError _failure ->
  throwM $ makeError Http.status401 mempty

-- | JWT verification settings used by 'myAuth'.
--
-- Tokens are verified against 'myJwk'. A verification error ('Left') maps to
-- 'BadAuth'. A valid token maps to @Authenticated userId@, where the id is
-- taken from its 'JwtUserInfo' payload.
--
-- NOTE(review): the @StringOrURI -> Bool@ predicate is @const True@, so every
-- value is accepted. This is presumably the audience check; confirm, and
-- restrict it for anything beyond an example.
jwtSettings :: JwtAuthSettings M JwtUserInfo UserId
jwtSettings = defaultJwtAuthSettings (pure myJwk) (const True) \_makeError ->
  pure . either (const BadAuth) toAuthResult
  where
    -- Successful verification: authenticate as the user named in the payload.
    toAuthResult JwtPayload{payload = JwtUserInfo{..}} = Authenticated userId

-- | Symmetric key built from raw octets.
--
-- 'showJwts' uses it for signing and 'jwtSettings' for verification.
--
-- NOTE(review): the secret is hard-coded for the sake of the example. A real
-- service should load it from configuration or a secret store.
myJwk :: JWK
myJwk = fromOctets @Bytes.Lazy.ByteString "really secret and long enough key"

-- | Print a signed JWT for every user id given on the command line.
--
-- Each argument is parsed as a 'UserId', using plain integer syntax because
-- 'Read' is newtype-derived from 'Int'. For each id, a token whose payload is
-- @JwtUserInfo userId@ is signed with 'myJwk' (algorithm chosen by
-- 'bestJWSAlg'). It is printed as @<userId>: <compact JWT>@.
--
-- Fails in 'IO' with a descriptive message when an argument is not a valid
-- user id or when signing fails. This avoids the partial 'read' and a
-- refutable @Right@ bind, which crashed without naming the cause.
showJwts :: IO ()
showJwts = do
  userIds <- traverse parseUserId =<< getArgs
  forM_ userIds \userId -> do
    signed <- runJOSE @JWTError do
      alg <- bestJWSAlg myJwk
      signJWT myJwk (newJWSHeader ((), alg)) (JwtUserInfo userId)
    jwt <- either (signingFailed userId) pure signed
    let jwtTxt :: Text
        jwtTxt = Data.Text.Encoding.decodeUtf8 . Bytes.Lazy.toStrict $ encodeCompact jwt
    putStr (show userId <> ": ")
    Data.Text.IO.putStrLn jwtTxt
  where
    -- Total replacement for 'read' that reports the offending argument.
    parseUserId :: String -> IO UserId
    parseUserId arg =
      maybe (fail $ "showJwts: not a valid user id: " <> show arg) pure (readMaybe arg)

    -- Surface the JOSE error instead of a generic pattern-match failure.
    signingFailed :: UserId -> JWTError -> IO a
    signingFailed uid err =
      fail $ "showJwts: failed to sign JWT for user " <> show uid <> ": " <> show err