{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
------------------------------------------------------
-- |
-- Module      : Crypto.Noise.Internal.Handshake.State
-- Maintainer  : John Galt
-- Stability   : experimental
-- Portability : POSIX

module Crypto.Noise.Internal.Handshake.State where

import Control.Lens
import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.Catch.Pure
import Control.Monad.State (MonadState(..), StateT)
import Control.Monad.Trans.Class (MonadTrans(lift))
import Data.ByteArray (ScrubbedBytes, convert)
import Data.ByteString (ByteString)
import Data.Monoid ((<>))
import Data.Proxy

import Crypto.Noise.Cipher
import Crypto.Noise.DH
import Crypto.Noise.Hash
import Crypto.Noise.Internal.Handshake.Pattern hiding (ss)
import Crypto.Noise.Internal.SymmetricState

-- | The side of the conversation on which a party resides.
data HandshakeRole
  = InitiatorRole
  | ResponderRole
  deriving (Show, Eq)

-- | Options and key material for a handshake, parameterized by the 'DH'
--   method @d@.
data HandshakeOpts d = HandshakeOpts
  { _hoRole :: HandshakeRole
    -- ^ Which side of the conversation this party is on.
  , _hoPrologue :: Plaintext
    -- ^ Prologue data; 'handshakeState' mixes it into the hash of the
    --   initial symmetric state.
  , _hoLocalEphemeral :: Maybe (KeyPair d)
    -- ^ Local ephemeral key pair, if one has been supplied.
  , _hoLocalStatic :: Maybe (KeyPair d)
    -- ^ Local static key pair, if one has been supplied.
  , _hoRemoteEphemeral :: Maybe (PublicKey d)
    -- ^ Remote party's ephemeral public key, if one has been supplied.
  , _hoRemoteStatic :: Maybe (PublicKey d)
    -- ^ Remote party's static public key, if one has been supplied.
  }

makeLenses ''HandshakeOpts

-- | All state associated with the handshake interpreter, for cipher @c@,
--   DH method @d@ and hash @h@.
data HandshakeState c d h = HandshakeState
  { _hsSymmetricState :: SymmetricState c h
    -- ^ The symmetric state for the chosen cipher and hash.
  , _hsOpts :: HandshakeOpts d
    -- ^ The options the handshake was created with.
  , _hsPSKMode :: Bool
    -- ^ Set by 'handshakeState' from the pattern's 'hpPSKMode'.
  , _hsMsgBuffer :: ScrubbedBytes
    -- ^ Message buffer; 'handshakeState' initializes it to 'mempty'.
  }

makeLenses ''HandshakeState

-- | The value yielded by the coroutine when more data is needed.
data HandshakeResult
  = HandshakeResultMessage ScrubbedBytes
    -- ^ Carries the bytes of a handshake message.
  | HandshakeResultNeedPSK
    -- ^ Carries no payload; presumably signals that a pre-shared key is
    --   required before continuing (confirm against the interpreter).

-- | All HandshakePattern interpreters run within this Monad.
newtype Handshake c d h r = Handshake
  { runHandshake :: Coroutine (Request HandshakeResult ScrubbedBytes)
                              (StateT (HandshakeState c d h) Catch)
                              r
    -- ^ Unwraps the underlying coroutine: it suspends with a
    --   'HandshakeResult' request, is resumed with 'ScrubbedBytes', and
    --   runs over a 'StateT' holding the 'HandshakeState' on top of 'Catch'.
  } deriving ( Functor
             , Applicative
             , Monad
             , MonadThrow
             , MonadState (HandshakeState c d h)
             )

-- | @defaultHandshakeOpts role prologue@ builds a baseline set of handshake
--   options in which every key (local or remote, ephemeral or static) is
--   'Nothing'.
defaultHandshakeOpts :: HandshakeRole
                     -> Plaintext
                     -> HandshakeOpts d
-- Positional arguments follow the field order of 'HandshakeOpts':
-- role, prologue, local ephemeral, local static, remote ephemeral,
-- remote static.
defaultHandshakeOpts role prologue =
  HandshakeOpts role prologue Nothing Nothing Nothing Nothing

-- | Replaces the local ephemeral key pair.
setLocalEphemeral :: Maybe (KeyPair d)
                  -> HandshakeOpts d
                  -> HandshakeOpts d
setLocalEphemeral key opts = opts & hoLocalEphemeral .~ key

-- | Replaces the local static key pair.
setLocalStatic :: Maybe (KeyPair d)
               -> HandshakeOpts d
               -> HandshakeOpts d
setLocalStatic key opts = opts & hoLocalStatic .~ key

-- | Replaces the remote ephemeral public key (rarely needed).
setRemoteEphemeral :: Maybe (PublicKey d)
                   -> HandshakeOpts d
                   -> HandshakeOpts d
setRemoteEphemeral key opts = opts & hoRemoteEphemeral .~ key

-- | Replaces the remote static public key.
setRemoteStatic :: Maybe (PublicKey d)
                -> HandshakeOpts d
                -> HandshakeOpts d
setRemoteStatic key opts = opts & hoRemoteStatic .~ key

-- | Given a protocol name, returns the full handshake name according to the
--   rules in section 8. The result has the shape
--   @Noise_\<protocol\>_\<dh\>_\<cipher\>_\<hash\>@, where the last three
--   components come from 'dhName', 'cipherName' and 'hashName'.
mkHandshakeName :: forall c d h proxy. (Cipher c, DH d, Hash h)
                => ByteString
                -> proxy (c, d, h)
                -> ScrubbedBytes
mkHandshakeName protoName _ =
  prefix <> "_" <> dhPart <> "_" <> cipherPart <> "_" <> hashPart
  where
    prefix     = "Noise_" <> convert protoName
    dhPart     = dhName     (Proxy :: Proxy d)
    cipherPart = cipherName (Proxy :: Proxy c)
    hashPart   = hashName   (Proxy :: Proxy h)

-- | Constructs the initial 'HandshakeState' from a set of options and a
--   'HandshakePattern'. The pattern's name (such as \"NN\" or \"IK\") is
--   expanded with 'mkHandshakeName' to seed the symmetric state, after which
--   the prologue from the options is mixed into the hash.
handshakeState :: forall c d h. (Cipher c, DH d, Hash h)
               => HandshakeOpts d
               -> HandshakePattern
               -> HandshakeState c d h
handshakeState opts pat =
  HandshakeState
    { _hsSymmetricState = ssPrologue
    , _hsOpts           = opts
    , _hsPSKMode        = view hpPSKMode pat
    , _hsMsgBuffer      = mempty
    }
  where
    -- Full handshake name for this cipher / DH / hash combination.
    fullName   = mkHandshakeName (view hpName pat) (Proxy :: Proxy (c, d, h))
    -- Symmetric state built from the name alone.
    ssInit     = symmetricState fullName
    -- The same state after the prologue has been mixed into its hash.
    ssPrologue = mixHash (view hoPrologue opts) ssInit

-- | Orphan instance: exceptions are thrown in the base monad and lifted
--   into the coroutine.
instance (Functor f, MonadThrow m) => MonadThrow (Coroutine f m) where
  throwM e = lift (throwM e)

-- | Orphan instance: every state operation is delegated to the base monad.
instance (Functor f, MonadState s m) => MonadState s (Coroutine f m) where
  get     = lift get
  put s   = lift (put s)
  state f = lift (state f)