{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE OverloadedStrings #-}
-- | A small client for the AWS Lambda Runtime API.
--
-- 'defaultMain' and 'jsonMain' read the API endpoint from the
-- @AWS_LAMBDA_RUNTIME_API@ environment variable and then loop forever:
-- fetch the next invocation, run the handler, post the outcome back.
--
-- 'autoMockMain' and 'autoMockJsonMain' do the same when that variable is
-- set; otherwise they run the handler once against a caller-supplied mock.
module AWS.Lambda.RuntimeAPI (
defaultMain,
jsonMain,
autoMockMain,
autoMockJsonMain,
Mock, GenMock (..),
makeMockRequest,
Request,
Response (..),
GenRequest (..),
RequestId,
getTimeRemaining,
jsonResponse,
jsonMediaType,
) where
import Prelude ()
import Prelude.Compat
import Control.Concurrent.Async (waitCatch, withAsync)
import Control.DeepSeq (NFData (..), force)
import Control.Exception (SomeException, displayException, evaluate)
import Control.Monad (forever, void)
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Network.HTTP.Media (MediaType)
import System.Environment (lookupEnv)
import System.IO (hFlush, stdout)
import qualified Data.Aeson as Aeson
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BS8
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Map.Strict as Map
import qualified Data.Text as Text
import qualified Data.Time.Clock.System.Compat as Time
import qualified Network.HTTP.Client as H
import qualified Network.HTTP.Media as HTTP
import qualified Network.HTTP.Types as HTTP
-- | Identifier of one invocation.
--
-- Populated from the @lambda-runtime-aws-request-id@ response header and
-- substituted into the path of the response\/error requests.
newtype RequestId = RequestId ByteString
  deriving (Show)
-- | Point in time, in milliseconds, at which an invocation runs out of time.
--
-- Parsed from the @lambda-runtime-deadline-ms@ header and compared against
-- 'getCurrentMillis' by 'getTimeRemaining'.
newtype Deadline = Deadline Int64
  deriving (Show)
-- | An invocation whose payload is the raw, undecoded request body.
type Request = GenRequest ByteString

-- | One Lambda invocation, parameterised over the payload type.
--
-- The header-derived 'ByteString' fields are empty when the corresponding
-- header was not present in the runtime API response.
data GenRequest a = Request
  { requestPayload :: !a
    -- ^ The invocation body.
  , requestId :: !RequestId
    -- ^ Value of the @lambda-runtime-aws-request-id@ header.
  , requestXRayTraceId :: !ByteString
    -- ^ Value of the @lambda-runtime-trace-id@ header.
  , requestClientContext :: !ByteString
    -- ^ Value of the @lambda-runtime-client-context@ header.
  , requestCognitoIdentity :: !ByteString
    -- ^ Value of the @lambda-runtime-cognito-identity@ header.
  , requestARN :: !ByteString
    -- ^ Value of the @lambda-runtime-invoked-function-arn@ header.
  , requestDeadline :: !Deadline
    -- ^ Parsed from the @lambda-runtime-deadline-ms@ header; see
    -- 'getTimeRemaining'.
  }
  deriving (Show, Functor, Foldable, Traversable)
-- | What a handler reports back for an invocation.
data Response
  = SuccessResponse !MediaType !BS.ByteString
    -- ^ Media type (sent as the @Content-Type@ header) and response body.
  | FailureResponse !Text !Text
    -- ^ Error type and error message, in that order.
  deriving (Show)
-- | Deep evaluation of a 'Response'.
--
-- The media type is forced component by component (main type, sub type,
-- parameters) before the body.
instance NFData Response where
  rnf response = case response of
    SuccessResponse mediaType body ->
      rnf (HTTP.mainType mediaType)
        `seq` rnf (HTTP.subType mediaType)
        `seq` rnf (HTTP.parameters mediaType)
        `seq` rnf body
    FailureResponse errType errMsg ->
      rnf errMsg `seq` rnf errType
-- | Build a 'SuccessResponse' whose body is the JSON encoding of the given
-- value and whose media type is 'jsonMediaType'.
jsonResponse :: Aeson.ToJSON a => a -> Response
jsonResponse = SuccessResponse jsonMediaType . LBS.toStrict . Aeson.encode
-- | Milliseconds left until the request's deadline.
--
-- The result is clamped so that it is never negative: once the deadline has
-- passed this returns @0@.
getTimeRemaining :: GenRequest a -> IO Int64
getTimeRemaining Request{requestDeadline = Deadline deadlineMs} =
  fmap remainingAt getCurrentMillis
  where
    remainingAt nowMs = max 0 (deadlineMs - nowMs)
-- | Current system time in milliseconds, as reported by
-- 'Time.getSystemTime'.
--
-- The nanosecond component is truncated (not rounded) to whole milliseconds.
getCurrentMillis :: IO Int64
getCurrentMillis = toMillis <$> Time.getSystemTime
  where
    toMillis (Time.MkSystemTime secs nanos) =
      secs * 1000 + fromIntegral (nanos `div` 1000000)
-- | The @application/json@ media type, used by 'jsonResponse' and for the
-- body of failure reports.
jsonMediaType :: MediaType
jsonMediaType = "application/json"
-- | Write an @[ERROR]@-prefixed line to stdout and flush immediately.
logError :: ByteString -> IO ()
logError msg = BS8.putStrLn ("[ERROR] " <> msg) >> hFlush stdout
-- | Log an exception to stdout, tagged with the place it was caught, and
-- flush immediately.
--
-- NOTE(review): the exception text goes through 'BS8.pack', which keeps only
-- the low 8 bits of each character, so non-Latin-1 text will be mangled.
logException :: ByteString -> SomeException -> IO ()
logException context exc = do
  let rendered = BS8.pack (displayException exc)
  BS8.putStrLn (mconcat ["[ERROR] Exception ", context, " -- ", rendered])
  hFlush stdout
-- | Informational logging hook. Currently a no-op: the message is discarded.
logInfo :: ByteString -> IO ()
logInfo = const (return ())
-- | Hook for logging outgoing HTTP requests. Currently a no-op.
logHttpRequest :: H.Request -> IO ()
logHttpRequest = const (return ())
-- | Hook for logging received HTTP responses. Currently a no-op.
logHttpResponse :: HttpRes -> IO ()
logHttpResponse = const (return ())
-- | Perform an HTTP request with the context's manager; never throws.
--
-- Any exception raised while performing the request or while fully
-- evaluating the response is logged and turned into a synthetic response
-- with status 500, no headers and an empty body, so callers only ever have
-- to look at the status.
httpLbs :: Ctx -> H.Request -> IO HttpRes
httpLbs ctx req = do
  logHttpRequest req
  outcome <- tryDeep (toHttpRes <$> H.httpLbs req (ctxManager ctx))
  either onFailure onSuccess outcome
  where
    onSuccess res = res <$ logHttpResponse res

    onFailure exc = do
      logException "httpLbs" exc
      return syntheticFailure

    -- Stand-in response reported when the request itself blew up.
    syntheticFailure =
      HttpRes
        { hrStatus = HTTP.status500
        , hrHeaders = mempty
        , hrBody = mempty
        }
-- | The parts of an HTTP response this module cares about: status, headers
-- and the (lazy) body.
data HttpRes = HttpRes
  { hrStatus :: HTTP.Status
    -- ^ Response status.
  , hrHeaders :: HTTP.ResponseHeaders
    -- ^ Response headers.
  , hrBody :: LBS.ByteString
    -- ^ Response body.
  }
-- | Deep evaluation of an 'HttpRes': status code, status message, headers,
-- then body.
instance NFData HttpRes where
  rnf HttpRes{hrStatus = HTTP.Status code reason, hrHeaders = headers, hrBody = body} =
    rnf code
      `seq` rnf reason
      `seq` rnf headers
      `seq` rnf body
-- | Project an http-client response onto 'HttpRes', keeping only the status,
-- headers and body.
toHttpRes :: H.Response LBS.ByteString -> HttpRes
toHttpRes res = HttpRes status headers body
  where
    status = H.responseStatus res
    headers = H.responseHeaders res
    body = H.responseBody res
-- | Run an action and capture any exception it raises, including exceptions
-- hidden inside a lazily evaluated result.
--
-- The action runs via 'withAsync' and its result is fully forced there, so
-- a 'Right' value is already in normal form.
tryDeep :: NFData a => IO a -> IO (Either SomeException a)
tryDeep action = withAsync forced waitCatch
  where
    forced = do
      result <- action
      evaluate (force result)
-- | Header carrying the invocation's request id.
headerRequestId :: HTTP.HeaderName
headerRequestId = "lambda-runtime-aws-request-id"

-- | Header carrying the trace id.
headerTraceId :: HTTP.HeaderName
headerTraceId = "lambda-runtime-trace-id"

-- | Header carrying the client context.
headerClientContext :: HTTP.HeaderName
headerClientContext = "lambda-runtime-client-context"

-- | Header carrying the Cognito identity.
headerCognitoIdentity :: HTTP.HeaderName
headerCognitoIdentity = "lambda-runtime-cognito-identity"

-- | Header carrying the ARN of the invoked function.
headerARN :: HTTP.HeaderName
headerARN = "lambda-runtime-invoked-function-arn"

-- | Header carrying the invocation deadline in milliseconds.
headerDeadline :: HTTP.HeaderName
headerDeadline = "lambda-runtime-deadline-ms"
-- | Everything the runtime loop needs to talk to the runtime API: one HTTP
-- manager plus pre-parsed request templates.
--
-- The success and error templates contain a literal @:request@ path segment
-- that is replaced with the actual request id before sending (see
-- 'replaceRequestId').
data Ctx = Ctx
  { ctxManager :: H.Manager
    -- ^ Connection manager used for every request.
  , ctxNextRequest :: H.Request
    -- ^ Template for fetching the next invocation.
  , ctxSuccessRequest :: H.Request
    -- ^ @POST@ template for @.../invocation/:request/response@.
  , ctxErrorRequest :: H.Request
    -- ^ @POST@ template for @.../invocation/:request/error@.
  }
-- | Build the runtime 'Ctx' from the environment.
--
-- Reads the endpoint from @AWS_LAMBDA_RUNTIME_API@ and fails (via 'fail')
-- when it is not set. The manager is created with the response timeout
-- disabled.
initCtx :: IO Ctx
initCtx = do
  endpoint <-
    lookupEnv "AWS_LAMBDA_RUNTIME_API"
      >>= maybe (fail "AWS_LAMBDA_RUNTIME_API not defined") return
  let runtimeUrl path = "http://" ++ endpoint ++ path
  nextTemplate <- H.parseRequest (runtimeUrl "/2018-06-01/runtime/invocation/next")
  successTemplate <- H.parseRequest (runtimeUrl "/2018-06-01/runtime/invocation/:request/response")
  errorTemplate <- H.parseRequest (runtimeUrl "/2018-06-01/runtime/invocation/:request/error")
  manager <-
    H.newManager
      H.defaultManagerSettings
        { H.managerResponseTimeout = H.responseTimeoutNone
        }
  return
    Ctx
      { ctxManager = manager
      , ctxNextRequest = nextTemplate
      , ctxSuccessRequest = successTemplate{H.method = "POST"}
      , ctxErrorRequest = errorTemplate{H.method = "POST"}
      }
-- | Fetch the next invocation from the runtime API.
--
-- Returns 'Left' with the response status when the status is not successful
-- or when the request-id header is missing. Transport failures surface as a
-- 'Left' with status 500, because 'httpLbs' converts them. All other headers
-- are optional and default to empty; a missing or unparsable deadline
-- becomes @0@.
getNext :: Ctx -> IO (Either HTTP.Status Request)
getNext ctx = do
  res <- httpLbs ctx (ctxNextRequest ctx)
  let status = hrStatus res
      headers = Map.fromList (hrHeaders res)
      headerOrEmpty name = Map.findWithDefault mempty name headers
      deadline =
        Deadline . maybe 0 (fromInteger . fst) . BS8.readInteger $
          headerOrEmpty headerDeadline
      toRequest rid =
        Request
          { requestPayload = LBS.toStrict (hrBody res)
          , requestId = RequestId rid
          , requestXRayTraceId = headerOrEmpty headerTraceId
          , requestClientContext = headerOrEmpty headerClientContext
          , requestCognitoIdentity = headerOrEmpty headerCognitoIdentity
          , requestARN = headerOrEmpty headerARN
          , requestDeadline = deadline
          }
  if HTTP.statusIsSuccessful status
    then case Map.lookup headerRequestId headers of
      Just rid -> return (Right (toRequest rid))
      Nothing -> return (Left status)
    else return (Left status)
-- | Report a handler's 'Response' for the given invocation.
--
-- A 'SuccessResponse' is posted with its media type as @Content-Type@ and
-- its bytes as the body; a 'FailureResponse' is delegated to 'postFailure'.
-- The HTTP result is discarded.
postSuccess :: Ctx -> Request -> Response -> IO ()
postSuccess ctx req response = case response of
  FailureResponse ty msg -> postFailure ctx req ty msg
  SuccessResponse mt body -> void (httpLbs ctx (successRequest mt body))
  where
    template = ctxSuccessRequest ctx

    successRequest :: MediaType -> ByteString -> H.Request
    successRequest mt body =
      template
        { H.path = replaceRequestId (requestId req) (H.path template)
        , H.requestHeaders = [("Content-Type", HTTP.renderHeader mt)]
        , H.requestBody = H.RequestBodyBS body
        }
-- | Report an invocation failure to the runtime API.
--
-- Posts a JSON object with @errorMessage@, @errorType@ and an empty
-- @stackTrace@ array to the invocation's @error@ endpoint. The HTTP result
-- is discarded.
--
-- The request is built from 'ctxErrorRequest'. Building it from the success
-- template would deliver the error document to the @response@ endpoint
-- instead.
postFailure :: Ctx -> Request -> Text -> Text -> IO ()
postFailure ctx req ty msg =
  void $ httpLbs ctx httpReq
  where
    errorReq :: H.Request
    errorReq = ctxErrorRequest ctx

    httpReq :: H.Request
    httpReq =
      errorReq
        { H.path = replaceRequestId (requestId req) $ H.path errorReq
        , H.requestHeaders =
            [ ("Content-Type", HTTP.renderHeader jsonMediaType)
            ]
        , H.requestBody =
            H.RequestBodyLBS $
              Aeson.encode $
                Aeson.object
                  [ "errorMessage" Aeson..= msg
                  , "errorType" Aeson..= ty
                  , "stackTrace" Aeson..= Aeson.Array mempty
                  ]
        }
-- | Substitute the first @:request@ placeholder in a path with the given
-- request id.
--
-- If the placeholder does not occur, the id ends up appended to the
-- unchanged path.
replaceRequestId :: RequestId -> ByteString -> ByteString
replaceRequestId (RequestId rid) path =
  before <> rid <> BS.drop (BS.length placeholder) fromPlaceholder
  where
    placeholder = ":request"
    (before, fromPlaceholder) = BS.breakSubstring placeholder path
-- | Run the Lambda runtime loop with a handler over raw request bodies.
--
-- Initialises the context from the environment (see 'initCtx'), then loops
-- forever:
--
-- * fetch the next invocation; on failure log the status code and retry;
-- * run the handler, fully evaluating its 'Response';
-- * post the response, or, if the handler threw, log the exception and post
--   a failure with error type @SomeException@.
--
-- This function does not return.
--
-- NOTE(review): a failed fetch is retried immediately with no delay, so an
-- unreachable endpoint results in a tight retry loop.
defaultMain :: (Request -> IO Response) -> IO ()
defaultMain handler = do
  ctx <- initCtx
  forever $ do
    next <- getNext ctx
    case next of
      Left st ->
        -- Splice the actual status code into the message rather than
        -- printing a literal, never-substituted "%d" placeholder.
        logError $
          "HTTP request was not successful. HTTP response code: "
            <> BS8.pack (show (HTTP.statusCode st))
            <> ". Retrying.."
      Right req -> do
        logInfo "Invoking user handler"
        eres <- tryDeep (handler req)
        case eres of
          Left exc -> do
            logException "handler" exc
            postFailure ctx req "SomeException" (Text.pack $ displayException exc)
          Right res -> postSuccess ctx req res
-- | Like 'defaultMain', but for handlers that consume and produce JSON.
--
-- The request body is decoded with 'Aeson.eitherDecodeStrict'. A decoding
-- failure is reported as a 'FailureResponse' with error type @decode@ and
-- the handler is not run. Otherwise the handler's result is encoded with
-- 'jsonResponse'.
jsonMain :: (Aeson.FromJSON a, Aeson.ToJSON b) => (GenRequest a -> IO b) -> IO ()
jsonMain handler = defaultMain (either decodeFailure runHandler . decodePayload)
  where
    decodePayload = traverse Aeson.eitherDecodeStrict
    runHandler decoded = fmap jsonResponse (handler decoded)
    decodeFailure err = return (FailureResponse "decode" (Text.pack err))
-- | A mock whose request payload is a raw 'ByteString'.
type Mock = GenMock ByteString

-- | Stand-in for the runtime API, used by the @autoMock*@ entry points when
-- @AWS_LAMBDA_RUNTIME_API@ is not set.
data GenMock a = Mock
  { mockRequest :: IO (GenRequest a)
    -- ^ Produces the request handed to the handler.
  , mockResponse :: Response -> IO ()
    -- ^ Receives the handler's response.
  }
-- | Build a mock request with the given payload.
--
-- The request id is the fixed string @mock-deadc0de@ and all header-derived
-- fields are empty. The deadline is set to the current time plus the given
-- duration in milliseconds, where durations below 1000 are raised to 1000.
makeMockRequest
  :: a
  -- ^ Payload.
  -> Int64
  -- ^ Time budget in milliseconds (at least 1000 is always granted).
  -> IO (GenRequest a)
makeMockRequest payload duration = mockAt <$> getCurrentMillis
  where
    budgetMs = max 1000 duration

    mockAt nowMs =
      Request
        { requestPayload = payload
        , requestId = RequestId "mock-deadc0de"
        , requestXRayTraceId = mempty
        , requestClientContext = mempty
        , requestCognitoIdentity = mempty
        , requestARN = mempty
        , requestDeadline = Deadline (nowMs + budgetMs)
        }
-- | Run 'defaultMain' when @AWS_LAMBDA_RUNTIME_API@ is set; otherwise run
-- the handler exactly once against the mock.
--
-- In the mock path, exceptions thrown by the handler are not caught.
autoMockMain :: Mock -> (Request -> IO Response) -> IO ()
autoMockMain mock handler =
  lookupEnv "AWS_LAMBDA_RUNTIME_API" >>= maybe runOnceWithMock (const (defaultMain handler))
  where
    runOnceWithMock :: IO ()
    runOnceWithMock = do
      req <- mockRequest mock
      res <- handler req
      mockResponse mock res
-- | Run 'jsonMain' when @AWS_LAMBDA_RUNTIME_API@ is set; otherwise run the
-- handler exactly once against the mock.
--
-- In the mock path the handler gets the mock's already-typed request (no
-- JSON decoding happens), and its result is passed to 'mockResponse' as a
-- 'jsonResponse'. Exceptions thrown by the handler are not caught there.
autoMockJsonMain :: (Aeson.FromJSON a, Aeson.ToJSON b) => GenMock a -> (GenRequest a -> IO b) -> IO ()
autoMockJsonMain mock handler = do
  runtimeApi <- lookupEnv "AWS_LAMBDA_RUNTIME_API"
  case runtimeApi of
    Nothing -> do
      req <- mockRequest mock
      result <- handler req
      mockResponse mock (jsonResponse result)
    Just _ -> jsonMain handler