mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-09 21:38:00 -05:00
convert traits to async, remove message borrowing
This commit is contained in:
@@ -2,18 +2,21 @@ use crate::error::Error;
|
||||
use crate::msgs::codec;
|
||||
use crate::msgs::message::{BorrowedPlainMessage, OpaqueMessage, PlainMessage};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use ring::{aead, hkdf};
|
||||
|
||||
/// Objects with this trait can decrypt TLS messages.
|
||||
#[async_trait]
|
||||
pub trait MessageDecrypter: Send + Sync {
|
||||
/// Perform the decryption over the concerned TLS message.
|
||||
|
||||
fn decrypt(&self, m: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error>;
|
||||
async fn decrypt(&self, m: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error>;
|
||||
}
|
||||
|
||||
/// Objects with this trait can encrypt TLS messages.
|
||||
pub(crate) trait MessageEncrypter: Send + Sync {
|
||||
fn encrypt(&self, m: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error>;
|
||||
#[async_trait]
|
||||
pub trait MessageEncrypter: Send + Sync {
|
||||
async fn encrypt(&self, m: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error>;
|
||||
}
|
||||
|
||||
impl dyn MessageEncrypter {
|
||||
@@ -72,12 +75,9 @@ pub(crate) fn make_nonce(iv: &Iv, seq: u64) -> ring::aead::Nonce {
|
||||
let mut nonce = [0u8; ring::aead::NONCE_LEN];
|
||||
codec::put_u64(seq, &mut nonce[4..]);
|
||||
|
||||
nonce
|
||||
.iter_mut()
|
||||
.zip(iv.0.iter())
|
||||
.for_each(|(nonce, iv)| {
|
||||
*nonce ^= *iv;
|
||||
});
|
||||
nonce.iter_mut().zip(iv.0.iter()).for_each(|(nonce, iv)| {
|
||||
*nonce ^= *iv;
|
||||
});
|
||||
|
||||
aead::Nonce::assume_unique_for_key(nonce)
|
||||
}
|
||||
@@ -85,8 +85,9 @@ pub(crate) fn make_nonce(iv: &Iv, seq: u64) -> ring::aead::Nonce {
|
||||
/// A `MessageEncrypter` which doesn't work.
|
||||
struct InvalidMessageEncrypter {}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageEncrypter for InvalidMessageEncrypter {
|
||||
fn encrypt(&self, _m: BorrowedPlainMessage, _seq: u64) -> Result<OpaqueMessage, Error> {
|
||||
async fn encrypt(&self, _m: PlainMessage, _seq: u64) -> Result<OpaqueMessage, Error> {
|
||||
Err(Error::General("encrypt not yet available".to_string()))
|
||||
}
|
||||
}
|
||||
@@ -94,8 +95,9 @@ impl MessageEncrypter for InvalidMessageEncrypter {
|
||||
/// A `MessageDecrypter` which doesn't work.
|
||||
struct InvalidMessageDecrypter {}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageDecrypter for InvalidMessageDecrypter {
|
||||
fn decrypt(&self, _m: OpaqueMessage, _seq: u64) -> Result<PlainMessage, Error> {
|
||||
async fn decrypt(&self, _m: OpaqueMessage, _seq: u64) -> Result<PlainMessage, Error> {
|
||||
Err(Error::DecryptError)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -803,7 +803,7 @@ impl CommonState {
|
||||
self.message_fragmenter.fragment(m, &mut plain_messages);
|
||||
|
||||
for m in plain_messages {
|
||||
self.send_single_fragment(m.borrow()).await;
|
||||
self.send_single_fragment(m).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -819,10 +819,12 @@ impl CommonState {
|
||||
};
|
||||
|
||||
let mut plain_messages = VecDeque::new();
|
||||
self.message_fragmenter.fragment_borrow(
|
||||
ContentType::ApplicationData,
|
||||
ProtocolVersion::TLSv1_2,
|
||||
&payload[..len],
|
||||
self.message_fragmenter.fragment(
|
||||
PlainMessage {
|
||||
typ: ContentType::ApplicationData,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(&payload[..len]),
|
||||
},
|
||||
&mut plain_messages,
|
||||
);
|
||||
|
||||
@@ -833,7 +835,7 @@ impl CommonState {
|
||||
len
|
||||
}
|
||||
|
||||
async fn send_single_fragment<'a>(&mut self, m: BorrowedPlainMessage<'a>) {
|
||||
async fn send_single_fragment(&mut self, m: PlainMessage) {
|
||||
// Close connection once we start to run out of
|
||||
// sequence space.
|
||||
if self.record_layer.wants_close_before_encrypt() {
|
||||
|
||||
@@ -162,7 +162,7 @@ impl RecordLayer {
|
||||
) -> Result<PlainMessage, Error> {
|
||||
debug_assert!(self.is_decrypting());
|
||||
let seq = self.read_seq;
|
||||
let msg = self.message_decrypter.decrypt(encr, seq)?;
|
||||
let msg = self.message_decrypter.decrypt(encr, seq).await?;
|
||||
self.read_seq += 1;
|
||||
Ok(msg)
|
||||
}
|
||||
@@ -171,14 +171,11 @@ impl RecordLayer {
|
||||
///
|
||||
/// `plain` is a TLS message we'd like to send. This function
|
||||
/// panics if the requisite keying material hasn't been established yet.
|
||||
pub(crate) async fn encrypt_outgoing<'a>(
|
||||
&mut self,
|
||||
plain: BorrowedPlainMessage<'a>,
|
||||
) -> OpaqueMessage {
|
||||
pub(crate) async fn encrypt_outgoing(&mut self, plain: PlainMessage) -> OpaqueMessage {
|
||||
debug_assert!(self.encrypt_state == DirectionState::Active);
|
||||
assert!(!self.encrypt_exhausted());
|
||||
let seq = self.write_seq;
|
||||
self.write_seq += 1;
|
||||
self.message_encrypter.encrypt(plain, seq).unwrap()
|
||||
self.message_encrypter.encrypt(plain, seq).await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::msgs::enums::{ContentType, ProtocolVersion};
|
||||
use crate::msgs::fragmenter::MAX_FRAGMENT_LEN;
|
||||
use crate::msgs::message::{BorrowedPlainMessage, OpaqueMessage, PlainMessage};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use ring::aead;
|
||||
|
||||
const TLS12_AAD_SIZE: usize = 8 + 1 + 2 + 2;
|
||||
@@ -110,8 +111,9 @@ struct GcmMessageDecrypter {
|
||||
const GCM_EXPLICIT_NONCE_LEN: usize = 8;
|
||||
const GCM_OVERHEAD: usize = GCM_EXPLICIT_NONCE_LEN + 16;
|
||||
|
||||
#[async_trait]
|
||||
impl MessageDecrypter for GcmMessageDecrypter {
|
||||
fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
|
||||
async fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
|
||||
let payload = &mut msg.payload.0;
|
||||
if payload.len() < GCM_OVERHEAD {
|
||||
return Err(Error::DecryptError);
|
||||
@@ -141,15 +143,16 @@ impl MessageDecrypter for GcmMessageDecrypter {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageEncrypter for GcmMessageEncrypter {
|
||||
fn encrypt(&self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
|
||||
async fn encrypt(&self, msg: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
|
||||
let nonce = make_nonce(&self.iv, seq);
|
||||
let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.len());
|
||||
let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.0.len());
|
||||
|
||||
let total_len = msg.payload.len() + self.enc_key.algorithm().tag_len();
|
||||
let total_len = msg.payload.0.len() + self.enc_key.algorithm().tag_len();
|
||||
let mut payload = Vec::with_capacity(GCM_EXPLICIT_NONCE_LEN + total_len);
|
||||
payload.extend_from_slice(&nonce.as_ref()[4..]);
|
||||
payload.extend_from_slice(msg.payload);
|
||||
payload.extend_from_slice(&msg.payload.0);
|
||||
|
||||
self.enc_key
|
||||
.seal_in_place_separate_tag(nonce, aad, &mut payload[GCM_EXPLICIT_NONCE_LEN..])
|
||||
@@ -182,8 +185,9 @@ struct ChaCha20Poly1305MessageDecrypter {
|
||||
|
||||
const CHACHAPOLY1305_OVERHEAD: usize = 16;
|
||||
|
||||
#[async_trait]
|
||||
impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter {
|
||||
fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
|
||||
async fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
|
||||
let payload = &mut msg.payload.0;
|
||||
|
||||
if payload.len() < CHACHAPOLY1305_OVERHEAD {
|
||||
@@ -213,14 +217,15 @@ impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter {
|
||||
fn encrypt(&self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
|
||||
async fn encrypt(&self, msg: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
|
||||
let nonce = make_nonce(&self.enc_offset, seq);
|
||||
let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.len());
|
||||
let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.0.len());
|
||||
|
||||
let total_len = msg.payload.len() + self.enc_key.algorithm().tag_len();
|
||||
let total_len = msg.payload.0.len() + self.enc_key.algorithm().tag_len();
|
||||
let mut buf = Vec::with_capacity(total_len);
|
||||
buf.extend_from_slice(msg.payload);
|
||||
buf.extend_from_slice(&msg.payload.0);
|
||||
|
||||
self.enc_key
|
||||
.seal_in_place_append_tag(nonce, aad, &mut buf)
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::suites::{BulkAlgorithm, CipherSuiteCommon, SupportedCipherSuite};
|
||||
|
||||
use ring::{aead, hkdf};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::fmt;
|
||||
|
||||
pub(crate) mod key_schedule;
|
||||
@@ -146,11 +147,12 @@ fn make_tls13_aad(len: usize) -> ring::aead::Aad<[u8; TLS13_AAD_SIZE]> {
|
||||
// https://datatracker.ietf.org/doc/html/rfc8446#section-5.2
|
||||
const TLS13_AAD_SIZE: usize = 1 + 2 + 2;
|
||||
|
||||
#[async_trait]
|
||||
impl MessageEncrypter for Tls13MessageEncrypter {
|
||||
fn encrypt(&self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
|
||||
let total_len = msg.payload.len() + 1 + self.enc_key.algorithm().tag_len();
|
||||
async fn encrypt(&self, msg: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
|
||||
let total_len = msg.payload.0.len() + 1 + self.enc_key.algorithm().tag_len();
|
||||
let mut payload = Vec::with_capacity(total_len);
|
||||
payload.extend_from_slice(msg.payload);
|
||||
payload.extend_from_slice(&msg.payload.0);
|
||||
msg.typ.encode(&mut payload);
|
||||
|
||||
let nonce = make_nonce(&self.iv, seq);
|
||||
@@ -168,8 +170,9 @@ impl MessageEncrypter for Tls13MessageEncrypter {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageDecrypter for Tls13MessageDecrypter {
|
||||
fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
|
||||
async fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
|
||||
let payload = &mut msg.payload.0;
|
||||
if payload.len() < self.dec_key.algorithm().tag_len() {
|
||||
return Err(Error::DecryptError);
|
||||
|
||||
Reference in New Issue
Block a user