From 76842128ea75dad9987a44adcc344f68a672885d Mon Sep 17 00:00:00 2001 From: "sinu.eth" <65924192+sinui0@users.noreply.github.com> Date: Wed, 6 Sep 2023 08:46:23 -0700 Subject: [PATCH] feat: max transcript size handling (#319) * max transcript size handling * adjust default * consolidate handlers * clippy * Update comment Co-authored-by: dan * add comment on close_notify --------- Co-authored-by: dan --- components/tls/tls-mpc/src/config.rs | 10 +- components/tls/tls-mpc/src/error.rs | 2 + components/tls/tls-mpc/src/follower.rs | 124 +++++++++++++++++-------- components/tls/tls-mpc/src/leader.rs | 23 +++++ tlsn/tlsn-notary/src/config.rs | 15 +++ tlsn/tlsn-notary/src/lib.rs | 26 +----- tlsn/tlsn-prover/src/config.rs | 1 + 7 files changed, 139 insertions(+), 62 deletions(-) diff --git a/components/tls/tls-mpc/src/config.rs b/components/tls/tls-mpc/src/config.rs index 4b1446b05..d2b16dcaf 100644 --- a/components/tls/tls-mpc/src/config.rs +++ b/components/tls/tls-mpc/src/config.rs @@ -26,7 +26,10 @@ pub struct MpcTlsCommonConfig { /// Opaque Rx transcript ID #[builder(setter(into), default = "DEFAULT_OPAQUE_RX_TRANSCRIPT_ID.to_string()")] opaque_rx_transcript_id: String, - + /// Maximum size of the transcript in bytes. + /// 16 KiB by default. + #[builder(default = "1 << 14")] + max_transcript_size: usize, /// Whether the leader commits to the handshake data. #[builder(default = "true")] handshake_commit: bool, @@ -68,6 +71,11 @@ impl MpcTlsCommonConfig { &self.opaque_rx_transcript_id } + /// Returns the maximum size of the transcript in bytes. + pub fn max_transcript_size(&self) -> usize { + self.max_transcript_size + } + /// Whether the leader commits to the handshake data. pub fn handshake_commit(&self) -> bool { self.handshake_commit diff --git a/components/tls/tls-mpc/src/error.rs b/components/tls/tls-mpc/src/error.rs index fd548906d..74aba4616 100644 --- a/components/tls/tls-mpc/src/error.rs +++ b/components/tls/tls-mpc/src/error.rs @@ -24,6 +24,8 @@ pub enum MpcTlsError { UnexpectedContentType(ContentType), #[error("invalid message length: {0}")] InvalidMessageLength(usize), + #[error("maximum transcript length exceeded: {} > {}", .0, .1)] + MaxTranscriptLengthExceeded(usize, usize), #[error("unexpected sequence number: {0}")] UnexpectedSequenceNumber(u64), #[error("not set up")] diff --git a/components/tls/tls-mpc/src/follower.rs b/components/tls/tls-mpc/src/follower.rs index f3c22d06b..33871f9f3 100644 --- a/components/tls/tls-mpc/src/follower.rs +++ b/components/tls/tls-mpc/src/follower.rs @@ -38,6 +38,8 @@ pub struct MpcTlsFollower { decrypter: Decrypter, handshake_commitment: Option, + + closed: bool, } impl MpcTlsFollower { @@ -78,6 +80,7 @@ impl MpcTlsFollower { encrypter, decrypter, handshake_commitment: None, + closed: false, } } @@ -86,6 +89,11 @@ impl MpcTlsFollower { (self.encrypter.sent_bytes(), self.decrypter.recv_bytes()) } + /// Returns the total number of bytes sent and received. + fn total_bytes_transferred(&self) -> usize { + self.encrypter.sent_bytes() + self.decrypter.recv_bytes() + } + /// Returns the server's public key pub fn server_key(&self) -> Option { self.ke.server_key().map(|key| { @@ -182,6 +190,72 @@ impl MpcTlsFollower { Ok(()) } + async fn handle_encrypt_msg(&mut self, msg: EncryptMessage) -> Result<(), MpcTlsError> { + let EncryptMessage { typ, seq, len } = msg; + + if self.total_bytes_transferred() + len > self.config.common().max_transcript_size() { + return Err(MpcTlsError::MaxTranscriptLengthExceeded( + self.total_bytes_transferred() + len, + self.config.common().max_transcript_size(), + )); + } + + self.encrypter.encrypt_blind(typ, seq, len).await + } + + async fn handle_decrypt_msg(&mut self, msg: DecryptMessage) -> Result<(), MpcTlsError> { + let DecryptMessage { + typ, + explicit_nonce, + ciphertext, + seq, + } = msg; + + if self.total_bytes_transferred() + ciphertext.len() + > self.config.common().max_transcript_size() + { + return Err(MpcTlsError::MaxTranscriptLengthExceeded( + self.total_bytes_transferred() + ciphertext.len(), + self.config.common().max_transcript_size(), + )); + } + + match typ { + ContentType::ApplicationData => { + self.decrypter + .decrypt_blind(typ, explicit_nonce, ciphertext, seq) + .await + } + ContentType::Alert => { + let bytes = self + .decrypter + .decrypt_public(typ, explicit_nonce, ciphertext, seq) + .await?; + + let alert = AlertMessagePayload::read_bytes(&bytes) + .ok_or(MpcTlsError::PayloadDecodingError)?; + + if alert.level == AlertLevel::Fatal { + return Err(MpcTlsError::ReceivedFatalAlert); + } + + if alert.description == AlertDescription::CloseNotify { + self.closed = true; + } + + Ok(()) + } + typ => Err(MpcTlsError::UnexpectedContentType(typ)), + } + } + + async fn handle_close_notify(&mut self, msg: EncryptMessage) -> Result<(), MpcTlsError> { + let EncryptMessage { typ, seq, len } = msg; + + // We could use `encrypt_public` here, but it is not required. + self.encrypter.encrypt_blind(typ, seq, len).await + } + /// Runs the follower instance #[cfg_attr( feature = "tracing", @@ -199,53 +273,23 @@ impl MpcTlsFollower { }; match msg { - MpcTlsMessage::EncryptMessage(EncryptMessage { typ, seq, len }) => { - self.encrypter.encrypt_blind(typ, seq, len).await?; + MpcTlsMessage::EncryptMessage(msg) => { + self.handle_encrypt_msg(msg).await?; } - MpcTlsMessage::DecryptMessage(DecryptMessage { - typ, - explicit_nonce, - ciphertext, - seq, - }) if typ == ContentType::ApplicationData => { - self.decrypter - .decrypt_blind(typ, explicit_nonce, ciphertext, seq) - .await?; + MpcTlsMessage::DecryptMessage(msg) => { + self.handle_decrypt_msg(msg).await?; } - MpcTlsMessage::DecryptMessage(DecryptMessage { - typ, - explicit_nonce, - ciphertext, - seq, - }) => match typ { - ContentType::Alert => { - let bytes = self - .decrypter - .decrypt_public(typ, explicit_nonce, ciphertext, seq) - .await?; - - let alert = AlertMessagePayload::read_bytes(&bytes) - .ok_or(MpcTlsError::PayloadDecodingError)?; - - if alert.level == AlertLevel::Fatal { - return Err(MpcTlsError::ReceivedFatalAlert); - } - - if alert.description == AlertDescription::CloseNotify { - break; - } - } - _ => { - return Err(MpcTlsError::UnexpectedContentType(typ)); - } - }, - MpcTlsMessage::SendCloseNotify(EncryptMessage { typ, seq, len }) => { - self.encrypter.encrypt_blind(typ, seq, len).await?; + MpcTlsMessage::SendCloseNotify(msg) => { + self.handle_close_notify(msg).await?; } msg => { return Err(MpcTlsError::UnexpectedMessage(msg)); } } + + if self.closed { + break; + } } Ok(()) diff --git a/components/tls/tls-mpc/src/leader.rs b/components/tls/tls-mpc/src/leader.rs index ddbeb9087..d0100705d 100644 --- a/components/tls/tls-mpc/src/leader.rs +++ b/components/tls/tls-mpc/src/leader.rs @@ -164,6 +164,11 @@ impl MpcTlsLeader { (self.conn_state.sent_bytes, self.conn_state.recv_bytes) } + /// Returns the total number of bytes sent and received. + fn total_bytes_transferred(&self) -> usize { + self.conn_state.sent_bytes + self.conn_state.recv_bytes + } + /// Computes the combined key #[cfg_attr( feature = "tracing", @@ -301,6 +306,15 @@ impl MpcTlsLeader { m: PlainMessage, seq: u64, ) -> Result { + if self.total_bytes_transferred() + m.payload.0.len() + > self.config.common().max_transcript_size() + { + return Err(MpcTlsError::MaxTranscriptLengthExceeded( + self.total_bytes_transferred() + m.payload.0.len(), + self.config.common().max_transcript_size(), + )); + } + let explicit_nonce = seq.to_be_bytes().to_vec(); let aad = make_tls12_aad(seq, m.typ, m.version, m.payload.0.len()); @@ -365,6 +379,15 @@ impl MpcTlsLeader { m: OpaqueMessage, seq: u64, ) -> Result { + if self.total_bytes_transferred() + m.payload.0.len() + > self.config.common().max_transcript_size() + { + return Err(MpcTlsError::MaxTranscriptLengthExceeded( + self.total_bytes_transferred() + m.payload.0.len(), + self.config.common().max_transcript_size(), + )); + } + let mut payload = m.payload.0; let explicit_nonce: Vec = payload.drain(..8).collect(); diff --git a/tlsn/tlsn-notary/src/config.rs b/tlsn/tlsn-notary/src/config.rs index 0579142a6..41c3035bc 100644 --- a/tlsn/tlsn-notary/src/config.rs +++ b/tlsn/tlsn-notary/src/config.rs @@ -1,4 +1,5 @@ use mpz_ot::{chou_orlandi, kos}; +use tls_mpc::{MpcTlsCommonConfig, MpcTlsFollowerConfig}; const DEFAULT_MAX_TRANSCRIPT_SIZE: usize = 1 << 14; // 16Kb @@ -54,6 +55,20 @@ impl NotaryConfig { kos::ReceiverConfig::default() } + pub(crate) fn build_tls_mpc_config(&self) -> MpcTlsFollowerConfig { + MpcTlsFollowerConfig::builder() + .common( + MpcTlsCommonConfig::builder() + .id(format!("{}/mpc_tls", &self.id)) + .max_transcript_size(self.max_transcript_size) + .handshake_commit(true) + .build() + .unwrap(), + ) + .build() + .unwrap() + } + pub(crate) fn ot_count(&self) -> usize { self.max_transcript_size * 8 } diff --git a/tlsn/tlsn-notary/src/lib.rs b/tlsn/tlsn-notary/src/lib.rs index ec4f7e6f4..93b6e4111 100644 --- a/tlsn/tlsn-notary/src/lib.rs +++ b/tlsn/tlsn-notary/src/lib.rs @@ -25,9 +25,7 @@ use mpz_ot::{ use mpz_share_conversion as ff; use rand::Rng; use signature::Signer; -use tls_mpc::{ - setup_components, MpcTlsCommonConfig, MpcTlsFollower, MpcTlsFollowerConfig, TlsRole, -}; +use tls_mpc::{setup_components, MpcTlsFollower, TlsRole}; use tlsn_core::{ msg::{SignedSessionHeader, TlsnMessage}, signature::Signature, @@ -199,13 +197,9 @@ where #[cfg(feature = "tracing")] info!("Created point addition senders and receivers"); - let common_config = MpcTlsCommonConfig::builder() - .id(format!("{}/mpc_tls", &config.id())) - .handshake_commit(true) - .build() - .unwrap(); + let mpc_config = config.build_tls_mpc_config(); let (ke, prf, encrypter, decrypter) = setup_components( - &common_config, + mpc_config.common(), TlsRole::Follower, &mut mux, &mut vm, @@ -217,18 +211,8 @@ where .await .map_err(|e| NotaryError::MpcError(Box::new(e)))?; - let channel = mux.get_channel(common_config.id()).await?; - let mut mpc_tls = MpcTlsFollower::new( - MpcTlsFollowerConfig::builder() - .common(common_config) - .build() - .unwrap(), - channel, - ke, - prf, - encrypter, - decrypter, - ); + let channel = mux.get_channel(mpc_config.common().id()).await?; + let mut mpc_tls = MpcTlsFollower::new(mpc_config, channel, ke, prf, encrypter, decrypter); #[cfg(feature = "tracing")] info!("Finished setting up notary components"); diff --git a/tlsn/tlsn-prover/src/config.rs b/tlsn/tlsn-prover/src/config.rs index 00766aa49..1e13ab2b6 100644 --- a/tlsn/tlsn-prover/src/config.rs +++ b/tlsn/tlsn-prover/src/config.rs @@ -45,6 +45,7 @@ impl ProverConfig { .common( MpcTlsCommonConfig::builder() .id(format!("{}/mpc_tls", &self.id)) + .max_transcript_size(self.max_transcript_size) .handshake_commit(true) .build() .unwrap(),