mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-04-28 03:00:14 -04:00
consolidate traits, convert functions to fallible, and remove recursion
This commit is contained in:
@@ -378,7 +378,7 @@ pub struct Initialized {
|
||||
config: Arc<ClientConfig>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for Initialized {
|
||||
async fn start(
|
||||
self: Box<Self>,
|
||||
|
||||
@@ -83,7 +83,7 @@ pub(super) async fn start_handshake(
|
||||
|
||||
let support_tls13 = config.supports_version(ProtocolVersion::TLSv1_3);
|
||||
let key_share = if support_tls13 {
|
||||
Some(cx.common.handshaker.client_key_share().await?)
|
||||
Some(cx.common.crypto.client_key_share().await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -113,7 +113,7 @@ pub(super) async fn start_handshake(
|
||||
session_id = Some(SessionID::random()?);
|
||||
}
|
||||
|
||||
let random = cx.common.handshaker.client_random().await?;
|
||||
let random = cx.common.crypto.client_random().await?;
|
||||
let hello_details = ClientHelloDetails::new();
|
||||
let sent_tls13_fake_ccs = false;
|
||||
let may_send_sct_list = config.verifier.request_scts();
|
||||
@@ -134,7 +134,7 @@ pub(super) async fn start_handshake(
|
||||
may_send_sct_list,
|
||||
None,
|
||||
)
|
||||
.await)
|
||||
.await?)
|
||||
}
|
||||
|
||||
struct ExpectServerHello {
|
||||
@@ -172,7 +172,7 @@ async fn emit_client_hello_for_retry(
|
||||
extra_exts: Vec<ClientExtension>,
|
||||
may_send_sct_list: bool,
|
||||
suite: Option<SupportedCipherSuite>,
|
||||
) -> NextState {
|
||||
) -> Result<NextState, Error> {
|
||||
// For now we do not support session resumption
|
||||
//
|
||||
// Do we have a SessionID or ticket cached for this host?
|
||||
@@ -337,13 +337,13 @@ async fn emit_client_hello_for_retry(
|
||||
if retryreq.is_some() {
|
||||
// send dummy CCS to fool middleboxes prior
|
||||
// to second client hello
|
||||
tls13::emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await;
|
||||
tls13::emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await?;
|
||||
}
|
||||
|
||||
trace!("Sending ClientHello {:#?}", ch);
|
||||
|
||||
transcript_buffer.add_message(&ch);
|
||||
cx.common.send_msg(ch, false).await;
|
||||
cx.common.send_msg(ch, false).await?;
|
||||
|
||||
let next = ExpectServerHello {
|
||||
config,
|
||||
@@ -360,9 +360,12 @@ async fn emit_client_hello_for_retry(
|
||||
};
|
||||
|
||||
if support_tls13 && retryreq.is_none() {
|
||||
Box::new(ExpectServerHelloOrHelloRetryRequest { next, extra_exts })
|
||||
Ok(Box::new(ExpectServerHelloOrHelloRetryRequest {
|
||||
next,
|
||||
extra_exts,
|
||||
}))
|
||||
} else {
|
||||
Box::new(next)
|
||||
Ok(Box::new(next))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -377,7 +380,7 @@ pub(super) async fn process_alpn_protocol(
|
||||
if !config.alpn_protocols.contains(alpn_protocol) {
|
||||
return Err(common
|
||||
.illegal_param("server sent non-offered ALPN protocol")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -392,7 +395,7 @@ pub(super) fn sct_list_is_invalid(scts: &SCTList) -> bool {
|
||||
scts.is_empty() || scts.iter().any(|sct| sct.0.is_empty())
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectServerHello {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
@@ -429,7 +432,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server chose v1.2 using v1.3 extension")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
|
||||
TLSv1_2
|
||||
@@ -437,7 +440,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
_ => {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::ProtocolVersion)
|
||||
.await;
|
||||
.await?;
|
||||
let msg = match server_version {
|
||||
TLSv1_2 | TLSv1_3 => "server's TLS version is disabled in client",
|
||||
_ => "server does not support TLS v1.2/v1.3",
|
||||
@@ -446,19 +449,19 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
}
|
||||
};
|
||||
|
||||
cx.common.handshaker.select_protocol_version(version)?;
|
||||
cx.common.crypto.select_protocol_version(version)?;
|
||||
|
||||
if server_hello.compression_method != Compression::Null {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server chose non-Null compression")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
|
||||
if server_hello.has_duplicate_extension() {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"server sent duplicate extensions".to_string(),
|
||||
));
|
||||
@@ -471,7 +474,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
{
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"server sent unsolicited extension".to_string(),
|
||||
));
|
||||
@@ -491,7 +494,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
if !point_fmts.contains(&ECPointFormat::Uncompressed) {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::HandshakeFailure)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"server does not support uncompressed points".to_string(),
|
||||
));
|
||||
@@ -503,7 +506,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
None => {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::HandshakeFailure)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"server chose non-offered ciphersuite".to_string(),
|
||||
));
|
||||
@@ -514,7 +517,7 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server chose unusable ciphersuite for version")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
|
||||
match self.suite {
|
||||
@@ -522,13 +525,13 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server varied selected ciphersuite")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
_ => {
|
||||
debug!("Using ciphersuite {:?}", suite);
|
||||
self.suite = Some(suite);
|
||||
cx.common.suite = Some(suite);
|
||||
cx.common.handshaker.select_cipher_suite(suite)?;
|
||||
cx.common.crypto.select_cipher_suite(suite)?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -620,7 +623,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server requested hrr with our group")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
|
||||
// Or has an empty cookie.
|
||||
@@ -629,7 +632,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server requested hrr with empty cookie")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -637,7 +640,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
if hrr.has_unknown_extension() {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::PeerIncompatibleError(
|
||||
"server sent hrr with unhandled extension".to_string(),
|
||||
));
|
||||
@@ -648,7 +651,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server send duplicate hrr extensions")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
|
||||
// Or asks us to change nothing.
|
||||
@@ -656,7 +659,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server requested hrr with no changes")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
|
||||
// Or asks us to talk a protocol we didn't offer, or doesn't support HRR at all.
|
||||
@@ -668,7 +671,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server requested unsupported version in hrr")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -680,7 +683,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server requested unsupported cs in hrr")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -705,7 +708,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("server requested hrr with bad group")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
_ => offered_key_share,
|
||||
};
|
||||
@@ -727,11 +730,11 @@ impl ExpectServerHelloOrHelloRetryRequest {
|
||||
may_send_sct_list,
|
||||
Some(cs),
|
||||
)
|
||||
.await)
|
||||
.await?)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectServerHelloOrHelloRetryRequest {
|
||||
async fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> NextStateOrError {
|
||||
match m.payload {
|
||||
@@ -752,22 +755,27 @@ impl State<ClientConnectionData> for ExpectServerHelloOrHelloRetryRequest {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn send_cert_error_alert(common: &mut CommonState, err: Error) -> Error {
|
||||
pub(super) async fn send_cert_error_alert(
|
||||
common: &mut CommonState,
|
||||
err: Error,
|
||||
) -> Result<Error, Error> {
|
||||
match err {
|
||||
Error::InvalidCertificateEncoding => {
|
||||
common.send_fatal_alert(AlertDescription::DecodeError).await;
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await?;
|
||||
}
|
||||
Error::PeerMisbehavedError(_) => {
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::IllegalParameter)
|
||||
.await;
|
||||
.await?;
|
||||
}
|
||||
_ => {
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::BadCertificate)
|
||||
.await;
|
||||
.await?;
|
||||
}
|
||||
};
|
||||
|
||||
err
|
||||
Ok(err)
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ mod server_hello {
|
||||
return Err(cx
|
||||
.common
|
||||
.illegal_param("downgrade to TLS1.2 when TLS1.3 is supported")
|
||||
.await);
|
||||
.await?);
|
||||
}
|
||||
|
||||
// Doing EMS?
|
||||
@@ -209,7 +209,7 @@ struct ExpectCertificate {
|
||||
server_cert_sct_list: Option<SCTList>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectCertificate {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
@@ -271,7 +271,7 @@ struct ExpectCertificateStatusOrServerKx {
|
||||
must_issue_new_ticket: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectCertificateStatusOrServerKx {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
@@ -348,7 +348,7 @@ struct ExpectCertificateStatus {
|
||||
must_issue_new_ticket: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectCertificateStatus {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
@@ -402,7 +402,7 @@ struct ExpectServerKx {
|
||||
must_issue_new_ticket: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectServerKx {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
@@ -422,7 +422,7 @@ impl State<ClientConnectionData> for ExpectServerKx {
|
||||
// We only support ECDHE
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
};
|
||||
@@ -457,7 +457,7 @@ async fn emit_certificate(
|
||||
transcript: &mut HandshakeHash,
|
||||
cert_chain: CertificatePayload,
|
||||
common: &mut CommonState,
|
||||
) {
|
||||
) -> Result<(), Error> {
|
||||
let cert = Message {
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
@@ -467,14 +467,14 @@ async fn emit_certificate(
|
||||
};
|
||||
|
||||
transcript.add_message(&cert);
|
||||
common.send_msg(cert, false).await;
|
||||
common.send_msg(cert, false).await
|
||||
}
|
||||
|
||||
async fn emit_clientkx(
|
||||
transcript: &mut HandshakeHash,
|
||||
common: &mut CommonState,
|
||||
pubkey: &PublicKey,
|
||||
) {
|
||||
) -> Result<(), Error> {
|
||||
let ecpoint = PayloadU8::new(pubkey.key.clone());
|
||||
|
||||
let mut buf = Vec::new();
|
||||
@@ -490,7 +490,7 @@ async fn emit_clientkx(
|
||||
};
|
||||
|
||||
transcript.add_message(&ckx);
|
||||
common.send_msg(ckx, false).await;
|
||||
common.send_msg(ckx, false).await
|
||||
}
|
||||
|
||||
async fn emit_certverify(
|
||||
@@ -515,24 +515,23 @@ async fn emit_certverify(
|
||||
};
|
||||
|
||||
transcript.add_message(&m);
|
||||
common.send_msg(m, false).await;
|
||||
Ok(())
|
||||
common.send_msg(m, false).await
|
||||
}
|
||||
|
||||
async fn emit_ccs(common: &mut CommonState) {
|
||||
async fn emit_ccs(common: &mut CommonState) -> Result<(), Error> {
|
||||
let ccs = Message {
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: MessagePayload::ChangeCipherSpec(ChangeCipherSpecPayload {}),
|
||||
};
|
||||
|
||||
common.send_msg(ccs, false).await;
|
||||
common.send_msg(ccs, false).await
|
||||
}
|
||||
|
||||
async fn emit_finished(
|
||||
verify_data: &[u8],
|
||||
transcript: &mut HandshakeHash,
|
||||
common: &mut CommonState,
|
||||
) {
|
||||
) -> Result<(), Error> {
|
||||
let verify_data_payload = Payload::new(verify_data);
|
||||
|
||||
let f = Message {
|
||||
@@ -544,7 +543,7 @@ async fn emit_finished(
|
||||
};
|
||||
|
||||
transcript.add_message(&f);
|
||||
common.send_msg(f, true).await;
|
||||
common.send_msg(f, true).await
|
||||
}
|
||||
|
||||
struct ServerKxDetails {
|
||||
@@ -578,7 +577,7 @@ struct ExpectServerDoneOrCertReq {
|
||||
must_issue_new_ticket: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
@@ -644,7 +643,7 @@ struct ExpectCertificateRequest {
|
||||
must_issue_new_ticket: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectCertificateRequest {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
@@ -705,7 +704,7 @@ struct ExpectServerDone {
|
||||
must_issue_new_ticket: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectServerDone {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
@@ -764,7 +763,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
now,
|
||||
) {
|
||||
Ok(cert_verified) => cert_verified,
|
||||
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
|
||||
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await?),
|
||||
};
|
||||
|
||||
// 3.
|
||||
@@ -794,7 +793,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
sig,
|
||||
) {
|
||||
Ok(sig_verified) => sig_verified,
|
||||
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
|
||||
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await?),
|
||||
}
|
||||
};
|
||||
cx.common.peer_certificates = Some(st.server_cert.cert_chain);
|
||||
@@ -805,7 +804,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
ClientAuthDetails::Empty { .. } => Vec::new(),
|
||||
ClientAuthDetails::Verify { certkey, .. } => certkey.cert.clone(),
|
||||
};
|
||||
emit_certificate(&mut st.transcript, certs, cx.common).await;
|
||||
emit_certificate(&mut st.transcript, certs, cx.common).await?;
|
||||
}
|
||||
|
||||
// 5a.
|
||||
@@ -815,12 +814,12 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
None => {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
};
|
||||
|
||||
let key_share = cx.common.handshaker.client_key_share().await?;
|
||||
let key_share = cx.common.crypto.client_key_share().await?;
|
||||
if key_share.group != ecdh_params.curve_params.named_group {
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"peer chose an unsupported group".to_string(),
|
||||
@@ -829,7 +828,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
|
||||
// 5b.
|
||||
let mut transcript = st.transcript;
|
||||
emit_clientkx(&mut transcript, cx.common, &key_share).await;
|
||||
emit_clientkx(&mut transcript, cx.common, &key_share).await?;
|
||||
// nb. EMS handshake hash only runs up to ClientKeyExchange.
|
||||
let _ems_seed = st.using_ems.then(|| transcript.get_current_hash());
|
||||
|
||||
@@ -839,31 +838,25 @@ impl State<ClientConnectionData> for ExpectServerDone {
|
||||
}
|
||||
|
||||
// 5d.
|
||||
emit_ccs(cx.common).await;
|
||||
emit_ccs(cx.common).await?;
|
||||
|
||||
// 5e. Now commit secrets.
|
||||
let server_key_share =
|
||||
PublicKey::new(ecdh_params.curve_params.named_group, &ecdh_params.public.0);
|
||||
|
||||
cx.common
|
||||
.handshaker
|
||||
.crypto
|
||||
.set_server_key_share(server_key_share)
|
||||
.await?;
|
||||
|
||||
let enc = cx.common.handshaker.message_encrypter().await?;
|
||||
let dec = cx.common.handshaker.message_decrypter().await?;
|
||||
|
||||
cx.common.record_layer.set_message_encrypter(enc);
|
||||
cx.common.record_layer.set_message_decrypter(dec);
|
||||
|
||||
st.config
|
||||
.key_log
|
||||
.log("CLIENT_RANDOM", &st.randoms.client, &[]);
|
||||
|
||||
// 6.
|
||||
let hs = transcript.get_current_hash();
|
||||
let cf = cx.common.handshaker.client_finished(hs.as_ref()).await?;
|
||||
emit_finished(&cf, &mut transcript, cx.common).await;
|
||||
let cf = cx.common.crypto.client_finished(hs.as_ref()).await?;
|
||||
emit_finished(&cf, &mut transcript, cx.common).await?;
|
||||
|
||||
if st.must_issue_new_ticket {
|
||||
Ok(Box::new(ExpectNewTicket {
|
||||
@@ -906,7 +899,7 @@ struct ExpectNewTicket {
|
||||
sig_verified: verify::HandshakeSignatureValid,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectNewTicket {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
@@ -950,7 +943,7 @@ struct ExpectCcs {
|
||||
sig_verified: verify::HandshakeSignatureValid,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectCcs {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
@@ -1055,7 +1048,7 @@ struct ExpectFinished {
|
||||
// }
|
||||
// }
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectFinished {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
@@ -1070,7 +1063,7 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
|
||||
// Work out what verify_data we expect.
|
||||
let vh = st.transcript.get_current_hash();
|
||||
let expect_verify_data = cx.common.handshaker.server_finished(vh.as_ref()).await?;
|
||||
let expect_verify_data = cx.common.crypto.server_finished(vh.as_ref()).await?;
|
||||
|
||||
// Constant-time verification of this is relatively unimportant: they only
|
||||
// get one chance. But it can't hurt.
|
||||
@@ -1080,7 +1073,7 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
Err(_) => {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecryptError)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::DecryptError);
|
||||
}
|
||||
};
|
||||
@@ -1091,12 +1084,12 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
// st.save_session(cx);
|
||||
|
||||
if st.resuming {
|
||||
emit_ccs(cx.common).await;
|
||||
emit_ccs(cx.common).await?;
|
||||
cx.common.record_layer.start_encrypting();
|
||||
emit_finished(&expect_verify_data, &mut st.transcript, cx.common).await;
|
||||
emit_finished(&expect_verify_data, &mut st.transcript, cx.common).await?;
|
||||
}
|
||||
|
||||
cx.common.start_traffic().await;
|
||||
cx.common.start_traffic().await?;
|
||||
Ok(Box::new(ExpectTraffic {
|
||||
_cert_verified: st.cert_verified,
|
||||
_sig_verified: st.sig_verified,
|
||||
@@ -1112,7 +1105,7 @@ struct ExpectTraffic {
|
||||
_fin_verified: verify::FinishedMessageVerified,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectTraffic {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::check::inappropriate_handshake_message;
|
||||
use crate::conn::{CommonState, ConnectionRandoms, State};
|
||||
use crate::crypto::{DecryptMode, EncryptMode};
|
||||
use crate::error::Error;
|
||||
use crate::hash_hs::{HandshakeHash, HandshakeHashBuffer};
|
||||
#[cfg(feature = "logging")]
|
||||
@@ -70,17 +71,17 @@ pub(super) async fn handle_server_hello(
|
||||
None => {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::MissingExtension)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::PeerMisbehavedError("missing key share".to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
if our_key_share.group != their_key_share.group {
|
||||
return Err(cx.common.illegal_param("wrong group for key share").await);
|
||||
return Err(cx.common.illegal_param("wrong group for key share").await?);
|
||||
}
|
||||
|
||||
cx.common
|
||||
.handshaker
|
||||
.crypto
|
||||
.set_server_key_share(their_key_share.clone().into())
|
||||
.await?;
|
||||
|
||||
@@ -115,22 +116,19 @@ pub(super) async fn handle_server_hello(
|
||||
cx.common.check_aligned_handshake().await?;
|
||||
|
||||
cx.common
|
||||
.handshaker
|
||||
.crypto
|
||||
.set_hs_hash_server_hello(transcript.get_current_hash().as_ref())
|
||||
.await?;
|
||||
|
||||
let dec = cx.common.handshaker.message_decrypter().await?;
|
||||
let enc = cx.common.handshaker.message_encrypter().await?;
|
||||
|
||||
// Decrypt with the peer's key, encrypt with our own key
|
||||
cx.common.record_layer.set_message_decrypter(dec);
|
||||
cx.common.crypto.set_decrypt(DecryptMode::Handshake)?;
|
||||
|
||||
if !cx.data.early_data.is_enabled() {
|
||||
// Set the client encryption key for handshakes if early data is not used
|
||||
cx.common.record_layer.set_message_encrypter(enc);
|
||||
cx.common.crypto.set_encrypt(EncryptMode::Handshake)?;
|
||||
}
|
||||
|
||||
emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await;
|
||||
emit_fake_ccs(&mut sent_tls13_fake_ccs, cx.common).await?;
|
||||
|
||||
Ok(Box::new(ExpectEncryptedExtensions {
|
||||
config,
|
||||
@@ -151,7 +149,7 @@ async fn validate_server_hello(
|
||||
if !ALLOWED_PLAINTEXT_EXTS.contains(&ext.get_type()) {
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"server sent unexpected cleartext ext".to_string(),
|
||||
));
|
||||
@@ -224,16 +222,19 @@ async fn validate_server_hello(
|
||||
// exts.push(ClientExtension::PresharedKey(psk_ext));
|
||||
// }
|
||||
|
||||
pub(super) async fn emit_fake_ccs(sent_tls13_fake_ccs: &mut bool, common: &mut CommonState) {
|
||||
pub(super) async fn emit_fake_ccs(
|
||||
sent_tls13_fake_ccs: &mut bool,
|
||||
common: &mut CommonState,
|
||||
) -> Result<(), Error> {
|
||||
if std::mem::replace(sent_tls13_fake_ccs, true) {
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let m = Message {
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: MessagePayload::ChangeCipherSpec(ChangeCipherSpecPayload {}),
|
||||
};
|
||||
common.send_msg(m, false).await;
|
||||
common.send_msg(m, false).await
|
||||
}
|
||||
|
||||
async fn validate_encrypted_extensions(
|
||||
@@ -242,7 +243,9 @@ async fn validate_encrypted_extensions(
|
||||
exts: &EncryptedExtensions,
|
||||
) -> Result<(), Error> {
|
||||
if exts.has_duplicate_extension() {
|
||||
common.send_fatal_alert(AlertDescription::DecodeError).await;
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await?;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"server sent duplicate encrypted extensions".to_string(),
|
||||
));
|
||||
@@ -251,7 +254,7 @@ async fn validate_encrypted_extensions(
|
||||
if hello.server_sent_unsolicited_extensions(exts, &[]) {
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension)
|
||||
.await;
|
||||
.await?;
|
||||
let msg = "server sent unsolicited encrypted extension".to_string();
|
||||
return Err(Error::PeerMisbehavedError(msg));
|
||||
}
|
||||
@@ -262,7 +265,7 @@ async fn validate_encrypted_extensions(
|
||||
{
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension)
|
||||
.await;
|
||||
.await?;
|
||||
let msg = "server sent inappropriate encrypted extension".to_string();
|
||||
return Err(Error::PeerMisbehavedError(msg));
|
||||
}
|
||||
@@ -281,7 +284,7 @@ struct ExpectEncryptedExtensions {
|
||||
hello: ClientHelloDetails,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectEncryptedExtensions {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
@@ -311,9 +314,8 @@ impl State<ClientConnectionData> for ExpectEncryptedExtensions {
|
||||
}
|
||||
|
||||
if was_early_traffic && !cx.common.early_traffic {
|
||||
let enc = cx.common.handshaker.message_encrypter().await?;
|
||||
// If no early traffic, set the encryption key for handshakes
|
||||
cx.common.record_layer.set_message_encrypter(enc);
|
||||
cx.common.record_layer.set_message_encrypter();
|
||||
}
|
||||
|
||||
cx.common.peer_certificates = Some(resuming_session.server_cert_chain().to_vec());
|
||||
@@ -358,7 +360,7 @@ struct ExpectCertificateOrCertReq {
|
||||
may_send_sct_list: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectCertificateOrCertReq {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
@@ -421,7 +423,7 @@ struct ExpectCertificateRequest {
|
||||
may_send_sct_list: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectCertificateRequest {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
@@ -444,7 +446,7 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
|
||||
warn!("Server sent non-empty certreq context");
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
|
||||
@@ -461,7 +463,7 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
|
||||
if compat_sigschemes.is_empty() {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::HandshakeFailure)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::PeerIncompatibleError(
|
||||
"server sent bad certreq schemes".to_string(),
|
||||
));
|
||||
@@ -496,7 +498,7 @@ struct ExpectCertificate {
|
||||
client_auth: Option<ClientAuthDetails>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectCertificate {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
@@ -515,7 +517,7 @@ impl State<ClientConnectionData> for ExpectCertificate {
|
||||
warn!("certificate with non-empty context during handshake");
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
|
||||
@@ -525,7 +527,7 @@ impl State<ClientConnectionData> for ExpectCertificate {
|
||||
warn!("certificate chain contains unsolicited/unknown extension");
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::UnsupportedExtension)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"bad cert chain extensions".to_string(),
|
||||
));
|
||||
@@ -572,7 +574,7 @@ struct ExpectCertificateVerify {
|
||||
client_auth: Option<ClientAuthDetails>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectCertificateVerify {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
@@ -603,7 +605,7 @@ impl State<ClientConnectionData> for ExpectCertificateVerify {
|
||||
now,
|
||||
) {
|
||||
Ok(cert_verified) => cert_verified,
|
||||
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
|
||||
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await?),
|
||||
};
|
||||
|
||||
// 2. Verify their signature on the handshake.
|
||||
@@ -614,7 +616,7 @@ impl State<ClientConnectionData> for ExpectCertificateVerify {
|
||||
cert_verify,
|
||||
) {
|
||||
Ok(sig_verified) => sig_verified,
|
||||
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await),
|
||||
Err(e) => return Err(hs::send_cert_error_alert(cx.common, e).await?),
|
||||
};
|
||||
|
||||
cx.common.peer_certificates = Some(self.server_cert.cert_chain);
|
||||
@@ -638,7 +640,7 @@ async fn emit_certificate_tls13(
|
||||
certkey: Option<&CertifiedKey>,
|
||||
auth_context: Option<Vec<u8>>,
|
||||
common: &mut CommonState,
|
||||
) {
|
||||
) -> Result<(), Error> {
|
||||
let context = auth_context.unwrap_or_default();
|
||||
|
||||
let mut cert_payload = CertificatePayloadTLS13 {
|
||||
@@ -662,7 +664,7 @@ async fn emit_certificate_tls13(
|
||||
}),
|
||||
};
|
||||
transcript.add_message(&m);
|
||||
common.send_msg(m, true).await;
|
||||
common.send_msg(m, true).await
|
||||
}
|
||||
|
||||
async fn emit_certverify_tls13(
|
||||
@@ -685,15 +687,14 @@ async fn emit_certverify_tls13(
|
||||
};
|
||||
|
||||
transcript.add_message(&m);
|
||||
common.send_msg(m, true).await;
|
||||
Ok(())
|
||||
common.send_msg(m, true).await
|
||||
}
|
||||
|
||||
async fn emit_finished_tls13(
|
||||
verify_data: &[u8],
|
||||
transcript: &mut HandshakeHash,
|
||||
common: &mut CommonState,
|
||||
) {
|
||||
) -> Result<(), Error> {
|
||||
let verify_data_payload = Payload::new(verify_data);
|
||||
|
||||
let m = Message {
|
||||
@@ -705,10 +706,13 @@ async fn emit_finished_tls13(
|
||||
};
|
||||
|
||||
transcript.add_message(&m);
|
||||
common.send_msg(m, true).await;
|
||||
common.send_msg(m, true).await
|
||||
}
|
||||
|
||||
async fn emit_end_of_early_data_tls13(transcript: &mut HandshakeHash, common: &mut CommonState) {
|
||||
async fn emit_end_of_early_data_tls13(
|
||||
transcript: &mut HandshakeHash,
|
||||
common: &mut CommonState,
|
||||
) -> Result<(), Error> {
|
||||
let m = Message {
|
||||
version: ProtocolVersion::TLSv1_3,
|
||||
payload: MessagePayload::Handshake(HandshakeMessagePayload {
|
||||
@@ -718,7 +722,7 @@ async fn emit_end_of_early_data_tls13(transcript: &mut HandshakeHash, common: &m
|
||||
};
|
||||
|
||||
transcript.add_message(&m);
|
||||
common.send_msg(m, true).await;
|
||||
common.send_msg(m, true).await
|
||||
}
|
||||
|
||||
struct ExpectFinished {
|
||||
@@ -732,7 +736,7 @@ struct ExpectFinished {
|
||||
sig_verified: verify::HandshakeSignatureValid,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectFinished {
|
||||
async fn handle(
|
||||
self: Box<Self>,
|
||||
@@ -746,7 +750,7 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
let handshake_hash = st.transcript.get_current_hash();
|
||||
let expect_verify_data = cx
|
||||
.common
|
||||
.handshaker
|
||||
.crypto
|
||||
.server_finished(handshake_hash.as_ref())
|
||||
.await?;
|
||||
|
||||
@@ -758,7 +762,7 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
Err(_) => {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::DecryptError)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::DecryptError);
|
||||
}
|
||||
};
|
||||
@@ -768,11 +772,10 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
/* The EndOfEarlyData message to server is still encrypted with early data keys,
|
||||
* but appears in the transcript after the server Finished. */
|
||||
if cx.common.early_traffic {
|
||||
emit_end_of_early_data_tls13(&mut st.transcript, cx.common).await;
|
||||
emit_end_of_early_data_tls13(&mut st.transcript, cx.common).await?;
|
||||
cx.common.early_traffic = false;
|
||||
cx.data.early_data.finished();
|
||||
let enc = cx.common.handshaker.message_encrypter().await?;
|
||||
cx.common.record_layer.set_message_encrypter(enc);
|
||||
cx.common.crypto.set_encrypt(EncryptMode::Handshake)?;
|
||||
}
|
||||
|
||||
/* Send our authentication/finished messages. These are still encrypted
|
||||
@@ -782,7 +785,8 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
ClientAuthDetails::Empty {
|
||||
auth_context_tls13: auth_context,
|
||||
} => {
|
||||
emit_certificate_tls13(&mut st.transcript, None, auth_context, cx.common).await;
|
||||
emit_certificate_tls13(&mut st.transcript, None, auth_context, cx.common)
|
||||
.await?;
|
||||
}
|
||||
ClientAuthDetails::Verify {
|
||||
certkey,
|
||||
@@ -795,7 +799,7 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
auth_context,
|
||||
cx.common,
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
emit_certverify_tls13(&mut st.transcript, signer.as_ref(), cx.common).await?;
|
||||
}
|
||||
}
|
||||
@@ -804,21 +808,18 @@ impl State<ClientConnectionData> for ExpectFinished {
|
||||
let handshake_hash = st.transcript.get_current_hash();
|
||||
let client_finished = cx
|
||||
.common
|
||||
.handshaker
|
||||
.crypto
|
||||
.client_finished(handshake_hash.as_ref())
|
||||
.await?;
|
||||
emit_finished_tls13(&client_finished, &mut st.transcript, cx.common).await;
|
||||
emit_finished_tls13(&client_finished, &mut st.transcript, cx.common).await?;
|
||||
|
||||
/* Now move to our application traffic keys. */
|
||||
cx.common.check_aligned_handshake().await?;
|
||||
|
||||
let dec = cx.common.handshaker.message_decrypter().await?;
|
||||
cx.common.record_layer.set_message_decrypter(dec);
|
||||
cx.common.crypto.set_encrypt(EncryptMode::Application)?;
|
||||
cx.common.crypto.set_decrypt(DecryptMode::Application)?;
|
||||
|
||||
let enc = cx.common.handshaker.message_encrypter().await?;
|
||||
cx.common.record_layer.set_message_encrypter(enc);
|
||||
|
||||
cx.common.start_traffic().await;
|
||||
cx.common.start_traffic().await?;
|
||||
|
||||
let st = ExpectTraffic {
|
||||
session_storage: Arc::clone(&st.config.session_storage),
|
||||
@@ -859,7 +860,7 @@ impl ExpectTraffic {
|
||||
if nst.has_duplicate_extension() {
|
||||
cx.common
|
||||
.send_fatal_alert(AlertDescription::IllegalParameter)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"peer sent duplicate NewSessionTicket extensions".into(),
|
||||
));
|
||||
@@ -915,7 +916,7 @@ impl ExpectTraffic {
|
||||
// Client does not support key updates
|
||||
common
|
||||
.send_fatal_alert(AlertDescription::InternalError)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::General(
|
||||
"received unsupported key update request from peer".to_string(),
|
||||
));
|
||||
@@ -943,7 +944,7 @@ impl ExpectTraffic {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
impl State<ClientConnectionData> for ExpectTraffic {
|
||||
async fn handle(
|
||||
mut self: Box<Self>,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::client::ClientConnectionData;
|
||||
use crate::crypto::{Crypto, InvalidCrypto};
|
||||
use crate::error::Error;
|
||||
use crate::handshaker::{Handshake, InvalidHandShaker};
|
||||
#[cfg(feature = "logging")]
|
||||
use crate::log::{debug, error, trace, warn};
|
||||
use crate::record_layer;
|
||||
@@ -339,7 +339,7 @@ impl ConnectionCommon {
|
||||
if self.handshake_joiner.take_message(msg).is_none() {
|
||||
self.common_state
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
|
||||
@@ -369,7 +369,7 @@ impl ConnectionCommon {
|
||||
// handshake with an "unexpected_message" alert."
|
||||
self.common_state
|
||||
.send_fatal_alert(AlertDescription::UnexpectedMessage)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::PeerMisbehavedError(
|
||||
"illegal middlebox CCS received".into(),
|
||||
));
|
||||
@@ -406,7 +406,7 @@ impl ConnectionCommon {
|
||||
None => {
|
||||
self.common_state
|
||||
.send_fatal_alert(AlertDescription::DecodeError)
|
||||
.await;
|
||||
.await?;
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
}
|
||||
@@ -488,15 +488,15 @@ impl ConnectionCommon {
|
||||
}
|
||||
|
||||
/// Write buffer into connection
|
||||
pub async fn write_plaintext(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
pub async fn write_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
|
||||
if let Ok(st) = &mut self.state {
|
||||
st.perhaps_write_key_update(&mut self.common_state).await;
|
||||
}
|
||||
Ok(self.common_state.send_some_plaintext(buf).await)
|
||||
self.common_state.send_some_plaintext(buf).await
|
||||
}
|
||||
|
||||
/// Write entire buffer into connection
|
||||
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
pub async fn write_all_plaintext(&mut self, buf: &[u8]) -> Result<usize, Error> {
|
||||
let mut pos = 0;
|
||||
while pos < buf.len() {
|
||||
pos += self.write_plaintext(&buf[pos..]).await?;
|
||||
@@ -572,7 +572,7 @@ pub struct CommonState {
|
||||
pub(crate) negotiated_version: Option<ProtocolVersion>,
|
||||
pub(crate) side: Side,
|
||||
pub(crate) record_layer: record_layer::RecordLayer,
|
||||
pub(crate) handshaker: Box<dyn Handshake>,
|
||||
pub(crate) crypto: Box<dyn Crypto>,
|
||||
pub(crate) suite: Option<SupportedCipherSuite>,
|
||||
pub(crate) alpn_protocol: Option<Vec<u8>>,
|
||||
aligned_handshake: bool,
|
||||
@@ -600,7 +600,7 @@ impl CommonState {
|
||||
negotiated_version: None,
|
||||
side,
|
||||
record_layer: record_layer::RecordLayer::new(),
|
||||
handshaker: Box::new(InvalidHandShaker {}),
|
||||
crypto: Box::new(InvalidCrypto {}),
|
||||
suite: None,
|
||||
alpn_protocol: None,
|
||||
aligned_handshake: true,
|
||||
@@ -697,7 +697,7 @@ impl CommonState {
|
||||
};
|
||||
if msg.is_handshake_type(reject_ty) {
|
||||
self.send_warning_alert(AlertDescription::NoRenegotiation)
|
||||
.await;
|
||||
.await?;
|
||||
return Ok(state);
|
||||
}
|
||||
}
|
||||
@@ -711,7 +711,7 @@ impl CommonState {
|
||||
Err(e @ Error::InappropriateMessage { .. })
|
||||
| Err(e @ Error::InappropriateHandshakeMessage { .. }) => {
|
||||
self.send_fatal_alert(AlertDescription::UnexpectedMessage)
|
||||
.await;
|
||||
.await?;
|
||||
Err(e)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
@@ -723,7 +723,7 @@ impl CommonState {
|
||||
///
|
||||
/// If internal buffers are too small, this function will not accept
|
||||
/// all the data.
|
||||
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> usize {
|
||||
pub(crate) async fn send_some_plaintext(&mut self, data: &[u8]) -> Result<usize, Error> {
|
||||
self.send_plain(data, Limit::Yes).await
|
||||
}
|
||||
|
||||
@@ -734,7 +734,7 @@ impl CommonState {
|
||||
pub(crate) async fn check_aligned_handshake(&mut self) -> Result<(), Error> {
|
||||
if !self.aligned_handshake {
|
||||
self.send_fatal_alert(AlertDescription::UnexpectedMessage)
|
||||
.await;
|
||||
.await?;
|
||||
Err(Error::PeerMisbehavedError(
|
||||
"key epoch or handshake flight with pending fragment".to_string(),
|
||||
))
|
||||
@@ -743,10 +743,10 @@ impl CommonState {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn illegal_param(&mut self, why: &str) -> Error {
|
||||
pub(crate) async fn illegal_param(&mut self, why: &str) -> Result<Error, Error> {
|
||||
self.send_fatal_alert(AlertDescription::IllegalParameter)
|
||||
.await;
|
||||
Error::PeerMisbehavedError(why.to_string())
|
||||
.await?;
|
||||
Ok(Error::PeerMisbehavedError(why.to_string()))
|
||||
}
|
||||
|
||||
pub(crate) async fn decrypt_incoming(
|
||||
@@ -754,16 +754,19 @@ impl CommonState {
|
||||
encr: OpaqueMessage,
|
||||
) -> Result<Option<PlainMessage>, Error> {
|
||||
if self.record_layer.wants_close_before_decrypt() {
|
||||
self.send_close_notify().await;
|
||||
self.send_close_notify().await?;
|
||||
}
|
||||
|
||||
let encrypted_len = encr.payload.0.len();
|
||||
let plain = self.record_layer.decrypt_incoming(encr).await;
|
||||
let plain = self
|
||||
.record_layer
|
||||
.decrypt_incoming(self.crypto.as_mut(), encr)
|
||||
.await;
|
||||
|
||||
match plain {
|
||||
Err(Error::PeerSentOversizedRecord) => {
|
||||
self.send_fatal_alert(AlertDescription::RecordOverflow)
|
||||
.await;
|
||||
.await?;
|
||||
Err(Error::PeerSentOversizedRecord)
|
||||
}
|
||||
Err(Error::DecryptError) if self.record_layer.doing_trial_decryption(encrypted_len) => {
|
||||
@@ -771,7 +774,8 @@ impl CommonState {
|
||||
Ok(None)
|
||||
}
|
||||
Err(Error::DecryptError) => {
|
||||
self.send_fatal_alert(AlertDescription::BadRecordMac).await;
|
||||
self.send_fatal_alert(AlertDescription::BadRecordMac)
|
||||
.await?;
|
||||
Err(Error::DecryptError)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
@@ -781,18 +785,26 @@ impl CommonState {
|
||||
|
||||
/// Fragment `m`, encrypt the fragments, and then queue
|
||||
/// the encrypted fragments for sending.
|
||||
#[async_recursion]
|
||||
pub(crate) async fn send_msg_encrypt(&mut self, m: PlainMessage) {
|
||||
pub(crate) async fn send_msg_encrypt(&mut self, m: PlainMessage) -> Result<(), Error> {
|
||||
let mut plain_messages = VecDeque::new();
|
||||
self.message_fragmenter.fragment(m, &mut plain_messages);
|
||||
|
||||
for m in plain_messages {
|
||||
self.send_single_fragment(m).await;
|
||||
// Close connection once we start to run out of
|
||||
// sequence space.
|
||||
if self.record_layer.wants_close_before_encrypt() {
|
||||
debug!("Sending warning alert {:?}", AlertDescription::CloseNotify);
|
||||
let m = Message::build_alert(AlertLevel::Warning, AlertDescription::CloseNotify);
|
||||
self.send_single_fragment(m.into()).await?;
|
||||
}
|
||||
|
||||
for m in plain_messages {
|
||||
self.send_single_fragment(m).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Like send_msg_encrypt, but operate on an appdata directly.
|
||||
async fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> usize {
|
||||
async fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> Result<usize, Error> {
|
||||
// Here, the limit on sendable_tls applies to encrypted data,
|
||||
// but we're respecting it for plaintext data -- so we'll
|
||||
// be out by whatever the cipher+record overhead is. That's a
|
||||
@@ -813,27 +825,25 @@ impl CommonState {
|
||||
);
|
||||
|
||||
for m in plain_messages {
|
||||
self.send_single_fragment(m).await;
|
||||
self.send_single_fragment(m).await?;
|
||||
}
|
||||
|
||||
len
|
||||
Ok(len)
|
||||
}
|
||||
|
||||
async fn send_single_fragment(&mut self, m: PlainMessage) {
|
||||
// Close connection once we start to run out of
|
||||
// sequence space.
|
||||
if self.record_layer.wants_close_before_encrypt() {
|
||||
self.send_close_notify().await;
|
||||
}
|
||||
|
||||
async fn send_single_fragment(&mut self, m: PlainMessage) -> Result<(), Error> {
|
||||
// Refuse to wrap counter at all costs. This
|
||||
// is basically untestable unfortunately.
|
||||
if self.record_layer.encrypt_exhausted() {
|
||||
return;
|
||||
return Err(Error::EncryptError);
|
||||
}
|
||||
|
||||
let em = self.record_layer.encrypt_outgoing(m).await;
|
||||
let em = self
|
||||
.record_layer
|
||||
.encrypt_outgoing(self.crypto.as_mut(), m)
|
||||
.await?;
|
||||
self.queue_tls_message(em);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Writes TLS messages to `wr`.
|
||||
@@ -852,7 +862,7 @@ impl CommonState {
|
||||
///
|
||||
/// Returns the number of bytes written from `data`: this might
|
||||
/// be less than `data.len()` if buffer limits were exceeded.
|
||||
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> usize {
|
||||
async fn send_plain(&mut self, data: &[u8], limit: Limit) -> Result<usize, Error> {
|
||||
if !self.may_send_application_data {
|
||||
// If we haven't completed handshaking, buffer
|
||||
// plaintext to send once we do.
|
||||
@@ -860,27 +870,27 @@ impl CommonState {
|
||||
Limit::Yes => self.sendable_plaintext.append_limited_copy(data),
|
||||
Limit::No => self.sendable_plaintext.append(data.to_vec()),
|
||||
};
|
||||
return len;
|
||||
return Ok(len);
|
||||
}
|
||||
|
||||
debug_assert!(self.record_layer.is_encrypting());
|
||||
|
||||
if data.is_empty() {
|
||||
// Don't send empty fragments.
|
||||
return 0;
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
self.send_appdata_encrypt(data, limit).await
|
||||
}
|
||||
|
||||
pub(crate) async fn start_outgoing_traffic(&mut self) {
|
||||
pub(crate) async fn start_outgoing_traffic(&mut self) -> Result<(), Error> {
|
||||
self.may_send_application_data = true;
|
||||
self.flush_plaintext().await;
|
||||
self.flush_plaintext().await
|
||||
}
|
||||
|
||||
pub(crate) async fn start_traffic(&mut self) {
|
||||
pub(crate) async fn start_traffic(&mut self) -> Result<(), Error> {
|
||||
self.may_receive_application_data = true;
|
||||
self.start_outgoing_traffic().await;
|
||||
self.start_outgoing_traffic().await
|
||||
}
|
||||
|
||||
/// Sets a limit on the internal buffers used to buffer
|
||||
@@ -929,14 +939,16 @@ impl CommonState {
|
||||
|
||||
/// Send any buffered plaintext. Plaintext is buffered if
|
||||
/// written during handshake.
|
||||
async fn flush_plaintext(&mut self) {
|
||||
async fn flush_plaintext(&mut self) -> Result<(), Error> {
|
||||
if !self.may_send_application_data {
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
while let Some(buf) = self.sendable_plaintext.pop() {
|
||||
self.send_plain(&buf, Limit::No).await;
|
||||
self.send_plain(&buf, Limit::No).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Put m into sendable_tls for writing.
|
||||
@@ -945,15 +957,16 @@ impl CommonState {
|
||||
}
|
||||
|
||||
/// Send a raw TLS message, fragmenting it if needed.
|
||||
pub(crate) async fn send_msg(&mut self, m: Message, must_encrypt: bool) {
|
||||
pub(crate) async fn send_msg(&mut self, m: Message, must_encrypt: bool) -> Result<(), Error> {
|
||||
if !must_encrypt {
|
||||
let mut to_send = VecDeque::new();
|
||||
self.message_fragmenter.fragment(m.into(), &mut to_send);
|
||||
for mm in to_send {
|
||||
self.queue_tls_message(mm.into_unencrypted_opaque());
|
||||
}
|
||||
return Ok(());
|
||||
} else {
|
||||
self.send_msg_encrypt(m.into()).await;
|
||||
self.send_msg_encrypt(m.into()).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -961,16 +974,16 @@ impl CommonState {
|
||||
self.received_plaintext.append(bytes.0);
|
||||
}
|
||||
|
||||
async fn send_warning_alert(&mut self, desc: AlertDescription) {
|
||||
async fn send_warning_alert(&mut self, desc: AlertDescription) -> Result<(), Error> {
|
||||
warn!("Sending warning alert {:?}", desc);
|
||||
self.send_warning_alert_no_log(desc).await;
|
||||
self.send_warning_alert_no_log(desc).await
|
||||
}
|
||||
|
||||
async fn process_alert(&mut self, alert: &AlertMessagePayload) -> Result<(), Error> {
|
||||
// Reject unknown AlertLevels.
|
||||
if let AlertLevel::Unknown(_) = alert.level {
|
||||
self.send_fatal_alert(AlertDescription::IllegalParameter)
|
||||
.await;
|
||||
.await?;
|
||||
}
|
||||
|
||||
// If we get a CloseNotify, make a note to declare EOF to our
|
||||
@@ -984,7 +997,7 @@ impl CommonState {
|
||||
// (except, for no good reason, user_cancelled).
|
||||
if alert.level == AlertLevel::Warning {
|
||||
if self.is_tls13() && alert.description != AlertDescription::UserCanceled {
|
||||
self.send_fatal_alert(AlertDescription::DecodeError).await;
|
||||
self.send_fatal_alert(AlertDescription::DecodeError).await?;
|
||||
} else {
|
||||
warn!("TLS alert warning received: {:#?}", alert);
|
||||
return Ok(());
|
||||
@@ -995,26 +1008,27 @@ impl CommonState {
|
||||
Err(Error::AlertReceived(alert.description))
|
||||
}
|
||||
|
||||
pub(crate) async fn send_fatal_alert(&mut self, desc: AlertDescription) {
|
||||
pub(crate) async fn send_fatal_alert(&mut self, desc: AlertDescription) -> Result<(), Error> {
|
||||
warn!("Sending fatal alert {:?}", desc);
|
||||
debug_assert!(!self.sent_fatal_alert);
|
||||
let m = Message::build_alert(AlertLevel::Fatal, desc);
|
||||
self.send_msg(m, self.record_layer.is_encrypting()).await;
|
||||
self.send_msg(m, self.record_layer.is_encrypting()).await?;
|
||||
self.sent_fatal_alert = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Queues a close_notify warning alert to be sent in the next
|
||||
/// [`CommonState::write_tls`] call. This informs the peer that the
|
||||
/// connection is being closed.
|
||||
pub async fn send_close_notify(&mut self) {
|
||||
pub async fn send_close_notify(&mut self) -> Result<(), Error> {
|
||||
debug!("Sending warning alert {:?}", AlertDescription::CloseNotify);
|
||||
self.send_warning_alert_no_log(AlertDescription::CloseNotify)
|
||||
.await;
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_warning_alert_no_log(&mut self, desc: AlertDescription) {
|
||||
async fn send_warning_alert_no_log(&mut self, desc: AlertDescription) -> Result<(), Error> {
|
||||
let m = Message::build_alert(AlertLevel::Warning, desc);
|
||||
self.send_msg(m, self.record_layer.is_encrypting()).await;
|
||||
self.send_msg(m, self.record_layer.is_encrypting()).await
|
||||
}
|
||||
|
||||
pub(crate) fn set_max_fragment_size(&mut self, new: Option<usize>) -> Result<(), Error> {
|
||||
@@ -1054,7 +1068,7 @@ impl CommonState {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(?Send)]
|
||||
pub(crate) trait State<ClientConnectionData>: Send + Sync {
|
||||
async fn start(
|
||||
self: Box<Self>,
|
||||
|
||||
@@ -1,16 +1,35 @@
|
||||
use crate::Error;
|
||||
use tls_core::msgs::enums::ProtocolVersion;
|
||||
use tls_core::msgs::handshake::Random;
|
||||
use tls_core::msgs::message::{OpaqueMessage, PlainMessage};
|
||||
use tls_core::{key::PublicKey, suites::SupportedCipherSuite};
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::cipher::{MessageDecrypter, MessageEncrypter};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum EncryptMode {
|
||||
/// Encrypt payload with PSK
|
||||
EarlyData,
|
||||
/// Encrypt payload with Handshake keys
|
||||
Handshake,
|
||||
/// Encrypt payload with Application traffic keys
|
||||
Application,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DecryptMode {
|
||||
/// Decrypt payload with Handshake keys
|
||||
Handshake,
|
||||
/// Decrypt payload with Application traffic keys
|
||||
Application,
|
||||
}
|
||||
|
||||
/// Core trait which manages crypto operations for the TLS connection such as key exchange, encryption
|
||||
/// and decryption.
|
||||
#[async_trait]
|
||||
pub trait Handshake: Send {
|
||||
pub trait Crypto: Send {
|
||||
/// Signals selected protocol version to implementor.
|
||||
/// Throws error if version is not supported.
|
||||
fn select_protocol_version(&mut self, version: ProtocolVersion) -> Result<(), Error>;
|
||||
@@ -19,6 +38,10 @@ pub trait Handshake: Send {
|
||||
fn select_cipher_suite(&mut self, suite: SupportedCipherSuite) -> Result<(), Error>;
|
||||
/// Returns configured cipher suite.
|
||||
fn suite(&self) -> Result<SupportedCipherSuite, Error>;
|
||||
/// Set encryption mode
|
||||
fn set_encrypt(&mut self, mode: EncryptMode) -> Result<(), Error>;
|
||||
/// Start decryption
|
||||
fn set_decrypt(&mut self, mode: DecryptMode) -> Result<(), Error>;
|
||||
/// Returns client_random value.
|
||||
async fn client_random(&mut self) -> Result<Random, Error>;
|
||||
/// Returns public client keyshare.
|
||||
@@ -33,16 +56,16 @@ pub trait Handshake: Send {
|
||||
async fn server_finished(&mut self, hash: &[u8]) -> Result<Vec<u8>, Error>;
|
||||
/// Returns ClientFinished verify_data.
|
||||
async fn client_finished(&mut self, hash: &[u8]) -> Result<Vec<u8>, Error>;
|
||||
/// Returns initialized MessageEncrypter.
|
||||
async fn message_encrypter(&mut self) -> Result<Box<dyn MessageEncrypter>, Error>;
|
||||
/// Returns initialized MessageDecrypter.
|
||||
async fn message_decrypter(&mut self) -> Result<Box<dyn MessageDecrypter>, Error>;
|
||||
/// Perform the encryption over the concerned TLS message.
|
||||
async fn encrypt(&self, m: PlainMessage, seq: u64) -> Result<OpaqueMessage, Error>;
|
||||
/// Perform the decryption over the concerned TLS message.
|
||||
async fn decrypt(&self, m: OpaqueMessage, seq: u64) -> Result<PlainMessage, Error>;
|
||||
}
|
||||
|
||||
pub struct InvalidHandShaker {}
|
||||
pub struct InvalidCrypto {}
|
||||
|
||||
#[async_trait]
|
||||
impl Handshake for InvalidHandShaker {
|
||||
impl Crypto for InvalidCrypto {
|
||||
fn select_protocol_version(&mut self, _version: ProtocolVersion) -> Result<(), Error> {
|
||||
Err(Error::General("handshaker not yet available".to_string()))
|
||||
}
|
||||
@@ -52,6 +75,14 @@ impl Handshake for InvalidHandShaker {
|
||||
fn suite(&self) -> Result<SupportedCipherSuite, Error> {
|
||||
Err(Error::General("handshaker not yet available".to_string()))
|
||||
}
|
||||
/// Start encryption
|
||||
fn set_encrypt(&mut self, _mode: EncryptMode) -> Result<(), Error> {
|
||||
Err(Error::General("handshaker not yet available".to_string()))
|
||||
}
|
||||
/// Start decryption
|
||||
fn set_decrypt(&mut self, _mode: DecryptMode) -> Result<(), Error> {
|
||||
Err(Error::General("handshaker not yet available".to_string()))
|
||||
}
|
||||
async fn client_random(&mut self) -> Result<Random, Error> {
|
||||
Err(Error::General("handshaker not yet available".to_string()))
|
||||
}
|
||||
@@ -73,10 +104,10 @@ impl Handshake for InvalidHandShaker {
|
||||
async fn client_finished(&mut self, _hash: &[u8]) -> Result<Vec<u8>, Error> {
|
||||
Err(Error::General("handshaker not yet available".to_string()))
|
||||
}
|
||||
async fn message_encrypter(&mut self) -> Result<Box<dyn MessageEncrypter>, Error> {
|
||||
async fn encrypt(&self, _m: PlainMessage, _seq: u64) -> Result<OpaqueMessage, Error> {
|
||||
Err(Error::General("handshaker not yet available".to_string()))
|
||||
}
|
||||
async fn message_decrypter(&mut self) -> Result<Box<dyn MessageDecrypter>, Error> {
|
||||
async fn decrypt(&self, _m: OpaqueMessage, _seq: u64) -> Result<PlainMessage, Error> {
|
||||
Err(Error::General("handshaker not yet available".to_string()))
|
||||
}
|
||||
}
|
||||
@@ -315,8 +315,8 @@ pub extern crate tls_core;
|
||||
mod anchors;
|
||||
mod cipher;
|
||||
mod conn;
|
||||
mod crypto;
|
||||
mod error;
|
||||
mod handshaker;
|
||||
mod hash_hs;
|
||||
mod limited_cache;
|
||||
mod msgs;
|
||||
@@ -357,7 +357,7 @@ pub use crate::key_log::{KeyLog, NoKeyLog};
|
||||
pub use crate::key_log_file::KeyLogFile;
|
||||
pub use crate::kx::{SupportedKxGroup, ALL_KX_GROUPS};
|
||||
pub use cipher::{MessageDecrypter, MessageEncrypter};
|
||||
pub use handshaker::Handshake;
|
||||
pub use crypto::Crypto;
|
||||
pub use tls_core::key::{Certificate, PrivateKey};
|
||||
pub use tls_core::msgs::enums::CipherSuite;
|
||||
pub use tls_core::msgs::enums::ProtocolVersion;
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::{
|
||||
InvalidMessageDecrypter, InvalidMessageEncrypter, MessageDecrypter, MessageEncrypter,
|
||||
},
|
||||
error::Error,
|
||||
Crypto,
|
||||
};
|
||||
use tls_core::msgs::message::{OpaqueMessage, PlainMessage};
|
||||
|
||||
@@ -22,8 +23,6 @@ enum DirectionState {
|
||||
}
|
||||
|
||||
pub(crate) struct RecordLayer {
|
||||
message_encrypter: Box<dyn MessageEncrypter>,
|
||||
message_decrypter: Box<dyn MessageDecrypter>,
|
||||
write_seq: u64,
|
||||
read_seq: u64,
|
||||
encrypt_state: DirectionState,
|
||||
@@ -38,8 +37,6 @@ pub(crate) struct RecordLayer {
|
||||
impl RecordLayer {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
message_encrypter: Box::new(InvalidMessageEncrypter {}),
|
||||
message_decrypter: Box::new(InvalidMessageDecrypter {}),
|
||||
write_seq: 0,
|
||||
read_seq: 0,
|
||||
encrypt_state: DirectionState::Invalid,
|
||||
@@ -71,16 +68,14 @@ impl RecordLayer {
|
||||
|
||||
/// Prepare to use the given `MessageEncrypter` for future message encryption.
|
||||
/// It is not used until you call `start_encrypting`.
|
||||
pub(crate) fn prepare_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
|
||||
self.message_encrypter = cipher;
|
||||
pub(crate) fn prepare_message_encrypter(&mut self) {
|
||||
self.write_seq = 0;
|
||||
self.encrypt_state = DirectionState::Prepared;
|
||||
}
|
||||
|
||||
/// Prepare to use the given `MessageDecrypter` for future message decryption.
|
||||
/// It is not used until you call `start_decrypting`.
|
||||
pub(crate) fn prepare_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
|
||||
self.message_decrypter = cipher;
|
||||
pub(crate) fn prepare_message_decrypter(&mut self) {
|
||||
self.read_seq = 0;
|
||||
self.decrypt_state = DirectionState::Prepared;
|
||||
}
|
||||
@@ -101,15 +96,15 @@ impl RecordLayer {
|
||||
|
||||
/// Set and start using the given `MessageEncrypter` for future outgoing
|
||||
/// message encryption.
|
||||
pub(crate) fn set_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
|
||||
self.prepare_message_encrypter(cipher);
|
||||
pub(crate) fn set_message_encrypter(&mut self) {
|
||||
self.prepare_message_encrypter();
|
||||
self.start_encrypting();
|
||||
}
|
||||
|
||||
/// Set and start using the given `MessageDecrypter` for future incoming
|
||||
/// message decryption.
|
||||
pub(crate) fn set_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
|
||||
self.prepare_message_decrypter(cipher);
|
||||
pub(crate) fn set_message_decrypter(&mut self) {
|
||||
self.prepare_message_decrypter();
|
||||
self.start_decrypting();
|
||||
self.trial_decryption_len = None;
|
||||
}
|
||||
@@ -117,12 +112,8 @@ impl RecordLayer {
|
||||
/// Set and start using the given `MessageDecrypter` for future incoming
|
||||
/// message decryption, and enable "trial decryption" mode for when TLS1.3
|
||||
/// 0-RTT is attempted but rejected by the server.
|
||||
pub(crate) fn set_message_decrypter_with_trial_decryption(
|
||||
&mut self,
|
||||
cipher: Box<dyn MessageDecrypter>,
|
||||
max_length: usize,
|
||||
) {
|
||||
self.prepare_message_decrypter(cipher);
|
||||
pub(crate) fn set_message_decrypter_with_trial_decryption(&mut self, max_length: usize) {
|
||||
self.prepare_message_decrypter();
|
||||
self.start_decrypting();
|
||||
self.trial_decryption_len = Some(max_length);
|
||||
}
|
||||
@@ -162,11 +153,12 @@ impl RecordLayer {
|
||||
/// an error is returned.
|
||||
pub(crate) async fn decrypt_incoming(
|
||||
&mut self,
|
||||
cipher: &mut dyn Crypto,
|
||||
encr: OpaqueMessage,
|
||||
) -> Result<PlainMessage, Error> {
|
||||
debug_assert!(self.is_decrypting());
|
||||
let seq = self.read_seq;
|
||||
let msg = self.message_decrypter.decrypt(encr, seq).await?;
|
||||
let msg = cipher.decrypt(encr, seq).await?;
|
||||
self.read_seq += 1;
|
||||
Ok(msg)
|
||||
}
|
||||
@@ -175,11 +167,15 @@ impl RecordLayer {
|
||||
///
|
||||
/// `plain` is a TLS message we'd like to send. This function
|
||||
/// panics if the requisite keying material hasn't been established yet.
|
||||
pub(crate) async fn encrypt_outgoing(&mut self, plain: PlainMessage) -> OpaqueMessage {
|
||||
pub(crate) async fn encrypt_outgoing(
|
||||
&mut self,
|
||||
cipher: &mut dyn Crypto,
|
||||
plain: PlainMessage,
|
||||
) -> Result<OpaqueMessage, Error> {
|
||||
debug_assert!(self.encrypt_state == DirectionState::Active);
|
||||
assert!(!self.encrypt_exhausted());
|
||||
let seq = self.write_seq;
|
||||
self.write_seq += 1;
|
||||
self.message_encrypter.encrypt(plain, seq).await.unwrap()
|
||||
cipher.encrypt(plain, seq).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -460,7 +460,7 @@ async fn client_close_notify() {
|
||||
// check that alerts don't overtake appdata
|
||||
assert_eq!(12, server.writer().write(b"from-server!").unwrap());
|
||||
assert_eq!(12, client.write_plaintext(b"from-client!").await.unwrap());
|
||||
client.send_close_notify().await;
|
||||
client.send_close_notify().await.unwrap();
|
||||
|
||||
send(&mut client, &mut server);
|
||||
let io_state = server.process_new_packets().unwrap();
|
||||
@@ -973,7 +973,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn new_fails(sess: &'a mut C) -> ServerSession<'a, C, S> {
|
||||
fn _new_fails(sess: &'a mut C) -> ServerSession<'a, C, S> {
|
||||
let mut os = ServerSession::new(sess);
|
||||
os.fail_ok = true;
|
||||
os
|
||||
@@ -1067,7 +1067,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn new_fails(sess: &'a mut C) -> ClientSession<'a, C> {
|
||||
fn _new_fails(sess: &'a mut C) -> ClientSession<'a, C> {
|
||||
let mut os = ClientSession::new(sess);
|
||||
os.fail_ok = true;
|
||||
os
|
||||
@@ -1088,7 +1088,7 @@ impl<'a, C> io::Write for ClientSession<'a, C>
|
||||
where
|
||||
C: DerefMut + Deref<Target = tls_client::ConnectionCommon>,
|
||||
{
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user