{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE OverloadedStrings #-}

module Network.QUIC.Crypto.Keys (
    defaultCipher,
    initialSecrets,
    clientInitialSecret,
    serverInitialSecret,
    aeadKey,
    initialVector,
    nextSecret,
    headerProtectionKey,
) where

import Network.TLS hiding (Version)
import Network.TLS.Extra.Cipher
import Network.TLS.QUIC
import qualified UnliftIO.Exception as E

import Network.QUIC.Crypto.Types
import Network.QUIC.Imports
import Network.QUIC.Types

----------------------------------------------------------------

-- | The TLS 1.3 AES-128-GCM / SHA-256 cipher suite.
--
-- This module takes its hash function when deriving Initial secrets
-- (RFC 9001 mandates AEAD_AES_128_GCM for Initial packets).
defaultCipher :: Cipher
defaultCipher = cipher_TLS13_AES128GCM_SHA256

----------------------------------------------------------------

-- | Version-specific salt fed to HKDF-Extract when deriving Initial secrets.
--
-- Draft29, Version1 and Version2 each have a fixed 20-byte salt. Any other
-- version produces an imprecise 'VersionIsUnknown' exception (via
-- 'E.impureThrow'), raised when the result is evaluated.
initialSalt :: Version -> Salt
initialSalt ver = case ver of
    Draft29 ->
        "\xaf\xbf\xec\x28\x99\x93\xd2\x4c\x9e\x97\x86\xf1\x9c\x61\x11\xe0\x43\x90\xa8\x99"
    Version1 ->
        "\x38\x76\x2c\xf7\xf5\x59\x34\xb3\x4d\x17\x9a\xe6\xa4\xc8\x0c\xad\xcc\xbb\x7f\x0a"
    Version2 ->
        "\x0d\xed\xe3\xde\xf7\x00\xa6\xdb\x81\x93\x81\xbe\x6e\x26\x9d\xcb\xf9\xbd\x2e\xd9"
    Version w -> E.impureThrow (VersionIsUnknown w)

-- | The client and server Initial traffic secrets for a QUIC version and
-- connection ID, returned as a (client, server) pair.
initialSecrets :: Version -> CID -> TrafficSecrets InitialSecret
initialSecrets ver cid = (client, server)
  where
    client = clientInitialSecret ver cid
    server = serverInitialSecret ver cid

-- | Client-side Initial traffic secret, derived with the label "client in".
clientInitialSecret :: Version -> CID -> ClientTrafficSecret InitialSecret
clientInitialSecret ver cid =
    ClientTrafficSecret (initialSecret ver cid (Label "client in"))

-- | Server-side Initial traffic secret, derived with the label "server in".
serverInitialSecret :: Version -> CID -> ServerTrafficSecret InitialSecret
serverInitialSecret ver cid =
    ServerTrafficSecret (initialSecret ver cid (Label "server in"))

-- | Initial secret for one direction, selected by its label.
--
-- For Draft29, Version1 and Version2 the version's salt is passed to
-- 'initialSecret''. Any other version yields the constant "not supported"
-- without touching the connection ID or label.
--
-- NOTE(review): this fallback returns a bogus secret, whereas 'initialSalt'
-- throws 'VersionIsUnknown' -- confirm the silent fallback is intended.
initialSecret :: Version -> CID -> Label -> ByteString
initialSecret ver = case ver of
    Draft29 -> derive
    Version1 -> derive
    Version2 -> derive
    _ -> \_ _ -> "not supported"
  where
    derive = initialSecret' (initialSalt ver)

-- | Derive an Initial secret from a salt, a connection ID and a label.
--
-- Steps, using the hash of 'defaultCipher':
--
-- 1. HKDF-Extract with the given salt over the raw connection ID bytes.
-- 2. HKDF-Expand-Label of that result with the label and an empty context.
--
-- The output length is the hash's digest size.
initialSecret' :: ByteString -> CID -> Label -> ByteString
initialSecret' salt cid (Label label) =
    hkdfExpandLabel h prk label "" (hashDigestSize h)
  where
    h = cipherHash defaultCipher
    prk = hkdfExtract h salt (fromCID cid)

-- | AEAD packet-protection key derived from a traffic secret.
--
-- The label is "quic key" for Draft29 and Version1, and "quicv2 key" for
-- Version2. Other versions fall back to the label "not supported".
--
-- NOTE(review): the fallback silently produces a key instead of failing;
-- confirm this is intended.
aeadKey :: Version -> Cipher -> Secret -> Key
aeadKey ver = case ver of
    Draft29 -> withLabel "quic key"
    Version1 -> withLabel "quic key"
    Version2 -> withLabel "quicv2 key"
    _ -> withLabel "not supported"
  where
    withLabel = genKey . Label

-- | Header-protection key derived from a traffic secret.
--
-- The label is "quic hp" for Draft29 and Version1, and "quicv2 hp" for
-- Version2. Other versions fall back to the label "not supported".
--
-- NOTE(review): the fallback silently produces a key instead of failing;
-- confirm this is intended.
headerProtectionKey :: Version -> Cipher -> Secret -> Key
headerProtectionKey ver = case ver of
    Draft29 -> withLabel "quic hp"
    Version1 -> withLabel "quic hp"
    Version2 -> withLabel "quicv2 hp"
    _ -> withLabel "not supported"
  where
    withLabel = genKey . Label

-- | Expand a secret into a key with HKDF-Expand-Label.
--
-- Uses the given label, an empty context and the cipher's hash. The key
-- length is the cipher's bulk key size.
genKey :: Label -> Cipher -> Secret -> Key
genKey (Label label) cipher (Secret secret) =
    Key $ hkdfExpandLabel (cipherHash cipher) secret label "" keyLen
  where
    keyLen = bulkKeySize (cipherBulk cipher)

-- | AEAD IV derived from a traffic secret.
--
-- Uses HKDF-Expand-Label with the version's IV label ('ivLabel') and an empty
-- context. The IV length is the bulk cipher's implicit plus explicit IV size,
-- but never less than 8 bytes.
initialVector :: Version -> Cipher -> Secret -> IV
initialVector ver cipher (Secret secret) =
    IV $ hkdfExpandLabel (cipherHash cipher) secret (ivLabel ver) "" ivLen
  where
    bulk = cipherBulk cipher
    ivLen = max 8 $ bulkIVSize bulk + bulkExplicitIV bulk

-- | HKDF label used when deriving the packet-protection IV.
--
-- Unknown versions map to "not supported".
ivLabel :: Version -> ByteString
ivLabel ver = case ver of
    Draft29 -> "quic iv"
    Version1 -> "quic iv"
    Version2 -> "quicv2 iv"
    _ -> "not supported"

-- | Next-generation traffic secret for a key update.
--
-- Uses HKDF-Expand-Label of the current secret with the version's key-update
-- label ('kuLabel') and an empty context. The output length is the hash's
-- digest size.
nextSecret :: Version -> Cipher -> Secret -> Secret
nextSecret ver cipher (Secret current) =
    Secret (hkdfExpandLabel h current (kuLabel ver) "" (hashDigestSize h))
  where
    h = cipherHash cipher

-- | HKDF label used when deriving the next secret on key update.
--
-- Unknown versions map to "not supported".
kuLabel :: Version -> ByteString
kuLabel ver = case ver of
    Draft29 -> "quic ku"
    Version1 -> "quic ku"
    Version2 -> "quicv2 ku"
    _ -> "not supported"