mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-09 14:48:13 -05:00
feat: compress partial transcript (#653)
* feat: compress partial transcript * add missing dep
This commit is contained in:
@@ -40,8 +40,9 @@ web-time = { workspace = true }
|
||||
webpki-roots = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
rstest = { workspace = true }
|
||||
bincode = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
rstest = { workspace = true }
|
||||
tlsn-data-fixtures = { workspace = true }
|
||||
|
||||
[[test]]
|
||||
|
||||
@@ -152,8 +152,8 @@ impl Transcript {
|
||||
PartialTranscript {
|
||||
sent,
|
||||
received,
|
||||
sent_authed: sent_idx,
|
||||
received_authed: recv_idx,
|
||||
sent_authed_idx: sent_idx,
|
||||
received_authed_idx: recv_idx,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -163,16 +163,83 @@ impl Transcript {
|
||||
/// A partial transcript is a transcript which may not have all the data
|
||||
/// authenticated.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(try_from = "validation::PartialTranscriptUnchecked")]
|
||||
#[serde(try_from = "CompressedPartialTranscript")]
|
||||
#[serde(into = "CompressedPartialTranscript")]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub struct PartialTranscript {
|
||||
/// Data sent from the Prover to the Server.
|
||||
sent: Vec<u8>,
|
||||
/// Data received by the Prover from the Server.
|
||||
received: Vec<u8>,
|
||||
/// Index of `sent` which have been authenticated.
|
||||
sent_authed: Idx,
|
||||
sent_authed_idx: Idx,
|
||||
/// Index of `received` which have been authenticated.
|
||||
received_authed: Idx,
|
||||
received_authed_idx: Idx,
|
||||
}
|
||||
|
||||
/// `PartialTranscript` in a compressed form.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(try_from = "validation::CompressedPartialTranscriptUnchecked")]
|
||||
pub struct CompressedPartialTranscript {
|
||||
/// Sent data which has been authenticated.
|
||||
sent_authed: Vec<u8>,
|
||||
/// Received data which has been authenticated.
|
||||
received_authed: Vec<u8>,
|
||||
/// Index of `sent_authed`.
|
||||
sent_idx: Idx,
|
||||
/// Index of `received_authed`.
|
||||
recv_idx: Idx,
|
||||
/// Total bytelength of sent data in the original partial transcript.
|
||||
sent_total: usize,
|
||||
/// Total bytelength of received data in the original partial transcript.
|
||||
recv_total: usize,
|
||||
}
|
||||
|
||||
impl From<PartialTranscript> for CompressedPartialTranscript {
|
||||
fn from(uncompressed: PartialTranscript) -> Self {
|
||||
Self {
|
||||
sent_authed: uncompressed
|
||||
.sent
|
||||
.index_ranges(&uncompressed.sent_authed_idx.0),
|
||||
received_authed: uncompressed
|
||||
.received
|
||||
.index_ranges(&uncompressed.received_authed_idx.0),
|
||||
sent_idx: uncompressed.sent_authed_idx,
|
||||
recv_idx: uncompressed.received_authed_idx,
|
||||
sent_total: uncompressed.sent.len(),
|
||||
recv_total: uncompressed.received.len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CompressedPartialTranscript> for PartialTranscript {
|
||||
fn from(compressed: CompressedPartialTranscript) -> Self {
|
||||
let mut sent = vec![0; compressed.sent_total];
|
||||
let mut received = vec![0; compressed.recv_total];
|
||||
|
||||
let mut offset = 0;
|
||||
|
||||
for range in compressed.sent_idx.iter_ranges() {
|
||||
sent[range.clone()]
|
||||
.copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]);
|
||||
offset += range.len();
|
||||
}
|
||||
|
||||
let mut offset = 0;
|
||||
|
||||
for range in compressed.recv_idx.iter_ranges() {
|
||||
received[range.clone()]
|
||||
.copy_from_slice(&compressed.received_authed[offset..offset + range.len()]);
|
||||
offset += range.len();
|
||||
}
|
||||
|
||||
Self {
|
||||
sent,
|
||||
received,
|
||||
sent_authed_idx: compressed.sent_idx,
|
||||
received_authed_idx: compressed.recv_idx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialTranscript {
|
||||
@@ -186,8 +253,8 @@ impl PartialTranscript {
|
||||
Self {
|
||||
sent: vec![0; sent_len],
|
||||
received: vec![0; received_len],
|
||||
sent_authed: Idx::default(),
|
||||
received_authed: Idx::default(),
|
||||
sent_authed_idx: Idx::default(),
|
||||
received_authed_idx: Idx::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,8 +270,8 @@ impl PartialTranscript {
|
||||
|
||||
/// Returns whether the transcript is complete.
|
||||
pub fn is_complete(&self) -> bool {
|
||||
self.sent_authed.len() == self.sent.len()
|
||||
&& self.received_authed.len() == self.received.len()
|
||||
self.sent_authed_idx.len() == self.sent.len()
|
||||
&& self.received_authed_idx.len() == self.received.len()
|
||||
}
|
||||
|
||||
/// Returns whether the index is in bounds of the transcript.
|
||||
@@ -239,29 +306,29 @@ impl PartialTranscript {
|
||||
|
||||
/// Returns the index of sent data which have been authenticated.
|
||||
pub fn sent_authed(&self) -> &Idx {
|
||||
&self.sent_authed
|
||||
&self.sent_authed_idx
|
||||
}
|
||||
|
||||
/// Returns the index of received data which have been authenticated.
|
||||
pub fn received_authed(&self) -> &Idx {
|
||||
&self.received_authed
|
||||
&self.received_authed_idx
|
||||
}
|
||||
|
||||
/// Returns the index of sent data which haven't been authenticated.
|
||||
pub fn sent_unauthed(&self) -> Idx {
|
||||
Idx(RangeSet::from(0..self.sent.len()).difference(&self.sent_authed.0))
|
||||
Idx(RangeSet::from(0..self.sent.len()).difference(&self.sent_authed_idx.0))
|
||||
}
|
||||
|
||||
/// Returns the index of received data which haven't been authenticated.
|
||||
pub fn received_unauthed(&self) -> Idx {
|
||||
Idx(RangeSet::from(0..self.received.len()).difference(&self.received_authed.0))
|
||||
Idx(RangeSet::from(0..self.received.len()).difference(&self.received_authed_idx.0))
|
||||
}
|
||||
|
||||
/// Returns an iterator over the authenticated data in the transcript.
|
||||
pub fn iter(&self, direction: Direction) -> impl Iterator<Item = u8> + '_ {
|
||||
let (data, authed) = match direction {
|
||||
Direction::Sent => (&self.sent, &self.sent_authed),
|
||||
Direction::Received => (&self.received, &self.received_authed),
|
||||
Direction::Sent => (&self.sent, &self.sent_authed_idx),
|
||||
Direction::Received => (&self.received, &self.received_authed_idx),
|
||||
};
|
||||
|
||||
authed.0.iter().map(|i| data[i])
|
||||
@@ -285,25 +352,25 @@ impl PartialTranscript {
|
||||
);
|
||||
|
||||
for range in other
|
||||
.sent_authed
|
||||
.sent_authed_idx
|
||||
.0
|
||||
.difference(&self.sent_authed.0)
|
||||
.difference(&self.sent_authed_idx.0)
|
||||
.iter_ranges()
|
||||
{
|
||||
self.sent[range.clone()].copy_from_slice(&other.sent[range]);
|
||||
}
|
||||
|
||||
for range in other
|
||||
.received_authed
|
||||
.received_authed_idx
|
||||
.0
|
||||
.difference(&self.received_authed.0)
|
||||
.difference(&self.received_authed_idx.0)
|
||||
.iter_ranges()
|
||||
{
|
||||
self.received[range.clone()].copy_from_slice(&other.received[range]);
|
||||
}
|
||||
|
||||
self.sent_authed = self.sent_authed.union(&other.sent_authed);
|
||||
self.received_authed = self.received_authed.union(&other.received_authed);
|
||||
self.sent_authed_idx = self.sent_authed_idx.union(&other.sent_authed_idx);
|
||||
self.received_authed_idx = self.received_authed_idx.union(&other.received_authed_idx);
|
||||
}
|
||||
|
||||
/// Unions an authenticated subsequence into this transcript.
|
||||
@@ -315,11 +382,11 @@ impl PartialTranscript {
|
||||
match direction {
|
||||
Direction::Sent => {
|
||||
seq.copy_to(&mut self.sent);
|
||||
self.sent_authed = self.sent_authed.union(&seq.idx);
|
||||
self.sent_authed_idx = self.sent_authed_idx.union(&seq.idx);
|
||||
}
|
||||
Direction::Received => {
|
||||
seq.copy_to(&mut self.received);
|
||||
self.received_authed = self.received_authed.union(&seq.idx);
|
||||
self.received_authed_idx = self.received_authed_idx.union(&seq.idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -348,12 +415,12 @@ impl PartialTranscript {
|
||||
pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) {
|
||||
match direction {
|
||||
Direction::Sent => {
|
||||
for range in range.difference(&self.sent_authed.0).iter_ranges() {
|
||||
for range in range.difference(&self.sent_authed_idx.0).iter_ranges() {
|
||||
self.sent[range].fill(value);
|
||||
}
|
||||
}
|
||||
Direction::Received => {
|
||||
for range in range.difference(&self.received_authed.0).iter_ranges() {
|
||||
for range in range.difference(&self.received_authed_idx.0).iter_ranges() {
|
||||
self.received[range].fill(value);
|
||||
}
|
||||
}
|
||||
@@ -549,51 +616,118 @@ mod validation {
|
||||
}
|
||||
}
|
||||
|
||||
/// Invalid partial transcript error.
|
||||
/// Invalid compressed partial transcript error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("invalid partial transcript: {0}")]
|
||||
pub struct InvalidPartialTranscript(&'static str);
|
||||
#[error("invalid compressed partial transcript: {0}")]
|
||||
pub struct InvalidCompressedPartialTranscript(&'static str);
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(super) struct PartialTranscriptUnchecked {
|
||||
sent: Vec<u8>,
|
||||
received: Vec<u8>,
|
||||
sent_authed: Idx,
|
||||
received_authed: Idx,
|
||||
#[cfg_attr(test, derive(Serialize))]
|
||||
pub(super) struct CompressedPartialTranscriptUnchecked {
|
||||
sent_authed: Vec<u8>,
|
||||
received_authed: Vec<u8>,
|
||||
sent_idx: Idx,
|
||||
recv_idx: Idx,
|
||||
sent_total: usize,
|
||||
recv_total: usize,
|
||||
}
|
||||
|
||||
impl TryFrom<PartialTranscriptUnchecked> for PartialTranscript {
|
||||
type Error = InvalidPartialTranscript;
|
||||
impl TryFrom<CompressedPartialTranscriptUnchecked> for CompressedPartialTranscript {
|
||||
type Error = InvalidCompressedPartialTranscript;
|
||||
|
||||
fn try_from(unchecked: PartialTranscriptUnchecked) -> Result<Self, Self::Error> {
|
||||
if unchecked.sent_authed.end() > unchecked.sent.len()
|
||||
|| unchecked.received_authed.end() > unchecked.received.len()
|
||||
fn try_from(unchecked: CompressedPartialTranscriptUnchecked) -> Result<Self, Self::Error> {
|
||||
if unchecked.sent_authed.len() != unchecked.sent_idx.len()
|
||||
|| unchecked.received_authed.len() != unchecked.recv_idx.len()
|
||||
{
|
||||
return Err(InvalidPartialTranscript(
|
||||
"authenticated ranges are not in bounds of the data",
|
||||
return Err(InvalidCompressedPartialTranscript(
|
||||
"lengths of index and data don't match",
|
||||
));
|
||||
}
|
||||
|
||||
// Rewrite the data to ensure that unauthenticated data is zeroed out.
|
||||
let mut sent = vec![0; unchecked.sent.len()];
|
||||
let mut received = vec![0; unchecked.received.len()];
|
||||
|
||||
for range in unchecked.sent_authed.iter_ranges() {
|
||||
sent[range.clone()].copy_from_slice(&unchecked.sent[range]);
|
||||
}
|
||||
|
||||
for range in unchecked.received_authed.iter_ranges() {
|
||||
received[range.clone()].copy_from_slice(&unchecked.received[range]);
|
||||
if unchecked.sent_idx.end() > unchecked.sent_total
|
||||
|| unchecked.recv_idx.end() > unchecked.recv_total
|
||||
{
|
||||
return Err(InvalidCompressedPartialTranscript(
|
||||
"ranges are not in bounds of the data",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
sent,
|
||||
received,
|
||||
sent_authed: unchecked.sent_authed,
|
||||
received_authed: unchecked.received_authed,
|
||||
recv_idx: unchecked.recv_idx,
|
||||
recv_total: unchecked.recv_total,
|
||||
sent_authed: unchecked.sent_authed,
|
||||
sent_idx: unchecked.sent_idx,
|
||||
sent_total: unchecked.sent_total,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rstest::{fixture, rstest};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[fixture]
|
||||
fn partial_transcript() -> CompressedPartialTranscriptUnchecked {
|
||||
CompressedPartialTranscriptUnchecked {
|
||||
received_authed: vec![1, 2, 3, 11, 12, 13],
|
||||
sent_authed: vec![4, 5, 6, 14, 15, 16],
|
||||
recv_idx: Idx(RangeSet::new(&[1..4, 11..14])),
|
||||
sent_idx: Idx(RangeSet::new(&[4..7, 14..17])),
|
||||
sent_total: 20,
|
||||
recv_total: 20,
|
||||
}
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_partial_transcript_valid(partial_transcript: CompressedPartialTranscriptUnchecked) {
|
||||
let bytes = bincode::serialize(&partial_transcript).unwrap();
|
||||
let transcript: Result<CompressedPartialTranscript, Box<bincode::ErrorKind>> =
|
||||
bincode::deserialize(&bytes);
|
||||
assert!(transcript.is_ok());
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
// Expect to fail since the length of data and the length of the index do not
|
||||
// match.
|
||||
fn test_partial_transcript_invalid_lengths(
|
||||
mut partial_transcript: CompressedPartialTranscriptUnchecked,
|
||||
) {
|
||||
// Add an extra byte to the data.
|
||||
let mut old = partial_transcript.sent_authed;
|
||||
old.extend([1]);
|
||||
partial_transcript.sent_authed = old;
|
||||
|
||||
let bytes = bincode::serialize(&partial_transcript).unwrap();
|
||||
let transcript: Result<CompressedPartialTranscript, Box<bincode::ErrorKind>> =
|
||||
bincode::deserialize(&bytes);
|
||||
assert!(transcript.is_err());
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
// Expect to fail since the index is out of bounds.
|
||||
fn test_partial_transcript_invalid_ranges(
|
||||
mut partial_transcript: CompressedPartialTranscriptUnchecked,
|
||||
) {
|
||||
// Change the total to be less than the last range's end bound.
|
||||
let end = partial_transcript
|
||||
.sent_idx
|
||||
.0
|
||||
.iter_ranges()
|
||||
.last()
|
||||
.unwrap()
|
||||
.end;
|
||||
|
||||
partial_transcript.sent_total = end - 1;
|
||||
|
||||
let bytes = bincode::serialize(&partial_transcript).unwrap();
|
||||
let transcript: Result<CompressedPartialTranscript, Box<bincode::ErrorKind>> =
|
||||
bincode::deserialize(&bytes);
|
||||
assert!(transcript.is_err());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -610,6 +744,14 @@ mod tests {
|
||||
)
|
||||
}
|
||||
|
||||
#[fixture]
|
||||
fn partial_transcript() -> PartialTranscript {
|
||||
transcript().to_partial(
|
||||
Idx::new(RangeSet::new(&[1..4, 6..9])),
|
||||
Idx::new(RangeSet::new(&[2..5, 7..10])),
|
||||
)
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_transcript_get_subsequence(transcript: Transcript) {
|
||||
let subseq = transcript
|
||||
@@ -632,6 +774,13 @@ mod tests {
|
||||
assert_eq!(subseq, None);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_partial_transcript_serialization_ok(partial_transcript: PartialTranscript) {
|
||||
let bytes = bincode::serialize(&partial_transcript).unwrap();
|
||||
let deserialized_transcript: PartialTranscript = bincode::deserialize(&bytes).unwrap();
|
||||
assert_eq!(partial_transcript, deserialized_transcript);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_transcript_to_partial_success(transcript: Transcript) {
|
||||
let partial = transcript.to_partial(Idx::new(0..2), Idx::new(3..7));
|
||||
|
||||
Reference in New Issue
Block a user