feat: max transcript size handling (#319)

* max transcript size handling

* adjust default

* consolidate handlers

* clippy

* Update comment

Co-authored-by: dan <themighty1@users.noreply.github.com>

* add comment on close_notify

---------

Co-authored-by: dan <themighty1@users.noreply.github.com>
This commit is contained in:
sinu.eth
2023-09-06 08:46:23 -07:00
committed by GitHub
parent 832d1baec1
commit 76842128ea
7 changed files with 139 additions and 62 deletions

View File

@@ -26,7 +26,10 @@ pub struct MpcTlsCommonConfig {
/// Opaque Rx transcript ID
#[builder(setter(into), default = "DEFAULT_OPAQUE_RX_TRANSCRIPT_ID.to_string()")]
opaque_rx_transcript_id: String,
/// Maximum size of the transcript in bytes.
/// 16 KiB by default.
#[builder(default = "1 << 14")]
max_transcript_size: usize,
/// Whether the leader commits to the handshake data.
#[builder(default = "true")]
handshake_commit: bool,
@@ -68,6 +71,11 @@ impl MpcTlsCommonConfig {
&self.opaque_rx_transcript_id
}
/// Returns the maximum size of the transcript in bytes.
pub fn max_transcript_size(&self) -> usize {
self.max_transcript_size
}
/// Whether the leader commits to the handshake data.
pub fn handshake_commit(&self) -> bool {
self.handshake_commit

View File

@@ -24,6 +24,8 @@ pub enum MpcTlsError {
UnexpectedContentType(ContentType),
#[error("invalid message length: {0}")]
InvalidMessageLength(usize),
#[error("maximum transcript length exceeded: {} > {}", .0, .1)]
MaxTranscriptLengthExceeded(usize, usize),
#[error("unexpected sequence number: {0}")]
UnexpectedSequenceNumber(u64),
#[error("not set up")]

View File

@@ -38,6 +38,8 @@ pub struct MpcTlsFollower {
decrypter: Decrypter,
handshake_commitment: Option<Hash>,
closed: bool,
}
impl MpcTlsFollower {
@@ -78,6 +80,7 @@ impl MpcTlsFollower {
encrypter,
decrypter,
handshake_commitment: None,
closed: false,
}
}
@@ -86,6 +89,11 @@ impl MpcTlsFollower {
(self.encrypter.sent_bytes(), self.decrypter.recv_bytes())
}
/// Returns the total number of bytes sent and received.
fn total_bytes_transferred(&self) -> usize {
self.encrypter.sent_bytes() + self.decrypter.recv_bytes()
}
/// Returns the server's public key
pub fn server_key(&self) -> Option<PublicKey> {
self.ke.server_key().map(|key| {
@@ -182,6 +190,72 @@ impl MpcTlsFollower {
Ok(())
}
async fn handle_encrypt_msg(&mut self, msg: EncryptMessage) -> Result<(), MpcTlsError> {
let EncryptMessage { typ, seq, len } = msg;
if self.total_bytes_transferred() + len > self.config.common().max_transcript_size() {
return Err(MpcTlsError::MaxTranscriptLengthExceeded(
self.total_bytes_transferred() + len,
self.config.common().max_transcript_size(),
));
}
self.encrypter.encrypt_blind(typ, seq, len).await
}
async fn handle_decrypt_msg(&mut self, msg: DecryptMessage) -> Result<(), MpcTlsError> {
let DecryptMessage {
typ,
explicit_nonce,
ciphertext,
seq,
} = msg;
if self.total_bytes_transferred() + ciphertext.len()
> self.config.common().max_transcript_size()
{
return Err(MpcTlsError::MaxTranscriptLengthExceeded(
self.total_bytes_transferred() + ciphertext.len(),
self.config.common().max_transcript_size(),
));
}
match typ {
ContentType::ApplicationData => {
self.decrypter
.decrypt_blind(typ, explicit_nonce, ciphertext, seq)
.await
}
ContentType::Alert => {
let bytes = self
.decrypter
.decrypt_public(typ, explicit_nonce, ciphertext, seq)
.await?;
let alert = AlertMessagePayload::read_bytes(&bytes)
.ok_or(MpcTlsError::PayloadDecodingError)?;
if alert.level == AlertLevel::Fatal {
return Err(MpcTlsError::ReceivedFatalAlert);
}
if alert.description == AlertDescription::CloseNotify {
self.closed = true;
}
Ok(())
}
typ => Err(MpcTlsError::UnexpectedContentType(typ)),
}
}
async fn handle_close_notify(&mut self, msg: EncryptMessage) -> Result<(), MpcTlsError> {
let EncryptMessage { typ, seq, len } = msg;
// We could use `encrypt_public` here, but it is not required.
self.encrypter.encrypt_blind(typ, seq, len).await
}
/// Runs the follower instance
#[cfg_attr(
feature = "tracing",
@@ -199,53 +273,23 @@ impl MpcTlsFollower {
};
match msg {
MpcTlsMessage::EncryptMessage(EncryptMessage { typ, seq, len }) => {
self.encrypter.encrypt_blind(typ, seq, len).await?;
MpcTlsMessage::EncryptMessage(msg) => {
self.handle_encrypt_msg(msg).await?;
}
MpcTlsMessage::DecryptMessage(DecryptMessage {
typ,
explicit_nonce,
ciphertext,
seq,
}) if typ == ContentType::ApplicationData => {
self.decrypter
.decrypt_blind(typ, explicit_nonce, ciphertext, seq)
.await?;
MpcTlsMessage::DecryptMessage(msg) => {
self.handle_decrypt_msg(msg).await?;
}
MpcTlsMessage::DecryptMessage(DecryptMessage {
typ,
explicit_nonce,
ciphertext,
seq,
}) => match typ {
ContentType::Alert => {
let bytes = self
.decrypter
.decrypt_public(typ, explicit_nonce, ciphertext, seq)
.await?;
let alert = AlertMessagePayload::read_bytes(&bytes)
.ok_or(MpcTlsError::PayloadDecodingError)?;
if alert.level == AlertLevel::Fatal {
return Err(MpcTlsError::ReceivedFatalAlert);
}
if alert.description == AlertDescription::CloseNotify {
break;
}
}
_ => {
return Err(MpcTlsError::UnexpectedContentType(typ));
}
},
MpcTlsMessage::SendCloseNotify(EncryptMessage { typ, seq, len }) => {
self.encrypter.encrypt_blind(typ, seq, len).await?;
MpcTlsMessage::SendCloseNotify(msg) => {
self.handle_close_notify(msg).await?;
}
msg => {
return Err(MpcTlsError::UnexpectedMessage(msg));
}
}
if self.closed {
break;
}
}
Ok(())

View File

@@ -164,6 +164,11 @@ impl MpcTlsLeader {
(self.conn_state.sent_bytes, self.conn_state.recv_bytes)
}
/// Returns the total number of bytes sent and received.
fn total_bytes_transferred(&self) -> usize {
self.conn_state.sent_bytes + self.conn_state.recv_bytes
}
/// Computes the combined key
#[cfg_attr(
feature = "tracing",
@@ -301,6 +306,15 @@ impl MpcTlsLeader {
m: PlainMessage,
seq: u64,
) -> Result<OpaqueMessage, MpcTlsError> {
if self.total_bytes_transferred() + m.payload.0.len()
> self.config.common().max_transcript_size()
{
return Err(MpcTlsError::MaxTranscriptLengthExceeded(
self.total_bytes_transferred() + m.payload.0.len(),
self.config.common().max_transcript_size(),
));
}
let explicit_nonce = seq.to_be_bytes().to_vec();
let aad = make_tls12_aad(seq, m.typ, m.version, m.payload.0.len());
@@ -365,6 +379,15 @@ impl MpcTlsLeader {
m: OpaqueMessage,
seq: u64,
) -> Result<PlainMessage, MpcTlsError> {
if self.total_bytes_transferred() + m.payload.0.len()
> self.config.common().max_transcript_size()
{
return Err(MpcTlsError::MaxTranscriptLengthExceeded(
self.total_bytes_transferred() + m.payload.0.len(),
self.config.common().max_transcript_size(),
));
}
let mut payload = m.payload.0;
let explicit_nonce: Vec<u8> = payload.drain(..8).collect();

View File

@@ -1,4 +1,5 @@
use mpz_ot::{chou_orlandi, kos};
use tls_mpc::{MpcTlsCommonConfig, MpcTlsFollowerConfig};
const DEFAULT_MAX_TRANSCRIPT_SIZE: usize = 1 << 14; // 16Kb
@@ -54,6 +55,20 @@ impl NotaryConfig {
kos::ReceiverConfig::default()
}
pub(crate) fn build_tls_mpc_config(&self) -> MpcTlsFollowerConfig {
MpcTlsFollowerConfig::builder()
.common(
MpcTlsCommonConfig::builder()
.id(format!("{}/mpc_tls", &self.id))
.max_transcript_size(self.max_transcript_size)
.handshake_commit(true)
.build()
.unwrap(),
)
.build()
.unwrap()
}
pub(crate) fn ot_count(&self) -> usize {
self.max_transcript_size * 8
}

View File

@@ -25,9 +25,7 @@ use mpz_ot::{
use mpz_share_conversion as ff;
use rand::Rng;
use signature::Signer;
use tls_mpc::{
setup_components, MpcTlsCommonConfig, MpcTlsFollower, MpcTlsFollowerConfig, TlsRole,
};
use tls_mpc::{setup_components, MpcTlsFollower, TlsRole};
use tlsn_core::{
msg::{SignedSessionHeader, TlsnMessage},
signature::Signature,
@@ -199,13 +197,9 @@ where
#[cfg(feature = "tracing")]
info!("Created point addition senders and receivers");
let common_config = MpcTlsCommonConfig::builder()
.id(format!("{}/mpc_tls", &config.id()))
.handshake_commit(true)
.build()
.unwrap();
let mpc_config = config.build_tls_mpc_config();
let (ke, prf, encrypter, decrypter) = setup_components(
&common_config,
mpc_config.common(),
TlsRole::Follower,
&mut mux,
&mut vm,
@@ -217,18 +211,8 @@ where
.await
.map_err(|e| NotaryError::MpcError(Box::new(e)))?;
let channel = mux.get_channel(common_config.id()).await?;
let mut mpc_tls = MpcTlsFollower::new(
MpcTlsFollowerConfig::builder()
.common(common_config)
.build()
.unwrap(),
channel,
ke,
prf,
encrypter,
decrypter,
);
let channel = mux.get_channel(mpc_config.common().id()).await?;
let mut mpc_tls = MpcTlsFollower::new(mpc_config, channel, ke, prf, encrypter, decrypter);
#[cfg(feature = "tracing")]
info!("Finished setting up notary components");

View File

@@ -45,6 +45,7 @@ impl ProverConfig {
.common(
MpcTlsCommonConfig::builder()
.id(format!("{}/mpc_tls", &self.id))
.max_transcript_size(self.max_transcript_size)
.handshake_commit(true)
.build()
.unwrap(),