{-# LANGUAGE RecordWildCards, OverloadedStrings #-}
{-# LANGUAGE FlexibleContexts, DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Web.WebPush
(
generateVAPIDKeys
, readVAPIDKeys
, vapidPublicKeyBytes
, sendPushNotification
, pushEndpoint
, pushP256dh
, pushAuth
, pushSenderEmail
, pushExpireInSeconds
, pushMessage
, mkPushNotification
, VAPIDKeys
, VAPIDKeysMinDetails(..)
, PushNotification
, PushNotificationMessage(..)
, PushNotificationError(..)
, PushEndpoint
, PushP256dh
, PushAuth
) where
import Web.WebPush.Internal
import Crypto.Random (MonadRandom(getRandomBytes))
import Control.Exception (Exception)
import Control.Lens ((^.), Lens', Lens, lens)
import qualified Crypto.PubKey.ECC.Types as ECC
import qualified Crypto.PubKey.ECC.Generate as ECC
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import qualified Crypto.PubKey.ECC.DH as ECDH
import qualified Data.Bits as Bits
import Data.Word (Word8)
import GHC.Generics (Generic)
import GHC.Int (Int64)
import qualified Data.List as L
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.ByteString.Lazy as LB
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as C8
import qualified Data.Aeson as A
import qualified Data.ByteString.Base64.URL as B64.URL
import Network.HTTP.Client (Manager, httpLbs, parseUrlThrow, HttpException(HttpExceptionRequest)
                           , HttpExceptionContent(StatusCodeException), RequestBody(..)
                           , requestBody, requestHeaders, method, responseStatus)
import Network.HTTP.Types (hContentType, hAuthorization, hContentEncoding)
import Network.HTTP.Types.Status (Status(statusCode))
import Crypto.Error (CryptoError)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Exception.Base (SomeException(..), fromException, toException, throw, throwIO)
import Control.Exception.Safe (tryAny, handleAny)
import System.Random (randomRIO)
-- | Generate a fresh VAPID key pair on the @SEC_p256r1@ curve.
--
-- The key pair is returned in the minimal form accepted by 'readVAPIDKeys':
-- the private scalar plus the X and Y coordinates of the public point.
generateVAPIDKeys :: MonadRandom m => m VAPIDKeysMinDetails
generateVAPIDKeys = do
  (publicKey, privateKey) <- ECC.generate secp256r1
  return (toMinDetails publicKey privateKey)
  where
    secp256r1 = ECC.getCurveByName ECC.SEC_p256r1
    -- The let-pattern is lazy, so a public point that is not an 'ECC.Point'
    -- would only fail once one of its coordinates is demanded.
    toMinDetails publicKey privateKey =
      let ECC.Point x y = ECDSA.public_q publicKey
      in VAPIDKeysMinDetails
           { privateNumber = ECDSA.private_d privateKey
           , publicCoordX = x
           , publicCoordY = y
           }
-- | Rebuild a full @SEC_p256r1@ ECDSA key pair from its minimal details.
--
-- The input is the private scalar and the public point's coordinates, as
-- produced by 'generateVAPIDKeys'. The coordinates are not checked to lie on
-- the curve.
readVAPIDKeys :: VAPIDKeysMinDetails -> VAPIDKeys
readVAPIDKeys (VAPIDKeysMinDetails secret x y) =
  ECDSA.KeyPair secp256r1 (ECC.Point x y) secret
  where
    secp256r1 = ECC.getCurveByName ECC.SEC_p256r1
-- | Serialise the VAPID public key as a byte list.
--
-- The layout is a leading @0x04@, then the X coordinate, then the Y
-- coordinate. Each coordinate is written as exactly 32 big-endian bytes
-- (65 bytes in total). A coordinate wider than 32 bytes is silently
-- truncated to its low 32 bytes.
vapidPublicKeyBytes :: VAPIDKeys -> [Word8]
vapidPublicKeyBytes keys = 0x04 : (bigEndian32 x ++ bigEndian32 y)
  where
    ECC.Point x y = ECDSA.public_q (ECDSA.toPublicKey keys)

    -- Byte i (counting from the least significant) is the low 8 bits of
    -- the number shifted right by 8*i. Emit from i = 31 down to 0 so the
    -- most significant byte comes first.
    bigEndian32 :: Integer -> [Word8]
    bigEndian32 n = [ fromIntegral (n `Bits.shiftR` (8 * i)) | i <- [31, 30 .. 0 :: Int] ]
-- | Encrypt a notification's payload and POST it to the subscription's
-- push-service endpoint.
--
-- Payload and encryption:
--
-- * The payload is the JSON encoding of the notification's message.
-- * It is encrypted with 'webPushEncrypt' (sent with content encoding
--   @aesgcm@) using a freshly generated ECDH key.
-- * Encryption also uses a random 16-byte salt and a random padding length
--   in @[0, 20]@.
--
-- The request carries an @Authorization: WebPush <jwt>@ header, where the
-- token comes from 'webPushJWT' applied to the VAPID keys and sender email.
-- It also carries the VAPID public key in @Crypto-Key@.
--
-- Synchronous exceptions raised while sending are caught (via 'tryAny') and
-- reported as a 'Left':
--
-- * 'EndpointParseFailed': the endpoint could not be parsed as a URL;
-- * 'MessageEncryptionFailed': 'webPushEncrypt' returned an error;
-- * 'RecepientEndpointNotFound': the push service answered 404 or 410;
-- * 'PushRequestFailed': any other failure, including other error statuses.
sendPushNotification :: (MonadIO m, A.ToJSON msg)
                     => VAPIDKeys
                     -> Manager
                     -> PushNotification msg
                     -> m (Either PushNotificationError ())
sendPushNotification vapidKeys httpManager pushNotification = do
  result <- liftIO $ tryAny $ do
    -- 'parseUrlThrow' also makes the later request throw on non-2xx
    -- responses; 'onError' turns those into a 'PushNotificationError'.
    initReq <- handleAny (throwIO . EndpointParseFailed) $
      parseUrlThrow . T.unpack $ pushNotification ^. pushEndpoint
    jwt <- webPushJWT vapidKeys initReq (pushNotification ^. pushSenderEmail)
    ecdhServerPrivateKey <- ECDH.generatePrivate p256
    randSalt <- getRandomBytes 16
    padLen <- randomRIO (0, 20)
    let authSecretBytes =
          B64.URL.decodeLenient . TE.encodeUtf8 $ pushNotification ^. pushAuth
        subscriptionPublicKeyBytes =
          B64.URL.decodeLenient . TE.encodeUtf8 $ pushNotification ^. pushP256dh
        -- The plaintext to encrypt: the message's JSON encoding (no base64).
        plainMessageJson = A.encode . A.toJSON $ pushNotification ^. pushMessage
        encryptionInput =
          EncryptionInput
            { applicationServerPrivateKey = ecdhServerPrivateKey
            , userAgentPublicKeyBytes = subscriptionPublicKeyBytes
            , authenticationSecret = authSecretBytes
            , salt = randSalt
            , plainText = plainMessageJson
            , paddingLength = padLen
            }
    -- 'throwIO' (not the pure 'throw') raises the failure at exactly this
    -- point of the IO sequence; 'tryAny' converts it into a 'Left'.
    encryptionOutput <-
      either (throwIO . MessageEncryptionFailed) pure (webPushEncrypt encryptionInput)
    let ecdhServerPublicKeyBytes =
          LB.toStrict . ecPublicKeyToBytes . ECDH.calculatePublic p256 $ ecdhServerPrivateKey
        authorizationHeader = LB.toStrict $ "WebPush " <> jwt
        cryptoKeyHeader =
          BS.concat
            [ "dh=", b64UrlNoPadding ecdhServerPublicKeyBytes
            , ";"
            , "p256ecdsa=", b64UrlNoPadding vapidPublicKeyBytestring
            ]
        postHeaders =
          [ ("TTL", C8.pack . show $ pushNotification ^. pushExpireInSeconds)
          , (hContentType, "application/octet-stream")
          , (hAuthorization, authorizationHeader)
          , ("Crypto-Key", cryptoKeyHeader)
          , (hContentEncoding, "aesgcm")
          , ("Encryption", "salt=" <> b64UrlNoPadding randSalt)
          ]
        ourHeaderNames = map fst postHeaders
        -- Headers already on the parsed request survive unless we set the
        -- same header name ourselves.
        keptHeaders =
          filter (\(name, _) -> L.notElem name ourHeaderNames) (requestHeaders initReq)
        request =
          initReq
            { method = "POST"
            , requestHeaders = postHeaders ++ keptHeaders
            , requestBody = RequestBodyBS (encryptedMessage encryptionOutput)
            }
    httpLbs request httpManager
  pure $ either (Left . onError) (const (Right ())) result
  where
    p256 = ECC.getCurveByName ECC.SEC_p256r1
    vapidPublicKeyBytestring =
      LB.toStrict . ecPublicKeyToBytes . ECDSA.public_q . ECDSA.toPublicKey $ vapidKeys

    -- Map any caught exception to the library's error type. Our own
    -- 'PushNotificationError's pass through unchanged. 404/410 responses
    -- mean the subscription endpoint is gone.
    onError :: SomeException -> PushNotificationError
    onError err
      | Just (pushErr :: PushNotificationError) <- fromException err = pushErr
      | Just (HttpExceptionRequest _ (StatusCodeException resp _)) <- fromException err
      , statusCode (responseStatus resp) `elem` [404, 410]
      = RecepientEndpointNotFound
      | otherwise = PushRequestFailed err
-- | URL of the subscription's push-service endpoint. It is parsed with
-- 'parseUrlThrow' when the notification is sent.
type PushEndpoint = T.Text
-- | The subscription's @p256dh@ public key as URL-safe base64 text. It is
-- decoded leniently before use as the user agent's public key.
type PushP256dh = T.Text
-- | The subscription's @auth@ secret as URL-safe base64 text. It is decoded
-- leniently before use as the authentication secret.
type PushAuth = T.Text
-- | One push message addressed to one subscription, parameterised by the
-- payload type @msg@.
--
-- Only the type is exported, not the constructor. Build a value with
-- 'mkPushNotification' and adjust it through the exported lenses.
data PushNotification msg = PushNotification
  { _pnEndpoint :: PushEndpoint
    -- ^ Push-service URL; the encrypted message is POSTed here.
  , _pnP256dh :: PushP256dh
    -- ^ Subscription public key used to encrypt the payload.
  , _pnAuth :: PushAuth
    -- ^ Subscription auth secret used during encryption.
  , _pnSenderEmail :: T.Text
    -- ^ Sender contact passed to 'webPushJWT' when building the
    -- Authorization token.
  , _pnExpireInSeconds :: Int64
    -- ^ Sent as the @TTL@ request header.
  , _pnMessage :: msg
    -- ^ Payload; JSON-encoded and then encrypted to form the request body.
  }
-- | Lens onto the push-service endpoint URL the notification is sent to.
pushEndpoint :: Lens' (PushNotification msg) PushEndpoint
pushEndpoint f notification =
  (\endpoint -> notification {_pnEndpoint = endpoint}) <$> f (_pnEndpoint notification)
-- | Lens onto the subscription's @p256dh@ public key (URL-safe base64 text).
pushP256dh :: Lens' (PushNotification msg) PushP256dh
pushP256dh f notification =
  (\key -> notification {_pnP256dh = key}) <$> f (_pnP256dh notification)
-- | Lens onto the subscription's @auth@ secret (URL-safe base64 text).
pushAuth :: Lens' (PushNotification msg) PushAuth
pushAuth f notification =
  (\secret -> notification {_pnAuth = secret}) <$> f (_pnAuth notification)
-- | Lens onto the sender contact used when building the VAPID JWT.
pushSenderEmail :: Lens' (PushNotification msg) T.Text
pushSenderEmail f notification =
  (\email -> notification {_pnSenderEmail = email}) <$> f (_pnSenderEmail notification)
-- | Lens onto the time-to-live, in seconds, sent as the @TTL@ header.
pushExpireInSeconds :: Lens' (PushNotification msg) Int64
pushExpireInSeconds f notification =
  (\ttl -> notification {_pnExpireInSeconds = ttl}) <$> f (_pnExpireInSeconds notification)
-- | Type-changing lens onto the notification payload. Setting a value of
-- another type changes the notification's type parameter.
--
-- The lens itself places no constraint on the payload. The previous
-- @A.ToJSON msg@ constraint was never used by the body. A 'A.ToJSON'
-- instance is only required when the notification is sent with
-- 'sendPushNotification'.
pushMessage :: Lens (PushNotification a) (PushNotification msg) a msg
pushMessage = lens _pnMessage (\d v -> d {_pnMessage = v})
-- | Create a notification for a subscription from its endpoint URL,
-- @p256dh@ key and @auth@ secret.
--
-- Defaults:
--
-- * empty sender email;
-- * TTL of 3600 seconds;
-- * unit payload.
--
-- Use 'pushMessage' to set the real payload and 'pushSenderEmail' to set
-- the sender.
mkPushNotification :: PushEndpoint -> PushP256dh -> PushAuth -> PushNotification ()
mkPushNotification endpoint p256dh auth =
  -- Positional fields: endpoint, p256dh, auth, sender email, TTL, message.
  PushNotification endpoint p256dh auth "" 3600 ()
-- | A ready-made payload type with common notification fields.
--
-- Its 'A.ToJSON' instance is derived generically, so the field names below
-- become the JSON keys. NOTE(review): how the keys are interpreted is up to
-- the receiving client code; nothing here enforces it.
data PushNotificationMessage = PushNotificationMessage
  { title :: T.Text
  , body :: T.Text
  , icon :: T.Text
  , url :: T.Text
  , tag :: T.Text
  } deriving (Eq, Show, Generic, A.ToJSON)
-- | The minimal data needed to rebuild a VAPID key pair with 'readVAPIDKeys'.
--
-- NOTE(review): the derived 'Show' instance prints 'privateNumber', i.e. the
-- private key. Avoid logging values of this type.
data VAPIDKeysMinDetails = VAPIDKeysMinDetails
  { privateNumber :: Integer
    -- ^ Private ECDSA scalar on the @SEC_p256r1@ curve.
  , publicCoordX :: Integer
    -- ^ X coordinate of the public point.
  , publicCoordY :: Integer
    -- ^ Y coordinate of the public point.
  } deriving (Show)
-- | Failures reported by 'sendPushNotification'.
data PushNotificationError
  = EndpointParseFailed SomeException
    -- ^ The endpoint text could not be parsed as a request URL.
  | MessageEncryptionFailed CryptoError
    -- ^ Encrypting the payload failed.
  | RecepientEndpointNotFound
    -- ^ The push service answered 404 or 410; the subscription is presumably
    -- no longer valid. (The constructor name is misspelt; it is kept as-is
    -- because it is part of the exported API.)
  | PushRequestFailed SomeException
    -- ^ Any other failure while sending, including other HTTP error statuses.
  deriving (Show, Exception)