fix: record layer handshake control flow (#733)

This commit is contained in:
sinu.eth
2025-03-17 11:04:41 -07:00
committed by GitHub
parent 9649d6e4cf
commit b24041b9f5
8 changed files with 200 additions and 36 deletions

View File

@@ -357,6 +357,9 @@ impl MpcTlsFollower {
)
.map_err(MpcTlsError::record_layer)?;
}
Message::StartTraffic => {
record_layer.start_traffic();
}
Message::Flush { is_decrypting } => {
record_layer
.flush(&mut self.ctx, vm.clone(), is_decrypting)

View File

@@ -43,7 +43,7 @@ use tls_core::{
},
suites::SupportedCipherSuite,
};
use tracing::{debug, instrument, trace};
use tracing::{debug, instrument, trace, warn};
/// Controller for MPC-TLS leader.
pub type LeaderCtrl = actor::MpcTlsLeaderCtrl;
@@ -692,13 +692,20 @@ impl Backend for MpcTlsLeader {
#[instrument(level = "debug", skip_all, err)]
async fn push_incoming(&mut self, msg: OpaqueMessage) -> Result<(), BackendError> {
let State::Active {
ctx, record_layer, ..
} = &mut self.state
else {
return Err(
MpcTlsError::state("must be in active state to push incoming message").into(),
);
let (ctx, record_layer) = match &mut self.state {
State::Handshake {
ctx, record_layer, ..
} => (ctx, record_layer),
State::Active {
ctx, record_layer, ..
} => (ctx, record_layer),
_ => {
return Err(MpcTlsError::state(format!(
"can not push incoming message in state: {}",
self.state
))
.into())
}
};
let OpaqueMessage {
@@ -746,12 +753,14 @@ impl Backend for MpcTlsLeader {
#[instrument(level = "debug", skip_all, err)]
async fn next_incoming(&mut self) -> Result<Option<PlainMessage>, BackendError> {
let record_layer = match &mut self.state {
State::Handshake { record_layer, .. } => record_layer,
State::Active { record_layer, .. } => record_layer,
State::Closed { record_layer, .. } => record_layer,
_ => {
return Err(MpcTlsError::state(
"must be in active or closed state to pull next incoming message",
)
return Err(MpcTlsError::state(format!(
"can not pull next incoming message in state: {}",
self.state
))
.into())
}
};
@@ -779,13 +788,20 @@ impl Backend for MpcTlsLeader {
#[instrument(level = "debug", skip_all, err)]
async fn push_outgoing(&mut self, msg: PlainMessage) -> Result<(), BackendError> {
let State::Active {
ctx, record_layer, ..
} = &mut self.state
else {
return Err(
MpcTlsError::state("must be in active state to push outgoing message").into(),
);
let (ctx, record_layer) = match &mut self.state {
State::Handshake {
ctx, record_layer, ..
} => (ctx, record_layer),
State::Active {
ctx, record_layer, ..
} => (ctx, record_layer),
_ => {
return Err(MpcTlsError::state(format!(
"can not push outgoing message in state: {}",
self.state
))
.into())
}
};
debug!(
@@ -828,12 +844,14 @@ impl Backend for MpcTlsLeader {
#[instrument(level = "debug", skip_all, err)]
async fn next_outgoing(&mut self) -> Result<Option<OpaqueMessage>, BackendError> {
let record_layer = match &mut self.state {
State::Handshake { record_layer, .. } => record_layer,
State::Active { record_layer, .. } => record_layer,
State::Closed { record_layer, .. } => record_layer,
_ => {
return Err(MpcTlsError::state(
"must be in active or closed state to pull next outgoing message",
)
return Err(MpcTlsError::state(format!(
"can not pull next outgoing message in state: {}",
self.state
))
.into())
}
};
@@ -860,9 +878,36 @@ impl Backend for MpcTlsLeader {
Ok(record)
}
async fn start_traffic(&mut self) -> Result<(), BackendError> {
match &mut self.state {
State::Active {
ctx, record_layer, ..
} => {
record_layer.start_traffic();
ctx.io_mut()
.send(Message::StartTraffic)
.await
.map_err(MpcTlsError::from)?;
}
_ => {
return Err(MpcTlsError::state(format!(
"can not start traffic in state: {}",
self.state
))
.into())
}
}
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
async fn flush(&mut self) -> Result<(), BackendError> {
let (ctx, vm, record_layer) = match &mut self.state {
State::Handshake { .. } => {
warn!("record layer is not ready, skipping flush");
return Ok(());
}
State::Active {
ctx,
vm,
@@ -876,20 +921,21 @@ impl Backend for MpcTlsLeader {
..
} => (ctx, vm, record_layer),
_ => {
return Err(MpcTlsError::state(
"must be in active or closed state to flush record layer",
)
return Err(MpcTlsError::state(format!(
"can not flush record layer in state: {}",
self.state
))
.into())
}
};
debug!("flushing record layer");
if !record_layer.wants_flush() {
debug!("record layer is empty, skipping flush");
return Ok(());
}
debug!("flushing record layer");
ctx.io_mut()
.send(Message::Flush {
is_decrypting: self.is_decrypting,
@@ -1002,3 +1048,16 @@ impl std::fmt::Debug for State {
}
}
}
impl std::fmt::Display for State {
    /// Writes the human-readable state name (used in state-error messages).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Init { .. } => "Init",
            Self::Setup { .. } => "Setup",
            Self::Handshake { .. } => "Handshake",
            Self::Active { .. } => "Active",
            Self::Closed { .. } => "Closed",
            Self::Error => "Error",
        };
        f.write_str(name)
    }
}

View File

@@ -224,6 +224,12 @@ impl Dispatch<MpcTlsLeader> for MpcTlsLeaderMsg {
})
.await;
}
MpcTlsLeaderMsg::BackendMsgStartTraffic(msg) => {
msg.dispatch(actor, ctx, |value| {
ret(Self::Return::BackendMsgStartTraffic(value))
})
.await;
}
MpcTlsLeaderMsg::BackendMsgFlush(msg) => {
msg.dispatch(actor, ctx, |value| {
ret(Self::Return::BackendMsgFlush(value))
@@ -410,6 +416,13 @@ impl Backend for MpcTlsLeaderCtrl {
.map_err(|err| BackendError::InternalError(err.to_string()))?
}
/// Forwards the start-traffic request to the leader actor.
async fn start_traffic(&mut self) -> Result<(), BackendError> {
    // A channel/actor failure is surfaced as an internal backend error;
    // otherwise the actor's own Result is returned unchanged.
    match self.address.send(BackendMsgStartTraffic).await {
        Ok(result) => result,
        Err(err) => Err(BackendError::InternalError(err.to_string())),
    }
}
async fn flush(&mut self) -> Result<(), BackendError> {
self.address
.send(BackendMsgFlush)
@@ -859,6 +872,27 @@ impl Handler<BackendMsgNextOutgoing> for MpcTlsLeader {
}
}
impl Dispatch<MpcTlsLeader> for BackendMsgStartTraffic {
    /// Routes this message into the leader actor's generic `process`
    /// machinery; `ret` receives the handler's return value.
    fn dispatch<R: FnOnce(Self::Return) + Send>(
        self,
        actor: &mut MpcTlsLeader,
        ctx: &mut LudiCtx<MpcTlsLeader>,
        ret: R,
    ) -> impl Future<Output = ()> + Send {
        actor.process(self, ctx, ret)
    }
}
impl Handler<BackendMsgStartTraffic> for MpcTlsLeader {
    /// Handles the actor message by delegating to `Backend::start_traffic`
    /// on the leader itself; the message carries no payload.
    async fn handle(
        &mut self,
        _msg: BackendMsgStartTraffic,
        _ctx: &mut LudiCtx<Self>,
    ) -> <BackendMsgStartTraffic as Message>::Return {
        self.start_traffic().await
    }
}
impl Dispatch<MpcTlsLeader> for BackendMsgFlush {
fn dispatch<R: FnOnce(Self::Return) + Send>(
self,
@@ -1005,6 +1039,7 @@ pub enum MpcTlsLeaderMsg {
BackendMsgPushIncoming(BackendMsgPushIncoming),
BackendMsgNextOutgoing(BackendMsgNextOutgoing),
BackendMsgPushOutgoing(BackendMsgPushOutgoing),
BackendMsgStartTraffic(BackendMsgStartTraffic),
BackendMsgFlush(BackendMsgFlush),
BackendMsgGetNotify(BackendMsgGetNotify),
BackendMsgIsEmpty(BackendMsgIsEmpty),
@@ -1039,6 +1074,7 @@ pub enum MpcTlsLeaderMsgReturn {
BackendMsgPushIncoming(<BackendMsgPushIncoming as Message>::Return),
BackendMsgNextOutgoing(<BackendMsgNextOutgoing as Message>::Return),
BackendMsgPushOutgoing(<BackendMsgPushOutgoing as Message>::Return),
BackendMsgStartTraffic(<BackendMsgStartTraffic as Message>::Return),
BackendMsgFlush(<BackendMsgFlush as Message>::Return),
BackendMsgGetNotify(<BackendMsgGetNotify as Message>::Return),
BackendMsgIsEmpty(<BackendMsgIsEmpty as Message>::Return),
@@ -1573,6 +1609,31 @@ impl Wrap<BackendMsgNextOutgoing> for MpcTlsLeaderMsg {
}
}
/// Actor message instructing the MPC-TLS leader to start processing
/// application data traffic.
#[allow(missing_docs)]
#[derive(Debug)]
pub struct BackendMsgStartTraffic;

impl Message for BackendMsgStartTraffic {
    // Mirrors the return type of `Backend::start_traffic`.
    type Return = Result<(), BackendError>;
}
impl From<BackendMsgStartTraffic> for MpcTlsLeaderMsg {
    /// Wraps the message into the leader's unified message enum.
    fn from(value: BackendMsgStartTraffic) -> Self {
        MpcTlsLeaderMsg::BackendMsgStartTraffic(value)
    }
}
impl Wrap<BackendMsgStartTraffic> for MpcTlsLeaderMsg {
fn unwrap_return(
ret: Self::Return,
) -> Result<<BackendMsgStartTraffic as Message>::Return, Error> {
match ret {
Self::Return::BackendMsgStartTraffic(value) => Ok(value),
_ => Err(Error::Wrapper),
}
}
}
#[allow(missing_docs)]
#[derive(Debug)]
pub struct BackendMsgFlush;

View File

@@ -15,6 +15,7 @@ pub(crate) enum Message {
ServerFinishedVd(ServerFinishedVd),
Encrypt(Encrypt),
Decrypt(Decrypt),
StartTraffic,
Flush { is_decrypting: bool },
CloseConnection,
}

View File

@@ -23,6 +23,7 @@ use tls_core::{
};
use tlsn_common::transcript::{Record, TlsTranscript};
use tokio::sync::Mutex;
use tracing::{debug, instrument};
use crate::{
record_layer::{aes_ctr::AesCtr, decrypt::DecryptOp, encrypt::EncryptOp},
@@ -77,6 +78,8 @@ pub(crate) struct RecordLayer {
decrypt: Arc<Mutex<MpcAesGcm>>,
aes_ctr: AesCtr,
state: State,
/// Whether the record layer has started processing application data.
started: bool,
encrypt_buffer: Vec<EncryptOp>,
decrypt_buffer: Vec<DecryptOp>,
@@ -95,6 +98,7 @@ impl RecordLayer {
decrypt: Arc::new(Mutex::new(decrypt)),
aes_ctr: AesCtr::new(role),
state: State::Init,
started: false,
encrypt_buffer: Vec::new(),
decrypt_buffer: Vec::new(),
encrypted_buffer: VecDeque::new(),
@@ -248,6 +252,11 @@ impl RecordLayer {
!self.encrypt_buffer.is_empty() || !self.decrypt_buffer.is_empty()
}
/// Marks the record layer as ready to process application data.
///
/// Until this is called, `next_encrypted` and `next_decrypted` withhold
/// `ApplicationData` records from the buffers.
pub(crate) fn start_traffic(&mut self) {
    self.started = true;
    debug!("started processing application data");
}
pub(crate) fn push_encrypt(
&mut self,
typ: ContentType,
@@ -305,14 +314,27 @@ impl RecordLayer {
/// Returns the next encrypted record.
pub(crate) fn next_encrypted(&mut self) -> Option<EncryptedRecord> {
self.encrypted_buffer.pop_front()
let typ = self.encrypted_buffer.front().map(|r| r.typ)?;
// If we haven't started processing application data we return None.
if !self.started && typ == ContentType::ApplicationData {
None
} else {
self.encrypted_buffer.pop_front()
}
}
/// Returns the next decrypted record.
pub(crate) fn next_decrypted(&mut self) -> Option<PlainRecord> {
self.decrypted_buffer.pop_front()
let typ = self.decrypted_buffer.front().map(|r| r.typ)?;
// If we haven't started processing application data we return None.
if !self.started && typ == ContentType::ApplicationData {
None
} else {
self.decrypted_buffer.pop_front()
}
}
#[instrument(level = "debug", skip(self, ctx, vm), err)]
pub(crate) async fn flush(
&mut self,
ctx: &mut Context,
@@ -345,19 +367,30 @@ impl RecordLayer {
.try_lock()
.map_err(|_| MpcTlsError::record_layer("decrypt lock is held"))?;
let encrypt_ops = take(&mut self.encrypt_buffer);
let decrypt_end = if is_decrypting {
self.decrypt_buffer.len()
let encrypt_ops: Vec<_> = self.encrypt_buffer.drain(..).collect();
let decrypt_ops: Vec<_> = if is_decrypting {
self.decrypt_buffer.drain(..).collect()
} else {
// Position of the first application data in the decrypt buffer.
self.decrypt_buffer
// Process non-application data even if we're not decrypting.
let decrypt_pos = self
.decrypt_buffer
.iter()
.position(|op| op.typ == ContentType::ApplicationData)
.unwrap_or(self.decrypt_buffer.len())
.unwrap_or(self.decrypt_buffer.len());
self.decrypt_buffer.drain(..decrypt_pos).collect()
};
let decrypt_ops: Vec<_> = self.decrypt_buffer.drain(..decrypt_end).collect();
if encrypt_ops.is_empty() && decrypt_ops.is_empty() {
debug!("no operations to process, skipping");
return Ok(());
}
debug!(
"processing {} encrypt ops and {} decrypt ops",
encrypt_ops.len(),
decrypt_ops.len()
);
let (pending_encrypt, compute_tags) =
encrypt::encrypt(&mut (*vm), &mut encrypter, &encrypt_ops)?;

View File

@@ -118,6 +118,8 @@ pub trait Backend: Send {
async fn push_outgoing(&mut self, msg: PlainMessage) -> Result<(), BackendError>;
/// Returns next outgoing message.
async fn next_outgoing(&mut self) -> Result<Option<OpaqueMessage>, BackendError>;
/// Starts processing application data traffic.
async fn start_traffic(&mut self) -> Result<(), BackendError>;
/// Flushes the record layer.
async fn flush(&mut self) -> Result<(), BackendError>;
/// Returns a notification future which resolves when the backend is ready

View File

@@ -427,6 +427,10 @@ impl Backend for RustCryptoBackend {
Ok(self.incoming_plain.pop_front())
}
async fn start_traffic(&mut self) -> Result<(), BackendError> {
    // Intentional no-op: this backend has nothing to gate before
    // application data flows.
    Ok(())
}
async fn flush(&mut self) -> Result<(), BackendError> {
for incoming in take(&mut self.incoming_encrypted) {
let seq = self.read_seq;

View File

@@ -964,6 +964,7 @@ impl CommonState {
/// Begins application data traffic for the connection.
///
/// Order matters: the receive flag is set and the backend notified
/// before any buffered outgoing traffic is released.
pub(crate) async fn start_traffic(&mut self) -> Result<(), Error> {
    self.may_receive_application_data = true;
    self.backend.start_traffic().await?;
    self.start_outgoing_traffic().await
}