consolidate traits, convert functions to fallible, and remove recursion

This commit is contained in:
sinuio
2022-06-15 16:43:52 -07:00
parent ce66d2c419
commit 452db97dd2
9 changed files with 276 additions and 233 deletions

View File

@@ -378,7 +378,7 @@ pub struct Initialized {
config: Arc<ClientConfig>,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for Initialized {
async fn start(
self: Box<Self>,

View File

@@ -83,7 +83,7 @@ pub(super) async fn start_handshake(
let support_tls13 = config.supports_version(ProtocolVersion::TLSv1_3);
let key_share = if support_tls13 {
Some(cx.common.handshaker.client_key_share().await?)
Some(cx.common.crypto.client_key_share().await?)
} else {
None
};
@@ -113,7 +113,7 @@ pub(super) async fn start_handshake(
session_id = Some(SessionID::random()?);
}
let random = cx.common.handshaker.client_random().await?;
let random = cx.common.crypto.client_random().await?;
let hello_details = ClientHelloDetails::new();
let sent_tls13_fake_ccs = false;
let may_send_sct_list = config.verifier.request_scts();
@@ -134,7 +134,7 @@ pub(super) async fn start_handshake(
may_send_sct_list,
None,
)
.await)
.await?)
}
struct ExpectServerHello {
@@ -172,7 +172,7 @@ async fn emit_client_hello_for_retry(
extra_exts: Vec<ClientExtension>,
may_send_sct_list: bool,
suite: Option<SupportedCipherSuite>,
) -> NextState {
) -> Result<NextState, Error> {
// For now we do not support session resumption
//
// Do we have a SessionID or ticket cached for this host?
@@ -337,13 +337,13 @@ async fn emit_client_hello_for_retry(
if retryreq.is_some() {
// send dummy CCS to fool middleboxes prior
// to second client hello
tls13::emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await;
tls13::emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await?;
}
trace!("Sending ClientHello {:#?}", ch);
transcript_buffer.add_message(&ch);
cx.common.send_msg(ch, false).await;
cx.common.send_msg(ch, false).await?;
let next = ExpectServerHello {
config,
@@ -360,9 +360,12 @@ async fn emit_client_hello_for_retry(
};
if support_tls13 && retryreq.is_none() {
Box::new(ExpectServerHelloOrHelloRetryRequest { next, extra_exts })
Ok(Box::new(ExpectServerHelloOrHelloRetryRequest {
next,
extra_exts,
}))
} else {
Box::new(next)
Ok(Box::new(next))
}
}
@@ -377,7 +380,7 @@ pub(super) async fn process_alpn_protocol(
if !config.alpn_protocols.contains(alpn_protocol) {
return Err(common
.illegal_param("server sent non-offered ALPN protocol")
.await);
.await?);
}
}
@@ -392,7 +395,7 @@ pub(super) fn sct_list_is_invalid(scts: &SCTList) -> bool {
scts.is_empty() || scts.iter().any(|sct| sct.0.is_empty())
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectServerHello {
async fn handle(
mut self: Box<Self>,
@@ -429,7 +432,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
return Err(cx
.common
.illegal_param("server chose v1.2 using v1.3 extension")
.await);
.await?);
}
TLSv1_2
@@ -437,7 +440,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
_ => {
cx.common
.send_fatal_alert(AlertDescription::ProtocolVersion)
.await;
.await?;
let msg = match server_version {
TLSv1_2 | TLSv1_3 => "server's TLS version is disabled in client",
_ => "server does not support TLS v1.2/v1.3",
@@ -446,19 +449,19 @@ impl State<ClientConnectionData> for ExpectServerHello {
}
};
cx.common.handshaker.select_protocol_version(version)?;
cx.common.crypto.select_protocol_version(version)?;
if server_hello.compression_method != Compression::Null {
return Err(cx
.common
.illegal_param("server chose non-Null compression")
.await);
.await?);
}
if server_hello.has_duplicate_extension() {
cx.common
.send_fatal_alert(AlertDescription::DecodeError)
.await;
.await?;
return Err(Error::PeerMisbehavedError(
"server sent duplicate extensions".to_string(),
));
@@ -471,7 +474,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
{
cx.common
.send_fatal_alert(AlertDescription::UnsupportedExtension)
.await;
.await?;
return Err(Error::PeerMisbehavedError(
"server sent unsolicited extension".to_string(),
));
@@ -491,7 +494,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
if !point_fmts.contains(&ECPointFormat::Uncompressed) {
cx.common
.send_fatal_alert(AlertDescription::HandshakeFailure)
.await;
.await?;
return Err(Error::PeerMisbehavedError(
"server does not support uncompressed points".to_string(),
));
@@ -503,7 +506,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
None => {
cx.common
.send_fatal_alert(AlertDescription::HandshakeFailure)
.await;
.await?;
return Err(Error::PeerMisbehavedError(
"server chose non-offered ciphersuite".to_string(),
));
@@ -514,7 +517,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
return Err(cx
.common
.illegal_param("server chose unusable ciphersuite for version")
.await);
.await?);
}
match self.suite {
@@ -522,13 +525,13 @@ impl State<ClientConnectionData> for ExpectServerHello {
return Err(cx
.common
.illegal_param("server varied selected ciphersuite")
.await);
.await?);
}
_ => {
debug!("Using ciphersuite {:?}", suite);
self.suite = Some(suite);
cx.common.suite = Some(suite);
cx.common.handshaker.select_cipher_suite(suite)?;
cx.common.crypto.select_cipher_suite(suite)?;
}
}
@@ -620,7 +623,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
return Err(cx
.common
.illegal_param("server requested hrr with our group")
.await);
.await?);
}
// Or has an empty cookie.
@@ -629,7 +632,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
return Err(cx
.common
.illegal_param("server requested hrr with empty cookie")
.await);
.await?);
}
}
@@ -637,7 +640,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
if hrr.has_unknown_extension() {
cx.common
.send_fatal_alert(AlertDescription::UnsupportedExtension)
.await;
.await?;
return Err(Error::PeerIncompatibleError(
"server sent hrr with unhandled extension".to_string(),
));
@@ -648,7 +651,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
return Err(cx
.common
.illegal_param("server send duplicate hrr extensions")
.await);
.await?);
}
// Or asks us to change nothing.
@@ -656,7 +659,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
return Err(cx
.common
.illegal_param("server requested hrr with no changes")
.await);
.await?);
}
// Or asks us to talk a protocol we didn't offer, or doesn't support HRR at all.
@@ -668,7 +671,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
return Err(cx
.common
.illegal_param("server requested unsupported version in hrr")
.await);
.await?);
}
}
@@ -680,7 +683,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
return Err(cx
.common
.illegal_param("server requested unsupported cs in hrr")
.await);
.await?);
}
};
@@ -705,7 +708,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
return Err(cx
.common
.illegal_param("server requested hrr with bad group")
.await);
.await?);
}
_ => offered_key_share,
};
@@ -727,11 +730,11 @@ impl ExpectServerHelloOrHelloRetryRequest {
may_send_sct_list,
Some(cs),
)
.await)
.await?)
}
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectServerHelloOrHelloRetryRequest {
async fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> NextStateOrError {
match m.payload {
@@ -752,22 +755,27 @@ impl State<ClientConnectionData> for ExpectServerHelloOrHelloRetryRequest {
}
}
pub(super) async fn send_cert_error_alert(common: &mut CommonState, err: Error) -> Error {
pub(super) async fn send_cert_error_alert(
common: &mut CommonState,
err: Error,
) -> Result<Error, Error> {
match err {
Error::InvalidCertificateEncoding => {
common.send_fatal_alert(AlertDescription::DecodeError).await;
common
.send_fatal_alert(AlertDescription::DecodeError)
.await?;
}
Error::PeerMisbehavedError(_) => {
common
.send_fatal_alert(AlertDescription::IllegalParameter)
.await;
.await?;
}
_ => {
common
.send_fatal_alert(AlertDescription::BadCertificate)
.await;
.await?;
}
};
err
Ok(err)
}

View File

@@ -72,7 +72,7 @@ mod server_hello {
return Err(cx
.common
.illegal_param("downgrade to TLS1.2 when TLS1.3 is supported")
.await);
.await?);
}
// Doing EMS?
@@ -209,7 +209,7 @@ struct ExpectCertificate {
server_cert_sct_list: Option<SCTList>,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectCertificate {
async fn handle(
mut self: Box<Self>,
@@ -271,7 +271,7 @@ struct ExpectCertificateStatusOrServerKx {
must_issue_new_ticket: bool,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectCertificateStatusOrServerKx {
async fn handle(
self: Box<Self>,
@@ -348,7 +348,7 @@ struct ExpectCertificateStatus {
must_issue_new_ticket: bool,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectCertificateStatus {
async fn handle(
mut self: Box<Self>,
@@ -402,7 +402,7 @@ struct ExpectServerKx {
must_issue_new_ticket: bool,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectServerKx {
async fn handle(
mut self: Box<Self>,
@@ -422,7 +422,7 @@ impl State<ClientConnectionData> for ExpectServerKx {
// We only support ECDHE
cx.common
.send_fatal_alert(AlertDescription::DecodeError)
.await;
.await?;
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
};
@@ -457,7 +457,7 @@ async fn emit_certificate(
transcript: &mut HandshakeHash,
cert_chain: CertificatePayload,
common: &mut CommonState,
) {
) -> Result<(), Error> {
let cert = Message {
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::Handshake(HandshakeMessagePayload {
@@ -467,14 +467,14 @@ async fn emit_certificate(
};
transcript.add_message(&cert);
common.send_msg(cert, false).await;
common.send_msg(cert, false).await
}
async fn emit_clientkx(
transcript: &mut HandshakeHash,
common: &mut CommonState,
pubkey: &PublicKey,
) {
) -> Result<(), Error> {
let ecpoint = PayloadU8::new(pubkey.key.clone());
let mut buf = Vec::new();
@@ -490,7 +490,7 @@ async fn emit_clientkx(
};
transcript.add_message(&ckx);
common.send_msg(ckx, false).await;
common.send_msg(ckx, false).await
}
async fn emit_certverify(
@@ -515,24 +515,23 @@ async fn emit_certverify(
};
transcript.add_message(&m);
common.send_msg(m, false).await;
Ok(())
common.send_msg(m, false).await
}
async fn emit_ccs(common: &mut CommonState) {
async fn emit_ccs(common: &mut CommonState) -> Result<(), Error> {
let ccs = Message {
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::ChangeCipherSpec(ChangeCipherSpecPayload {}),
};
common.send_msg(ccs, false).await;
common.send_msg(ccs, false).await
}
async fn emit_finished(
verify_data: &[u8],
transcript: &mut HandshakeHash,
common: &mut CommonState,
) {
) -> Result<(), Error> {
let verify_data_payload = Payload::new(verify_data);
let f = Message {
@@ -544,7 +543,7 @@ async fn emit_finished(
};
transcript.add_message(&f);
common.send_msg(f, true).await;
common.send_msg(f, true).await
}
struct ServerKxDetails {
@@ -578,7 +577,7 @@ struct ExpectServerDoneOrCertReq {
must_issue_new_ticket: bool,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
async fn handle(
mut self: Box<Self>,
@@ -644,7 +643,7 @@ struct ExpectCertificateRequest {
must_issue_new_ticket: bool,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectCertificateRequest {
async fn handle(
mut self: Box<Self>,
@@ -705,7 +704,7 @@ struct ExpectServerDone {
must_issue_new_ticket: bool,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectServerDone {
async fn handle(
self: Box<Self>,
@@ -764,7 +763,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
now,
) {
Ok(cert_verified) => cert_verified,
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await?),
};
// 3.
@@ -794,7 +793,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
sig,
) {
Ok(sig_verified) => sig_verified,
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await?),
}
};
cx.common.peer_certificates = Some(st.server_cert.cert_chain);
@@ -805,7 +804,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
ClientAuthDetails::Empty { .. } => Vec::new(),
ClientAuthDetails::Verify { certkey, .. } => certkey.cert.clone(),
};
emit_certificate(&mut st.transcript, certs, cx.common).await;
emit_certificate(&mut st.transcript, certs, cx.common).await?;
}
// 5a.
@@ -815,12 +814,12 @@ impl State<ClientConnectionData> for ExpectServerDone {
None => {
cx.common
.send_fatal_alert(AlertDescription::DecodeError)
.await;
.await?;
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
};
let key_share = cx.common.handshaker.client_key_share().await?;
let key_share = cx.common.crypto.client_key_share().await?;
if key_share.group != ecdh_params.curve_params.named_group {
return Err(Error::PeerMisbehavedError(
"peer chose an unsupported group".to_string(),
@@ -829,7 +828,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
// 5b.
let mut transcript = st.transcript;
emit_clientkx(&mut transcript, cx.common, &key_share).await;
emit_clientkx(&mut transcript, cx.common, &key_share).await?;
// nb. EMS handshake hash only runs up to ClientKeyExchange.
let _ems_seed = st.using_ems.then(|| transcript.get_current_hash());
@@ -839,31 +838,25 @@ impl State<ClientConnectionData> for ExpectServerDone {
}
// 5d.
emit_ccs(cx.common).await;
emit_ccs(cx.common).await?;
// 5e. Now commit secrets.
let server_key_share =
PublicKey::new(ecdh_params.curve_params.named_group, &ecdh_params.public.0);
cx.common
.handshaker
.crypto
.set_server_key_share(server_key_share)
.await?;
let enc = cx.common.handshaker.message_encrypter().await?;
let dec = cx.common.handshaker.message_decrypter().await?;
cx.common.record_layer.set_message_encrypter(enc);
cx.common.record_layer.set_message_decrypter(dec);
st.config
.key_log
.log("CLIENT_RANDOM", &st.randoms.client, &[]);
// 6.
let hs = transcript.get_current_hash();
let cf = cx.common.handshaker.client_finished(hs.as_ref()).await?;
emit_finished(&cf, &mut transcript, cx.common).await;
let cf = cx.common.crypto.client_finished(hs.as_ref()).await?;
emit_finished(&cf, &mut transcript, cx.common).await?;
if st.must_issue_new_ticket {
Ok(Box::new(ExpectNewTicket {
@@ -906,7 +899,7 @@ struct ExpectNewTicket {
sig_verified: verify::HandshakeSignatureValid,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectNewTicket {
async fn handle(
mut self: Box<Self>,
@@ -950,7 +943,7 @@ struct ExpectCcs {
sig_verified: verify::HandshakeSignatureValid,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectCcs {
async fn handle(
self: Box<Self>,
@@ -1055,7 +1048,7 @@ struct ExpectFinished {
// }
// }
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectFinished {
async fn handle(
self: Box<Self>,
@@ -1070,7 +1063,7 @@ impl State<ClientConnectionData> for ExpectFinished {
// Work out what verify_data we expect.
let vh = st.transcript.get_current_hash();
let expect_verify_data = cx.common.handshaker.server_finished(vh.as_ref()).await?;
let expect_verify_data = cx.common.crypto.server_finished(vh.as_ref()).await?;
// Constant-time verification of this is relatively unimportant: they only
// get one chance. But it can't hurt.
@@ -1080,7 +1073,7 @@ impl State<ClientConnectionData> for ExpectFinished {
Err(_) => {
cx.common
.send_fatal_alert(AlertDescription::DecryptError)
.await;
.await?;
return Err(Error::DecryptError);
}
};
@@ -1091,12 +1084,12 @@ impl State<ClientConnectionData> for ExpectFinished {
// st.save_session(cx);
if st.resuming {
emit_ccs(cx.common).await;
emit_ccs(cx.common).await?;
cx.common.record_layer.start_encrypting();
emit_finished(&expect_verify_data, &mut st.transcript, cx.common).await;
emit_finished(&expect_verify_data, &mut st.transcript, cx.common).await?;
}
cx.common.start_traffic().await;
cx.common.start_traffic().await?;
Ok(Box::new(ExpectTraffic {
_cert_verified: st.cert_verified,
_sig_verified: st.sig_verified,
@@ -1112,7 +1105,7 @@ struct ExpectTraffic {
_fin_verified: verify::FinishedMessageVerified,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectTraffic {
async fn handle(
self: Box<Self>,

View File

@@ -1,5 +1,6 @@
use crate::check::inappropriate_handshake_message;
use crate::conn::{CommonState, ConnectionRandoms, State};
use crate::crypto::{DecryptMode, EncryptMode};
use crate::error::Error;
use crate::hash_hs::{HandshakeHash, HandshakeHashBuffer};
#[cfg(feature = "logging")]
@@ -70,17 +71,17 @@ pub(super) async fn handle_server_hello(
None => {
cx.common
.send_fatal_alert(AlertDescription::MissingExtension)
.await;
.await?;
return Err(Error::PeerMisbehavedError("missing key share".to_string()));
}
};
if our_key_share.group != their_key_share.group {
return Err(cx.common.illegal_param("wrong group for key share").await);
return Err(cx.common.illegal_param("wrong group for key share").await?);
}
cx.common
.handshaker
.crypto
.set_server_key_share(their_key_share.clone().into())
.await?;
@@ -115,22 +116,19 @@ pub(super) async fn handle_server_hello(
cx.common.check_aligned_handshake().await?;
cx.common
.handshaker
.crypto
.set_hs_hash_server_hello(transcript.get_current_hash().as_ref())
.await?;
let dec = cx.common.handshaker.message_decrypter().await?;
let enc = cx.common.handshaker.message_encrypter().await?;
// Decrypt with the peer's key, encrypt with our own key
cx.common.record_layer.set_message_decrypter(dec);
cx.common.crypto.set_decrypt(DecryptMode::Handshake)?;
if !cx.data.early_data.is_enabled() {
// Set the client encryption key for handshakes if early data is not used
cx.common.record_layer.set_message_encrypter(enc);
cx.common.crypto.set_encrypt(EncryptMode::Handshake)?;
}
emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await;
emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await?;
Ok(Box::new(ExpectEncryptedExtensions {
config,
@@ -151,7 +149,7 @@ async fn validate_server_hello(
if !ALLOWED_PLAINTEXT_EXTS.contains(&ext.get_type()) {
common
.send_fatal_alert(AlertDescription::UnsupportedExtension)
.await;
.await?;
return Err(Error::PeerMisbehavedError(
"server sent unexpected cleartext ext".to_string(),
));
@@ -224,16 +222,19 @@ async fn validate_server_hello(
// exts.push(ClientExtension::PresharedKey(psk_ext));
// }
pub(super) async fn emit_fake_ccs(sent_tls13_fake_ccs: &mut bool, common: &mut CommonState) {
pub(super) async fn emit_fake_ccs(
sent_tls13_fake_ccs: &mut bool,
common: &mut CommonState,
) -> Result<(), Error> {
if std::mem::replace(sent_tls13_fake_ccs, true) {
return;
return Ok(());
}
let m = Message {
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::ChangeCipherSpec(ChangeCipherSpecPayload {}),
};
common.send_msg(m, false).await;
common.send_msg(m, false).await
}
async fn validate_encrypted_extensions(
@@ -242,7 +243,9 @@ async fn validate_encrypted_extensions(
exts: &EncryptedExtensions,
) -> Result<(), Error> {
if exts.has_duplicate_extension() {
common.send_fatal_alert(AlertDescription::DecodeError).await;
common
.send_fatal_alert(AlertDescription::DecodeError)
.await?;
return Err(Error::PeerMisbehavedError(
"server sent duplicate encrypted extensions".to_string(),
));
@@ -251,7 +254,7 @@ async fn validate_encrypted_extensions(
if hello.server_sent_unsolicited_extensions(exts, &[]) {
common
.send_fatal_alert(AlertDescription::UnsupportedExtension)
.await;
.await?;
let msg = "server sent unsolicited encrypted extension".to_string();
return Err(Error::PeerMisbehavedError(msg));
}
@@ -262,7 +265,7 @@ async fn validate_encrypted_extensions(
{
common
.send_fatal_alert(AlertDescription::UnsupportedExtension)
.await;
.await?;
let msg = "server sent inappropriate encrypted extension".to_string();
return Err(Error::PeerMisbehavedError(msg));
}
@@ -281,7 +284,7 @@ struct ExpectEncryptedExtensions {
hello: ClientHelloDetails,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectEncryptedExtensions {
async fn handle(
mut self: Box<Self>,
@@ -311,9 +314,8 @@ impl State<ClientConnectionData> for ExpectEncryptedExtensions {
}
if was_early_traffic && !cx.common.early_traffic {
let enc = cx.common.handshaker.message_encrypter().await?;
// If no early traffic, set the encryption key for handshakes
cx.common.record_layer.set_message_encrypter(enc);
cx.common.record_layer.set_message_encrypter();
}
cx.common.peer_certificates = Some(resuming_session.server_cert_chain().to_vec());
@@ -358,7 +360,7 @@ struct ExpectCertificateOrCertReq {
may_send_sct_list: bool,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectCertificateOrCertReq {
async fn handle(
self: Box<Self>,
@@ -421,7 +423,7 @@ struct ExpectCertificateRequest {
may_send_sct_list: bool,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectCertificateRequest {
async fn handle(
mut self: Box<Self>,
@@ -444,7 +446,7 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
warn!("Server sent non-empty certreq context");
cx.common
.send_fatal_alert(AlertDescription::DecodeError)
.await;
.await?;
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
@@ -461,7 +463,7 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
if compat_sigschemes.is_empty() {
cx.common
.send_fatal_alert(AlertDescription::HandshakeFailure)
.await;
.await?;
return Err(Error::PeerIncompatibleError(
"server sent bad certreq schemes".to_string(),
));
@@ -496,7 +498,7 @@ struct ExpectCertificate {
client_auth: Option<ClientAuthDetails>,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectCertificate {
async fn handle(
mut self: Box<Self>,
@@ -515,7 +517,7 @@ impl State<ClientConnectionData> for ExpectCertificate {
warn!("certificate with non-empty context during handshake");
cx.common
.send_fatal_alert(AlertDescription::DecodeError)
.await;
.await?;
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
@@ -525,7 +527,7 @@ impl State<ClientConnectionData> for ExpectCertificate {
warn!("certificate chain contains unsolicited/unknown extension");
cx.common
.send_fatal_alert(AlertDescription::UnsupportedExtension)
.await;
.await?;
return Err(Error::PeerMisbehavedError(
"bad cert chain extensions".to_string(),
));
@@ -572,7 +574,7 @@ struct ExpectCertificateVerify {
client_auth: Option<ClientAuthDetails>,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectCertificateVerify {
async fn handle(
mut self: Box<Self>,
@@ -603,7 +605,7 @@ impl State<ClientConnectionData> for ExpectCertificateVerify {
now,
) {
Ok(cert_verified) => cert_verified,
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await?),
};
// 2. Verify their signature on the handshake.
@@ -614,7 +616,7 @@ impl State<ClientConnectionData> for ExpectCertificateVerify {
cert_verify,
) {
Ok(sig_verified) => sig_verified,
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await?),
};
cx.common.peer_certificates = Some(self.server_cert.cert_chain);
@@ -638,7 +640,7 @@ async fn emit_certificate_tls13(
certkey: Option<&CertifiedKey>,
auth_context: Option<Vec<u8>>,
common: &mut CommonState,
) {
) -> Result<(), Error> {
let context = auth_context.unwrap_or_default();
let mut cert_payload = CertificatePayloadTLS13 {
@@ -662,7 +664,7 @@ async fn emit_certificate_tls13(
}),
};
transcript.add_message(&m);
common.send_msg(m, true).await;
common.send_msg(m, true).await
}
async fn emit_certverify_tls13(
@@ -685,15 +687,14 @@ async fn emit_certverify_tls13(
};
transcript.add_message(&m);
common.send_msg(m, true).await;
Ok(())
common.send_msg(m, true).await
}
async fn emit_finished_tls13(
verify_data: &[u8],
transcript: &mut HandshakeHash,
common: &mut CommonState,
) {
) -> Result<(), Error> {
let verify_data_payload = Payload::new(verify_data);
let m = Message {
@@ -705,10 +706,13 @@ async fn emit_finished_tls13(
};
transcript.add_message(&m);
common.send_msg(m, true).await;
common.send_msg(m, true).await
}
async fn emit_end_of_early_data_tls13(transcript: &mut HandshakeHash, common: &mut CommonState) {
async fn emit_end_of_early_data_tls13(
transcript: &mut HandshakeHash,
common: &mut CommonState,
) -> Result<(), Error> {
let m = Message {
version: ProtocolVersion::TLSv1_3,
payload: MessagePayload::Handshake(HandshakeMessagePayload {
@@ -718,7 +722,7 @@ async fn emit_end_of_early_data_tls13(transcript: &mut HandshakeHash, common: &m
};
transcript.add_message(&m);
common.send_msg(m, true).await;
common.send_msg(m, true).await
}
struct ExpectFinished {
@@ -732,7 +736,7 @@ struct ExpectFinished {
sig_verified: verify::HandshakeSignatureValid,
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectFinished {
async fn handle(
self: Box<Self>,
@@ -746,7 +750,7 @@ impl State<ClientConnectionData> for ExpectFinished {
let handshake_hash = st.transcript.get_current_hash();
let expect_verify_data = cx
.common
.handshaker
.crypto
.server_finished(handshake_hash.as_ref())
.await?;
@@ -758,7 +762,7 @@ impl State<ClientConnectionData> for ExpectFinished {
Err(_) => {
cx.common
.send_fatal_alert(AlertDescription::DecryptError)
.await;
.await?;
return Err(Error::DecryptError);
}
};
@@ -768,11 +772,10 @@ impl State<ClientConnectionData> for ExpectFinished {
/* The EndOfEarlyData message to server is still encrypted with early data keys,
* but appears in the transcript after the server Finished. */
if cx.common.early_traffic {
emit_end_of_early_data_tls13(&mut st.transcript, cx.common).await;
emit_end_of_early_data_tls13(&mut st.transcript, cx.common).await?;
cx.common.early_traffic = false;
cx.data.early_data.finished();
let enc = cx.common.handshaker.message_encrypter().await?;
cx.common.record_layer.set_message_encrypter(enc);
cx.common.crypto.set_encrypt(EncryptMode::Handshake)?;
}
/* Send our authentication/finished messages. These are still encrypted
@@ -782,7 +785,8 @@ impl State<ClientConnectionData> for ExpectFinished {
ClientAuthDetails::Empty {
auth_context_tls13: auth_context,
} => {
emit_certificate_tls13(&mut st.transcript, None, auth_context, cx.common).await;
emit_certificate_tls13(&mut st.transcript, None, auth_context, cx.common)
.await?;
}
ClientAuthDetails::Verify {
certkey,
@@ -795,7 +799,7 @@ impl State<ClientConnectionData> for ExpectFinished {
auth_context,
cx.common,
)
.await;
.await?;
emit_certverify_tls13(&mut st.transcript, signer.as_ref(), cx.common).await?;
}
}
@@ -804,21 +808,18 @@ impl State<ClientConnectionData> for ExpectFinished {
let handshake_hash = st.transcript.get_current_hash();
let client_finished = cx
.common
.handshaker
.crypto
.client_finished(handshake_hash.as_ref())
.await?;
emit_finished_tls13(&client_finished, &mut st.transcript, cx.common).await;
emit_finished_tls13(&client_finished, &mut st.transcript, cx.common).await?;
/* Now move to our application traffic keys. */
cx.common.check_aligned_handshake().await?;
let dec = cx.common.handshaker.message_decrypter().await?;
cx.common.record_layer.set_message_decrypter(dec);
cx.common.crypto.set_encrypt(EncryptMode::Application)?;
cx.common.crypto.set_decrypt(DecryptMode::Application)?;
let enc = cx.common.handshaker.message_encrypter().await?;
cx.common.record_layer.set_message_encrypter(enc);
cx.common.start_traffic().await;
cx.common.start_traffic().await?;
let st = ExpectTraffic {
session_storage: Arc::clone(&st.config.session_storage),
@@ -859,7 +860,7 @@ impl ExpectTraffic {
if nst.has_duplicate_extension() {
cx.common
.send_fatal_alert(AlertDescription::IllegalParameter)
.await;
.await?;
return Err(Error::PeerMisbehavedError(
"peer sent duplicate NewSessionTicket extensions".into(),
));
@@ -915,7 +916,7 @@ impl ExpectTraffic {
// Client does not support key updates
common
.send_fatal_alert(AlertDescription::InternalError)
.await;
.await?;
return Err(Error::General(
"received unsupported key update request from peer".to_string(),
));
@@ -943,7 +944,7 @@ impl ExpectTraffic {
}
}
#[async_trait]
#[async_trait(?Send)]
impl State<ClientConnectionData> for ExpectTraffic {
async fn handle(
mut self: Box<Self>,

View File

@@ -1,6 +1,6 @@
use crate::client::ClientConnectionData;
use crate::crypto::{Crypto, InvalidCrypto};
use crate::error::Error;
use crate::handshaker::{Handshake, InvalidHandShaker};
#[cfg(feature = "logging")]
use crate::log::{debug, error, trace, warn};
use crate::record_layer;
@@ -339,7 +339,7 @@ impl ConnectionCommon {
if self.handshake_joiner.take_message(msg).is_none() {
self.common_state
.send_fatal_alert(AlertDescription::DecodeError)
.await;
.await?;
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
@@ -369,7 +369,7 @@ impl ConnectionCommon {
// handshake with an "unexpected_message" alert."
self.common_state
.send_fatal_alert(AlertDescription::UnexpectedMessage)
.await;
.await?;
return Err(Error::PeerMisbehavedError(
"illegal middlebox CCS received".into(),
));
@@ -406,7 +406,7 @@ impl ConnectionCommon {
None => {
self.common_state
.send_fatal_alert(AlertDescription::DecodeError)
.await;
.await?;
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
}
@@ -488,15 +488,15 @@ impl ConnectionCommon {
}
/// Write buffer into connection
pub async fn write_plaintext(&mut self, buf: &[u8]) -> io::Result<usize> {
pub async fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
if let Ok(st) = &mut self.state {
st.perhaps_write_key_update(&mut self.common_state).await;
}
Ok(self.common_state.send_some_plaintext(buf).await)
self.common_state.send_some_plaintext(buf).await
}
/// Write entire buffer into connection
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> io::Result<usize> {
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
let mut pos = 0;
while pos < buf.len() {
pos += self.write_plaintext(&buf[pos..]).await?;
@@ -572,7 +572,7 @@ pub struct CommonState {
pub(crate) negotiated_version: Option<ProtocolVersion>,
pub(crate) side: Side,
pub(crate) record_layer: record_layer::RecordLayer,
pub(crate) handshaker: Box<dyn Handshake>,
pub(crate) crypto: Box<dyn Crypto>,
pub(crate) suite: Option<SupportedCipherSuite>,
pub(crate) alpn_protocol: Option<Vec<u8>>,
aligned_handshake: bool,
@@ -600,7 +600,7 @@ impl CommonState {
negotiated_version: None,
side,
record_layer: record_layer::RecordLayer::new(),
handshaker: Box::new(InvalidHandShaker {}),
crypto: Box::new(InvalidCrypto {}),
suite: None,
alpn_protocol: None,
aligned_handshake: true,
@@ -697,7 +697,7 @@ impl CommonState {
};
if msg.is_handshake_type(reject_ty) {
self.send_warning_alert(AlertDescription::NoRenegotiation)
.await;
.await?;
return Ok(state);
}
}
@@ -711,7 +711,7 @@ impl CommonState {
Err(e @ Error::InappropriateMessage { .. })
| Err(e @ Error::InappropriateHandshakeMessage { .. }) => {
self.send_fatal_alert(AlertDescription::UnexpectedMessage)
.await;
.await?;
Err(e)
}
Err(e) => Err(e),
@@ -723,7 +723,7 @@ impl CommonState {
///
/// If internal buffers are too small, this function will not accept
/// all the data.
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> usize {
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> Result<usize, Error> {
self.send_plain(data, Limit::Yes).await
}
@@ -734,7 +734,7 @@ impl CommonState {
pub(crate) async fn check_aligned_handshake(&mut self) -> Result<(), Error> {
if !self.aligned_handshake {
self.send_fatal_alert(AlertDescription::UnexpectedMessage)
.await;
.await?;
Err(Error::PeerMisbehavedError(
"key epoch or handshake flight with pending fragment".to_string(),
))
@@ -743,10 +743,10 @@ impl CommonState {
}
}
pub(crate) async fn illegal_param(&mut self, why: &str) -> Error {
pub(crate) async fn illegal_param(&mut self, why: &str) -> Result<Error, Error> {
self.send_fatal_alert(AlertDescription::IllegalParameter)
.await;
Error::PeerMisbehavedError(why.to_string())
.await?;
Ok(Error::PeerMisbehavedError(why.to_string()))
}
pub(crate) async fn decrypt_incoming(
@@ -754,16 +754,19 @@ impl CommonState {
encr: OpaqueMessage,
) -> Result<Option<PlainMessage>, Error> {
if self.record_layer.wants_close_before_decrypt() {
self.send_close_notify().await;
self.send_close_notify().await?;
}
let encrypted_len = encr.payload.0.len();
let plain = self.record_layer.decrypt_incoming(encr).await;
let plain = self
.record_layer
.decrypt_incoming(self.crypto.as_mut(), encr)
.await;
match plain {
Err(Error::PeerSentOversizedRecord) => {
self.send_fatal_alert(AlertDescription::RecordOverflow)
.await;
.await?;
Err(Error::PeerSentOversizedRecord)
}
Err(Error::DecryptError) if self.record_layer.doing_trial_decryption(encrypted_len) => {
@@ -771,7 +774,8 @@ impl CommonState {
Ok(None)
}
Err(Error::DecryptError) => {
self.send_fatal_alert(AlertDescription::BadRecordMac).await;
self.send_fatal_alert(AlertDescription::BadRecordMac)
.await?;
Err(Error::DecryptError)
}
Err(e) => Err(e),
@@ -781,18 +785,26 @@ impl CommonState {
/// Fragment `m`, encrypt the fragments, and then queue
/// the encrypted fragments for sending.
#[async_recursion]
pub(crate) async fn send_msg_encrypt(&mut self, m: PlainMessage) {
pub(crate) async fn send_msg_encrypt(&mut self, m: PlainMessage) -> Result<(), Error> {
let mut plain_messages = VecDeque::new();
self.message_fragmenter.fragment(m, &mut plain_messages);
for m in plain_messages {
self.send_single_fragment(m).await;
// Close connection once we start to run out of
// sequence space.
if self.record_layer.wants_close_before_encrypt() {
debug!("Sending warning alert {:?}", AlertDescription::CloseNotify);
let m = Message::build_alert(AlertLevel::Warning, AlertDescription::CloseNotify);
self.send_single_fragment(m.into()).await?;
}
for m in plain_messages {
self.send_single_fragment(m).await?;
}
Ok(())
}
/// Like send_msg_encrypt, but operate on an appdata directly.
async fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> usize {
async fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> Result<usize, Error> {
// Here, the limit on sendable_tls applies to encrypted data,
// but we're respecting it for plaintext data -- so we'll
// be out by whatever the cipher+record overhead is. That's a
@@ -813,27 +825,25 @@ impl CommonState {
);
for m in plain_messages {
self.send_single_fragment(m).await;
self.send_single_fragment(m).await?;
}
len
Ok(len)
}
async fn send_single_fragment(&mut self, m: PlainMessage) {
// Close connection once we start to run out of
// sequence space.
if self.record_layer.wants_close_before_encrypt() {
self.send_close_notify().await;
}
async fn send_single_fragment(&mut self, m: PlainMessage) -> Result<(), Error> {
// Refuse to wrap counter at all costs. This
// is basically untestable unfortunately.
if self.record_layer.encrypt_exhausted() {
return;
return Err(Error::EncryptError);
}
let em = self.record_layer.encrypt_outgoing(m).await;
let em = self
.record_layer
.encrypt_outgoing(self.crypto.as_mut(), m)
.await?;
self.queue_tls_message(em);
Ok(())
}
/// Writes TLS messages to `wr`.
@@ -852,7 +862,7 @@ impl CommonState {
///
/// Returns the number of bytes written from `data`: this might
/// be less than `data.len()` if buffer limits were exceeded.
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> usize {
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> Result<usize, Error> {
if !self.may_send_application_data {
// If we haven't completed handshaking, buffer
// plaintext to send once we do.
@@ -860,27 +870,27 @@ impl CommonState {
Limit::Yes => self.sendable_plaintext.append_limited_copy(data),
Limit::No => self.sendable_plaintext.append(data.to_vec()),
};
return len;
return Ok(len);
}
debug_assert!(self.record_layer.is_encrypting());
if data.is_empty() {
// Don't send empty fragments.
return 0;
return Ok(0);
}
self.send_appdata_encrypt(data, limit).await
}
pub(crate) async fn start_outgoing_traffic(&mut self) {
pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> {
self.may_send_application_data = true;
self.flush_plaintext().await;
self.flush_plaintext().await
}
pub(crate) async fn start_traffic(&mut self) {
pub(crate) async fn start_traffic(&mut self) -> Result<(), Error> {
self.may_receive_application_data = true;
self.start_outgoing_traffic().await;
self.start_outgoing_traffic().await
}
/// Sets a limit on the internal buffers used to buffer
@@ -929,14 +939,16 @@ impl CommonState {
/// Send any buffered plaintext. Plaintext is buffered if
/// written during handshake.
async fn flush_plaintext(&mut self) {
async fn flush_plaintext(&mut self) -> Result<(), Error> {
if !self.may_send_application_data {
return;
return Ok(());
}
while let Some(buf) = self.sendable_plaintext.pop() {
self.send_plain(&buf, Limit::No).await;
self.send_plain(&buf, Limit::No).await?;
}
Ok(())
}
// Put m into sendable_tls for writing.
@@ -945,15 +957,16 @@ impl CommonState {
}
/// Send a raw TLS message, fragmenting it if needed.
pub(crate) async fn send_msg(&mut self, m: Message, must_encrypt: bool) {
pub(crate) async fn send_msg(&mut self, m: Message, must_encrypt: bool) -> Result<(), Error> {
if !must_encrypt {
let mut to_send = VecDeque::new();
self.message_fragmenter.fragment(m.into(), &mut to_send);
for mm in to_send {
self.queue_tls_message(mm.into_unencrypted_opaque());
}
return Ok(());
} else {
self.send_msg_encrypt(m.into()).await;
self.send_msg_encrypt(m.into()).await
}
}
@@ -961,16 +974,16 @@ impl CommonState {
self.received_plaintext.append(bytes.0);
}
async fn send_warning_alert(&mut self, desc: AlertDescription) {
async fn send_warning_alert(&mut self, desc: AlertDescription) -> Result<(), Error> {
warn!("Sending warning alert {:?}", desc);
self.send_warning_alert_no_log(desc).await;
self.send_warning_alert_no_log(desc).await
}
async fn process_alert(&mut self, alert: &AlertMessagePayload) -> Result<(), Error> {
// Reject unknown AlertLevels.
if let AlertLevel::Unknown(_) = alert.level {
self.send_fatal_alert(AlertDescription::IllegalParameter)
.await;
.await?;
}
// If we get a CloseNotify, make a note to declare EOF to our
@@ -984,7 +997,7 @@ impl CommonState {
// (except, for no good reason, user_cancelled).
if alert.level == AlertLevel::Warning {
if self.is_tls13() && alert.description != AlertDescription::UserCanceled {
self.send_fatal_alert(AlertDescription::DecodeError).await;
self.send_fatal_alert(AlertDescription::DecodeError).await?;
} else {
warn!("TLS alert warning received: {:#?}", alert);
return Ok(());
@@ -995,26 +1008,27 @@ impl CommonState {
Err(Error::AlertReceived(alert.description))
}
pub(crate) async fn send_fatal_alert(&mut self, desc: AlertDescription) {
pub(crate) async fn send_fatal_alert(&mut self, desc: AlertDescription) -> Result<(), Error> {
warn!("Sending fatal alert {:?}", desc);
debug_assert!(!self.sent_fatal_alert);
let m = Message::build_alert(AlertLevel::Fatal, desc);
self.send_msg(m, self.record_layer.is_encrypting()).await;
self.send_msg(m, self.record_layer.is_encrypting()).await?;
self.sent_fatal_alert = true;
Ok(())
}
/// Queues a close_notify warning alert to be sent in the next
/// [`CommonState::write_tls`] call. This informs the peer that the
/// connection is being closed.
pub async fn send_close_notify(&mut self) {
pub async fn send_close_notify(&mut self) -> Result<(), Error> {
debug!("Sending warning alert {:?}", AlertDescription::CloseNotify);
self.send_warning_alert_no_log(AlertDescription::CloseNotify)
.await;
.await
}
async fn send_warning_alert_no_log(&mut self, desc: AlertDescription) {
async fn send_warning_alert_no_log(&mut self, desc: AlertDescription) -> Result<(), Error> {
let m = Message::build_alert(AlertLevel::Warning, desc);
self.send_msg(m, self.record_layer.is_encrypting()).await;
self.send_msg(m, self.record_layer.is_encrypting()).await
}
pub(crate) fn set_max_fragment_size(&mut self, new: Option<usize>) -> Result<(), Error> {
@@ -1054,7 +1068,7 @@ impl CommonState {
}
}
#[async_trait]
#[async_trait(?Send)]
pub(crate) trait State<ClientConnectionData>: Send + Sync {
async fn start(
self: Box<Self>,

View File

@@ -1,16 +1,35 @@
use crate::Error;
use tls_core::msgs::enums::ProtocolVersion;
use tls_core::msgs::handshake::Random;
use tls_core::msgs::message::{OpaqueMessage, PlainMessage};
use tls_core::{key::PublicKey, suites::SupportedCipherSuite};
use async_trait::async_trait;
use crate::cipher::{MessageDecrypter, MessageEncrypter};
#[derive(Debug, Clone)]
pub enum EncryptMode {
/// Encrypt payload with PSK
EarlyData,
/// Encrypt payload with Handshake keys
Handshake,
/// Encrypt payload with Application traffic keys
Application,
}
#[derive(Debug, Clone)]
pub enum DecryptMode {
/// Decrypt payload with Handshake keys
Handshake,
/// Decrypt payload with Application traffic keys
Application,
}
/// Core trait which manages crypto operations for the TLS connection such as key exchange, encryption
/// and decryption.
#[async_trait]
pub trait Handshake: Send {
pub trait Crypto: Send {
/// Signals selected protocol version to implementor.
/// Throws error if version is not supported.
fn select_protocol_version(&mut self, version: ProtocolVersion) -> Result<(), Error>;
@@ -19,6 +38,10 @@ pub trait Handshake: Send {
fn select_cipher_suite(&mut self, suite: SupportedCipherSuite) -> Result<(), Error>;
/// Returns configured cipher suite.
fn suite(&self) -> Result<SupportedCipherSuite, Error>;
/// Set encryption mode
fn set_encrypt(&mut self, mode: EncryptMode) -> Result<(), Error>;
/// Start decryption
fn set_decrypt(&mut self, mode: DecryptMode) -> Result<(), Error>;
/// Returns client_random value.
async fn client_random(&mut self) -> Result<Random, Error>;
/// Returns public client keyshare.
@@ -33,16 +56,16 @@ pub trait Handshake: Send {
async fn server_finished(&mut self, hash: &[u8]) -> Result<Vec<u8>, Error>;
/// Returns ClientFinished verify_data.
async fn client_finished(&mut self, hash: &[u8]) -> Result<Vec<u8>, Error>;
/// Returns initialized MessageEncrypter.
async fn message_encrypter(&mut self) -> Result<Box<dyn MessageEncrypter>, Error>;
/// Returns initialized MessageDecrypter.
async fn message_decrypter(&mut self) -> Result<Box<dyn MessageDecrypter>, Error>;
/// Perform the encryption over the concerned TLS message.
async fn encrypt(&self, m: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error>;
/// Perform the decryption over the concerned TLS message.
async fn decrypt(&self, m: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error>;
}
pub struct InvalidHandShaker {}
pub struct InvalidCrypto {}
#[async_trait]
impl Handshake for InvalidHandShaker {
impl Crypto for InvalidCrypto {
fn select_protocol_version(&mut self, _version: ProtocolVersion) -> Result<(), Error> {
Err(Error::General("handshaker not yet available".to_string()))
}
@@ -52,6 +75,14 @@ impl Handshake for InvalidHandShaker {
fn suite(&self) -> Result<SupportedCipherSuite, Error> {
Err(Error::General("handshaker not yet available".to_string()))
}
/// Start encryption
fn set_encrypt(&mut self, _mode: EncryptMode) -> Result<(), Error> {
Err(Error::General("handshaker not yet available".to_string()))
}
/// Start decryption
fn set_decrypt(&mut self, _mode: DecryptMode) -> Result<(), Error> {
Err(Error::General("handshaker not yet available".to_string()))
}
async fn client_random(&mut self) -> Result<Random, Error> {
Err(Error::General("handshaker not yet available".to_string()))
}
@@ -73,10 +104,10 @@ impl Handshake for InvalidHandShaker {
async fn client_finished(&mut self, _hash: &[u8]) -> Result<Vec<u8>, Error> {
Err(Error::General("handshaker not yet available".to_string()))
}
async fn message_encrypter(&mut self) -> Result<Box<dyn MessageEncrypter>, Error> {
async fn encrypt(&self, _m: PlainMessage, _seq: u64) -> Result<OpaqueMessage, Error> {
Err(Error::General("handshaker not yet available".to_string()))
}
async fn message_decrypter(&mut self) -> Result<Box<dyn MessageDecrypter>, Error> {
async fn decrypt(&self, _m: OpaqueMessage, _seq: u64) -> Result<PlainMessage, Error> {
Err(Error::General("handshaker not yet available".to_string()))
}
}

View File

@@ -315,8 +315,8 @@ pub extern crate tls_core;
mod anchors;
mod cipher;
mod conn;
mod crypto;
mod error;
mod handshaker;
mod hash_hs;
mod limited_cache;
mod msgs;
@@ -357,7 +357,7 @@ pub use crate::key_log::{KeyLog, NoKeyLog};
pub use crate::key_log_file::KeyLogFile;
pub use crate::kx::{SupportedKxGroup, ALL_KX_GROUPS};
pub use cipher::{MessageDecrypter, MessageEncrypter};
pub use handshaker::Handshake;
pub use crypto::Crypto;
pub use tls_core::key::{Certificate, PrivateKey};
pub use tls_core::msgs::enums::CipherSuite;
pub use tls_core::msgs::enums::ProtocolVersion;

View File

@@ -3,6 +3,7 @@ use crate::{
InvalidMessageDecrypter, InvalidMessageEncrypter, MessageDecrypter, MessageEncrypter,
},
error::Error,
Crypto,
};
use tls_core::msgs::message::{OpaqueMessage, PlainMessage};
@@ -22,8 +23,6 @@ enum DirectionState {
}
pub(crate) struct RecordLayer {
message_encrypter: Box<dyn MessageEncrypter>,
message_decrypter: Box<dyn MessageDecrypter>,
write_seq: u64,
read_seq: u64,
encrypt_state: DirectionState,
@@ -38,8 +37,6 @@ pub(crate) struct RecordLayer {
impl RecordLayer {
pub(crate) fn new() -> Self {
Self {
message_encrypter: Box::new(InvalidMessageEncrypter {}),
message_decrypter: Box::new(InvalidMessageDecrypter {}),
write_seq: 0,
read_seq: 0,
encrypt_state: DirectionState::Invalid,
@@ -71,16 +68,14 @@ impl RecordLayer {
/// Prepare to use the given `MessageEncrypter` for future message encryption.
/// It is not used until you call `start_encrypting`.
pub(crate) fn prepare_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
self.message_encrypter = cipher;
pub(crate) fn prepare_message_encrypter(&mut self) {
self.write_seq = 0;
self.encrypt_state = DirectionState::Prepared;
}
/// Prepare to use the given `MessageDecrypter` for future message decryption.
/// It is not used until you call `start_decrypting`.
pub(crate) fn prepare_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
self.message_decrypter = cipher;
pub(crate) fn prepare_message_decrypter(&mut self) {
self.read_seq = 0;
self.decrypt_state = DirectionState::Prepared;
}
@@ -101,15 +96,15 @@ impl RecordLayer {
/// Set and start using the given `MessageEncrypter` for future outgoing
/// message encryption.
pub(crate) fn set_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
self.prepare_message_encrypter(cipher);
pub(crate) fn set_message_encrypter(&mut self) {
self.prepare_message_encrypter();
self.start_encrypting();
}
/// Set and start using the given `MessageDecrypter` for future incoming
/// message decryption.
pub(crate) fn set_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
self.prepare_message_decrypter(cipher);
pub(crate) fn set_message_decrypter(&mut self) {
self.prepare_message_decrypter();
self.start_decrypting();
self.trial_decryption_len = None;
}
@@ -117,12 +112,8 @@ impl RecordLayer {
/// Set and start using the given `MessageDecrypter` for future incoming
/// message decryption, and enable "trial decryption" mode for when TLS1.3
/// 0-RTT is attempted but rejected by the server.
pub(crate) fn set_message_decrypter_with_trial_decryption(
&mut self,
cipher: Box<dyn MessageDecrypter>,
max_length: usize,
) {
self.prepare_message_decrypter(cipher);
pub(crate) fn set_message_decrypter_with_trial_decryption(&mut self, max_length: usize) {
self.prepare_message_decrypter();
self.start_decrypting();
self.trial_decryption_len = Some(max_length);
}
@@ -162,11 +153,12 @@ impl RecordLayer {
/// an error is returned.
pub(crate) async fn decrypt_incoming(
&mut self,
cipher: &mut dyn Crypto,
encr: OpaqueMessage,
) -> Result<PlainMessage, Error> {
debug_assert!(self.is_decrypting());
let seq = self.read_seq;
let msg = self.message_decrypter.decrypt(encr, seq).await?;
let msg = cipher.decrypt(encr, seq).await?;
self.read_seq += 1;
Ok(msg)
}
@@ -175,11 +167,15 @@ impl RecordLayer {
///
/// `plain` is a TLS message we'd like to send. This function
/// panics if the requisite keying material hasn't been established yet.
pub(crate) async fn encrypt_outgoing(&mut self, plain: PlainMessage) -> OpaqueMessage {
pub(crate) async fn encrypt_outgoing(
&mut self,
cipher: &mut dyn Crypto,
plain: PlainMessage,
) -> Result<OpaqueMessage, Error> {
debug_assert!(self.encrypt_state == DirectionState::Active);
assert!(!self.encrypt_exhausted());
let seq = self.write_seq;
self.write_seq += 1;
self.message_encrypter.encrypt(plain, seq).await.unwrap()
cipher.encrypt(plain, seq).await
}
}

View File

@@ -460,7 +460,7 @@ async fn client_close_notify() {
// check that alerts don't overtake appdata
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
client.send_close_notify().await;
client.send_close_notify().await.unwrap();
send(&mut client, &mut server);
let io_state = server.process_new_packets().unwrap();
@@ -973,7 +973,7 @@ where
}
}
fn new_fails(sess: &'a mut C) -> ServerSession<'a, C, S> {
fn _new_fails(sess: &'a mut C) -> ServerSession<'a, C, S> {
let mut os = ServerSession::new(sess);
os.fail_ok = true;
os
@@ -1067,7 +1067,7 @@ where
}
}
fn new_fails(sess: &'a mut C) -> ClientSession<'a, C> {
fn _new_fails(sess: &'a mut C) -> ClientSession<'a, C> {
let mut os = ClientSession::new(sess);
os.fail_ok = true;
os
@@ -1088,7 +1088,7 @@ impl<'a, C> io::Write for ClientSession<'a, C>
where
C: DerefMut + Deref<Target = tls_client::ConnectionCommon>,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
unreachable!()
}