migrate MessageEncrypter and MessageDecrypter traits to tls-aio

This commit is contained in:
sinuio
2022-06-03 20:06:32 -07:00
parent e428db6163
commit 04d08ec60f
12 changed files with 146 additions and 88 deletions

View File

@@ -7,3 +7,6 @@ edition = "2021"
name = "tls_aio"
[dependencies]
async-trait = "0.1.56"
thiserror = "1.0.30"
tlsn-tls-core = { path = "../tls-core" }

18
tls-aio/src/cipher/mod.rs Normal file
View File

@@ -0,0 +1,18 @@
use async_trait::async_trait;
use tls_core::msgs::message::{OpaqueMessage, PlainMessage};
/// Objects with this trait can decrypt TLS messages.
#[async_trait]
pub trait MessageDecrypter: Send + Sync {
type Error;
/// Perform the decryption over the concerned TLS message.
async fn decrypt(&self, m: OpaqueMessage, seq: u64) -> Result<PlainMessage, Self::Error>;
}
/// Objects with this trait can encrypt TLS messages.
#[async_trait]
pub trait MessageEncrypter: Send + Sync {
type Error: Sized;
/// Perform the encryption over the concerned TLS message.
async fn encrypt(&self, m: PlainMessage, seq: u64) -> Result<OpaqueMessage, Self::Error>;
}

7
tls-aio/src/error.rs Normal file
View File

@@ -0,0 +1,7 @@
#[derive(thiserror::Error, Debug, Clone, PartialEq)]
pub enum Error {
#[error("Encountered error during encryption")]
EncryptError,
#[error("Encountered error during decryption")]
DecryptError,
}

View File

@@ -1,8 +1,4 @@
#[cfg(test)]
mod tests {
#[test]
fn it_works() {
let result = 2 + 2;
assert_eq!(result, 4);
}
}
pub mod cipher;
pub mod error;
pub use error::Error;

View File

@@ -25,6 +25,7 @@ sct = "0.7.0"
tokio = { version = "1.18.2", features = ["macros", "rt", "rt-multi-thread"] }
webpki = { version = "0.22.0", features = ["alloc", "std"] }
tlsn-tls-core = { path = "../tls-core" }
tlsn-tls-aio = { path = "../tls-aio" }
[features]
default = ["logging", "tls12"]

View File

@@ -1,33 +1,32 @@
use crate::error::Error;
use tls_core::msgs::codec;
use tls_core::msgs::message::{OpaqueMessage, PlainMessage};
use crate::Error;
use tls_aio::cipher::{MessageDecrypter, MessageEncrypter};
use tls_core::msgs::{
codec,
message::{OpaqueMessage, PlainMessage},
};
use async_trait::async_trait;
use ring::{aead, hkdf};
/// Objects with this trait can decrypt TLS messages.
/// A `MessageEncrypter` which doesn't work.
pub struct InvalidMessageEncrypter {}
#[async_trait]
pub trait MessageDecrypter: Send + Sync {
/// Perform the decryption over the concerned TLS message.
async fn decrypt(&self, m: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error>;
}
/// Objects with this trait can encrypt TLS messages.
#[async_trait]
pub trait MessageEncrypter: Send + Sync {
async fn encrypt(&self, m: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error>;
}
impl dyn MessageEncrypter {
pub(crate) fn invalid() -> Box<dyn MessageEncrypter> {
Box::new(InvalidMessageEncrypter {})
impl MessageEncrypter for InvalidMessageEncrypter {
type Error = Error;
async fn encrypt(&self, _m: PlainMessage, _seq: u64) -> Result<OpaqueMessage, Error> {
Err(Error::EncryptError)
}
}
impl dyn MessageDecrypter {
pub(crate) fn invalid() -> Box<dyn MessageDecrypter> {
Box::new(InvalidMessageDecrypter {})
/// A `MessageDecrypter` which doesn't work.
pub struct InvalidMessageDecrypter {}
#[async_trait]
impl MessageDecrypter for InvalidMessageDecrypter {
type Error = Error;
async fn decrypt(&self, _m: OpaqueMessage, _seq: u64) -> Result<PlainMessage, Error> {
Err(Error::DecryptError)
}
}
@@ -81,23 +80,3 @@ pub(crate) fn make_nonce(iv: &Iv, seq: u64) -> ring::aead::Nonce {
aead::Nonce::assume_unique_for_key(nonce)
}
/// A `MessageEncrypter` which doesn't work.
struct InvalidMessageEncrypter {}
#[async_trait]
impl MessageEncrypter for InvalidMessageEncrypter {
async fn encrypt(&self, _m: PlainMessage, _seq: u64) -> Result<OpaqueMessage, Error> {
Err(Error::General("encrypt not yet available".to_string()))
}
}
/// A `MessageDecrypter` which doesn't work.
struct InvalidMessageDecrypter {}
#[async_trait]
impl MessageDecrypter for InvalidMessageDecrypter {
async fn decrypt(&self, _m: OpaqueMessage, _seq: u64) -> Result<PlainMessage, Error> {
Err(Error::DecryptError)
}
}

View File

@@ -1,4 +1,5 @@
use crate::rand;
use tls_aio::Error as AioError;
use tls_core::msgs::enums::{AlertDescription, ContentType, HandshakeType};
use tls_core::Error as CoreError;
@@ -9,9 +10,12 @@ use std::time::SystemTimeError;
/// rustls reports protocol errors using this type.
#[derive(Debug, PartialEq, Clone)]
pub enum Error {
/// Error propagated from Core component
/// Error propagated from tls-core library
CoreError(CoreError),
/// Error propagated from tls-aio library
AioError(AioError),
/// We received a TLS message that isn't valid right now.
/// `expect_types` lists the message types we can expect right now.
/// `got_type` is the type we found. This error is typically
@@ -34,6 +38,13 @@ pub enum Error {
got_type: HandshakeType,
},
/// We couldn't decrypt a message. This is invariably fatal.
DecryptError,
/// We couldn't encrypt a message because it was larger than the allowed message size.
/// This should never happen if the application is using valid record sizes.
EncryptError,
/// The peer sent us a syntactically incorrect TLS message.
CorruptMessage,
@@ -46,13 +57,6 @@ pub enum Error {
/// The certificate verifier doesn't support the given type of name.
UnsupportedNameType,
/// We couldn't decrypt a message. This is invariably fatal.
DecryptError,
/// We couldn't encrypt a message because it was larger than the allowed message size.
/// This should never happen if the application is using valid record sizes.
EncryptError,
/// The peer doesn't support a protocol version/feature we require.
/// The parameter gives a hint as to what version/feature it is.
PeerIncompatibleError(String),
@@ -117,6 +121,9 @@ impl fmt::Display for Error {
Error::CoreError(ref e) => {
write!(f, "core error: {}", e)
}
Error::AioError(ref e) => {
write!(f, "aio error: {}", e)
}
Error::InappropriateMessage {
ref expect_types,
ref got_type,
@@ -179,6 +186,17 @@ impl From<CoreError> for Error {
}
}
impl From<AioError> for Error {
#[inline]
fn from(e: AioError) -> Self {
match e {
AioError::DecryptError => Self::DecryptError,
AioError::EncryptError => Self::EncryptError,
e => Self::AioError(e),
}
}
}
impl From<SystemTimeError> for Error {
#[inline]
fn from(_: SystemTimeError) -> Self {
@@ -215,7 +233,6 @@ mod tests {
Error::CorruptMessage,
Error::CorruptMessagePayload(ContentType::Alert),
Error::NoCertificatesPresented,
Error::DecryptError,
Error::PeerIncompatibleError("no tls1.2".to_string()),
Error::PeerMisbehavedError("inconsistent something".to_string()),
Error::AlertReceived(AlertDescription::ExportRestriction),

View File

@@ -349,10 +349,6 @@ pub mod internal {
pub mod msgs {
pub use tls_core::msgs::*;
}
/// Low-level TLS message decryption functions.
pub mod cipher {
pub use crate::cipher::MessageDecrypter;
}
}
// The public interface is:

View File

@@ -1,5 +1,8 @@
use crate::cipher::{MessageDecrypter, MessageEncrypter};
use crate::error::Error;
use crate::{
cipher::{InvalidMessageDecrypter, InvalidMessageEncrypter},
error::Error,
};
use tls_aio::cipher::{MessageDecrypter, MessageEncrypter};
use tls_core::msgs::message::{OpaqueMessage, PlainMessage};
static SEQ_SOFT_LIMIT: u64 = 0xffff_ffff_ffff_0000u64;
@@ -18,8 +21,8 @@ enum DirectionState {
}
pub(crate) struct RecordLayer {
message_encrypter: Box<dyn MessageEncrypter>,
message_decrypter: Box<dyn MessageDecrypter>,
message_encrypter: Box<dyn MessageEncrypter<Error = Error>>,
message_decrypter: Box<dyn MessageDecrypter<Error = Error>>,
write_seq: u64,
read_seq: u64,
encrypt_state: DirectionState,
@@ -34,8 +37,8 @@ pub(crate) struct RecordLayer {
impl RecordLayer {
pub(crate) fn new() -> Self {
Self {
message_encrypter: <dyn MessageEncrypter>::invalid(),
message_decrypter: <dyn MessageDecrypter>::invalid(),
message_encrypter: Box::new(InvalidMessageEncrypter {}),
message_decrypter: Box::new(InvalidMessageDecrypter {}),
write_seq: 0,
read_seq: 0,
encrypt_state: DirectionState::Invalid,
@@ -67,7 +70,10 @@ impl RecordLayer {
/// Prepare to use the given `MessageEncrypter` for future message encryption.
/// It is not used until you call `start_encrypting`.
pub(crate) fn prepare_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
pub(crate) fn prepare_message_encrypter(
&mut self,
cipher: Box<dyn MessageEncrypter<Error = Error>>,
) {
self.message_encrypter = cipher;
self.write_seq = 0;
self.encrypt_state = DirectionState::Prepared;
@@ -75,7 +81,10 @@ impl RecordLayer {
/// Prepare to use the given `MessageDecrypter` for future message decryption.
/// It is not used until you call `start_decrypting`.
pub(crate) fn prepare_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
pub(crate) fn prepare_message_decrypter(
&mut self,
cipher: Box<dyn MessageDecrypter<Error = Error>>,
) {
self.message_decrypter = cipher;
self.read_seq = 0;
self.decrypt_state = DirectionState::Prepared;
@@ -97,14 +106,20 @@ impl RecordLayer {
/// Set and start using the given `MessageEncrypter` for future outgoing
/// message encryption.
pub(crate) fn set_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
pub(crate) fn set_message_encrypter(
&mut self,
cipher: Box<dyn MessageEncrypter<Error = Error>>,
) {
self.prepare_message_encrypter(cipher);
self.start_encrypting();
}
/// Set and start using the given `MessageDecrypter` for future incoming
/// message decryption.
pub(crate) fn set_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
pub(crate) fn set_message_decrypter(
&mut self,
cipher: Box<dyn MessageDecrypter<Error = Error>>,
) {
self.prepare_message_decrypter(cipher);
self.start_decrypting();
self.trial_decryption_len = None;
@@ -115,7 +130,7 @@ impl RecordLayer {
/// 0-RTT is attempted but rejected by the server.
pub(crate) fn set_message_decrypter_with_trial_decryption(
&mut self,
cipher: Box<dyn MessageDecrypter>,
cipher: Box<dyn MessageDecrypter<Error = Error>>,
max_length: usize,
) {
self.prepare_message_decrypter(cipher);

View File

@@ -1,5 +1,6 @@
use crate::cipher::{make_nonce, Iv, MessageDecrypter, MessageEncrypter};
use crate::cipher::{make_nonce, Iv};
use crate::error::Error;
use tls_aio::cipher::{MessageDecrypter, MessageEncrypter};
use tls_core::msgs::base::Payload;
use tls_core::msgs::codec;
use tls_core::msgs::enums::{ContentType, ProtocolVersion};
@@ -28,7 +29,11 @@ fn make_tls12_aad(
pub(crate) struct AesGcm;
impl Tls12AeadAlgorithm for AesGcm {
fn decrypter(&self, dec_key: aead::LessSafeKey, dec_iv: &[u8]) -> Box<dyn MessageDecrypter> {
fn decrypter(
&self,
dec_key: aead::LessSafeKey,
dec_iv: &[u8],
) -> Box<dyn MessageDecrypter<Error = Error>> {
let mut ret = GcmMessageDecrypter {
dec_key,
dec_salt: [0u8; 4],
@@ -44,7 +49,7 @@ impl Tls12AeadAlgorithm for AesGcm {
enc_key: aead::LessSafeKey,
write_iv: &[u8],
explicit: &[u8],
) -> Box<dyn MessageEncrypter> {
) -> Box<dyn MessageEncrypter<Error = Error>> {
debug_assert_eq!(write_iv.len(), 4);
debug_assert_eq!(explicit.len(), 8);
@@ -66,7 +71,11 @@ impl Tls12AeadAlgorithm for AesGcm {
pub(crate) struct ChaCha20Poly1305;
impl Tls12AeadAlgorithm for ChaCha20Poly1305 {
fn decrypter(&self, dec_key: aead::LessSafeKey, iv: &[u8]) -> Box<dyn MessageDecrypter> {
fn decrypter(
&self,
dec_key: aead::LessSafeKey,
iv: &[u8],
) -> Box<dyn MessageDecrypter<Error = Error>> {
Box::new(ChaCha20Poly1305MessageDecrypter {
dec_key,
dec_offset: Iv::copy(iv),
@@ -78,7 +87,7 @@ impl Tls12AeadAlgorithm for ChaCha20Poly1305 {
enc_key: aead::LessSafeKey,
enc_iv: &[u8],
_: &[u8],
) -> Box<dyn MessageEncrypter> {
) -> Box<dyn MessageEncrypter<Error = Error>> {
Box::new(ChaCha20Poly1305MessageEncrypter {
enc_key,
enc_offset: Iv::copy(enc_iv),
@@ -87,13 +96,17 @@ impl Tls12AeadAlgorithm for ChaCha20Poly1305 {
}
pub(crate) trait Tls12AeadAlgorithm: Send + Sync + 'static {
fn decrypter(&self, key: aead::LessSafeKey, iv: &[u8]) -> Box<dyn MessageDecrypter>;
fn decrypter(
&self,
key: aead::LessSafeKey,
iv: &[u8],
) -> Box<dyn MessageDecrypter<Error = Error>>;
fn encrypter(
&self,
key: aead::LessSafeKey,
iv: &[u8],
extra: &[u8],
) -> Box<dyn MessageEncrypter>;
) -> Box<dyn MessageEncrypter<Error = Error>>;
}
/// A `MessageEncrypter` for AES-GCM AEAD ciphersuites. TLS 1.2 only.
@@ -113,6 +126,7 @@ const GCM_OVERHEAD: usize = GCM_EXPLICIT_NONCE_LEN + 16;
#[async_trait]
impl MessageDecrypter for GcmMessageDecrypter {
type Error = Error;
async fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
let payload = &mut msg.payload.0;
if payload.len() < GCM_OVERHEAD {
@@ -145,6 +159,7 @@ impl MessageDecrypter for GcmMessageDecrypter {
#[async_trait]
impl MessageEncrypter for GcmMessageEncrypter {
type Error = Error;
async fn encrypt(&self, msg: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let nonce = make_nonce(&self.iv, seq);
let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.0.len());
@@ -187,6 +202,7 @@ const CHACHAPOLY1305_OVERHEAD: usize = 16;
#[async_trait]
impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter {
type Error = Error;
async fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
let payload = &mut msg.payload.0;
@@ -219,6 +235,7 @@ impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter {
#[async_trait]
impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter {
type Error = Error;
async fn encrypt(&self, msg: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let nonce = make_nonce(&self.enc_offset, seq);
let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.0.len());

View File

@@ -1,12 +1,12 @@
use crate::cipher::{MessageDecrypter, MessageEncrypter};
use crate::conn::{CommonState, ConnectionRandoms, Side};
use crate::kx;
use crate::suites::{BulkAlgorithm, CipherSuiteCommon, SupportedCipherSuite};
use crate::Error;
use tls_aio::cipher::{MessageDecrypter, MessageEncrypter};
use tls_core::msgs::codec::{Codec, Reader};
use tls_core::msgs::enums::{AlertDescription, ContentType};
use tls_core::msgs::enums::{CipherSuite, SignatureScheme};
use tls_core::msgs::handshake::KeyExchangeAlgorithm;
use crate::suites::{BulkAlgorithm, CipherSuiteCommon, SupportedCipherSuite};
use crate::Error;
use ring::aead;
use ring::digest::Digest;
@@ -408,7 +408,10 @@ fn join_randoms(first: &[u8; 32], second: &[u8; 32]) -> [u8; 64] {
randoms
}
type MessageCipherPair = (Box<dyn MessageDecrypter>, Box<dyn MessageEncrypter>);
type MessageCipherPair = (
Box<dyn MessageDecrypter<Error = Error>>,
Box<dyn MessageEncrypter<Error = Error>>,
);
pub(crate) async fn decode_ecdh_params<T: Codec>(
common: &mut CommonState,

View File

@@ -1,6 +1,7 @@
use crate::cipher::{make_nonce, Iv, MessageDecrypter, MessageEncrypter};
use crate::cipher::{make_nonce, Iv};
use crate::error::Error;
use crate::suites::{BulkAlgorithm, CipherSuiteCommon, SupportedCipherSuite};
use tls_aio::cipher::{MessageDecrypter, MessageEncrypter};
use tls_core::msgs::base::Payload;
use tls_core::msgs::codec::Codec;
use tls_core::msgs::enums::{CipherSuite, ContentType, ProtocolVersion};
@@ -60,7 +61,10 @@ pub struct Tls13CipherSuite {
}
impl Tls13CipherSuite {
pub(crate) fn derive_encrypter(&self, secret: &hkdf::Prk) -> Box<dyn MessageEncrypter> {
pub(crate) fn derive_encrypter(
&self,
secret: &hkdf::Prk,
) -> Box<dyn MessageEncrypter<Error = Error>> {
let key = derive_traffic_key(secret, self.common.aead_algorithm);
let iv = derive_traffic_iv(secret);
@@ -72,7 +76,7 @@ impl Tls13CipherSuite {
/// Derive a `MessageDecrypter` object from the concerned TLS 1.3
/// cipher suite.
pub fn derive_decrypter(&self, secret: &hkdf::Prk) -> Box<dyn MessageDecrypter> {
pub fn derive_decrypter(&self, secret: &hkdf::Prk) -> Box<dyn MessageDecrypter<Error = Error>> {
let key = derive_traffic_key(secret, self.common.aead_algorithm);
let iv = derive_traffic_iv(secret);
@@ -149,6 +153,7 @@ const TLS13_AAD_SIZE: usize = 1 + 2 + 2;
#[async_trait]
impl MessageEncrypter for Tls13MessageEncrypter {
type Error = Error;
async fn encrypt(&self, msg: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let total_len = msg.payload.0.len() + 1 + self.enc_key.algorithm().tag_len();
let mut payload = Vec::with_capacity(total_len);
@@ -172,6 +177,7 @@ impl MessageEncrypter for Tls13MessageEncrypter {
#[async_trait]
impl MessageDecrypter for Tls13MessageDecrypter {
type Error = Error;
async fn decrypt(&self, mut msg: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error> {
let payload = &mut msg.payload.0;
if payload.len() < self.dec_key.algorithm().tag_len() {