mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-04-28 03:00:14 -04:00
feat: max transcript size handling (#319)
* max transcript size handling * adjust default * consolidate handlers * clippy * Update comment Co-authored-by: dan <themighty1@users.noreply.github.com> * add comment on close_notify --------- Co-authored-by: dan <themighty1@users.noreply.github.com>
This commit is contained in:
@@ -26,7 +26,10 @@ pub struct MpcTlsCommonConfig {
|
||||
/// Opaque Rx transcript ID
|
||||
#[builder(setter(into), default = "DEFAULT_OPAQUE_RX_TRANSCRIPT_ID.to_string()")]
|
||||
opaque_rx_transcript_id: String,
|
||||
|
||||
/// Maximum size of the transcript in bytes.
|
||||
/// 16 KiB by default.
|
||||
#[builder(default = "1 << 14")]
|
||||
max_transcript_size: usize,
|
||||
/// Whether the leader commits to the handshake data.
|
||||
#[builder(default = "true")]
|
||||
handshake_commit: bool,
|
||||
@@ -68,6 +71,11 @@ impl MpcTlsCommonConfig {
|
||||
&self.opaque_rx_transcript_id
|
||||
}
|
||||
|
||||
/// Returns the maximum size of the transcript in bytes.
|
||||
pub fn max_transcript_size(&self) -> usize {
|
||||
self.max_transcript_size
|
||||
}
|
||||
|
||||
/// Whether the leader commits to the handshake data.
|
||||
pub fn handshake_commit(&self) -> bool {
|
||||
self.handshake_commit
|
||||
|
||||
@@ -24,6 +24,8 @@ pub enum MpcTlsError {
|
||||
UnexpectedContentType(ContentType),
|
||||
#[error("invalid message length: {0}")]
|
||||
InvalidMessageLength(usize),
|
||||
#[error("maximum transcript length exceeded: {} > {}", .0, .1)]
|
||||
MaxTranscriptLengthExceeded(usize, usize),
|
||||
#[error("unexpected sequence number: {0}")]
|
||||
UnexpectedSequenceNumber(u64),
|
||||
#[error("not set up")]
|
||||
|
||||
@@ -38,6 +38,8 @@ pub struct MpcTlsFollower {
|
||||
decrypter: Decrypter,
|
||||
|
||||
handshake_commitment: Option<Hash>,
|
||||
|
||||
closed: bool,
|
||||
}
|
||||
|
||||
impl MpcTlsFollower {
|
||||
@@ -78,6 +80,7 @@ impl MpcTlsFollower {
|
||||
encrypter,
|
||||
decrypter,
|
||||
handshake_commitment: None,
|
||||
closed: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,6 +89,11 @@ impl MpcTlsFollower {
|
||||
(self.encrypter.sent_bytes(), self.decrypter.recv_bytes())
|
||||
}
|
||||
|
||||
/// Returns the total number of bytes sent and received.
|
||||
fn total_bytes_transferred(&self) -> usize {
|
||||
self.encrypter.sent_bytes() + self.decrypter.recv_bytes()
|
||||
}
|
||||
|
||||
/// Returns the server's public key
|
||||
pub fn server_key(&self) -> Option<PublicKey> {
|
||||
self.ke.server_key().map(|key| {
|
||||
@@ -182,6 +190,72 @@ impl MpcTlsFollower {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_encrypt_msg(&mut self, msg: EncryptMessage) -> Result<(), MpcTlsError> {
|
||||
let EncryptMessage { typ, seq, len } = msg;
|
||||
|
||||
if self.total_bytes_transferred() + len > self.config.common().max_transcript_size() {
|
||||
return Err(MpcTlsError::MaxTranscriptLengthExceeded(
|
||||
self.total_bytes_transferred() + len,
|
||||
self.config.common().max_transcript_size(),
|
||||
));
|
||||
}
|
||||
|
||||
self.encrypter.encrypt_blind(typ, seq, len).await
|
||||
}
|
||||
|
||||
async fn handle_decrypt_msg(&mut self, msg: DecryptMessage) -> Result<(), MpcTlsError> {
|
||||
let DecryptMessage {
|
||||
typ,
|
||||
explicit_nonce,
|
||||
ciphertext,
|
||||
seq,
|
||||
} = msg;
|
||||
|
||||
if self.total_bytes_transferred() + ciphertext.len()
|
||||
> self.config.common().max_transcript_size()
|
||||
{
|
||||
return Err(MpcTlsError::MaxTranscriptLengthExceeded(
|
||||
self.total_bytes_transferred() + ciphertext.len(),
|
||||
self.config.common().max_transcript_size(),
|
||||
));
|
||||
}
|
||||
|
||||
match typ {
|
||||
ContentType::ApplicationData => {
|
||||
self.decrypter
|
||||
.decrypt_blind(typ, explicit_nonce, ciphertext, seq)
|
||||
.await
|
||||
}
|
||||
ContentType::Alert => {
|
||||
let bytes = self
|
||||
.decrypter
|
||||
.decrypt_public(typ, explicit_nonce, ciphertext, seq)
|
||||
.await?;
|
||||
|
||||
let alert = AlertMessagePayload::read_bytes(&bytes)
|
||||
.ok_or(MpcTlsError::PayloadDecodingError)?;
|
||||
|
||||
if alert.level == AlertLevel::Fatal {
|
||||
return Err(MpcTlsError::ReceivedFatalAlert);
|
||||
}
|
||||
|
||||
if alert.description == AlertDescription::CloseNotify {
|
||||
self.closed = true;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
typ => Err(MpcTlsError::UnexpectedContentType(typ)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_close_notify(&mut self, msg: EncryptMessage) -> Result<(), MpcTlsError> {
|
||||
let EncryptMessage { typ, seq, len } = msg;
|
||||
|
||||
// We could use `encrypt_public` here, but it is not required.
|
||||
self.encrypter.encrypt_blind(typ, seq, len).await
|
||||
}
|
||||
|
||||
/// Runs the follower instance
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
@@ -199,53 +273,23 @@ impl MpcTlsFollower {
|
||||
};
|
||||
|
||||
match msg {
|
||||
MpcTlsMessage::EncryptMessage(EncryptMessage { typ, seq, len }) => {
|
||||
self.encrypter.encrypt_blind(typ, seq, len).await?;
|
||||
MpcTlsMessage::EncryptMessage(msg) => {
|
||||
self.handle_encrypt_msg(msg).await?;
|
||||
}
|
||||
MpcTlsMessage::DecryptMessage(DecryptMessage {
|
||||
typ,
|
||||
explicit_nonce,
|
||||
ciphertext,
|
||||
seq,
|
||||
}) if typ == ContentType::ApplicationData => {
|
||||
self.decrypter
|
||||
.decrypt_blind(typ, explicit_nonce, ciphertext, seq)
|
||||
.await?;
|
||||
MpcTlsMessage::DecryptMessage(msg) => {
|
||||
self.handle_decrypt_msg(msg).await?;
|
||||
}
|
||||
MpcTlsMessage::DecryptMessage(DecryptMessage {
|
||||
typ,
|
||||
explicit_nonce,
|
||||
ciphertext,
|
||||
seq,
|
||||
}) => match typ {
|
||||
ContentType::Alert => {
|
||||
let bytes = self
|
||||
.decrypter
|
||||
.decrypt_public(typ, explicit_nonce, ciphertext, seq)
|
||||
.await?;
|
||||
|
||||
let alert = AlertMessagePayload::read_bytes(&bytes)
|
||||
.ok_or(MpcTlsError::PayloadDecodingError)?;
|
||||
|
||||
if alert.level == AlertLevel::Fatal {
|
||||
return Err(MpcTlsError::ReceivedFatalAlert);
|
||||
}
|
||||
|
||||
if alert.description == AlertDescription::CloseNotify {
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(MpcTlsError::UnexpectedContentType(typ));
|
||||
}
|
||||
},
|
||||
MpcTlsMessage::SendCloseNotify(EncryptMessage { typ, seq, len }) => {
|
||||
self.encrypter.encrypt_blind(typ, seq, len).await?;
|
||||
MpcTlsMessage::SendCloseNotify(msg) => {
|
||||
self.handle_close_notify(msg).await?;
|
||||
}
|
||||
msg => {
|
||||
return Err(MpcTlsError::UnexpectedMessage(msg));
|
||||
}
|
||||
}
|
||||
|
||||
if self.closed {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -164,6 +164,11 @@ impl MpcTlsLeader {
|
||||
(self.conn_state.sent_bytes, self.conn_state.recv_bytes)
|
||||
}
|
||||
|
||||
/// Returns the total number of bytes sent and received.
|
||||
fn total_bytes_transferred(&self) -> usize {
|
||||
self.conn_state.sent_bytes + self.conn_state.recv_bytes
|
||||
}
|
||||
|
||||
/// Computes the combined key
|
||||
#[cfg_attr(
|
||||
feature = "tracing",
|
||||
@@ -301,6 +306,15 @@ impl MpcTlsLeader {
|
||||
m: PlainMessage,
|
||||
seq: u64,
|
||||
) -> Result<OpaqueMessage, MpcTlsError> {
|
||||
if self.total_bytes_transferred() + m.payload.0.len()
|
||||
> self.config.common().max_transcript_size()
|
||||
{
|
||||
return Err(MpcTlsError::MaxTranscriptLengthExceeded(
|
||||
self.total_bytes_transferred() + m.payload.0.len(),
|
||||
self.config.common().max_transcript_size(),
|
||||
));
|
||||
}
|
||||
|
||||
let explicit_nonce = seq.to_be_bytes().to_vec();
|
||||
|
||||
let aad = make_tls12_aad(seq, m.typ, m.version, m.payload.0.len());
|
||||
@@ -365,6 +379,15 @@ impl MpcTlsLeader {
|
||||
m: OpaqueMessage,
|
||||
seq: u64,
|
||||
) -> Result<PlainMessage, MpcTlsError> {
|
||||
if self.total_bytes_transferred() + m.payload.0.len()
|
||||
> self.config.common().max_transcript_size()
|
||||
{
|
||||
return Err(MpcTlsError::MaxTranscriptLengthExceeded(
|
||||
self.total_bytes_transferred() + m.payload.0.len(),
|
||||
self.config.common().max_transcript_size(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut payload = m.payload.0;
|
||||
|
||||
let explicit_nonce: Vec<u8> = payload.drain(..8).collect();
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use mpz_ot::{chou_orlandi, kos};
|
||||
use tls_mpc::{MpcTlsCommonConfig, MpcTlsFollowerConfig};
|
||||
|
||||
const DEFAULT_MAX_TRANSCRIPT_SIZE: usize = 1 << 14; // 16Kb
|
||||
|
||||
@@ -54,6 +55,20 @@ impl NotaryConfig {
|
||||
kos::ReceiverConfig::default()
|
||||
}
|
||||
|
||||
pub(crate) fn build_tls_mpc_config(&self) -> MpcTlsFollowerConfig {
|
||||
MpcTlsFollowerConfig::builder()
|
||||
.common(
|
||||
MpcTlsCommonConfig::builder()
|
||||
.id(format!("{}/mpc_tls", &self.id))
|
||||
.max_transcript_size(self.max_transcript_size)
|
||||
.handshake_commit(true)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub(crate) fn ot_count(&self) -> usize {
|
||||
self.max_transcript_size * 8
|
||||
}
|
||||
|
||||
@@ -25,9 +25,7 @@ use mpz_ot::{
|
||||
use mpz_share_conversion as ff;
|
||||
use rand::Rng;
|
||||
use signature::Signer;
|
||||
use tls_mpc::{
|
||||
setup_components, MpcTlsCommonConfig, MpcTlsFollower, MpcTlsFollowerConfig, TlsRole,
|
||||
};
|
||||
use tls_mpc::{setup_components, MpcTlsFollower, TlsRole};
|
||||
use tlsn_core::{
|
||||
msg::{SignedSessionHeader, TlsnMessage},
|
||||
signature::Signature,
|
||||
@@ -199,13 +197,9 @@ where
|
||||
#[cfg(feature = "tracing")]
|
||||
info!("Created point addition senders and receivers");
|
||||
|
||||
let common_config = MpcTlsCommonConfig::builder()
|
||||
.id(format!("{}/mpc_tls", &config.id()))
|
||||
.handshake_commit(true)
|
||||
.build()
|
||||
.unwrap();
|
||||
let mpc_config = config.build_tls_mpc_config();
|
||||
let (ke, prf, encrypter, decrypter) = setup_components(
|
||||
&common_config,
|
||||
mpc_config.common(),
|
||||
TlsRole::Follower,
|
||||
&mut mux,
|
||||
&mut vm,
|
||||
@@ -217,18 +211,8 @@ where
|
||||
.await
|
||||
.map_err(|e| NotaryError::MpcError(Box::new(e)))?;
|
||||
|
||||
let channel = mux.get_channel(common_config.id()).await?;
|
||||
let mut mpc_tls = MpcTlsFollower::new(
|
||||
MpcTlsFollowerConfig::builder()
|
||||
.common(common_config)
|
||||
.build()
|
||||
.unwrap(),
|
||||
channel,
|
||||
ke,
|
||||
prf,
|
||||
encrypter,
|
||||
decrypter,
|
||||
);
|
||||
let channel = mux.get_channel(mpc_config.common().id()).await?;
|
||||
let mut mpc_tls = MpcTlsFollower::new(mpc_config, channel, ke, prf, encrypter, decrypter);
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
info!("Finished setting up notary components");
|
||||
|
||||
@@ -45,6 +45,7 @@ impl ProverConfig {
|
||||
.common(
|
||||
MpcTlsCommonConfig::builder()
|
||||
.id(format!("{}/mpc_tls", &self.id))
|
||||
.max_transcript_size(self.max_transcript_size)
|
||||
.handshake_commit(true)
|
||||
.build()
|
||||
.unwrap(),
|
||||
|
||||
Reference in New Issue
Block a user