feat(mpc-tls): improve error message for incorrect transcript config (#754)

* feat(mpc-tls): improve error message for incorrect transcript config

* rustfmt

---------

Co-authored-by: dan <themighty1@users.noreply.github.com>
This commit is contained in:
sinu.eth
2025-04-07 17:44:02 +07:00
committed by GitHub
parent a34dd57926
commit 93535ca955
4 changed files with 55 additions and 3 deletions

View File

@@ -50,7 +50,7 @@ impl ConfigBuilder {
let mut max_recv_online = self
.max_recv_online
.ok_or(ConfigBuilderError::UninitializedField("max_recv_online"))?;
let max_recv = self
let mut max_recv = self
.max_recv
.ok_or(ConfigBuilderError::UninitializedField("max_recv"))?;
@@ -61,6 +61,7 @@ impl ConfigBuilder {
}
max_recv_online += MIN_RECV;
max_recv += MIN_RECV;
let max_sent_records = self
.max_sent_records

View File

@@ -135,6 +135,7 @@ impl MpcTlsFollower {
self.config.max_recv_records,
self.config.max_sent,
self.config.max_recv_online,
self.config.max_recv,
)?;
(keys, cf_vd, sf_vd)

View File

@@ -169,6 +169,7 @@ impl MpcTlsLeader {
self.config.max_recv_records,
self.config.max_sent,
self.config.max_recv_online,
self.config.max_recv,
)?;
self.state = State::Setup {

View File

@@ -80,6 +80,18 @@ pub(crate) struct RecordLayer {
state: State,
/// Whether the record layer has started processing application data.
started: bool,
/// Number of bytes sent.
sent: usize,
/// Number of bytes received and decrypted online.
recv_online: usize,
/// Number of bytes received.
recv: usize,
/// Maximum number of bytes sent.
max_sent: usize,
/// Maximum number of bytes received to be decrypted online.
max_recv_online: usize,
/// Maximum number of bytes received.
max_recv: usize,
encrypt_buffer: Vec<EncryptOp>,
decrypt_buffer: Vec<DecryptOp>,
@@ -99,6 +111,12 @@ impl RecordLayer {
aes_ctr: AesCtr::new(role),
state: State::Init,
started: false,
sent: 0,
recv_online: 0,
recv: 0,
max_sent: 0,
max_recv_online: 0,
max_recv: 0,
encrypt_buffer: Vec::new(),
decrypt_buffer: Vec::new(),
encrypted_buffer: VecDeque::new(),
@@ -114,6 +132,8 @@ impl RecordLayer {
/// * `sent_records` - Number of sent records to allocate.
/// * `recv_records` - Number of received records to allocate.
/// * `sent_len` - Total length of sent records to allocate.
/// * `recv_len_online` - Total length of received records to be decrypted
/// online.
/// * `recv_len` - Total length of received records to allocate.
pub(crate) fn alloc(
&mut self,
@@ -121,6 +141,7 @@ impl RecordLayer {
sent_records: usize,
recv_records: usize,
sent_len: usize,
recv_len_online: usize,
recv_len: usize,
) -> Result<(), MpcTlsError> {
let State::Init = self.state.take() else {
@@ -142,12 +163,12 @@ impl RecordLayer {
.map_err(MpcTlsError::record_layer)?;
decrypt
.alloc(vm, recv_records, recv_len)
.alloc(vm, recv_records, recv_len_online)
.map_err(MpcTlsError::record_layer)?;
let recv_otp = match self.role {
Role::Leader => {
let mut recv_otp = vec![0u8; recv_len];
let mut recv_otp = vec![0u8; recv_len_online];
rand::rng().fill_bytes(&mut recv_otp);
Some(recv_otp)
@@ -157,6 +178,10 @@ impl RecordLayer {
self.aes_ctr.alloc(vm)?;
self.max_sent += sent_len;
self.max_recv_online += recv_len_online;
self.max_recv += recv_len;
self.state = State::Online {
recv_otp,
sent_records: Vec::new(),
@@ -267,9 +292,15 @@ impl RecordLayer {
) -> Result<(), MpcTlsError> {
if self.encrypt_buffer.len() >= MAX_BUFFER_SIZE {
return Err(MpcTlsError::peer("encrypt buffer is full"));
} else if self.sent + len > self.max_sent {
return Err(MpcTlsError::record_layer(format!(
"attempted to send more data than was configured, increase `max_sent` in the config: current={}, additional={}, max={}",
self.sent, len, self.max_sent
)));
}
let (seq, explicit_nonce, aad) = self.next_write(typ, version, len);
self.sent += len;
self.encrypt_buffer.push(EncryptOp::new(
seq,
typ,
@@ -295,9 +326,15 @@ impl RecordLayer {
) -> Result<(), MpcTlsError> {
if self.decrypt_buffer.len() >= MAX_BUFFER_SIZE {
return Err(MpcTlsError::peer("decrypt buffer is full"));
} else if self.recv + ciphertext.len() > self.max_recv {
return Err(MpcTlsError::record_layer(format!(
"attempted to receive more data than was configured, increase `max_recv` in the config: current={}, additional={}, max={}",
self.recv, ciphertext.len(), self.max_recv
)));
}
let (seq, aad) = self.next_read(typ, version, ciphertext.len());
self.recv += ciphertext.len();
self.decrypt_buffer.push(DecryptOp::new(
seq,
typ,
@@ -386,6 +423,18 @@ impl RecordLayer {
return Ok(());
}
if is_decrypting {
let decrypt_len: usize = decrypt_ops.iter().map(|op| op.ciphertext.len()).sum();
if self.recv_online + decrypt_len > self.max_recv_online {
return Err(MpcTlsError::record_layer(format!(
"attempted to decrypt more data in the online phase than was configured, increase `max_recv_online` in the config: current={}, additional={}, max={}",
self.recv_online, decrypt_len, self.max_recv_online
)));
} else {
self.recv_online += decrypt_len;
}
}
debug!(
"processing {} encrypt ops and {} decrypt ops",
encrypt_ops.len(),