convert traits to async, remove message borrowing

This commit is contained in:
sinuio
2022-06-02 15:53:46 -07:00
parent 6886158c69
commit f786dfce1f
5 changed files with 46 additions and 37 deletions

View File

@@ -2,18 +2,21 @@ use crate::error::Error;
use crate::msgs::codec;
use crate::msgs::message::{BorrowedPlainMessage, OpaqueMessage, PlainMessage};
use async_trait::async_trait;
use ring::{aead, hkdf};
/// Objects with this trait can decrypt TLS messages.
#[async_trait]
pub trait MessageDecrypter: Send + Sync {
/// Perform the decryption over the concerned TLS message.
fn decrypt(&self, m: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error>;
async fn decrypt(&self, m: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error>;
}
/// Objects with this trait can encrypt TLS messages.
pub(crate) trait MessageEncrypter: Send + Sync {
fn encrypt(&self, m: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error>;
#[async_trait]
pub trait MessageEncrypter: Send + Sync {
async fn encrypt(&self, m: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error>;
}
impl dyn MessageEncrypter {
@@ -72,12 +75,9 @@ pub(crate) fn make_nonce(iv: &Iv, seq: u64) -> ring::aead::Nonce {
let mut nonce = [0u8; ring::aead::NONCE_LEN];
codec::put_u64(seq, &mut nonce[4..]);
nonce
.iter_mut()
.zip(iv.0.iter())
.for_each(|(nonce, iv)| {
*nonce ^= *iv;
});
nonce.iter_mut().zip(iv.0.iter()).for_each(|(nonce, iv)| {
*nonce ^= *iv;
});
aead::Nonce::assume_unique_for_key(nonce)
}
@@ -85,8 +85,9 @@ pub(crate) fn make_nonce(iv: &Iv, seq: u64) -> ring::aead::Nonce {
/// A `MessageEncrypter` which doesn't work.
struct InvalidMessageEncrypter {}
#[async_trait]
impl MessageEncrypter for InvalidMessageEncrypter {
fn encrypt(&self, _m: BorrowedPlainMessage, _seq: u64) -> Result<OpaqueMessage, Error> {
async fn encrypt(&self, _m: PlainMessage, _seq: u64) -> Result<OpaqueMessage, Error> {
Err(Error::General("encrypt not yet available".to_string()))
}
}
@@ -94,8 +95,9 @@ impl MessageEncrypter for InvalidMessageEncrypter {
/// A `MessageDecrypter` which doesn't work.
struct InvalidMessageDecrypter {}
#[async_trait]
impl MessageDecrypter for InvalidMessageDecrypter {
fn decrypt(&self, _m: OpaqueMessage, _seq: u64) -> Result<PlainMessage, Error> {
async fn decrypt(&self, _m: OpaqueMessage, _seq: u64) -> Result<PlainMessage, Error> {
Err(Error::DecryptError)
}
}

View File

@@ -803,7 +803,7 @@ impl CommonState {
self.message_fragmenter.fragment(m, &mut plain_messages);
for m in plain_messages {
self.send_single_fragment(m.borrow()).await;
self.send_single_fragment(m).await;
}
}
@@ -819,10 +819,12 @@ impl CommonState {
};
let mut plain_messages = VecDeque::new();
self.message_fragmenter.fragment_borrow(
ContentType::ApplicationData,
ProtocolVersion::TLSv1_2,
&payload[..len],
self.message_fragmenter.fragment(
PlainMessage {
typ: ContentType::ApplicationData,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(&payload[..len]),
},
&mut plain_messages,
);
@@ -833,7 +835,7 @@ impl CommonState {
len
}
async fn send_single_fragment<'a>(&mut self, m: BorrowedPlainMessage<'a>) {
async fn send_single_fragment(&mut self, m: PlainMessage) {
// Close connection once we start to run out of
// sequence space.
if self.record_layer.wants_close_before_encrypt() {

View File

@@ -162,7 +162,7 @@ impl RecordLayer {
) -> Result<PlainMessage, Error> {
debug_assert!(self.is_decrypting());
let seq = self.read_seq;
let msg = self.message_decrypter.decrypt(encr, seq)?;
let msg = self.message_decrypter.decrypt(encr, seq).await?;
self.read_seq += 1;
Ok(msg)
}
@@ -171,14 +171,11 @@ impl RecordLayer {
///
/// `plain` is a TLS message we'd like to send. This function
/// panics if the requisite keying material hasn't been established yet.
pub(crate) async fn encrypt_outgoing<'a>(
&mut self,
plain: BorrowedPlainMessage<'a>,
) -> OpaqueMessage {
pub(crate) async fn encrypt_outgoing(&mut self, plain: PlainMessage) -> OpaqueMessage {
debug_assert!(self.encrypt_state == DirectionState::Active);
assert!(!self.encrypt_exhausted());
let seq = self.write_seq;
self.write_seq += 1;
self.message_encrypter.encrypt(plain, seq).unwrap()
self.message_encrypter.encrypt(plain, seq).await.unwrap()
}
}

View File

@@ -6,6 +6,7 @@ use crate::msgs::enums::{ContentType, ProtocolVersion};
use crate::msgs::fragmenter::MAX_FRAGMENT_LEN;
use crate::msgs::message::{BorrowedPlainMessage, OpaqueMessage, PlainMessage};
use async_trait::async_trait;
use ring::aead;
const TLS12_AAD_SIZE: usize = 8 + 1 + 2 + 2;
@@ -110,8 +111,9 @@ struct GcmMessageDecrypter {
const GCM_EXPLICIT_NONCE_LEN: usize = 8;
const GCM_OVERHEAD: usize = GCM_EXPLICIT_NONCE_LEN + 16;
#[async_trait]
impl MessageDecrypter for GcmMessageDecrypter {
fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
async fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
let payload = &mut msg.payload.0;
if payload.len() < GCM_OVERHEAD {
return Err(Error::DecryptError);
@@ -141,15 +143,16 @@ impl MessageDecrypter for GcmMessageDecrypter {
}
}
#[async_trait]
impl MessageEncrypter for GcmMessageEncrypter {
fn encrypt(&self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
async fn encrypt(&self, msg: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let nonce = make_nonce(&self.iv, seq);
let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.len());
let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.0.len());
let total_len = msg.payload.len() + self.enc_key.algorithm().tag_len();
let total_len = msg.payload.0.len() + self.enc_key.algorithm().tag_len();
let mut payload = Vec::with_capacity(GCM_EXPLICIT_NONCE_LEN + total_len);
payload.extend_from_slice(&nonce.as_ref()[4..]);
payload.extend_from_slice(msg.payload);
payload.extend_from_slice(&msg.payload.0);
self.enc_key
.seal_in_place_separate_tag(nonce, aad, &mut payload[GCM_EXPLICIT_NONCE_LEN..])
@@ -182,8 +185,9 @@ struct ChaCha20Poly1305MessageDecrypter {
const CHACHAPOLY1305_OVERHEAD: usize = 16;
#[async_trait]
impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter {
fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
async fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
let payload = &mut msg.payload.0;
if payload.len() < CHACHAPOLY1305_OVERHEAD {
@@ -213,14 +217,15 @@ impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter {
}
}
#[async_trait]
impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter {
fn encrypt(&self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
async fn encrypt(&self, msg: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let nonce = make_nonce(&self.enc_offset, seq);
let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.len());
let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.0.len());
let total_len = msg.payload.len() + self.enc_key.algorithm().tag_len();
let total_len = msg.payload.0.len() + self.enc_key.algorithm().tag_len();
let mut buf = Vec::with_capacity(total_len);
buf.extend_from_slice(msg.payload);
buf.extend_from_slice(&msg.payload.0);
self.enc_key
.seal_in_place_append_tag(nonce, aad, &mut buf)

View File

@@ -9,6 +9,7 @@ use crate::suites::{BulkAlgorithm, CipherSuiteCommon, SupportedCipherSuite};
use ring::{aead, hkdf};
use async_trait::async_trait;
use std::fmt;
pub(crate) mod key_schedule;
@@ -146,11 +147,12 @@ fn make_tls13_aad(len: usize) -> ring::aead::Aad<[u8; TLS13_AAD_SIZE]> {
// https://datatracker.ietf.org/doc/html/rfc8446#section-5.2
const TLS13_AAD_SIZE: usize = 1 + 2 + 2;
#[async_trait]
impl MessageEncrypter for Tls13MessageEncrypter {
fn encrypt(&self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let total_len = msg.payload.len() + 1 + self.enc_key.algorithm().tag_len();
async fn encrypt(&self, msg: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let total_len = msg.payload.0.len() + 1 + self.enc_key.algorithm().tag_len();
let mut payload = Vec::with_capacity(total_len);
payload.extend_from_slice(msg.payload);
payload.extend_from_slice(&msg.payload.0);
msg.typ.encode(&mut payload);
let nonce = make_nonce(&self.iv, seq);
@@ -168,8 +170,9 @@ impl MessageEncrypter for Tls13MessageEncrypter {
}
}
#[async_trait]
impl MessageDecrypter for Tls13MessageDecrypter {
fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
async fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
let payload = &mut msg.payload.0;
if payload.len() < self.dec_key.algorithm().tag_len() {
return Err(Error::DecryptError);