mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-09 14:48:13 -05:00
fix: record layer handshake control flow (#733)
This commit is contained in:
@@ -357,6 +357,9 @@ impl MpcTlsFollower {
|
||||
)
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
}
|
||||
Message::StartTraffic => {
|
||||
record_layer.start_traffic();
|
||||
}
|
||||
Message::Flush { is_decrypting } => {
|
||||
record_layer
|
||||
.flush(&mut self.ctx, vm.clone(), is_decrypting)
|
||||
|
||||
@@ -43,7 +43,7 @@ use tls_core::{
|
||||
},
|
||||
suites::SupportedCipherSuite,
|
||||
};
|
||||
use tracing::{debug, instrument, trace};
|
||||
use tracing::{debug, instrument, trace, warn};
|
||||
|
||||
/// Controller for MPC-TLS leader.
|
||||
pub type LeaderCtrl = actor::MpcTlsLeaderCtrl;
|
||||
@@ -692,13 +692,20 @@ impl Backend for MpcTlsLeader {
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn push_incoming(&mut self, msg: OpaqueMessage) -> Result<(), BackendError> {
|
||||
let State::Active {
|
||||
ctx, record_layer, ..
|
||||
} = &mut self.state
|
||||
else {
|
||||
return Err(
|
||||
MpcTlsError::state("must be in active state to push incoming message").into(),
|
||||
);
|
||||
let (ctx, record_layer) = match &mut self.state {
|
||||
State::Handshake {
|
||||
ctx, record_layer, ..
|
||||
} => (ctx, record_layer),
|
||||
State::Active {
|
||||
ctx, record_layer, ..
|
||||
} => (ctx, record_layer),
|
||||
_ => {
|
||||
return Err(MpcTlsError::state(format!(
|
||||
"can not push incoming message in state: {}",
|
||||
self.state
|
||||
))
|
||||
.into())
|
||||
}
|
||||
};
|
||||
|
||||
let OpaqueMessage {
|
||||
@@ -746,12 +753,14 @@ impl Backend for MpcTlsLeader {
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn next_incoming(&mut self) -> Result<Option<PlainMessage>, BackendError> {
|
||||
let record_layer = match &mut self.state {
|
||||
State::Handshake { record_layer, .. } => record_layer,
|
||||
State::Active { record_layer, .. } => record_layer,
|
||||
State::Closed { record_layer, .. } => record_layer,
|
||||
_ => {
|
||||
return Err(MpcTlsError::state(
|
||||
"must be in active or closed state to pull next incoming message",
|
||||
)
|
||||
return Err(MpcTlsError::state(format!(
|
||||
"can not pull next incoming message in state: {}",
|
||||
self.state
|
||||
))
|
||||
.into())
|
||||
}
|
||||
};
|
||||
@@ -779,13 +788,20 @@ impl Backend for MpcTlsLeader {
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn push_outgoing(&mut self, msg: PlainMessage) -> Result<(), BackendError> {
|
||||
let State::Active {
|
||||
ctx, record_layer, ..
|
||||
} = &mut self.state
|
||||
else {
|
||||
return Err(
|
||||
MpcTlsError::state("must be in active state to push outgoing message").into(),
|
||||
);
|
||||
let (ctx, record_layer) = match &mut self.state {
|
||||
State::Handshake {
|
||||
ctx, record_layer, ..
|
||||
} => (ctx, record_layer),
|
||||
State::Active {
|
||||
ctx, record_layer, ..
|
||||
} => (ctx, record_layer),
|
||||
_ => {
|
||||
return Err(MpcTlsError::state(format!(
|
||||
"can not push outgoing message in state: {}",
|
||||
self.state
|
||||
))
|
||||
.into())
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
@@ -828,12 +844,14 @@ impl Backend for MpcTlsLeader {
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn next_outgoing(&mut self) -> Result<Option<OpaqueMessage>, BackendError> {
|
||||
let record_layer = match &mut self.state {
|
||||
State::Handshake { record_layer, .. } => record_layer,
|
||||
State::Active { record_layer, .. } => record_layer,
|
||||
State::Closed { record_layer, .. } => record_layer,
|
||||
_ => {
|
||||
return Err(MpcTlsError::state(
|
||||
"must be in active or closed state to pull next outgoing message",
|
||||
)
|
||||
return Err(MpcTlsError::state(format!(
|
||||
"can not pull next outgoing message in state: {}",
|
||||
self.state
|
||||
))
|
||||
.into())
|
||||
}
|
||||
};
|
||||
@@ -860,9 +878,36 @@ impl Backend for MpcTlsLeader {
|
||||
Ok(record)
|
||||
}
|
||||
|
||||
async fn start_traffic(&mut self) -> Result<(), BackendError> {
|
||||
match &mut self.state {
|
||||
State::Active {
|
||||
ctx, record_layer, ..
|
||||
} => {
|
||||
record_layer.start_traffic();
|
||||
ctx.io_mut()
|
||||
.send(Message::StartTraffic)
|
||||
.await
|
||||
.map_err(MpcTlsError::from)?;
|
||||
}
|
||||
_ => {
|
||||
return Err(MpcTlsError::state(format!(
|
||||
"can not start traffic in state: {}",
|
||||
self.state
|
||||
))
|
||||
.into())
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn flush(&mut self) -> Result<(), BackendError> {
|
||||
let (ctx, vm, record_layer) = match &mut self.state {
|
||||
State::Handshake { .. } => {
|
||||
warn!("record layer is not ready, skipping flush");
|
||||
return Ok(());
|
||||
}
|
||||
State::Active {
|
||||
ctx,
|
||||
vm,
|
||||
@@ -876,20 +921,21 @@ impl Backend for MpcTlsLeader {
|
||||
..
|
||||
} => (ctx, vm, record_layer),
|
||||
_ => {
|
||||
return Err(MpcTlsError::state(
|
||||
"must be in active or closed state to flush record layer",
|
||||
)
|
||||
return Err(MpcTlsError::state(format!(
|
||||
"can not flush record layer in state: {}",
|
||||
self.state
|
||||
))
|
||||
.into())
|
||||
}
|
||||
};
|
||||
|
||||
debug!("flushing record layer");
|
||||
|
||||
if !record_layer.wants_flush() {
|
||||
debug!("record layer is empty, skipping flush");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
debug!("flushing record layer");
|
||||
|
||||
ctx.io_mut()
|
||||
.send(Message::Flush {
|
||||
is_decrypting: self.is_decrypting,
|
||||
@@ -1002,3 +1048,16 @@ impl std::fmt::Debug for State {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for State {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Init { .. } => write!(f, "Init"),
|
||||
Self::Setup { .. } => write!(f, "Setup"),
|
||||
Self::Handshake { .. } => write!(f, "Handshake"),
|
||||
Self::Active { .. } => write!(f, "Active"),
|
||||
Self::Closed { .. } => write!(f, "Closed"),
|
||||
Self::Error => write!(f, "Error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,6 +224,12 @@ impl Dispatch<MpcTlsLeader> for MpcTlsLeaderMsg {
|
||||
})
|
||||
.await;
|
||||
}
|
||||
MpcTlsLeaderMsg::BackendMsgStartTraffic(msg) => {
|
||||
msg.dispatch(actor, ctx, |value| {
|
||||
ret(Self::Return::BackendMsgStartTraffic(value))
|
||||
})
|
||||
.await;
|
||||
}
|
||||
MpcTlsLeaderMsg::BackendMsgFlush(msg) => {
|
||||
msg.dispatch(actor, ctx, |value| {
|
||||
ret(Self::Return::BackendMsgFlush(value))
|
||||
@@ -410,6 +416,13 @@ impl Backend for MpcTlsLeaderCtrl {
|
||||
.map_err(|err| BackendError::InternalError(err.to_string()))?
|
||||
}
|
||||
|
||||
async fn start_traffic(&mut self) -> Result<(), BackendError> {
|
||||
self.address
|
||||
.send(BackendMsgStartTraffic)
|
||||
.await
|
||||
.map_err(|err| BackendError::InternalError(err.to_string()))?
|
||||
}
|
||||
|
||||
async fn flush(&mut self) -> Result<(), BackendError> {
|
||||
self.address
|
||||
.send(BackendMsgFlush)
|
||||
@@ -859,6 +872,27 @@ impl Handler<BackendMsgNextOutgoing> for MpcTlsLeader {
|
||||
}
|
||||
}
|
||||
|
||||
impl Dispatch<MpcTlsLeader> for BackendMsgStartTraffic {
|
||||
fn dispatch<R: FnOnce(Self::Return) + Send>(
|
||||
self,
|
||||
actor: &mut MpcTlsLeader,
|
||||
ctx: &mut LudiCtx<MpcTlsLeader>,
|
||||
ret: R,
|
||||
) -> impl Future<Output = ()> + Send {
|
||||
actor.process(self, ctx, ret)
|
||||
}
|
||||
}
|
||||
|
||||
impl Handler<BackendMsgStartTraffic> for MpcTlsLeader {
|
||||
async fn handle(
|
||||
&mut self,
|
||||
_msg: BackendMsgStartTraffic,
|
||||
_ctx: &mut LudiCtx<Self>,
|
||||
) -> <BackendMsgStartTraffic as Message>::Return {
|
||||
self.start_traffic().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Dispatch<MpcTlsLeader> for BackendMsgFlush {
|
||||
fn dispatch<R: FnOnce(Self::Return) + Send>(
|
||||
self,
|
||||
@@ -1005,6 +1039,7 @@ pub enum MpcTlsLeaderMsg {
|
||||
BackendMsgPushIncoming(BackendMsgPushIncoming),
|
||||
BackendMsgNextOutgoing(BackendMsgNextOutgoing),
|
||||
BackendMsgPushOutgoing(BackendMsgPushOutgoing),
|
||||
BackendMsgStartTraffic(BackendMsgStartTraffic),
|
||||
BackendMsgFlush(BackendMsgFlush),
|
||||
BackendMsgGetNotify(BackendMsgGetNotify),
|
||||
BackendMsgIsEmpty(BackendMsgIsEmpty),
|
||||
@@ -1039,6 +1074,7 @@ pub enum MpcTlsLeaderMsgReturn {
|
||||
BackendMsgPushIncoming(<BackendMsgPushIncoming as Message>::Return),
|
||||
BackendMsgNextOutgoing(<BackendMsgNextOutgoing as Message>::Return),
|
||||
BackendMsgPushOutgoing(<BackendMsgPushOutgoing as Message>::Return),
|
||||
BackendMsgStartTraffic(<BackendMsgStartTraffic as Message>::Return),
|
||||
BackendMsgFlush(<BackendMsgFlush as Message>::Return),
|
||||
BackendMsgGetNotify(<BackendMsgGetNotify as Message>::Return),
|
||||
BackendMsgIsEmpty(<BackendMsgIsEmpty as Message>::Return),
|
||||
@@ -1573,6 +1609,31 @@ impl Wrap<BackendMsgNextOutgoing> for MpcTlsLeaderMsg {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug)]
|
||||
pub struct BackendMsgStartTraffic;
|
||||
|
||||
impl Message for BackendMsgStartTraffic {
|
||||
type Return = Result<(), BackendError>;
|
||||
}
|
||||
|
||||
impl From<BackendMsgStartTraffic> for MpcTlsLeaderMsg {
|
||||
fn from(value: BackendMsgStartTraffic) -> Self {
|
||||
MpcTlsLeaderMsg::BackendMsgStartTraffic(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Wrap<BackendMsgStartTraffic> for MpcTlsLeaderMsg {
|
||||
fn unwrap_return(
|
||||
ret: Self::Return,
|
||||
) -> Result<<BackendMsgStartTraffic as Message>::Return, Error> {
|
||||
match ret {
|
||||
Self::Return::BackendMsgStartTraffic(value) => Ok(value),
|
||||
_ => Err(Error::Wrapper),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug)]
|
||||
pub struct BackendMsgFlush;
|
||||
|
||||
@@ -15,6 +15,7 @@ pub(crate) enum Message {
|
||||
ServerFinishedVd(ServerFinishedVd),
|
||||
Encrypt(Encrypt),
|
||||
Decrypt(Decrypt),
|
||||
StartTraffic,
|
||||
Flush { is_decrypting: bool },
|
||||
CloseConnection,
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ use tls_core::{
|
||||
};
|
||||
use tlsn_common::transcript::{Record, TlsTranscript};
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
use crate::{
|
||||
record_layer::{aes_ctr::AesCtr, decrypt::DecryptOp, encrypt::EncryptOp},
|
||||
@@ -77,6 +78,8 @@ pub(crate) struct RecordLayer {
|
||||
decrypt: Arc<Mutex<MpcAesGcm>>,
|
||||
aes_ctr: AesCtr,
|
||||
state: State,
|
||||
/// Whether the record layer has started processing application data.
|
||||
started: bool,
|
||||
|
||||
encrypt_buffer: Vec<EncryptOp>,
|
||||
decrypt_buffer: Vec<DecryptOp>,
|
||||
@@ -95,6 +98,7 @@ impl RecordLayer {
|
||||
decrypt: Arc::new(Mutex::new(decrypt)),
|
||||
aes_ctr: AesCtr::new(role),
|
||||
state: State::Init,
|
||||
started: false,
|
||||
encrypt_buffer: Vec::new(),
|
||||
decrypt_buffer: Vec::new(),
|
||||
encrypted_buffer: VecDeque::new(),
|
||||
@@ -248,6 +252,11 @@ impl RecordLayer {
|
||||
!self.encrypt_buffer.is_empty() || !self.decrypt_buffer.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn start_traffic(&mut self) {
|
||||
self.started = true;
|
||||
debug!("started processing application data");
|
||||
}
|
||||
|
||||
pub(crate) fn push_encrypt(
|
||||
&mut self,
|
||||
typ: ContentType,
|
||||
@@ -305,14 +314,27 @@ impl RecordLayer {
|
||||
|
||||
/// Returns the next encrypted record.
|
||||
pub(crate) fn next_encrypted(&mut self) -> Option<EncryptedRecord> {
|
||||
self.encrypted_buffer.pop_front()
|
||||
let typ = self.encrypted_buffer.front().map(|r| r.typ)?;
|
||||
// If we haven't started processing application data we return None.
|
||||
if !self.started && typ == ContentType::ApplicationData {
|
||||
None
|
||||
} else {
|
||||
self.encrypted_buffer.pop_front()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the next decrypted record.
|
||||
pub(crate) fn next_decrypted(&mut self) -> Option<PlainRecord> {
|
||||
self.decrypted_buffer.pop_front()
|
||||
let typ = self.decrypted_buffer.front().map(|r| r.typ)?;
|
||||
// If we haven't started processing application data we return None.
|
||||
if !self.started && typ == ContentType::ApplicationData {
|
||||
None
|
||||
} else {
|
||||
self.decrypted_buffer.pop_front()
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip(self, ctx, vm), err)]
|
||||
pub(crate) async fn flush(
|
||||
&mut self,
|
||||
ctx: &mut Context,
|
||||
@@ -345,19 +367,30 @@ impl RecordLayer {
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::record_layer("decrypt lock is held"))?;
|
||||
|
||||
let encrypt_ops = take(&mut self.encrypt_buffer);
|
||||
|
||||
let decrypt_end = if is_decrypting {
|
||||
self.decrypt_buffer.len()
|
||||
let encrypt_ops: Vec<_> = self.encrypt_buffer.drain(..).collect();
|
||||
let decrypt_ops: Vec<_> = if is_decrypting {
|
||||
self.decrypt_buffer.drain(..).collect()
|
||||
} else {
|
||||
// Position of the first application data in the decrypt buffer.
|
||||
self.decrypt_buffer
|
||||
// Process non-application data even if we're not decrypting.
|
||||
let decrypt_pos = self
|
||||
.decrypt_buffer
|
||||
.iter()
|
||||
.position(|op| op.typ == ContentType::ApplicationData)
|
||||
.unwrap_or(self.decrypt_buffer.len())
|
||||
.unwrap_or(self.decrypt_buffer.len());
|
||||
|
||||
self.decrypt_buffer.drain(..decrypt_pos).collect()
|
||||
};
|
||||
|
||||
let decrypt_ops: Vec<_> = self.decrypt_buffer.drain(..decrypt_end).collect();
|
||||
if encrypt_ops.is_empty() && decrypt_ops.is_empty() {
|
||||
debug!("no operations to process, skipping");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
debug!(
|
||||
"processing {} encrypt ops and {} decrypt ops",
|
||||
encrypt_ops.len(),
|
||||
decrypt_ops.len()
|
||||
);
|
||||
|
||||
let (pending_encrypt, compute_tags) =
|
||||
encrypt::encrypt(&mut (*vm), &mut encrypter, &encrypt_ops)?;
|
||||
|
||||
@@ -118,6 +118,8 @@ pub trait Backend: Send {
|
||||
async fn push_outgoing(&mut self, msg: PlainMessage) -> Result<(), BackendError>;
|
||||
/// Returns next outgoing message.
|
||||
async fn next_outgoing(&mut self) -> Result<Option<OpaqueMessage>, BackendError>;
|
||||
/// Starts processing application data traffic.
|
||||
async fn start_traffic(&mut self) -> Result<(), BackendError>;
|
||||
/// Flushes the record layer.
|
||||
async fn flush(&mut self) -> Result<(), BackendError>;
|
||||
/// Returns a notification future which resolves when the backend is ready
|
||||
|
||||
@@ -427,6 +427,10 @@ impl Backend for RustCryptoBackend {
|
||||
Ok(self.incoming_plain.pop_front())
|
||||
}
|
||||
|
||||
async fn start_traffic(&mut self) -> Result<(), BackendError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn flush(&mut self) -> Result<(), BackendError> {
|
||||
for incoming in take(&mut self.incoming_encrypted) {
|
||||
let seq = self.read_seq;
|
||||
|
||||
@@ -964,6 +964,7 @@ impl CommonState {
|
||||
|
||||
pub(crate) async fn start_traffic(&mut self) -> Result<(), Error> {
|
||||
self.may_receive_application_data = true;
|
||||
self.backend.start_traffic().await?;
|
||||
self.start_outgoing_traffic().await
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user