feat: compress partial transcript (#653)

* feat: compress partial transcript

* add missing dep
This commit is contained in:
dan
2024-12-26 11:41:22 +01:00
committed by GitHub
parent 7bec5a84ee
commit c03418a642
2 changed files with 204 additions and 54 deletions

View File

@@ -40,8 +40,9 @@ web-time = { workspace = true }
webpki-roots = { workspace = true }
[dev-dependencies]
rstest = { workspace = true }
bincode = { workspace = true }
hex = { workspace = true }
rstest = { workspace = true }
tlsn-data-fixtures = { workspace = true }
[[test]]

View File

@@ -152,8 +152,8 @@ impl Transcript {
PartialTranscript {
sent,
received,
sent_authed: sent_idx,
received_authed: recv_idx,
sent_authed_idx: sent_idx,
received_authed_idx: recv_idx,
}
}
}
@@ -163,16 +163,83 @@ impl Transcript {
/// A partial transcript is a transcript which may not have all the data
/// authenticated.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(try_from = "validation::PartialTranscriptUnchecked")]
#[serde(try_from = "CompressedPartialTranscript")]
#[serde(into = "CompressedPartialTranscript")]
#[cfg_attr(test, derive(PartialEq))]
pub struct PartialTranscript {
/// Data sent from the Prover to the Server.
sent: Vec<u8>,
/// Data received by the Prover from the Server.
received: Vec<u8>,
/// Index of `sent` which have been authenticated.
sent_authed: Idx,
sent_authed_idx: Idx,
/// Index of `received` which have been authenticated.
received_authed: Idx,
received_authed_idx: Idx,
}
/// `PartialTranscript` in a compressed form.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(try_from = "validation::CompressedPartialTranscriptUnchecked")]
pub struct CompressedPartialTranscript {
/// Sent data which has been authenticated.
sent_authed: Vec<u8>,
/// Received data which has been authenticated.
received_authed: Vec<u8>,
/// Index of `sent_authed`.
sent_idx: Idx,
/// Index of `received_authed`.
recv_idx: Idx,
/// Total bytelength of sent data in the original partial transcript.
sent_total: usize,
/// Total bytelength of received data in the original partial transcript.
recv_total: usize,
}
impl From<PartialTranscript> for CompressedPartialTranscript {
fn from(uncompressed: PartialTranscript) -> Self {
Self {
sent_authed: uncompressed
.sent
.index_ranges(&uncompressed.sent_authed_idx.0),
received_authed: uncompressed
.received
.index_ranges(&uncompressed.received_authed_idx.0),
sent_idx: uncompressed.sent_authed_idx,
recv_idx: uncompressed.received_authed_idx,
sent_total: uncompressed.sent.len(),
recv_total: uncompressed.received.len(),
}
}
}
impl From<CompressedPartialTranscript> for PartialTranscript {
fn from(compressed: CompressedPartialTranscript) -> Self {
let mut sent = vec![0; compressed.sent_total];
let mut received = vec![0; compressed.recv_total];
let mut offset = 0;
for range in compressed.sent_idx.iter_ranges() {
sent[range.clone()]
.copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]);
offset += range.len();
}
let mut offset = 0;
for range in compressed.recv_idx.iter_ranges() {
received[range.clone()]
.copy_from_slice(&compressed.received_authed[offset..offset + range.len()]);
offset += range.len();
}
Self {
sent,
received,
sent_authed_idx: compressed.sent_idx,
received_authed_idx: compressed.recv_idx,
}
}
}
impl PartialTranscript {
@@ -186,8 +253,8 @@ impl PartialTranscript {
Self {
sent: vec![0; sent_len],
received: vec![0; received_len],
sent_authed: Idx::default(),
received_authed: Idx::default(),
sent_authed_idx: Idx::default(),
received_authed_idx: Idx::default(),
}
}
@@ -203,8 +270,8 @@ impl PartialTranscript {
/// Returns whether the transcript is complete.
pub fn is_complete(&self) -> bool {
self.sent_authed.len() == self.sent.len()
&& self.received_authed.len() == self.received.len()
self.sent_authed_idx.len() == self.sent.len()
&& self.received_authed_idx.len() == self.received.len()
}
/// Returns whether the index is in bounds of the transcript.
@@ -239,29 +306,29 @@ impl PartialTranscript {
/// Returns the index of sent data which have been authenticated.
pub fn sent_authed(&self) -> &Idx {
&self.sent_authed
&self.sent_authed_idx
}
/// Returns the index of received data which have been authenticated.
pub fn received_authed(&self) -> &Idx {
&self.received_authed
&self.received_authed_idx
}
/// Returns the index of sent data which haven't been authenticated.
pub fn sent_unauthed(&self) -> Idx {
Idx(RangeSet::from(0..self.sent.len()).difference(&self.sent_authed.0))
Idx(RangeSet::from(0..self.sent.len()).difference(&self.sent_authed_idx.0))
}
/// Returns the index of received data which haven't been authenticated.
pub fn received_unauthed(&self) -> Idx {
Idx(RangeSet::from(0..self.received.len()).difference(&self.received_authed.0))
Idx(RangeSet::from(0..self.received.len()).difference(&self.received_authed_idx.0))
}
/// Returns an iterator over the authenticated data in the transcript.
pub fn iter(&self, direction: Direction) -> impl Iterator<Item = u8> + '_ {
let (data, authed) = match direction {
Direction::Sent => (&self.sent, &self.sent_authed),
Direction::Received => (&self.received, &self.received_authed),
Direction::Sent => (&self.sent, &self.sent_authed_idx),
Direction::Received => (&self.received, &self.received_authed_idx),
};
authed.0.iter().map(|i| data[i])
@@ -285,25 +352,25 @@ impl PartialTranscript {
);
for range in other
.sent_authed
.sent_authed_idx
.0
.difference(&self.sent_authed.0)
.difference(&self.sent_authed_idx.0)
.iter_ranges()
{
self.sent[range.clone()].copy_from_slice(&other.sent[range]);
}
for range in other
.received_authed
.received_authed_idx
.0
.difference(&self.received_authed.0)
.difference(&self.received_authed_idx.0)
.iter_ranges()
{
self.received[range.clone()].copy_from_slice(&other.received[range]);
}
self.sent_authed = self.sent_authed.union(&other.sent_authed);
self.received_authed = self.received_authed.union(&other.received_authed);
self.sent_authed_idx = self.sent_authed_idx.union(&other.sent_authed_idx);
self.received_authed_idx = self.received_authed_idx.union(&other.received_authed_idx);
}
/// Unions an authenticated subsequence into this transcript.
@@ -315,11 +382,11 @@ impl PartialTranscript {
match direction {
Direction::Sent => {
seq.copy_to(&mut self.sent);
self.sent_authed = self.sent_authed.union(&seq.idx);
self.sent_authed_idx = self.sent_authed_idx.union(&seq.idx);
}
Direction::Received => {
seq.copy_to(&mut self.received);
self.received_authed = self.received_authed.union(&seq.idx);
self.received_authed_idx = self.received_authed_idx.union(&seq.idx);
}
}
}
@@ -348,12 +415,12 @@ impl PartialTranscript {
pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) {
match direction {
Direction::Sent => {
for range in range.difference(&self.sent_authed.0).iter_ranges() {
for range in range.difference(&self.sent_authed_idx.0).iter_ranges() {
self.sent[range].fill(value);
}
}
Direction::Received => {
for range in range.difference(&self.received_authed.0).iter_ranges() {
for range in range.difference(&self.received_authed_idx.0).iter_ranges() {
self.received[range].fill(value);
}
}
@@ -549,51 +616,118 @@ mod validation {
}
}
/// Invalid partial transcript error.
/// Invalid compressed partial transcript error.
#[derive(Debug, thiserror::Error)]
#[error("invalid partial transcript: {0}")]
pub struct InvalidPartialTranscript(&'static str);
#[error("invalid compressed partial transcript: {0}")]
pub struct InvalidCompressedPartialTranscript(&'static str);
#[derive(Debug, Deserialize)]
pub(super) struct PartialTranscriptUnchecked {
sent: Vec<u8>,
received: Vec<u8>,
sent_authed: Idx,
received_authed: Idx,
#[cfg_attr(test, derive(Serialize))]
pub(super) struct CompressedPartialTranscriptUnchecked {
sent_authed: Vec<u8>,
received_authed: Vec<u8>,
sent_idx: Idx,
recv_idx: Idx,
sent_total: usize,
recv_total: usize,
}
impl TryFrom<PartialTranscriptUnchecked> for PartialTranscript {
type Error = InvalidPartialTranscript;
impl TryFrom<CompressedPartialTranscriptUnchecked> for CompressedPartialTranscript {
type Error = InvalidCompressedPartialTranscript;
fn try_from(unchecked: PartialTranscriptUnchecked) -> Result<Self, Self::Error> {
if unchecked.sent_authed.end() > unchecked.sent.len()
|| unchecked.received_authed.end() > unchecked.received.len()
fn try_from(unchecked: CompressedPartialTranscriptUnchecked) -> Result<Self, Self::Error> {
if unchecked.sent_authed.len() != unchecked.sent_idx.len()
|| unchecked.received_authed.len() != unchecked.recv_idx.len()
{
return Err(InvalidPartialTranscript(
"authenticated ranges are not in bounds of the data",
return Err(InvalidCompressedPartialTranscript(
"lengths of index and data don't match",
));
}
// Rewrite the data to ensure that unauthenticated data is zeroed out.
let mut sent = vec![0; unchecked.sent.len()];
let mut received = vec![0; unchecked.received.len()];
for range in unchecked.sent_authed.iter_ranges() {
sent[range.clone()].copy_from_slice(&unchecked.sent[range]);
}
for range in unchecked.received_authed.iter_ranges() {
received[range.clone()].copy_from_slice(&unchecked.received[range]);
if unchecked.sent_idx.end() > unchecked.sent_total
|| unchecked.recv_idx.end() > unchecked.recv_total
{
return Err(InvalidCompressedPartialTranscript(
"ranges are not in bounds of the data",
));
}
Ok(Self {
sent,
received,
sent_authed: unchecked.sent_authed,
received_authed: unchecked.received_authed,
recv_idx: unchecked.recv_idx,
recv_total: unchecked.recv_total,
sent_authed: unchecked.sent_authed,
sent_idx: unchecked.sent_idx,
sent_total: unchecked.sent_total,
})
}
}
#[cfg(test)]
mod tests {
use rstest::{fixture, rstest};
use super::*;
#[fixture]
fn partial_transcript() -> CompressedPartialTranscriptUnchecked {
CompressedPartialTranscriptUnchecked {
received_authed: vec![1, 2, 3, 11, 12, 13],
sent_authed: vec![4, 5, 6, 14, 15, 16],
recv_idx: Idx(RangeSet::new(&[1..4, 11..14])),
sent_idx: Idx(RangeSet::new(&[4..7, 14..17])),
sent_total: 20,
recv_total: 20,
}
}
#[rstest]
fn test_partial_transcript_valid(partial_transcript: CompressedPartialTranscriptUnchecked) {
let bytes = bincode::serialize(&partial_transcript).unwrap();
let transcript: Result<CompressedPartialTranscript, Box<bincode::ErrorKind>> =
bincode::deserialize(&bytes);
assert!(transcript.is_ok());
}
#[rstest]
// Expect to fail since the length of data and the length of the index do not
// match.
fn test_partial_transcript_invalid_lengths(
mut partial_transcript: CompressedPartialTranscriptUnchecked,
) {
// Add an extra byte to the data.
let mut old = partial_transcript.sent_authed;
old.extend([1]);
partial_transcript.sent_authed = old;
let bytes = bincode::serialize(&partial_transcript).unwrap();
let transcript: Result<CompressedPartialTranscript, Box<bincode::ErrorKind>> =
bincode::deserialize(&bytes);
assert!(transcript.is_err());
}
#[rstest]
// Expect to fail since the index is out of bounds.
fn test_partial_transcript_invalid_ranges(
mut partial_transcript: CompressedPartialTranscriptUnchecked,
) {
// Change the total to be less than the last range's end bound.
let end = partial_transcript
.sent_idx
.0
.iter_ranges()
.last()
.unwrap()
.end;
partial_transcript.sent_total = end - 1;
let bytes = bincode::serialize(&partial_transcript).unwrap();
let transcript: Result<CompressedPartialTranscript, Box<bincode::ErrorKind>> =
bincode::deserialize(&bytes);
assert!(transcript.is_err());
}
}
}
#[cfg(test)]
@@ -610,6 +744,14 @@ mod tests {
)
}
#[fixture]
fn partial_transcript() -> PartialTranscript {
transcript().to_partial(
Idx::new(RangeSet::new(&[1..4, 6..9])),
Idx::new(RangeSet::new(&[2..5, 7..10])),
)
}
#[rstest]
fn test_transcript_get_subsequence(transcript: Transcript) {
let subseq = transcript
@@ -632,6 +774,13 @@ mod tests {
assert_eq!(subseq, None);
}
#[rstest]
fn test_partial_transcript_serialization_ok(partial_transcript: PartialTranscript) {
let bytes = bincode::serialize(&partial_transcript).unwrap();
let deserialized_transcript: PartialTranscript = bincode::deserialize(&bytes).unwrap();
assert_eq!(partial_transcript, deserialized_transcript);
}
#[rstest]
fn test_transcript_to_partial_success(transcript: Transcript) {
let partial = transcript.to_partial(Idx::new(0..2), Idx::new(3..7));