From 452db97dd21281da14d80e27e1868df910fd99bc Mon Sep 17 00:00:00 2001 From: sinuio <> Date: Wed, 15 Jun 2022 16:43:52 -0700 Subject: [PATCH] consolidate traits, convert functions to fallible, and remove recursion --- tls-client/src/client/client_conn.rs | 2 +- tls-client/src/client/hs.rs | 80 ++++++------ tls-client/src/client/tls12.rs | 81 ++++++------ tls-client/src/client/tls13.rs | 115 ++++++++--------- tls-client/src/conn.rs | 132 +++++++++++--------- tls-client/src/{handshaker.rs => crypto.rs} | 49 ++++++-- tls-client/src/lib.rs | 4 +- tls-client/src/record_layer.rs | 38 +++--- tls-client/tests/api.rs | 8 +- 9 files changed, 276 insertions(+), 233 deletions(-) rename tls-client/src/{handshaker.rs => crypto.rs} (68%) diff --git a/tls-client/src/client/client_conn.rs b/tls-client/src/client/client_conn.rs index b34c8bf3a..6b0210dc6 100644 --- a/tls-client/src/client/client_conn.rs +++ b/tls-client/src/client/client_conn.rs @@ -378,7 +378,7 @@ pub struct Initialized { config: Arc, } -#[async_trait] +#[async_trait(?Send)] impl State for Initialized { async fn start( self: Box, diff --git a/tls-client/src/client/hs.rs b/tls-client/src/client/hs.rs index 04c601402..a51551d4b 100644 --- a/tls-client/src/client/hs.rs +++ b/tls-client/src/client/hs.rs @@ -83,7 +83,7 @@ pub(super) async fn start_handshake( let support_tls13 = config.supports_version(ProtocolVersion::TLSv1_3); let key_share = if support_tls13 { - Some(cx.common.handshaker.client_key_share().await?) + Some(cx.common.crypto.client_key_share().await?) } else { None }; @@ -113,7 +113,7 @@ pub(super) async fn start_handshake( session_id = Some(SessionID::random()?); } - let random = cx.common.handshaker.client_random().await?; + let random = cx.common.crypto.client_random().await?; let hello_details = ClientHelloDetails::new(); let sent_tls13_fake_ccs = false; let may_send_sct_list = config.verifier.request_scts(); @@ -134,7 +134,7 @@ pub(super) async fn start_handshake( may_send_sct_list, None, ) - .await) + .await?) } struct ExpectServerHello { @@ -172,7 +172,7 @@ async fn emit_client_hello_for_retry( extra_exts: Vec, may_send_sct_list: bool, suite: Option, -) -> NextState { +) -> Result { // For now we do not support session resumption // // Do we have a SessionID or ticket cached for this host? @@ -337,13 +337,13 @@ async fn emit_client_hello_for_retry( if retryreq.is_some() { // send dummy CCS to fool middleboxes prior // to second client hello - tls13::emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await; + tls13::emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await?; } trace!("Sending ClientHello {:#?}", ch); transcript_buffer.add_message(&ch); - cx.common.send_msg(ch, false).await; + cx.common.send_msg(ch, false).await?; let next = ExpectServerHello { config, @@ -360,9 +360,12 @@ async fn emit_client_hello_for_retry( }; if support_tls13 && retryreq.is_none() { - Box::new(ExpectServerHelloOrHelloRetryRequest { next, extra_exts }) + Ok(Box::new(ExpectServerHelloOrHelloRetryRequest { + next, + extra_exts, + })) } else { - Box::new(next) + Ok(Box::new(next)) } } @@ -377,7 +380,7 @@ pub(super) async fn process_alpn_protocol( if !config.alpn_protocols.contains(alpn_protocol) { return Err(common .illegal_param("server sent non-offered ALPN protocol") - .await); + .await?); } } @@ -392,7 +395,7 @@ pub(super) fn sct_list_is_invalid(scts: &SCTList) -> bool { scts.is_empty() || scts.iter().any(|sct| sct.0.is_empty()) } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectServerHello { async fn handle( mut self: Box, @@ -429,7 +432,7 @@ impl State for ExpectServerHello { return Err(cx .common .illegal_param("server chose v1.2 using v1.3 extension") - .await); + .await?); } TLSv1_2 @@ -437,7 +440,7 @@ impl State for ExpectServerHello { _ => { cx.common .send_fatal_alert(AlertDescription::ProtocolVersion) - .await; + .await?; let msg = match server_version { TLSv1_2 | TLSv1_3 => "server's TLS version is disabled in client", _ => "server does not support TLS v1.2/v1.3", @@ -446,19 +449,19 @@ impl State for ExpectServerHello { } }; - cx.common.handshaker.select_protocol_version(version)?; + cx.common.crypto.select_protocol_version(version)?; if server_hello.compression_method != Compression::Null { return Err(cx .common .illegal_param("server chose non-Null compression") - .await); + .await?); } if server_hello.has_duplicate_extension() { cx.common .send_fatal_alert(AlertDescription::DecodeError) - .await; + .await?; return Err(Error::PeerMisbehavedError( "server sent duplicate extensions".to_string(), )); @@ -471,7 +474,7 @@ impl State for ExpectServerHello { { cx.common .send_fatal_alert(AlertDescription::UnsupportedExtension) - .await; + .await?; return Err(Error::PeerMisbehavedError( "server sent unsolicited extension".to_string(), )); @@ -491,7 +494,7 @@ impl State for ExpectServerHello { if !point_fmts.contains(&ECPointFormat::Uncompressed) { cx.common .send_fatal_alert(AlertDescription::HandshakeFailure) - .await; + .await?; return Err(Error::PeerMisbehavedError( "server does not support uncompressed points".to_string(), )); @@ -503,7 +506,7 @@ impl State for ExpectServerHello { None => { cx.common .send_fatal_alert(AlertDescription::HandshakeFailure) - .await; + .await?; return Err(Error::PeerMisbehavedError( "server chose non-offered ciphersuite".to_string(), )); @@ -514,7 +517,7 @@ impl State for ExpectServerHello { return Err(cx .common .illegal_param("server chose unusable ciphersuite for version") - .await); + .await?); } match self.suite { @@ -522,13 +525,13 @@ impl State for ExpectServerHello { return Err(cx .common .illegal_param("server varied selected ciphersuite") - .await); + .await?); } _ => { debug!("Using ciphersuite {:?}", suite); self.suite = Some(suite); cx.common.suite = Some(suite); - cx.common.handshaker.select_cipher_suite(suite)?; + cx.common.crypto.select_cipher_suite(suite)?; } } @@ -620,7 +623,7 @@ impl ExpectServerHelloOrHelloRetryRequest { return Err(cx .common .illegal_param("server requested hrr with our group") - .await); + .await?); } // Or has an empty cookie. @@ -629,7 +632,7 @@ impl ExpectServerHelloOrHelloRetryRequest { return Err(cx .common .illegal_param("server requested hrr with empty cookie") - .await); + .await?); } } @@ -637,7 +640,7 @@ impl ExpectServerHelloOrHelloRetryRequest { if hrr.has_unknown_extension() { cx.common .send_fatal_alert(AlertDescription::UnsupportedExtension) - .await; + .await?; return Err(Error::PeerIncompatibleError( "server sent hrr with unhandled extension".to_string(), )); @@ -648,7 +651,7 @@ impl ExpectServerHelloOrHelloRetryRequest { return Err(cx .common .illegal_param("server send duplicate hrr extensions") - .await); + .await?); } // Or asks us to change nothing. @@ -656,7 +659,7 @@ impl ExpectServerHelloOrHelloRetryRequest { return Err(cx .common .illegal_param("server requested hrr with no changes") - .await); + .await?); } // Or asks us to talk a protocol we didn't offer, or doesn't support HRR at all. @@ -668,7 +671,7 @@ impl ExpectServerHelloOrHelloRetryRequest { return Err(cx .common .illegal_param("server requested unsupported version in hrr") - .await); + .await?); } } @@ -680,7 +683,7 @@ impl ExpectServerHelloOrHelloRetryRequest { return Err(cx .common .illegal_param("server requested unsupported cs in hrr") - .await); + .await?); } }; @@ -705,7 +708,7 @@ impl ExpectServerHelloOrHelloRetryRequest { return Err(cx .common .illegal_param("server requested hrr with bad group") - .await); + .await?); } _ => offered_key_share, }; @@ -727,11 +730,11 @@ impl ExpectServerHelloOrHelloRetryRequest { may_send_sct_list, Some(cs), ) - .await) + .await?) } } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectServerHelloOrHelloRetryRequest { async fn handle(self: Box, cx: &mut ClientContext<'_>, m: Message) -> NextStateOrError { match m.payload { @@ -752,22 +755,27 @@ impl State for ExpectServerHelloOrHelloRetryRequest { } } -pub(super) async fn send_cert_error_alert(common: &mut CommonState, err: Error) -> Error { +pub(super) async fn send_cert_error_alert( + common: &mut CommonState, + err: Error, +) -> Result { match err { Error::InvalidCertificateEncoding => { - common.send_fatal_alert(AlertDescription::DecodeError).await; + common + .send_fatal_alert(AlertDescription::DecodeError) + .await?; } Error::PeerMisbehavedError(_) => { common .send_fatal_alert(AlertDescription::IllegalParameter) - .await; + .await?; } _ => { common .send_fatal_alert(AlertDescription::BadCertificate) - .await; + .await?; } }; - err + Ok(err) } diff --git a/tls-client/src/client/tls12.rs b/tls-client/src/client/tls12.rs index 905623a2c..e2f3ad9d9 100644 --- a/tls-client/src/client/tls12.rs +++ b/tls-client/src/client/tls12.rs @@ -72,7 +72,7 @@ mod server_hello { return Err(cx .common .illegal_param("downgrade to TLS1.2 when TLS1.3 is supported") - .await); + .await?); } // Doing EMS? @@ -209,7 +209,7 @@ struct ExpectCertificate { server_cert_sct_list: Option, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectCertificate { async fn handle( mut self: Box, @@ -271,7 +271,7 @@ struct ExpectCertificateStatusOrServerKx { must_issue_new_ticket: bool, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectCertificateStatusOrServerKx { async fn handle( self: Box, @@ -348,7 +348,7 @@ struct ExpectCertificateStatus { must_issue_new_ticket: bool, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectCertificateStatus { async fn handle( mut self: Box, @@ -402,7 +402,7 @@ struct ExpectServerKx { must_issue_new_ticket: bool, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectServerKx { async fn handle( mut self: Box, @@ -422,7 +422,7 @@ impl State for ExpectServerKx { // We only support ECDHE cx.common .send_fatal_alert(AlertDescription::DecodeError) - .await; + .await?; return Err(Error::CorruptMessagePayload(ContentType::Handshake)); } }; @@ -457,7 +457,7 @@ async fn emit_certificate( transcript: &mut HandshakeHash, cert_chain: CertificatePayload, common: &mut CommonState, -) { +) -> Result<(), Error> { let cert = Message { version: ProtocolVersion::TLSv1_2, payload: MessagePayload::Handshake(HandshakeMessagePayload { @@ -467,14 +467,14 @@ async fn emit_certificate( }; transcript.add_message(&cert); - common.send_msg(cert, false).await; + common.send_msg(cert, false).await } async fn emit_clientkx( transcript: &mut HandshakeHash, common: &mut CommonState, pubkey: &PublicKey, -) { +) -> Result<(), Error> { let ecpoint = PayloadU8::new(pubkey.key.clone()); let mut buf = Vec::new(); @@ -490,7 +490,7 @@ async fn emit_clientkx( }; transcript.add_message(&ckx); - common.send_msg(ckx, false).await; + common.send_msg(ckx, false).await } async fn emit_certverify( @@ -515,24 +515,23 @@ async fn emit_certverify( }; transcript.add_message(&m); - common.send_msg(m, false).await; - Ok(()) + common.send_msg(m, false).await } -async fn emit_ccs(common: &mut CommonState) { +async fn emit_ccs(common: &mut CommonState) -> Result<(), Error> { let ccs = Message { version: ProtocolVersion::TLSv1_2, payload: MessagePayload::ChangeCipherSpec(ChangeCipherSpecPayload {}), }; - common.send_msg(ccs, false).await; + common.send_msg(ccs, false).await } async fn emit_finished( verify_data: &[u8], transcript: &mut HandshakeHash, common: &mut CommonState, -) { +) -> Result<(), Error> { let verify_data_payload = Payload::new(verify_data); let f = Message { @@ -544,7 +543,7 @@ async fn emit_finished( }; transcript.add_message(&f); - common.send_msg(f, true).await; + common.send_msg(f, true).await } struct ServerKxDetails { @@ -578,7 +577,7 @@ struct ExpectServerDoneOrCertReq { must_issue_new_ticket: bool, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectServerDoneOrCertReq { async fn handle( mut self: Box, @@ -644,7 +643,7 @@ struct ExpectCertificateRequest { must_issue_new_ticket: bool, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectCertificateRequest { async fn handle( mut self: Box, @@ -705,7 +704,7 @@ struct ExpectServerDone { must_issue_new_ticket: bool, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectServerDone { async fn handle( self: Box, @@ -764,7 +763,7 @@ impl State for ExpectServerDone { now, ) { Ok(cert_verified) => cert_verified, - Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await), + Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await?), }; // 3. @@ -794,7 +793,7 @@ impl State for ExpectServerDone { sig, ) { Ok(sig_verified) => sig_verified, - Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await), + Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await?), } }; cx.common.peer_certificates = Some(st.server_cert.cert_chain); @@ -805,7 +804,7 @@ impl State for ExpectServerDone { ClientAuthDetails::Empty { .. } => Vec::new(), ClientAuthDetails::Verify { certkey, .. } => certkey.cert.clone(), }; - emit_certificate(&mut st.transcript, certs, cx.common).await; + emit_certificate(&mut st.transcript, certs, cx.common).await?; } // 5a. @@ -815,12 +814,12 @@ impl State for ExpectServerDone { None => { cx.common .send_fatal_alert(AlertDescription::DecodeError) - .await; + .await?; return Err(Error::CorruptMessagePayload(ContentType::Handshake)); } }; - let key_share = cx.common.handshaker.client_key_share().await?; + let key_share = cx.common.crypto.client_key_share().await?; if key_share.group != ecdh_params.curve_params.named_group { return Err(Error::PeerMisbehavedError( "peer chose an unsupported group".to_string(), @@ -829,7 +828,7 @@ impl State for ExpectServerDone { // 5b. let mut transcript = st.transcript; - emit_clientkx(&mut transcript, cx.common, &key_share).await; + emit_clientkx(&mut transcript, cx.common, &key_share).await?; // nb. EMS handshake hash only runs up to ClientKeyExchange. let _ems_seed = st.using_ems.then(|| transcript.get_current_hash()); @@ -839,31 +838,25 @@ impl State for ExpectServerDone { } // 5d. - emit_ccs(cx.common).await; + emit_ccs(cx.common).await?; // 5e. Now commit secrets. let server_key_share = PublicKey::new(ecdh_params.curve_params.named_group, &ecdh_params.public.0); cx.common - .handshaker + .crypto .set_server_key_share(server_key_share) .await?; - let enc = cx.common.handshaker.message_encrypter().await?; - let dec = cx.common.handshaker.message_decrypter().await?; - - cx.common.record_layer.set_message_encrypter(enc); - cx.common.record_layer.set_message_decrypter(dec); - st.config .key_log .log("CLIENT_RANDOM", &st.randoms.client, &[]); // 6. let hs = transcript.get_current_hash(); - let cf = cx.common.handshaker.client_finished(hs.as_ref()).await?; - emit_finished(&cf, &mut transcript, cx.common).await; + let cf = cx.common.crypto.client_finished(hs.as_ref()).await?; + emit_finished(&cf, &mut transcript, cx.common).await?; if st.must_issue_new_ticket { Ok(Box::new(ExpectNewTicket { @@ -906,7 +899,7 @@ struct ExpectNewTicket { sig_verified: verify::HandshakeSignatureValid, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectNewTicket { async fn handle( mut self: Box, @@ -950,7 +943,7 @@ struct ExpectCcs { sig_verified: verify::HandshakeSignatureValid, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectCcs { async fn handle( self: Box, @@ -1055,7 +1048,7 @@ struct ExpectFinished { // } // } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectFinished { async fn handle( self: Box, @@ -1070,7 +1063,7 @@ impl State for ExpectFinished { // Work out what verify_data we expect. let vh = st.transcript.get_current_hash(); - let expect_verify_data = cx.common.handshaker.server_finished(vh.as_ref()).await?; + let expect_verify_data = cx.common.crypto.server_finished(vh.as_ref()).await?; // Constant-time verification of this is relatively unimportant: they only // get one chance. But it can't hurt. @@ -1080,7 +1073,7 @@ impl State for ExpectFinished { Err(_) => { cx.common .send_fatal_alert(AlertDescription::DecryptError) - .await; + .await?; return Err(Error::DecryptError); } }; @@ -1091,12 +1084,12 @@ impl State for ExpectFinished { // st.save_session(cx); if st.resuming { - emit_ccs(cx.common).await; + emit_ccs(cx.common).await?; cx.common.record_layer.start_encrypting(); - emit_finished(&expect_verify_data, &mut st.transcript, cx.common).await; + emit_finished(&expect_verify_data, &mut st.transcript, cx.common).await?; } - cx.common.start_traffic().await; + cx.common.start_traffic().await?; Ok(Box::new(ExpectTraffic { _cert_verified: st.cert_verified, _sig_verified: st.sig_verified, @@ -1112,7 +1105,7 @@ struct ExpectTraffic { _fin_verified: verify::FinishedMessageVerified, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectTraffic { async fn handle( self: Box, diff --git a/tls-client/src/client/tls13.rs b/tls-client/src/client/tls13.rs index e45b42754..0ae274e57 100644 --- a/tls-client/src/client/tls13.rs +++ b/tls-client/src/client/tls13.rs @@ -1,5 +1,6 @@ use crate::check::inappropriate_handshake_message; use crate::conn::{CommonState, ConnectionRandoms, State}; +use crate::crypto::{DecryptMode, EncryptMode}; use crate::error::Error; use crate::hash_hs::{HandshakeHash, HandshakeHashBuffer}; #[cfg(feature = "logging")] @@ -70,17 +71,17 @@ pub(super) async fn handle_server_hello( None => { cx.common .send_fatal_alert(AlertDescription::MissingExtension) - .await; + .await?; return Err(Error::PeerMisbehavedError("missing key share".to_string())); } }; if our_key_share.group != their_key_share.group { - return Err(cx.common.illegal_param("wrong group for key share").await); + return Err(cx.common.illegal_param("wrong group for key share").await?); } cx.common - .handshaker + .crypto .set_server_key_share(their_key_share.clone().into()) .await?; @@ -115,22 +116,19 @@ pub(super) async fn handle_server_hello( cx.common.check_aligned_handshake().await?; cx.common - .handshaker + .crypto .set_hs_hash_server_hello(transcript.get_current_hash().as_ref()) .await?; - let dec = cx.common.handshaker.message_decrypter().await?; - let enc = cx.common.handshaker.message_encrypter().await?; - // Decrypt with the peer's key, encrypt with our own key - cx.common.record_layer.set_message_decrypter(dec); + cx.common.crypto.set_decrypt(DecryptMode::Handshake)?; if !cx.data.early_data.is_enabled() { // Set the client encryption key for handshakes if early data is not used - cx.common.record_layer.set_message_encrypter(enc); + cx.common.crypto.set_encrypt(EncryptMode::Handshake)?; } - emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await; + emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await?; Ok(Box::new(ExpectEncryptedExtensions { config, @@ -151,7 +149,7 @@ async fn validate_server_hello( if !ALLOWED_PLAINTEXT_EXTS.contains(&ext.get_type()) { common .send_fatal_alert(AlertDescription::UnsupportedExtension) - .await; + .await?; return Err(Error::PeerMisbehavedError( "server sent unexpected cleartext ext".to_string(), )); @@ -224,16 +222,19 @@ async fn validate_server_hello( // exts.push(ClientExtension::PresharedKey(psk_ext)); // } -pub(super) async fn emit_fake_ccs(sent_tls13_fake_ccs: &mut bool, common: &mut CommonState) { +pub(super) async fn emit_fake_ccs( + sent_tls13_fake_ccs: &mut bool, + common: &mut CommonState, +) -> Result<(), Error> { if std::mem::replace(sent_tls13_fake_ccs, true) { - return; + return Ok(()); } let m = Message { version: ProtocolVersion::TLSv1_2, payload: MessagePayload::ChangeCipherSpec(ChangeCipherSpecPayload {}), }; - common.send_msg(m, false).await; + common.send_msg(m, false).await } async fn validate_encrypted_extensions( @@ -242,7 +243,9 @@ async fn validate_encrypted_extensions( exts: &EncryptedExtensions, ) -> Result<(), Error> { if exts.has_duplicate_extension() { - common.send_fatal_alert(AlertDescription::DecodeError).await; + common + .send_fatal_alert(AlertDescription::DecodeError) + .await?; return Err(Error::PeerMisbehavedError( "server sent duplicate encrypted extensions".to_string(), )); @@ -251,7 +254,7 @@ async fn validate_encrypted_extensions( if hello.server_sent_unsolicited_extensions(exts, &[]) { common .send_fatal_alert(AlertDescription::UnsupportedExtension) - .await; + .await?; let msg = "server sent unsolicited encrypted extension".to_string(); return Err(Error::PeerMisbehavedError(msg)); } @@ -262,7 +265,7 @@ async fn validate_encrypted_extensions( { common .send_fatal_alert(AlertDescription::UnsupportedExtension) - .await; + .await?; let msg = "server sent inappropriate encrypted extension".to_string(); return Err(Error::PeerMisbehavedError(msg)); } @@ -281,7 +284,7 @@ struct ExpectEncryptedExtensions { hello: ClientHelloDetails, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectEncryptedExtensions { async fn handle( mut self: Box, @@ -311,9 +314,8 @@ impl State for ExpectEncryptedExtensions { } if was_early_traffic && !cx.common.early_traffic { - let enc = cx.common.handshaker.message_encrypter().await?; // If no early traffic, set the encryption key for handshakes - cx.common.record_layer.set_message_encrypter(enc); + cx.common.record_layer.set_message_encrypter(); } cx.common.peer_certificates = Some(resuming_session.server_cert_chain().to_vec()); @@ -358,7 +360,7 @@ struct ExpectCertificateOrCertReq { may_send_sct_list: bool, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectCertificateOrCertReq { async fn handle( self: Box, @@ -421,7 +423,7 @@ struct ExpectCertificateRequest { may_send_sct_list: bool, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectCertificateRequest { async fn handle( mut self: Box, @@ -444,7 +446,7 @@ impl State for ExpectCertificateRequest { warn!("Server sent non-empty certreq context"); cx.common .send_fatal_alert(AlertDescription::DecodeError) - .await; + .await?; return Err(Error::CorruptMessagePayload(ContentType::Handshake)); } @@ -461,7 +463,7 @@ impl State for ExpectCertificateRequest { if compat_sigschemes.is_empty() { cx.common .send_fatal_alert(AlertDescription::HandshakeFailure) - .await; + .await?; return Err(Error::PeerIncompatibleError( "server sent bad certreq schemes".to_string(), )); @@ -496,7 +498,7 @@ struct ExpectCertificate { client_auth: Option, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectCertificate { async fn handle( mut self: Box, @@ -515,7 +517,7 @@ impl State for ExpectCertificate { warn!("certificate with non-empty context during handshake"); cx.common .send_fatal_alert(AlertDescription::DecodeError) - .await; + .await?; return Err(Error::CorruptMessagePayload(ContentType::Handshake)); } @@ -525,7 +527,7 @@ impl State for ExpectCertificate { warn!("certificate chain contains unsolicited/unknown extension"); cx.common .send_fatal_alert(AlertDescription::UnsupportedExtension) - .await; + .await?; return Err(Error::PeerMisbehavedError( "bad cert chain extensions".to_string(), )); @@ -572,7 +574,7 @@ struct ExpectCertificateVerify { client_auth: Option, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectCertificateVerify { async fn handle( mut self: Box, @@ -603,7 +605,7 @@ impl State for ExpectCertificateVerify { now, ) { Ok(cert_verified) => cert_verified, - Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await), + Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await?), }; // 2. Verify their signature on the handshake. @@ -614,7 +616,7 @@ impl State for ExpectCertificateVerify { cert_verify, ) { Ok(sig_verified) => sig_verified, - Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await), + Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await?), }; cx.common.peer_certificates = Some(self.server_cert.cert_chain); @@ -638,7 +640,7 @@ async fn emit_certificate_tls13( certkey: Option<&CertifiedKey>, auth_context: Option>, common: &mut CommonState, -) { +) -> Result<(), Error> { let context = auth_context.unwrap_or_default(); let mut cert_payload = CertificatePayloadTLS13 { @@ -662,7 +664,7 @@ async fn emit_certificate_tls13( }), }; transcript.add_message(&m); - common.send_msg(m, true).await; + common.send_msg(m, true).await } async fn emit_certverify_tls13( @@ -685,15 +687,14 @@ async fn emit_certverify_tls13( }; transcript.add_message(&m); - common.send_msg(m, true).await; - Ok(()) + common.send_msg(m, true).await } async fn emit_finished_tls13( verify_data: &[u8], transcript: &mut HandshakeHash, common: &mut CommonState, -) { +) -> Result<(), Error> { let verify_data_payload = Payload::new(verify_data); let m = Message { @@ -705,10 +706,13 @@ async fn emit_finished_tls13( }; transcript.add_message(&m); - common.send_msg(m, true).await; + common.send_msg(m, true).await } -async fn emit_end_of_early_data_tls13(transcript: &mut HandshakeHash, common: &mut CommonState) { +async fn emit_end_of_early_data_tls13( + transcript: &mut HandshakeHash, + common: &mut CommonState, +) -> Result<(), Error> { let m = Message { version: ProtocolVersion::TLSv1_3, payload: MessagePayload::Handshake(HandshakeMessagePayload { @@ -718,7 +722,7 @@ async fn emit_end_of_early_data_tls13(transcript: &mut HandshakeHash, common: &m }; transcript.add_message(&m); - common.send_msg(m, true).await; + common.send_msg(m, true).await } struct ExpectFinished { @@ -732,7 +736,7 @@ struct ExpectFinished { sig_verified: verify::HandshakeSignatureValid, } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectFinished { async fn handle( self: Box, @@ -746,7 +750,7 @@ impl State for ExpectFinished { let handshake_hash = st.transcript.get_current_hash(); let expect_verify_data = cx .common - .handshaker + .crypto .server_finished(handshake_hash.as_ref()) .await?; @@ -758,7 +762,7 @@ impl State for ExpectFinished { Err(_) => { cx.common .send_fatal_alert(AlertDescription::DecryptError) - .await; + .await?; return Err(Error::DecryptError); } }; @@ -768,11 +772,10 @@ impl State for ExpectFinished { /* The EndOfEarlyData message to server is still encrypted with early data keys, * but appears in the transcript after the server Finished. */ if cx.common.early_traffic { - emit_end_of_early_data_tls13(&mut st.transcript, cx.common).await; + emit_end_of_early_data_tls13(&mut st.transcript, cx.common).await?; cx.common.early_traffic = false; cx.data.early_data.finished(); - let enc = cx.common.handshaker.message_encrypter().await?; - cx.common.record_layer.set_message_encrypter(enc); + cx.common.crypto.set_encrypt(EncryptMode::Handshake)?; } /* Send our authentication/finished messages. These are still encrypted @@ -782,7 +785,8 @@ impl State for ExpectFinished { ClientAuthDetails::Empty { auth_context_tls13: auth_context, } => { - emit_certificate_tls13(&mut st.transcript, None, auth_context, cx.common).await; + emit_certificate_tls13(&mut st.transcript, None, auth_context, cx.common) + .await?; } ClientAuthDetails::Verify { certkey, @@ -795,7 +799,7 @@ impl State for ExpectFinished { auth_context, cx.common, ) - .await; + .await?; emit_certverify_tls13(&mut st.transcript, signer.as_ref(), cx.common).await?; } } @@ -804,21 +808,18 @@ impl State for ExpectFinished { let handshake_hash = st.transcript.get_current_hash(); let client_finished = cx .common - .handshaker + .crypto .client_finished(handshake_hash.as_ref()) .await?; - emit_finished_tls13(&client_finished, &mut st.transcript, cx.common).await; + emit_finished_tls13(&client_finished, &mut st.transcript, cx.common).await?; /* Now move to our application traffic keys. */ cx.common.check_aligned_handshake().await?; - let dec = cx.common.handshaker.message_decrypter().await?; - cx.common.record_layer.set_message_decrypter(dec); + cx.common.crypto.set_encrypt(EncryptMode::Application)?; + cx.common.crypto.set_decrypt(DecryptMode::Application)?; - let enc = cx.common.handshaker.message_encrypter().await?; - cx.common.record_layer.set_message_encrypter(enc); - - cx.common.start_traffic().await; + cx.common.start_traffic().await?; let st = ExpectTraffic { session_storage: Arc::clone(&st.config.session_storage), @@ -859,7 +860,7 @@ impl ExpectTraffic { if nst.has_duplicate_extension() { cx.common .send_fatal_alert(AlertDescription::IllegalParameter) - .await; + .await?; return Err(Error::PeerMisbehavedError( "peer sent duplicate NewSessionTicket extensions".into(), )); @@ -915,7 +916,7 @@ impl ExpectTraffic { // Client does not support key updates common .send_fatal_alert(AlertDescription::InternalError) - .await; + .await?; return Err(Error::General( "received unsupported key update request from peer".to_string(), )); @@ -943,7 +944,7 @@ impl ExpectTraffic { } } -#[async_trait] +#[async_trait(?Send)] impl State for ExpectTraffic { async fn handle( mut self: Box, diff --git a/tls-client/src/conn.rs b/tls-client/src/conn.rs index 618159a7b..895376db3 100644 --- a/tls-client/src/conn.rs +++ b/tls-client/src/conn.rs @@ -1,6 +1,6 @@ use crate::client::ClientConnectionData; +use crate::crypto::{Crypto, InvalidCrypto}; use crate::error::Error; -use crate::handshaker::{Handshake, InvalidHandShaker}; #[cfg(feature = "logging")] use crate::log::{debug, error, trace, warn}; use crate::record_layer; @@ -339,7 +339,7 @@ impl ConnectionCommon { if self.handshake_joiner.take_message(msg).is_none() { self.common_state .send_fatal_alert(AlertDescription::DecodeError) - .await; + .await?; return Err(Error::CorruptMessagePayload(ContentType::Handshake)); } @@ -369,7 +369,7 @@ impl ConnectionCommon { // handshake with an "unexpected_message" alert." self.common_state .send_fatal_alert(AlertDescription::UnexpectedMessage) - .await; + .await?; return Err(Error::PeerMisbehavedError( "illegal middlebox CCS received".into(), )); @@ -406,7 +406,7 @@ impl ConnectionCommon { None => { self.common_state .send_fatal_alert(AlertDescription::DecodeError) - .await; + .await?; return Err(Error::CorruptMessagePayload(ContentType::Handshake)); } } @@ -488,15 +488,15 @@ impl ConnectionCommon { } /// Write buffer into connection - pub async fn write_plaintext(&mut self, buf: &[u8]) -> io::Result { + pub async fn write_plaintext(&mut self, buf: &[u8]) -> Result { if let Ok(st) = &mut self.state { st.perhaps_write_key_update(&mut self.common_state).await; } - Ok(self.common_state.send_some_plaintext(buf).await) + self.common_state.send_some_plaintext(buf).await } /// Write entire buffer into connection - pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> io::Result { + pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> Result { let mut pos = 0; while pos < buf.len() { pos += self.write_plaintext(&buf[pos..]).await?; @@ -572,7 +572,7 @@ pub struct CommonState { pub(crate) negotiated_version: Option, pub(crate) side: Side, pub(crate) record_layer: record_layer::RecordLayer, - pub(crate) handshaker: Box, + pub(crate) crypto: Box, pub(crate) suite: Option, pub(crate) alpn_protocol: Option>, aligned_handshake: bool, @@ -600,7 +600,7 @@ impl CommonState { negotiated_version: None, side, record_layer: record_layer::RecordLayer::new(), - handshaker: Box::new(InvalidHandShaker {}), + crypto: Box::new(InvalidCrypto {}), suite: None, alpn_protocol: None, aligned_handshake: true, @@ -697,7 +697,7 @@ impl CommonState { }; if msg.is_handshake_type(reject_ty) { self.send_warning_alert(AlertDescription::NoRenegotiation) - .await; + .await?; return Ok(state); } } @@ -711,7 +711,7 @@ impl CommonState { Err(e @ Error::InappropriateMessage { .. }) | Err(e @ Error::InappropriateHandshakeMessage { .. }) => { self.send_fatal_alert(AlertDescription::UnexpectedMessage) - .await; + .await?; Err(e) } Err(e) => Err(e), @@ -723,7 +723,7 @@ impl CommonState { /// /// If internal buffers are too small, this function will not accept /// all the data. - pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> usize { + pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> Result { self.send_plain(data, Limit::Yes).await } @@ -734,7 +734,7 @@ impl CommonState { pub(crate) async fn check_aligned_handshake(&mut self) -> Result<(), Error> { if !self.aligned_handshake { self.send_fatal_alert(AlertDescription::UnexpectedMessage) - .await; + .await?; Err(Error::PeerMisbehavedError( "key epoch or handshake flight with pending fragment".to_string(), )) @@ -743,10 +743,10 @@ impl CommonState { } } - pub(crate) async fn illegal_param(&mut self, why: &str) -> Error { + pub(crate) async fn illegal_param(&mut self, why: &str) -> Result { self.send_fatal_alert(AlertDescription::IllegalParameter) - .await; - Error::PeerMisbehavedError(why.to_string()) + .await?; + Ok(Error::PeerMisbehavedError(why.to_string())) } pub(crate) async fn decrypt_incoming( @@ -754,16 +754,19 @@ impl CommonState { encr: OpaqueMessage, ) -> Result, Error> { if self.record_layer.wants_close_before_decrypt() { - self.send_close_notify().await; + self.send_close_notify().await?; } let encrypted_len = encr.payload.0.len(); - let plain = self.record_layer.decrypt_incoming(encr).await; + let plain = self + .record_layer + .decrypt_incoming(self.crypto.as_mut(), encr) + .await; match plain { Err(Error::PeerSentOversizedRecord) => { self.send_fatal_alert(AlertDescription::RecordOverflow) - .await; + .await?; Err(Error::PeerSentOversizedRecord) } Err(Error::DecryptError) if self.record_layer.doing_trial_decryption(encrypted_len) => { @@ -771,7 +774,8 @@ impl CommonState { Ok(None) } Err(Error::DecryptError) => { - self.send_fatal_alert(AlertDescription::BadRecordMac).await; + self.send_fatal_alert(AlertDescription::BadRecordMac) + .await?; Err(Error::DecryptError) } Err(e) => Err(e), @@ -781,18 +785,26 @@ impl CommonState { /// Fragment `m`, encrypt the fragments, and then queue /// the encrypted fragments for sending. - #[async_recursion] - pub(crate) async fn send_msg_encrypt(&mut self, m: PlainMessage) { + pub(crate) async fn send_msg_encrypt(&mut self, m: PlainMessage) -> Result<(), Error> { let mut plain_messages = VecDeque::new(); self.message_fragmenter.fragment(m, &mut plain_messages); - for m in plain_messages { - self.send_single_fragment(m).await; + // Close connection once we start to run out of + // sequence space. + if self.record_layer.wants_close_before_encrypt() { + debug!("Sending warning alert {:?}", AlertDescription::CloseNotify); + let m = Message::build_alert(AlertLevel::Warning, AlertDescription::CloseNotify); + self.send_single_fragment(m.into()).await?; } + + for m in plain_messages { + self.send_single_fragment(m).await?; + } + Ok(()) } /// Like send_msg_encrypt, but operate on an appdata directly. - async fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> usize { + async fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> Result { // Here, the limit on sendable_tls applies to encrypted data, // but we're respecting it for plaintext data -- so we'll // be out by whatever the cipher+record overhead is. That's a @@ -813,27 +825,25 @@ impl CommonState { ); for m in plain_messages { - self.send_single_fragment(m).await; + self.send_single_fragment(m).await?; } - len + Ok(len) } - async fn send_single_fragment(&mut self, m: PlainMessage) { - // Close connection once we start to run out of - // sequence space. - if self.record_layer.wants_close_before_encrypt() { - self.send_close_notify().await; - } - + async fn send_single_fragment(&mut self, m: PlainMessage) -> Result<(), Error> { // Refuse to wrap counter at all costs. This // is basically untestable unfortunately. if self.record_layer.encrypt_exhausted() { - return; + return Err(Error::EncryptError); } - let em = self.record_layer.encrypt_outgoing(m).await; + let em = self + .record_layer + .encrypt_outgoing(self.crypto.as_mut(), m) + .await?; self.queue_tls_message(em); + Ok(()) } /// Writes TLS messages to `wr`. @@ -852,7 +862,7 @@ impl CommonState { /// /// Returns the number of bytes written from `data`: this might /// be less than `data.len()` if buffer limits were exceeded. - async fn send_plain(&mut self, data: &[u8], limit: Limit) -> usize { + async fn send_plain(&mut self, data: &[u8], limit: Limit) -> Result { if !self.may_send_application_data { // If we haven't completed handshaking, buffer // plaintext to send once we do. @@ -860,27 +870,27 @@ impl CommonState { Limit::Yes => self.sendable_plaintext.append_limited_copy(data), Limit::No => self.sendable_plaintext.append(data.to_vec()), }; - return len; + return Ok(len); } debug_assert!(self.record_layer.is_encrypting()); if data.is_empty() { // Don't send empty fragments. - return 0; + return Ok(0); } self.send_appdata_encrypt(data, limit).await } - pub(crate) async fn start_outgoing_traffic(&mut self) { + pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> { self.may_send_application_data = true; - self.flush_plaintext().await; + self.flush_plaintext().await } - pub(crate) async fn start_traffic(&mut self) { + pub(crate) async fn start_traffic(&mut self) -> Result<(), Error> { self.may_receive_application_data = true; - self.start_outgoing_traffic().await; + self.start_outgoing_traffic().await } /// Sets a limit on the internal buffers used to buffer @@ -929,14 +939,16 @@ impl CommonState { /// Send any buffered plaintext. Plaintext is buffered if /// written during handshake. - async fn flush_plaintext(&mut self) { + async fn flush_plaintext(&mut self) -> Result<(), Error> { if !self.may_send_application_data { - return; + return Ok(()); } while let Some(buf) = self.sendable_plaintext.pop() { - self.send_plain(&buf, Limit::No).await; + self.send_plain(&buf, Limit::No).await?; } + + Ok(()) } // Put m into sendable_tls for writing. @@ -945,15 +957,16 @@ impl CommonState { } /// Send a raw TLS message, fragmenting it if needed. - pub(crate) async fn send_msg(&mut self, m: Message, must_encrypt: bool) { + pub(crate) async fn send_msg(&mut self, m: Message, must_encrypt: bool) -> Result<(), Error> { if !must_encrypt { let mut to_send = VecDeque::new(); self.message_fragmenter.fragment(m.into(), &mut to_send); for mm in to_send { self.queue_tls_message(mm.into_unencrypted_opaque()); } + return Ok(()); } else { - self.send_msg_encrypt(m.into()).await; + self.send_msg_encrypt(m.into()).await } } @@ -961,16 +974,16 @@ impl CommonState { self.received_plaintext.append(bytes.0); } - async fn send_warning_alert(&mut self, desc: AlertDescription) { + async fn send_warning_alert(&mut self, desc: AlertDescription) -> Result<(), Error> { warn!("Sending warning alert {:?}", desc); - self.send_warning_alert_no_log(desc).await; + self.send_warning_alert_no_log(desc).await } async fn process_alert(&mut self, alert: &AlertMessagePayload) -> Result<(), Error> { // Reject unknown AlertLevels. if let AlertLevel::Unknown(_) = alert.level { self.send_fatal_alert(AlertDescription::IllegalParameter) - .await; + .await?; } // If we get a CloseNotify, make a note to declare EOF to our @@ -984,7 +997,7 @@ impl CommonState { // (except, for no good reason, user_cancelled). if alert.level == AlertLevel::Warning { if self.is_tls13() && alert.description != AlertDescription::UserCanceled { - self.send_fatal_alert(AlertDescription::DecodeError).await; + self.send_fatal_alert(AlertDescription::DecodeError).await?; } else { warn!("TLS alert warning received: {:#?}", alert); return Ok(()); @@ -995,26 +1008,27 @@ impl CommonState { Err(Error::AlertReceived(alert.description)) } - pub(crate) async fn send_fatal_alert(&mut self, desc: AlertDescription) { + pub(crate) async fn send_fatal_alert(&mut self, desc: AlertDescription) -> Result<(), Error> { warn!("Sending fatal alert {:?}", desc); debug_assert!(!self.sent_fatal_alert); let m = Message::build_alert(AlertLevel::Fatal, desc); - self.send_msg(m, self.record_layer.is_encrypting()).await; + self.send_msg(m, self.record_layer.is_encrypting()).await?; self.sent_fatal_alert = true; + Ok(()) } /// Queues a close_notify warning alert to be sent in the next /// [`CommonState::write_tls`] call. This informs the peer that the /// connection is being closed. - pub async fn send_close_notify(&mut self) { + pub async fn send_close_notify(&mut self) -> Result<(), Error> { debug!("Sending warning alert {:?}", AlertDescription::CloseNotify); self.send_warning_alert_no_log(AlertDescription::CloseNotify) - .await; + .await } - async fn send_warning_alert_no_log(&mut self, desc: AlertDescription) { + async fn send_warning_alert_no_log(&mut self, desc: AlertDescription) -> Result<(), Error> { let m = Message::build_alert(AlertLevel::Warning, desc); - self.send_msg(m, self.record_layer.is_encrypting()).await; + self.send_msg(m, self.record_layer.is_encrypting()).await } pub(crate) fn set_max_fragment_size(&mut self, new: Option) -> Result<(), Error> { @@ -1054,7 +1068,7 @@ impl CommonState { } } -#[async_trait] +#[async_trait(?Send)] pub(crate) trait State: Send + Sync { async fn start( self: Box, diff --git a/tls-client/src/handshaker.rs b/tls-client/src/crypto.rs similarity index 68% rename from tls-client/src/handshaker.rs rename to tls-client/src/crypto.rs index 0a40e8b49..bbc3c5501 100644 --- a/tls-client/src/handshaker.rs +++ b/tls-client/src/crypto.rs @@ -1,16 +1,35 @@ use crate::Error; use tls_core::msgs::enums::ProtocolVersion; use tls_core::msgs::handshake::Random; +use tls_core::msgs::message::{OpaqueMessage, PlainMessage}; use tls_core::{key::PublicKey, suites::SupportedCipherSuite}; use async_trait::async_trait; use crate::cipher::{MessageDecrypter, MessageEncrypter}; +#[derive(Debug, Clone)] +pub enum EncryptMode { + /// Encrypt payload with PSK + EarlyData, + /// Encrypt payload with Handshake keys + Handshake, + /// Encrypt payload with Application traffic keys + Application, +} + +#[derive(Debug, Clone)] +pub enum DecryptMode { + /// Decrypt payload with Handshake keys + Handshake, + /// Decrypt payload with Application traffic keys + Application, +} + /// Core trait which manages crypto operations for the TLS connection such as key exchange, encryption /// and decryption. #[async_trait] -pub trait Handshake: Send { +pub trait Crypto: Send { /// Signals selected protocol version to implementor. /// Throws error if version is not supported. fn select_protocol_version(&mut self, version: ProtocolVersion) -> Result<(), Error>; @@ -19,6 +38,10 @@ pub trait Handshake: Send { fn select_cipher_suite(&mut self, suite: SupportedCipherSuite) -> Result<(), Error>; /// Returns configured cipher suite. fn suite(&self) -> Result; + /// Set encryption mode + fn set_encrypt(&mut self, mode: EncryptMode) -> Result<(), Error>; + /// Start decryption + fn set_decrypt(&mut self, mode: DecryptMode) -> Result<(), Error>; /// Returns client_random value. async fn client_random(&mut self) -> Result; /// Returns public client keyshare. @@ -33,16 +56,16 @@ pub trait Handshake: Send { async fn server_finished(&mut self, hash: &[u8]) -> Result, Error>; /// Returns ClientFinished verify_data. async fn client_finished(&mut self, hash: &[u8]) -> Result, Error>; - /// Returns initialized MessageEncrypter. - async fn message_encrypter(&mut self) -> Result, Error>; - /// Returns initialized MessageDecrypter. - async fn message_decrypter(&mut self) -> Result, Error>; + /// Perform the encryption over the concerned TLS message. + async fn encrypt(&self, m: PlainMessage, seq: u64) -> Result; + /// Perform the decryption over the concerned TLS message. + async fn decrypt(&self, m: OpaqueMessage, seq: u64) -> Result; } -pub struct InvalidHandShaker {} +pub struct InvalidCrypto {} #[async_trait] -impl Handshake for InvalidHandShaker { +impl Crypto for InvalidCrypto { fn select_protocol_version(&mut self, _version: ProtocolVersion) -> Result<(), Error> { Err(Error::General("handshaker not yet available".to_string())) } @@ -52,6 +75,14 @@ impl Handshake for InvalidHandShaker { fn suite(&self) -> Result { Err(Error::General("handshaker not yet available".to_string())) } + /// Start encryption + fn set_encrypt(&mut self, _mode: EncryptMode) -> Result<(), Error> { + Err(Error::General("handshaker not yet available".to_string())) + } + /// Start decryption + fn set_decrypt(&mut self, _mode: DecryptMode) -> Result<(), Error> { + Err(Error::General("handshaker not yet available".to_string())) + } async fn client_random(&mut self) -> Result { Err(Error::General("handshaker not yet available".to_string())) } @@ -73,10 +104,10 @@ impl Handshake for InvalidHandShaker { async fn client_finished(&mut self, _hash: &[u8]) -> Result, Error> { Err(Error::General("handshaker not yet available".to_string())) } - async fn message_encrypter(&mut self) -> Result, Error> { + async fn encrypt(&self, _m: PlainMessage, _seq: u64) -> Result { Err(Error::General("handshaker not yet available".to_string())) } - async fn message_decrypter(&mut self) -> Result, Error> { + async fn decrypt(&self, _m: OpaqueMessage, _seq: u64) -> Result { Err(Error::General("handshaker not yet available".to_string())) } } diff --git a/tls-client/src/lib.rs b/tls-client/src/lib.rs index a236a8d27..f1648d8e1 100644 --- a/tls-client/src/lib.rs +++ b/tls-client/src/lib.rs @@ -315,8 +315,8 @@ pub extern crate tls_core; mod anchors; mod cipher; mod conn; +mod crypto; mod error; -mod handshaker; mod hash_hs; mod limited_cache; mod msgs; @@ -357,7 +357,7 @@ pub use crate::key_log::{KeyLog, NoKeyLog}; pub use crate::key_log_file::KeyLogFile; pub use crate::kx::{SupportedKxGroup, ALL_KX_GROUPS}; pub use cipher::{MessageDecrypter, MessageEncrypter}; -pub use handshaker::Handshake; +pub use crypto::Crypto; pub use tls_core::key::{Certificate, PrivateKey}; pub use tls_core::msgs::enums::CipherSuite; pub use tls_core::msgs::enums::ProtocolVersion; diff --git a/tls-client/src/record_layer.rs b/tls-client/src/record_layer.rs index 17f4fc9b4..ea43dc6e8 100644 --- a/tls-client/src/record_layer.rs +++ b/tls-client/src/record_layer.rs @@ -3,6 +3,7 @@ use crate::{ InvalidMessageDecrypter, InvalidMessageEncrypter, MessageDecrypter, MessageEncrypter, }, error::Error, + Crypto, }; use tls_core::msgs::message::{OpaqueMessage, PlainMessage}; @@ -22,8 +23,6 @@ enum DirectionState { } pub(crate) struct RecordLayer { - message_encrypter: Box, - message_decrypter: Box, write_seq: u64, read_seq: u64, encrypt_state: DirectionState, @@ -38,8 +37,6 @@ pub(crate) struct RecordLayer { impl RecordLayer { pub(crate) fn new() -> Self { Self { - message_encrypter: Box::new(InvalidMessageEncrypter {}), - message_decrypter: Box::new(InvalidMessageDecrypter {}), write_seq: 0, read_seq: 0, encrypt_state: DirectionState::Invalid, @@ -71,16 +68,14 @@ impl RecordLayer { /// Prepare to use the given `MessageEncrypter` for future message encryption. /// It is not used until you call `start_encrypting`. - pub(crate) fn prepare_message_encrypter(&mut self, cipher: Box) { - self.message_encrypter = cipher; + pub(crate) fn prepare_message_encrypter(&mut self) { self.write_seq = 0; self.encrypt_state = DirectionState::Prepared; } /// Prepare to use the given `MessageDecrypter` for future message decryption. /// It is not used until you call `start_decrypting`. - pub(crate) fn prepare_message_decrypter(&mut self, cipher: Box) { - self.message_decrypter = cipher; + pub(crate) fn prepare_message_decrypter(&mut self) { self.read_seq = 0; self.decrypt_state = DirectionState::Prepared; } @@ -101,15 +96,15 @@ impl RecordLayer { /// Set and start using the given `MessageEncrypter` for future outgoing /// message encryption. - pub(crate) fn set_message_encrypter(&mut self, cipher: Box) { - self.prepare_message_encrypter(cipher); + pub(crate) fn set_message_encrypter(&mut self) { + self.prepare_message_encrypter(); self.start_encrypting(); } /// Set and start using the given `MessageDecrypter` for future incoming /// message decryption. - pub(crate) fn set_message_decrypter(&mut self, cipher: Box) { - self.prepare_message_decrypter(cipher); + pub(crate) fn set_message_decrypter(&mut self) { + self.prepare_message_decrypter(); self.start_decrypting(); self.trial_decryption_len = None; } @@ -117,12 +112,8 @@ impl RecordLayer { /// Set and start using the given `MessageDecrypter` for future incoming /// message decryption, and enable "trial decryption" mode for when TLS1.3 /// 0-RTT is attempted but rejected by the server. - pub(crate) fn set_message_decrypter_with_trial_decryption( - &mut self, - cipher: Box, - max_length: usize, - ) { - self.prepare_message_decrypter(cipher); + pub(crate) fn set_message_decrypter_with_trial_decryption(&mut self, max_length: usize) { + self.prepare_message_decrypter(); self.start_decrypting(); self.trial_decryption_len = Some(max_length); } @@ -162,11 +153,12 @@ impl RecordLayer { /// an error is returned. pub(crate) async fn decrypt_incoming( &mut self, + cipher: &mut dyn Crypto, encr: OpaqueMessage, ) -> Result { debug_assert!(self.is_decrypting()); let seq = self.read_seq; - let msg = self.message_decrypter.decrypt(encr, seq).await?; + let msg = cipher.decrypt(encr, seq).await?; self.read_seq += 1; Ok(msg) } @@ -175,11 +167,15 @@ impl RecordLayer { /// /// `plain` is a TLS message we'd like to send. This function /// panics if the requisite keying material hasn't been established yet. - pub(crate) async fn encrypt_outgoing(&mut self, plain: PlainMessage) -> OpaqueMessage { + pub(crate) async fn encrypt_outgoing( + &mut self, + cipher: &mut dyn Crypto, + plain: PlainMessage, + ) -> Result { debug_assert!(self.encrypt_state == DirectionState::Active); assert!(!self.encrypt_exhausted()); let seq = self.write_seq; self.write_seq += 1; - self.message_encrypter.encrypt(plain, seq).await.unwrap() + cipher.encrypt(plain, seq).await } } diff --git a/tls-client/tests/api.rs b/tls-client/tests/api.rs index 915a1ed01..d4ee8a376 100644 --- a/tls-client/tests/api.rs +++ b/tls-client/tests/api.rs @@ -460,7 +460,7 @@ async fn client_close_notify() { // check that alerts don't overtake appdata assert_eq!(12, server.writer().write(b"from-server!").unwrap()); assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap()); - client.send_close_notify().await; + client.send_close_notify().await.unwrap(); send(&mut client, &mut server); let io_state = server.process_new_packets().unwrap(); @@ -973,7 +973,7 @@ where } } - fn new_fails(sess: &'a mut C) -> ServerSession<'a, C, S> { + fn _new_fails(sess: &'a mut C) -> ServerSession<'a, C, S> { let mut os = ServerSession::new(sess); os.fail_ok = true; os @@ -1067,7 +1067,7 @@ where } } - fn new_fails(sess: &'a mut C) -> ClientSession<'a, C> { + fn _new_fails(sess: &'a mut C) -> ClientSession<'a, C> { let mut os = ClientSession::new(sess); os.fail_ok = true; os @@ -1088,7 +1088,7 @@ impl<'a, C> io::Write for ClientSession<'a, C> where C: DerefMut + Deref, { - fn write(&mut self, buf: &[u8]) -> io::Result { + fn write(&mut self, _buf: &[u8]) -> io::Result { unreachable!() }