diff --git a/crates/mpc-tls/src/follower.rs b/crates/mpc-tls/src/follower.rs
index b4774a3e0..2800e4421 100644
--- a/crates/mpc-tls/src/follower.rs
+++ b/crates/mpc-tls/src/follower.rs
@@ -357,6 +357,9 @@ impl MpcTlsFollower {
                     )
                     .map_err(MpcTlsError::record_layer)?;
             }
+            Message::StartTraffic => {
+                record_layer.start_traffic();
+            }
             Message::Flush { is_decrypting } => {
                 record_layer
                     .flush(&mut self.ctx, vm.clone(), is_decrypting)
diff --git a/crates/mpc-tls/src/leader.rs b/crates/mpc-tls/src/leader.rs
index 057e2a598..a6b934c41 100644
--- a/crates/mpc-tls/src/leader.rs
+++ b/crates/mpc-tls/src/leader.rs
@@ -43,7 +43,7 @@ use tls_core::{
     },
     suites::SupportedCipherSuite,
 };
-use tracing::{debug, instrument, trace};
+use tracing::{debug, instrument, trace, warn};
 
 /// Controller for MPC-TLS leader.
 pub type LeaderCtrl = actor::MpcTlsLeaderCtrl;
@@ -692,13 +692,20 @@ impl Backend for MpcTlsLeader {
 
     #[instrument(level = "debug", skip_all, err)]
     async fn push_incoming(&mut self, msg: OpaqueMessage) -> Result<(), BackendError> {
-        let State::Active {
-            ctx, record_layer, ..
-        } = &mut self.state
-        else {
-            return Err(
-                MpcTlsError::state("must be in active state to push incoming message").into(),
-            );
+        let (ctx, record_layer) = match &mut self.state {
+            State::Handshake {
+                ctx, record_layer, ..
+            } => (ctx, record_layer),
+            State::Active {
+                ctx, record_layer, ..
+            } => (ctx, record_layer),
+            _ => {
+                return Err(MpcTlsError::state(format!(
+                    "can not push incoming message in state: {}",
+                    self.state
+                ))
+                .into())
+            }
         };
 
         let OpaqueMessage {
@@ -746,12 +753,14 @@ impl Backend for MpcTlsLeader {
     #[instrument(level = "debug", skip_all, err)]
     async fn next_incoming(&mut self) -> Result<Option<PlainMessage>, BackendError> {
         let record_layer = match &mut self.state {
+            State::Handshake { record_layer, .. } => record_layer,
             State::Active { record_layer, .. } => record_layer,
             State::Closed { record_layer, .. } => record_layer,
             _ => {
-                return Err(MpcTlsError::state(
-                    "must be in active or closed state to pull next incoming message",
-                )
+                return Err(MpcTlsError::state(format!(
+                    "can not pull next incoming message in state: {}",
+                    self.state
+                ))
                 .into())
             }
         };
@@ -779,13 +788,20 @@ impl Backend for MpcTlsLeader {
 
     #[instrument(level = "debug", skip_all, err)]
     async fn push_outgoing(&mut self, msg: PlainMessage) -> Result<(), BackendError> {
-        let State::Active {
-            ctx, record_layer, ..
-        } = &mut self.state
-        else {
-            return Err(
-                MpcTlsError::state("must be in active state to push outgoing message").into(),
-            );
+        let (ctx, record_layer) = match &mut self.state {
+            State::Handshake {
+                ctx, record_layer, ..
+            } => (ctx, record_layer),
+            State::Active {
+                ctx, record_layer, ..
+            } => (ctx, record_layer),
+            _ => {
+                return Err(MpcTlsError::state(format!(
+                    "can not push outgoing message in state: {}",
+                    self.state
+                ))
+                .into())
+            }
         };
 
         debug!(
@@ -828,12 +844,14 @@ impl Backend for MpcTlsLeader {
     #[instrument(level = "debug", skip_all, err)]
     async fn next_outgoing(&mut self) -> Result<Option<OpaqueMessage>, BackendError> {
         let record_layer = match &mut self.state {
+            State::Handshake { record_layer, .. } => record_layer,
             State::Active { record_layer, .. } => record_layer,
             State::Closed { record_layer, .. } => record_layer,
             _ => {
-                return Err(MpcTlsError::state(
-                    "must be in active or closed state to pull next outgoing message",
-                )
+                return Err(MpcTlsError::state(format!(
+                    "can not pull next outgoing message in state: {}",
+                    self.state
+                ))
                 .into())
             }
         };
@@ -860,9 +878,36 @@ impl Backend for MpcTlsLeader {
         Ok(record)
     }
 
+    async fn start_traffic(&mut self) -> Result<(), BackendError> {
+        match &mut self.state {
+            State::Active {
+                ctx, record_layer, ..
+            } => {
+                record_layer.start_traffic();
+                ctx.io_mut()
+                    .send(Message::StartTraffic)
+                    .await
+                    .map_err(MpcTlsError::from)?;
+            }
+            _ => {
+                return Err(MpcTlsError::state(format!(
+                    "can not start traffic in state: {}",
+                    self.state
+                ))
+                .into())
+            }
+        }
+
+        Ok(())
+    }
+
     #[instrument(level = "debug", skip_all, err)]
     async fn flush(&mut self) -> Result<(), BackendError> {
         let (ctx, vm, record_layer) = match &mut self.state {
+            State::Handshake { .. } => {
+                warn!("record layer is not ready, skipping flush");
+                return Ok(());
+            }
             State::Active {
                 ctx,
                 vm,
@@ -876,20 +921,21 @@ impl Backend for MpcTlsLeader {
                 ..
             } => (ctx, vm, record_layer),
             _ => {
-                return Err(MpcTlsError::state(
-                    "must be in active or closed state to flush record layer",
-                )
+                return Err(MpcTlsError::state(format!(
+                    "can not flush record layer in state: {}",
+                    self.state
+                ))
                 .into())
             }
         };
 
-        debug!("flushing record layer");
-
         if !record_layer.wants_flush() {
             debug!("record layer is empty, skipping flush");
             return Ok(());
         }
 
+        debug!("flushing record layer");
+
         ctx.io_mut()
             .send(Message::Flush {
                 is_decrypting: self.is_decrypting,
@@ -1002,3 +1048,16 @@ impl std::fmt::Debug for State {
         }
     }
 }
+
+impl std::fmt::Display for State {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Init { .. } => write!(f, "Init"),
+            Self::Setup { .. } => write!(f, "Setup"),
+            Self::Handshake { .. } => write!(f, "Handshake"),
+            Self::Active { .. } => write!(f, "Active"),
+            Self::Closed { .. } => write!(f, "Closed"),
+            Self::Error => write!(f, "Error"),
+        }
+    }
+}
diff --git a/crates/mpc-tls/src/leader/actor.rs b/crates/mpc-tls/src/leader/actor.rs
index 038ff80f4..ee3c48f1e 100644
--- a/crates/mpc-tls/src/leader/actor.rs
+++ b/crates/mpc-tls/src/leader/actor.rs
@@ -224,6 +224,12 @@ impl Dispatch<MpcTlsLeader> for MpcTlsLeaderMsg {
                 })
                 .await;
             }
+            MpcTlsLeaderMsg::BackendMsgStartTraffic(msg) => {
+                msg.dispatch(actor, ctx, |value| {
+                    ret(Self::Return::BackendMsgStartTraffic(value))
+                })
+                .await;
+            }
             MpcTlsLeaderMsg::BackendMsgFlush(msg) => {
                 msg.dispatch(actor, ctx, |value| {
                     ret(Self::Return::BackendMsgFlush(value))
@@ -410,6 +416,13 @@ impl Backend for MpcTlsLeaderCtrl {
             .map_err(|err| BackendError::InternalError(err.to_string()))?
     }
 
+    async fn start_traffic(&mut self) -> Result<(), BackendError> {
+        self.address
+            .send(BackendMsgStartTraffic)
+            .await
+            .map_err(|err| BackendError::InternalError(err.to_string()))?
+    }
+
     async fn flush(&mut self) -> Result<(), BackendError> {
         self.address
             .send(BackendMsgFlush)
@@ -859,6 +872,27 @@ impl Handler<BackendMsgPushOutgoing> for MpcTlsLeader {
     }
 }
 
+impl Dispatch<MpcTlsLeader> for BackendMsgStartTraffic {
+    fn dispatch<R: FnOnce(Self::Return) + Send>(
+        self,
+        actor: &mut MpcTlsLeader,
+        ctx: &mut LudiCtx<MpcTlsLeader>,
+        ret: R,
+    ) -> impl Future<Output = ()> + Send {
+        actor.process(self, ctx, ret)
+    }
+}
+
+impl Handler<BackendMsgStartTraffic> for MpcTlsLeader {
+    async fn handle(
+        &mut self,
+        _msg: BackendMsgStartTraffic,
+        _ctx: &mut LudiCtx<MpcTlsLeader>,
+    ) -> <BackendMsgStartTraffic as Message>::Return {
+        self.start_traffic().await
+    }
+}
+
 impl Dispatch<MpcTlsLeader> for BackendMsgFlush {
     fn dispatch<R: FnOnce(Self::Return) + Send>(
         self,
@@ -1005,6 +1039,7 @@ pub enum MpcTlsLeaderMsg {
     BackendMsgPushIncoming(BackendMsgPushIncoming),
     BackendMsgNextOutgoing(BackendMsgNextOutgoing),
     BackendMsgPushOutgoing(BackendMsgPushOutgoing),
+    BackendMsgStartTraffic(BackendMsgStartTraffic),
     BackendMsgFlush(BackendMsgFlush),
     BackendMsgGetNotify(BackendMsgGetNotify),
     BackendMsgIsEmpty(BackendMsgIsEmpty),
@@ -1039,6 +1074,7 @@ pub enum MpcTlsLeaderMsgReturn {
     BackendMsgPushIncoming(<BackendMsgPushIncoming as Message>::Return),
     BackendMsgNextOutgoing(<BackendMsgNextOutgoing as Message>::Return),
     BackendMsgPushOutgoing(<BackendMsgPushOutgoing as Message>::Return),
+    BackendMsgStartTraffic(<BackendMsgStartTraffic as Message>::Return),
     BackendMsgFlush(<BackendMsgFlush as Message>::Return),
     BackendMsgGetNotify(<BackendMsgGetNotify as Message>::Return),
     BackendMsgIsEmpty(<BackendMsgIsEmpty as Message>::Return),
@@ -1573,6 +1609,31 @@ impl Wrap<BackendMsgPushOutgoing> for MpcTlsLeaderMsg {
     }
 }
 
+#[allow(missing_docs)]
+#[derive(Debug)]
+pub struct BackendMsgStartTraffic;
+
+impl Message for BackendMsgStartTraffic {
+    type Return = Result<(), BackendError>;
+}
+
+impl From<BackendMsgStartTraffic> for MpcTlsLeaderMsg {
+    fn from(value: BackendMsgStartTraffic) -> Self {
+        MpcTlsLeaderMsg::BackendMsgStartTraffic(value)
+    }
+}
+
+impl Wrap<BackendMsgStartTraffic> for MpcTlsLeaderMsg {
+    fn unwrap_return(
+        ret: Self::Return,
+    ) -> Result<<BackendMsgStartTraffic as Message>::Return, Error> {
+        match ret {
+            Self::Return::BackendMsgStartTraffic(value) => Ok(value),
+            _ => Err(Error::Wrapper),
+        }
+    }
+}
+
 #[allow(missing_docs)]
 #[derive(Debug)]
 pub struct BackendMsgFlush;
diff --git a/crates/mpc-tls/src/msg.rs b/crates/mpc-tls/src/msg.rs
index a1bc5a786..c5c86e416 100644
--- a/crates/mpc-tls/src/msg.rs
+++ b/crates/mpc-tls/src/msg.rs
@@ -15,6 +15,7 @@ pub(crate) enum Message {
     ServerFinishedVd(ServerFinishedVd),
     Encrypt(Encrypt),
     Decrypt(Decrypt),
+    StartTraffic,
     Flush { is_decrypting: bool },
     CloseConnection,
 }
diff --git a/crates/mpc-tls/src/record_layer.rs b/crates/mpc-tls/src/record_layer.rs
index 749d1ef2a..45657754b 100644
--- a/crates/mpc-tls/src/record_layer.rs
+++ b/crates/mpc-tls/src/record_layer.rs
@@ -23,6 +23,7 @@ use tls_core::{
 };
 use tlsn_common::transcript::{Record, TlsTranscript};
 use tokio::sync::Mutex;
+use tracing::{debug, instrument};
 
 use crate::{
     record_layer::{aes_ctr::AesCtr, decrypt::DecryptOp, encrypt::EncryptOp},
@@ -77,6 +78,8 @@ pub(crate) struct RecordLayer {
     decrypt: Arc>,
     aes_ctr: AesCtr,
     state: State,
+    /// Whether the record layer has started processing application data.
+    started: bool,
 
     encrypt_buffer: Vec<EncryptOp>,
     decrypt_buffer: Vec<DecryptOp>,
@@ -95,6 +98,7 @@ impl RecordLayer {
             decrypt: Arc::new(Mutex::new(decrypt)),
             aes_ctr: AesCtr::new(role),
             state: State::Init,
+            started: false,
             encrypt_buffer: Vec::new(),
             decrypt_buffer: Vec::new(),
             encrypted_buffer: VecDeque::new(),
@@ -248,6 +252,11 @@ impl RecordLayer {
         !self.encrypt_buffer.is_empty() || !self.decrypt_buffer.is_empty()
     }
 
+    pub(crate) fn start_traffic(&mut self) {
+        self.started = true;
+        debug!("started processing application data");
+    }
+
     pub(crate) fn push_encrypt(
         &mut self,
         typ: ContentType,
@@ -305,14 +314,27 @@ impl RecordLayer {
 
     /// Returns the next encrypted record.
     pub(crate) fn next_encrypted(&mut self) -> Option<OpaqueMessage> {
-        self.encrypted_buffer.pop_front()
+        let typ = self.encrypted_buffer.front().map(|r| r.typ)?;
+        // If we haven't started processing application data we return None.
+        if !self.started && typ == ContentType::ApplicationData {
+            None
+        } else {
+            self.encrypted_buffer.pop_front()
+        }
     }
 
     /// Returns the next decrypted record.
     pub(crate) fn next_decrypted(&mut self) -> Option<PlainMessage> {
-        self.decrypted_buffer.pop_front()
+        let typ = self.decrypted_buffer.front().map(|r| r.typ)?;
+        // If we haven't started processing application data we return None.
+        if !self.started && typ == ContentType::ApplicationData {
+            None
+        } else {
+            self.decrypted_buffer.pop_front()
+        }
     }
 
+    #[instrument(level = "debug", skip(self, ctx, vm), err)]
     pub(crate) async fn flush(
         &mut self,
         ctx: &mut Context,
@@ -345,19 +367,30 @@ impl RecordLayer {
             .try_lock()
             .map_err(|_| MpcTlsError::record_layer("decrypt lock is held"))?;
 
-        let encrypt_ops = take(&mut self.encrypt_buffer);
-
-        let decrypt_end = if is_decrypting {
-            self.decrypt_buffer.len()
+        let encrypt_ops: Vec<_> = self.encrypt_buffer.drain(..).collect();
+        let decrypt_ops: Vec<_> = if is_decrypting {
+            self.decrypt_buffer.drain(..).collect()
         } else {
-            // Position of the first application data in the decrypt buffer.
-            self.decrypt_buffer
+            // Process non-application data even if we're not decrypting.
+            let decrypt_pos = self
+                .decrypt_buffer
                 .iter()
                 .position(|op| op.typ == ContentType::ApplicationData)
-                .unwrap_or(self.decrypt_buffer.len())
+                .unwrap_or(self.decrypt_buffer.len());
+
+            self.decrypt_buffer.drain(..decrypt_pos).collect()
         };
 
-        let decrypt_ops: Vec<_> = self.decrypt_buffer.drain(..decrypt_end).collect();
+        if encrypt_ops.is_empty() && decrypt_ops.is_empty() {
+            debug!("no operations to process, skipping");
+            return Ok(());
+        }
+
+        debug!(
+            "processing {} encrypt ops and {} decrypt ops",
+            encrypt_ops.len(),
+            decrypt_ops.len()
+        );
 
         let (pending_encrypt, compute_tags) =
             encrypt::encrypt(&mut (*vm), &mut encrypter, &encrypt_ops)?;
diff --git a/crates/tls/backend/src/lib.rs b/crates/tls/backend/src/lib.rs
index 234a04fde..29bd7d0fb 100644
--- a/crates/tls/backend/src/lib.rs
+++ b/crates/tls/backend/src/lib.rs
@@ -118,6 +118,8 @@ pub trait Backend: Send {
     async fn push_outgoing(&mut self, msg: PlainMessage) -> Result<(), BackendError>;
     /// Returns next outgoing message.
     async fn next_outgoing(&mut self) -> Result<Option<OpaqueMessage>, BackendError>;
+    /// Starts processing application data traffic.
+    async fn start_traffic(&mut self) -> Result<(), BackendError>;
     /// Flushes the record layer.
     async fn flush(&mut self) -> Result<(), BackendError>;
     /// Returns a notification future which resolves when the backend is ready
diff --git a/crates/tls/client/src/backend/standard.rs b/crates/tls/client/src/backend/standard.rs
index 6fda1c40e..5c693b157 100644
--- a/crates/tls/client/src/backend/standard.rs
+++ b/crates/tls/client/src/backend/standard.rs
@@ -427,6 +427,10 @@ impl Backend for RustCryptoBackend {
         Ok(self.incoming_plain.pop_front())
     }
 
+    async fn start_traffic(&mut self) -> Result<(), BackendError> {
+        Ok(())
+    }
+
     async fn flush(&mut self) -> Result<(), BackendError> {
         for incoming in take(&mut self.incoming_encrypted) {
             let seq = self.read_seq;
diff --git a/crates/tls/client/src/conn.rs b/crates/tls/client/src/conn.rs
index e6da4d0cf..54f5ed51f 100644
--- a/crates/tls/client/src/conn.rs
+++ b/crates/tls/client/src/conn.rs
@@ -964,6 +964,7 @@ impl CommonState {
 
     pub(crate) async fn start_traffic(&mut self) -> Result<(), Error> {
         self.may_receive_application_data = true;
+        self.backend.start_traffic().await?;
         self.start_outgoing_traffic().await
     }