{-# Language OverloadedStrings, ImportQualifiedPost #-}
module Client.Authentication.Ecdh
  (
    -- * Phase type
    Phase1,
    -- * Mechanism details
    mechanismName,
    -- * Transition functions
    clientFirst,
    clientResponse,
  ) where

import Control.Monad (guard)
import Crypto.Curve25519.Pure qualified as Curve
import Data.Bits (xor)
import Data.ByteString (ByteString)
import Data.ByteString qualified as B
import Data.ByteString.Base64 qualified as B64
import Data.Text (Text)
import Data.Text.Encoding qualified as Text
import Irc.Commands (AuthenticatePayload (AuthenticatePayload))
import OpenSSL.EVP.Digest (digestBS, getDigestByName, hmacBS, Digest)
import System.IO.Unsafe (unsafePerformIO)

-- | State carried from 'clientFirst' to 'clientResponse': the client's
-- Curve25519 private key, already decoded from its base64 text form by
-- 'clientFirst'. The constructor is not exported, so callers can only
-- obtain a 'Phase1' through 'clientFirst'.
newtype Phase1 = Phase1 Curve.PrivateKey

-- | Name of this authentication mechanism: @ECDH-X25519-CHALLENGE@.
mechanismName :: Text
mechanismName = "ECDH-X25519-CHALLENGE"

-- | Build the client's first message and the state needed to answer the
-- server's challenge later.
--
-- The payload is the UTF-8 encoded authentication identity. When an
-- authorization identity is supplied, a NUL byte and the UTF-8 encoded
-- authorization identity follow it.
--
-- The private key arrives as base64 text. It is decoded, imported with
-- 'Curve.importPrivate', and kept in 'Phase1' for 'clientResponse'.
--
-- Returns 'Nothing' when the key text is not valid base64 or when
-- 'Curve.importPrivate' rejects the decoded bytes.
clientFirst ::
  Maybe Text {- ^ authorization identity, if any -} ->
  Text       {- ^ authentication identity        -} ->
  Text       {- ^ base64-encoded private key     -} ->
  Maybe (AuthenticatePayload, Phase1)
clientFirst mbAuthz authc privateKeyText =
  case Curve.importPrivate <$> B64.decode (Text.encodeUtf8 privateKeyText) of
    Right (Just private) -> Just (AuthenticatePayload payload, Phase1 private)
    _                    -> Nothing
  where
    payload :: ByteString
    payload =
      case mbAuthz of
        Nothing    -> Text.encodeUtf8 authc
        Just authz -> Text.encodeUtf8 authc <> "\0" <> Text.encodeUtf8 authz

-- | Answer the server's challenge using the private key saved by
-- 'clientFirst'.
--
-- The server message must be exactly 96 bytes. It is read as three
-- consecutive 32-byte fields:
--
--   1. the server's public key,
--   2. a session salt,
--   3. the masked challenge.
--
-- A mask is derived from the X25519 shared secret, both public keys and the
-- salt. XORing it with the masked challenge recovers the session challenge,
-- which becomes the client's response.
--
-- Returns 'Nothing' when the message has the wrong length or when
-- 'Curve.importPublic' rejects the server's public key.
clientResponse ::
  Phase1 ->
  ByteString                {- ^ server response  -} ->
  Maybe AuthenticatePayload {- ^ client response  -}
clientResponse (Phase1 privateKey) serverMessage =
  do let (serverPubBS, (sessionSalt, maskedChallenge)) =
           B.splitAt 32 <$> B.splitAt 32 serverMessage
     -- Each split yields at most 32 leading bytes. The remainder can
     -- therefore be exactly 32 bytes only when the whole message was exactly
     -- 96 bytes, which also makes the key and salt fields full length.
     guard (B.length maskedChallenge == 32)
     serverPublic <- Curve.importPublic serverPubBS

     let sharedSecret, ikm, prk, betterSecret, sessionChallenge :: ByteString
         sharedSecret = Curve.makeShared privateKey serverPublic
         clientPublic = Curve.generatePublic privateKey
         -- Bind the raw shared secret to both parties' public keys.
         ikm = digestBS sha256
             $ sharedSecret <> Curve.exportPublic clientPublic <> serverPubBS
         -- NOTE(review): these two HMACs appear to be HKDF-Extract and the
         -- first block of HKDF-Expand (RFC 5869), with the mechanism name as
         -- "info". This assumes hmacBS takes the key before the message;
         -- confirm.
         prk          = hmacBS sha256 sessionSalt ikm
         betterSecret = hmacBS sha256 prk "ECDH-X25519-CHALLENGE\1"
         -- B.zipWith stops at the shorter input. An HMAC-SHA256 output is
         -- 32 bytes, the same length as the masked challenge.
         sessionChallenge = B.pack (B.zipWith xor maskedChallenge betterSecret)
     Just $! AuthenticatePayload sessionChallenge

-- | The SHA-256 digest, looked up from OpenSSL by name.
--
-- 'unsafePerformIO' is used on the assumption that looking up a digest by
-- name has no observable side effects and always returns the same result.
-- The @NOINLINE@ pragma keeps this a single shared top-level value, so the
-- lookup is not duplicated at use sites.
--
-- NOTE(review): with older OpenSSL releases, digest lookup may require the
-- library to be initialised first (e.g. via HsOpenSSL's @withOpenSSL@);
-- confirm for the supported OpenSSL versions.
--
-- If OpenSSL does not provide SHA256, forcing this value fails with a
-- descriptive error instead of an opaque pattern-match failure.
sha256 :: Digest
sha256 =
  case unsafePerformIO (getDigestByName "SHA256") of
    Just digest -> digest
    Nothing     ->
      error "Client.Authentication.Ecdh.sha256: OpenSSL digest \"SHA256\" is unavailable"
{-# NOINLINE sha256 #-}