diff --git a/Cargo.lock b/Cargo.lock index 985b3d35e..8090c2aa5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5931,6 +5931,7 @@ dependencies = [ "rangeset", "rstest", "rustls-pki-types", + "rustls-webpki", "semver 1.0.26", "serde", "serio", @@ -5951,7 +5952,7 @@ dependencies = [ "tracing-subscriber", "uid-mux", "web-spawn", - "webpki-roots 0.26.11", + "webpki-roots", ] [[package]] @@ -5975,7 +5976,6 @@ dependencies = [ "tlsn-core", "tlsn-data-fixtures", "tlsn-tls-core", - "webpki-roots 0.26.11", ] [[package]] @@ -6013,6 +6013,8 @@ dependencies = [ "rangeset", "rs_merkle", "rstest", + "rustls-pki-types", + "rustls-webpki", "serde", "sha2", "thiserror 1.0.69", @@ -6021,7 +6023,9 @@ dependencies = [ "tlsn-tls-core", "tlsn-utils", "web-time 0.2.4", - "webpki-roots 0.26.11", + "webpki-root-certs", + "webpki-roots", + "zeroize", ] [[package]] @@ -6068,11 +6072,9 @@ dependencies = [ "spansy", "tls-server-fixture", "tlsn", - "tlsn-core", "tlsn-formats", "tlsn-server-fixture", "tlsn-server-fixture-certs", - "tlsn-tls-core", "tokio", "tokio-util", "tracing", @@ -6119,10 +6121,8 @@ dependencies = [ "serde_json", "serio", "tlsn", - "tlsn-core", "tlsn-harness-core", "tlsn-server-fixture-certs", - "tlsn-tls-core", "tlsn-wasm", "tokio", "tokio-util", @@ -6250,6 +6250,8 @@ dependencies = [ "rand 0.9.1", "rand_chacha 0.9.0", "rstest", + "rustls-pki-types", + "rustls-webpki", "serde", "serio", "thiserror 1.0.69", @@ -6324,14 +6326,15 @@ dependencies = [ "ring 0.17.14", "rustls 0.20.9", "rustls-pemfile", + "rustls-pki-types", + "rustls-webpki", "sct", "sha2", "tlsn-tls-backend", "tlsn-tls-core", "tokio", "web-time 0.2.4", - "webpki", - "webpki-roots 0.26.11", + "webpki-roots", ] [[package]] @@ -6344,6 +6347,8 @@ dependencies = [ "hyper", "hyper-util", "rstest", + "rustls-pki-types", + "rustls-webpki", "thiserror 1.0.69", "tls-server-fixture", "tlsn-tls-client", @@ -6361,13 +6366,14 @@ dependencies = [ "rand 0.9.1", "ring 0.17.14", "rustls-pemfile", + "rustls-pki-types", + "rustls-webpki", "sct", "serde", "sha2", "thiserror 1.0.69", "tracing", "web-time 0.2.4", - "webpki", ] [[package]] @@ -7064,19 +7070,19 @@ dependencies = [ ] [[package]] -name = "webpki-roots" -version = "0.26.11" +name = "webpki-root-certs" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +checksum = "4e4ffd8df1c57e87c325000a3d6ef93db75279dc3a231125aac571650f22b12a" dependencies = [ - "webpki-roots 1.0.1", + "rustls-pki-types", ] [[package]] name = "webpki-roots" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8782dd5a41a24eed3a4f40b606249b3e236ca61adf1f25ea4d45c73de122b502" +checksum = "7e8983c3ab33d6fb807cfcdad2491c4ea8cbc8ed839181c7dfd9c67c83e261b2" dependencies = [ "rustls-pki-types", ] diff --git a/Cargo.toml b/Cargo.toml index 835ff0625..44b6b0d82 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -137,6 +137,8 @@ rs_merkle = { git = "https://github.com/tlsnotary/rs-merkle.git", rev = "85f3e82 rstest = { version = "0.17" } rustls = { version = "0.21" } rustls-pemfile = { version = "1.0" } +rustls-webpki = { version = "0.103" } +rustls-pki-types = { version = "1.12" } sct = { version = "0.7" } semver = { version = "1.0" } serde = { version = "1.0" } @@ -157,7 +159,8 @@ wasm-bindgen = { version = "0.2" } wasm-bindgen-futures = { version = "0.4" } web-spawn = { version = "0.2" } web-time = { version = "0.2" } -webpki = { version = "0.22" } -webpki-roots = { version = "0.26" } +webpki-roots = { version = "1.0" } +webpki-root-certs = { version = "1.0" } # Use the patched ws_stream_wasm to fix the issue https://github.com/najamelan/ws_stream_wasm/issues/12#issuecomment-1711902958 ws_stream_wasm = { git = "https://github.com/tlsnotary/ws_stream_wasm", rev = "2ed12aad9f0236e5321f577672f309920b2aef51" } +zeroize = { version = "1.8" } diff --git a/crates/attestation/Cargo.toml b/crates/attestation/Cargo.toml index 5aa936afa..18c2f424e 100644 --- a/crates/attestation/Cargo.toml +++ b/crates/attestation/Cargo.toml @@ -21,7 +21,6 @@ rand = { workspace = true } serde = { workspace = true, features = ["derive"] } thiserror = { workspace = true } tiny-keccak = { workspace = true, features = ["keccak"] } -webpki-roots = { workspace = true } [dev-dependencies] alloy-primitives = { version = "0.8.22", default-features = false } diff --git a/crates/attestation/src/builder.rs b/crates/attestation/src/builder.rs index 1f827c701..326d06f86 100644 --- a/crates/attestation/src/builder.rs +++ b/crates/attestation/src/builder.rs @@ -242,7 +242,7 @@ impl std::fmt::Display for AttestationBuilderError { mod test { use rstest::{fixture, rstest}; use tlsn_core::{ - connection::{HandshakeData, HandshakeDataV1_2}, + connection::{CertBinding, CertBindingV1_2}, fixtures::{ConnectionFixture, encoding_provider}, hash::Blake3, transcript::Transcript, @@ -399,10 +399,10 @@ mod test { server_cert_data, .. } = connection; - let HandshakeData::V1_2(HandshakeDataV1_2 { + let CertBinding::V1_2(CertBindingV1_2 { server_ephemeral_key, .. - }) = server_cert_data.handshake + }) = server_cert_data.binding else { panic!("expected v1.2 handshake data"); }; @@ -470,10 +470,10 @@ mod test { .. } = connection; - let HandshakeData::V1_2(HandshakeDataV1_2 { + let CertBinding::V1_2(CertBindingV1_2 { server_ephemeral_key, .. - }) = server_cert_data.handshake + }) = server_cert_data.binding else { panic!("expected v1.2 handshake data"); }; diff --git a/crates/attestation/src/connection.rs b/crates/attestation/src/connection.rs index 9327acec7..49080445f 100644 --- a/crates/attestation/src/connection.rs +++ b/crates/attestation/src/connection.rs @@ -22,7 +22,7 @@ use serde::{Deserialize, Serialize}; use tlsn_core::{ - connection::{CertificateVerificationError, ServerCertData, ServerEphemKey, ServerName}, + connection::{HandshakeData, HandshakeVerificationError, ServerEphemKey, ServerName}, hash::{Blinded, HashAlgorithm, HashProviderError, TypedHash}, }; @@ -30,14 +30,14 @@ use crate::{CryptoProvider, hash::HashAlgorithmExt, serialize::impl_domain_separ /// Opens a [`ServerCertCommitment`]. #[derive(Clone, Serialize, Deserialize)] -pub struct ServerCertOpening(Blinded); +pub struct ServerCertOpening(Blinded); impl_domain_separator!(ServerCertOpening); opaque_debug::implement!(ServerCertOpening); impl ServerCertOpening { - pub(crate) fn new(data: ServerCertData) -> Self { + pub(crate) fn new(data: HandshakeData) -> Self { Self(Blinded::new(data)) } @@ -49,7 +49,7 @@ impl ServerCertOpening { } /// Returns the server identity data. - pub fn data(&self) -> &ServerCertData { + pub fn data(&self) -> &HandshakeData { self.0.data() } } @@ -122,8 +122,8 @@ impl From for ServerIdentityProofError { } } -impl From for ServerIdentityProofError { - fn from(err: CertificateVerificationError) -> Self { +impl From for ServerIdentityProofError { + fn from(err: HandshakeVerificationError) -> Self { Self { kind: ErrorKind::Certificate, message: err.to_string(), diff --git a/crates/attestation/src/fixtures.rs b/crates/attestation/src/fixtures.rs index f586603ea..0e2d1b1ab 100644 --- a/crates/attestation/src/fixtures.rs +++ b/crates/attestation/src/fixtures.rs @@ -1,7 +1,7 @@ //! Attestation fixtures. use tlsn_core::{ - connection::{HandshakeData, HandshakeDataV1_2}, + connection::{CertBinding, CertBindingV1_2}, fixtures::ConnectionFixture, hash::HashAlgorithm, transcript::{ @@ -67,7 +67,7 @@ pub fn request_fixture( let mut request_builder = Request::builder(&request_config); request_builder .server_name(server_name) - .server_cert_data(server_cert_data) + .handshake_data(server_cert_data) .transcript(transcript); let (request, _) = request_builder.build(&provider).unwrap(); @@ -91,12 +91,12 @@ pub fn attestation_fixture( .. } = connection; - let HandshakeData::V1_2(HandshakeDataV1_2 { + let CertBinding::V1_2(CertBindingV1_2 { server_ephemeral_key, .. - }) = server_cert_data.handshake + }) = server_cert_data.binding else { - panic!("expected v1.2 handshake data"); + panic!("expected v1.2 binding data"); }; let mut provider = CryptoProvider::default(); diff --git a/crates/attestation/src/provider.rs b/crates/attestation/src/provider.rs index c02e6263a..3358f5808 100644 --- a/crates/attestation/src/provider.rs +++ b/crates/attestation/src/provider.rs @@ -1,8 +1,4 @@ -use tls_core::{ - anchors::{OwnedTrustAnchor, RootCertStore}, - verify::WebPkiVerifier, -}; -use tlsn_core::hash::HashProvider; +use tlsn_core::{hash::HashProvider, webpki::ServerCertVerifier}; use crate::signing::{SignatureVerifierProvider, SignerProvider}; @@ -28,7 +24,7 @@ pub struct CryptoProvider { /// This is used to verify the server's certificate chain. /// /// The default verifier uses the Mozilla root certificates. - pub cert: WebPkiVerifier, + pub cert: ServerCertVerifier, /// Signer provider. /// /// This is used for signing attestations. @@ -45,21 +41,9 @@ impl Default for CryptoProvider { fn default() -> Self { Self { hash: Default::default(), - cert: default_cert_verifier(), + cert: ServerCertVerifier::mozilla(), signer: Default::default(), signature: Default::default(), } } } - -pub(crate) fn default_cert_verifier() -> WebPkiVerifier { - let mut root_store = RootCertStore::empty(); - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject.as_ref(), - ta.subject_public_key_info.as_ref(), - ta.name_constraints.as_ref().map(|nc| nc.as_ref()), - ) - })); - WebPkiVerifier::new(root_store, None) -} diff --git a/crates/attestation/src/request/builder.rs b/crates/attestation/src/request/builder.rs index ac6e191a0..5c04195f2 100644 --- a/crates/attestation/src/request/builder.rs +++ b/crates/attestation/src/request/builder.rs @@ -1,5 +1,5 @@ use tlsn_core::{ - connection::{ServerCertData, ServerName}, + connection::{HandshakeData, ServerName}, transcript::{Transcript, TranscriptCommitment, TranscriptSecret}, }; @@ -13,7 +13,7 @@ use crate::{ pub struct RequestBuilder<'a> { config: &'a RequestConfig, server_name: Option, - server_cert_data: Option, + handshake_data: Option, transcript: Option, transcript_commitments: Vec, transcript_commitment_secrets: Vec, @@ -25,7 +25,7 @@ impl<'a> RequestBuilder<'a> { Self { config, server_name: None, - server_cert_data: None, + handshake_data: None, transcript: None, transcript_commitments: Vec::new(), transcript_commitment_secrets: Vec::new(), @@ -38,9 +38,9 @@ impl<'a> RequestBuilder<'a> { self } - /// Sets the server identity data. - pub fn server_cert_data(&mut self, data: ServerCertData) -> &mut Self { - self.server_cert_data = Some(data); + /// Sets the handshake data. + pub fn handshake_data(&mut self, data: HandshakeData) -> &mut Self { + self.handshake_data = Some(data); self } @@ -69,7 +69,7 @@ impl<'a> RequestBuilder<'a> { let Self { config, server_name, - server_cert_data, + handshake_data: server_cert_data, transcript, transcript_commitments, transcript_commitment_secrets, diff --git a/crates/attestation/src/serialize.rs b/crates/attestation/src/serialize.rs index 828e3d7a1..54ce33e30 100644 --- a/crates/attestation/src/serialize.rs +++ b/crates/attestation/src/serialize.rs @@ -46,7 +46,7 @@ pub(crate) use impl_domain_separator; impl_domain_separator!(tlsn_core::connection::ServerEphemKey); impl_domain_separator!(tlsn_core::connection::ConnectionInfo); -impl_domain_separator!(tlsn_core::connection::HandshakeData); +impl_domain_separator!(tlsn_core::connection::CertBinding); impl_domain_separator!(tlsn_core::transcript::TranscriptCommitment); impl_domain_separator!(tlsn_core::transcript::TranscriptSecret); impl_domain_separator!(tlsn_core::transcript::encoding::EncodingCommitment); diff --git a/crates/attestation/tests/api.rs b/crates/attestation/tests/api.rs index 167a9b101..edb41e210 100644 --- a/crates/attestation/tests/api.rs +++ b/crates/attestation/tests/api.rs @@ -5,7 +5,7 @@ use tlsn_attestation::{ signing::SignatureAlgId, }; use tlsn_core::{ - connection::{HandshakeData, HandshakeDataV1_2}, + connection::{CertBinding, CertBindingV1_2}, fixtures::{self, ConnectionFixture, encoder_secret}, hash::Blake3, transcript::{ @@ -36,10 +36,10 @@ fn test_api() { server_cert_data, } = ConnectionFixture::tlsnotary(transcript.length()); - let HandshakeData::V1_2(HandshakeDataV1_2 { + let CertBinding::V1_2(CertBindingV1_2 { server_ephemeral_key, .. - }) = server_cert_data.handshake.clone() + }) = server_cert_data.binding.clone() else { unreachable!() }; @@ -72,7 +72,7 @@ fn test_api() { request_builder .server_name(server_name.clone()) - .server_cert_data(server_cert_data) + .handshake_data(server_cert_data) .transcript(transcript) .transcript_commitments( vec![TranscriptSecret::Encoding(encoding_tree)], diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 36436825e..d63304eeb 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -36,10 +36,14 @@ thiserror = { workspace = true } tiny-keccak = { workspace = true, features = ["keccak"] } web-time = { workspace = true } webpki-roots = { workspace = true } +rustls-webpki = { workspace = true, features = ["ring"] } +rustls-pki-types = { workspace = true } itybity = { workspace = true } +zeroize = { workspace = true } [dev-dependencies] bincode = { workspace = true } hex = { workspace = true } rstest = { workspace = true } tlsn-data-fixtures = { workspace = true } +webpki-root-certs = { workspace = true } diff --git a/crates/core/src/connection.rs b/crates/core/src/connection.rs index 85f6a1ef0..1d9a9cd2f 100644 --- a/crates/core/src/connection.rs +++ b/crates/core/src/connection.rs @@ -2,16 +2,11 @@ use std::fmt; +use rustls_pki_types as webpki_types; use serde::{Deserialize, Serialize}; -use tls_core::{ - msgs::{ - codec::Codec, - enums::NamedGroup, - handshake::{DigitallySignedStruct, ServerECDHParams}, - }, - verify::{ServerCertVerifier as _, WebPkiVerifier}, -}; -use web_time::{Duration, UNIX_EPOCH}; +use tls_core::msgs::{codec::Codec, enums::NamedGroup, handshake::ServerECDHParams}; + +use crate::webpki::{CertificateDer, ServerCertVerifier, ServerCertVerifierError}; /// TLS version. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -35,40 +30,82 @@ impl TryFrom for TlsVersion { } } -/// Server's name, a.k.a. the DNS name. +/// Server's name. #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub struct ServerName(String); +pub enum ServerName { + /// DNS name. + Dns(DnsName), +} impl ServerName { - /// Creates a new server name. - pub fn new(name: String) -> Self { - Self(name) - } - - /// Returns the name as a string. - pub fn as_str(&self) -> &str { - &self.0 - } -} - -impl From<&str> for ServerName { - fn from(name: &str) -> Self { - Self(name.to_string()) - } -} - -impl AsRef for ServerName { - fn as_ref(&self) -> &str { - &self.0 + pub(crate) fn to_webpki(&self) -> webpki_types::ServerName<'static> { + match self { + ServerName::Dns(name) => webpki_types::ServerName::DnsName( + webpki_types::DnsName::try_from(name.0.as_str()) + .expect("name was validated") + .to_owned(), + ), + } } } impl fmt::Display for ServerName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ServerName::Dns(name) => write!(f, "{name}"), + } + } +} + +/// DNS name. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(try_from = "String")] +pub struct DnsName(String); + +impl DnsName { + /// Returns the DNS name as a string. + pub fn as_str(&self) -> &str { + self.0.as_str() + } +} + +impl fmt::Display for DnsName { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } +impl AsRef for DnsName { + fn as_ref(&self) -> &str { + &self.0 + } +} + +/// Error returned when a DNS name is invalid. +#[derive(Debug, thiserror::Error)] +#[error("invalid DNS name")] +pub struct InvalidDnsNameError {} + +impl TryFrom<&str> for DnsName { + type Error = InvalidDnsNameError; + + fn try_from(value: &str) -> Result { + // Borrow validation from rustls + match webpki_types::DnsName::try_from_str(value) { + Ok(_) => Ok(DnsName(value.to_string())), + Err(_) => Err(InvalidDnsNameError {}), + } + } +} + +impl TryFrom for DnsName { + type Error = InvalidDnsNameError; + + fn try_from(value: String) -> Result { + Self::try_from(value.as_str()) + } +} + /// Type of a public key. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] @@ -98,6 +135,25 @@ pub enum SignatureScheme { ED25519 = 0x0807, } +impl fmt::Display for SignatureScheme { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SignatureScheme::RSA_PKCS1_SHA1 => write!(f, "RSA_PKCS1_SHA1"), + SignatureScheme::ECDSA_SHA1_Legacy => write!(f, "ECDSA_SHA1_Legacy"), + SignatureScheme::RSA_PKCS1_SHA256 => write!(f, "RSA_PKCS1_SHA256"), + SignatureScheme::ECDSA_NISTP256_SHA256 => write!(f, "ECDSA_NISTP256_SHA256"), + SignatureScheme::RSA_PKCS1_SHA384 => write!(f, "RSA_PKCS1_SHA384"), + SignatureScheme::ECDSA_NISTP384_SHA384 => write!(f, "ECDSA_NISTP384_SHA384"), + SignatureScheme::RSA_PKCS1_SHA512 => write!(f, "RSA_PKCS1_SHA512"), + SignatureScheme::ECDSA_NISTP521_SHA512 => write!(f, "ECDSA_NISTP521_SHA512"), + SignatureScheme::RSA_PSS_SHA256 => write!(f, "RSA_PSS_SHA256"), + SignatureScheme::RSA_PSS_SHA384 => write!(f, "RSA_PSS_SHA384"), + SignatureScheme::RSA_PSS_SHA512 => write!(f, "RSA_PSS_SHA512"), + SignatureScheme::ED25519 => write!(f, "ED25519"), + } + } +} + impl TryFrom for SignatureScheme { type Error = &'static str; @@ -142,16 +198,6 @@ impl From for tls_core::msgs::enums::SignatureScheme { } } -/// X.509 certificate, DER encoded. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Certificate(pub Vec); - -impl From for Certificate { - fn from(cert: tls_core::key::Certificate) -> Self { - Self(cert.0) - } -} - /// Server's signature of the key exchange parameters. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ServerSignature { @@ -220,9 +266,9 @@ pub struct TranscriptLength { pub received: u32, } -/// TLS 1.2 handshake data. +/// TLS 1.2 certificate binding. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct HandshakeDataV1_2 { +pub struct CertBindingV1_2 { /// Client random. pub client_random: [u8; 32], /// Server random. @@ -231,13 +277,18 @@ pub struct HandshakeDataV1_2 { pub server_ephemeral_key: ServerEphemKey, } -/// TLS handshake data. +/// TLS certificate binding. +/// +/// This is the data that the server signs using its public key in the +/// certificate it presents during the TLS handshake. This provides a binding +/// between the server's identity and the ephemeral keys used to authenticate +/// the TLS session. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] #[non_exhaustive] -pub enum HandshakeData { - /// TLS 1.2 handshake data. - V1_2(HandshakeDataV1_2), +pub enum CertBinding { + /// TLS 1.2 certificate binding. + V1_2(CertBindingV1_2), } /// Verify data from the TLS handshake finished messages. @@ -249,19 +300,19 @@ pub struct VerifyData { pub server_finished: Vec, } -/// Server certificate and handshake data. +/// TLS handshake data. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ServerCertData { - /// Certificate chain. - pub certs: Vec, - /// Server signature of the key exchange parameters. +pub struct HandshakeData { + /// Server certificate chain. + pub certs: Vec, + /// Server certificate signature over the binding message. pub sig: ServerSignature, - /// TLS handshake data. - pub handshake: HandshakeData, + /// Certificate binding. + pub binding: CertBinding, } -impl ServerCertData { - /// Verifies the server certificate data. +impl HandshakeData { + /// Verifies the handshake data. /// /// # Arguments /// @@ -271,53 +322,35 @@ impl ServerCertData { /// * `server_name` - The server name. pub fn verify( &self, - verifier: &WebPkiVerifier, + verifier: &ServerCertVerifier, time: u64, server_ephemeral_key: &ServerEphemKey, server_name: &ServerName, - ) -> Result<(), CertificateVerificationError> { + ) -> Result<(), HandshakeVerificationError> { #[allow(irrefutable_let_patterns)] - let HandshakeData::V1_2(HandshakeDataV1_2 { + let CertBinding::V1_2(CertBindingV1_2 { client_random, server_random, server_ephemeral_key: expected_server_ephemeral_key, - }) = &self.handshake + }) = &self.binding else { unreachable!("only TLS 1.2 is implemented") }; if server_ephemeral_key != expected_server_ephemeral_key { - return Err(CertificateVerificationError::InvalidServerEphemeralKey); + return Err(HandshakeVerificationError::InvalidServerEphemeralKey); } - // Verify server name. - let server_name = tls_core::dns::ServerName::try_from(server_name.as_ref()) - .map_err(|_| CertificateVerificationError::InvalidIdentity(server_name.clone()))?; - - // Verify server certificate. - let cert_chain = self + let (end_entity, intermediates) = self .certs - .clone() - .into_iter() - .map(|cert| tls_core::key::Certificate(cert.0)) - .collect::>(); - - let (end_entity, intermediates) = cert_chain .split_first() - .ok_or(CertificateVerificationError::MissingCerts)?; + .ok_or(HandshakeVerificationError::MissingCerts)?; // Verify the end entity cert is valid for the provided server name // and that it chains to at least one of the roots we trust. verifier - .verify_server_cert( - end_entity, - intermediates, - &server_name, - &mut [].into_iter(), - &[], - UNIX_EPOCH + Duration::from_secs(time), - ) - .map_err(|_| CertificateVerificationError::InvalidCert)?; + .verify_server_cert(end_entity, intermediates, server_name, time) + .map_err(HandshakeVerificationError::ServerCert)?; // Verify the signature matches the certificate and key exchange parameters. let mut message = Vec::new(); @@ -325,11 +358,31 @@ impl ServerCertData { message.extend_from_slice(server_random); message.extend_from_slice(&server_ephemeral_key.kx_params()); - let dss = DigitallySignedStruct::new(self.sig.scheme.into(), self.sig.sig.clone()); + use webpki::ring as alg; + let sig_alg = match self.sig.scheme { + SignatureScheme::RSA_PKCS1_SHA256 => alg::RSA_PKCS1_2048_8192_SHA256, + SignatureScheme::RSA_PKCS1_SHA384 => alg::RSA_PKCS1_2048_8192_SHA384, + SignatureScheme::RSA_PKCS1_SHA512 => alg::RSA_PKCS1_2048_8192_SHA512, + SignatureScheme::RSA_PSS_SHA256 => alg::RSA_PSS_2048_8192_SHA256_LEGACY_KEY, + SignatureScheme::RSA_PSS_SHA384 => alg::RSA_PSS_2048_8192_SHA384_LEGACY_KEY, + SignatureScheme::RSA_PSS_SHA512 => alg::RSA_PSS_2048_8192_SHA512_LEGACY_KEY, + SignatureScheme::ECDSA_NISTP256_SHA256 => alg::ECDSA_P256_SHA256, + SignatureScheme::ECDSA_NISTP384_SHA384 => alg::ECDSA_P384_SHA384, + SignatureScheme::ED25519 => alg::ED25519, + scheme => { + return Err(HandshakeVerificationError::UnsupportedSignatureScheme( + scheme, + )) + } + }; - verifier - .verify_tls12_signature(&message, end_entity, &dss) - .map_err(|_| CertificateVerificationError::InvalidServerSignature)?; + let end_entity = webpki_types::CertificateDer::from(end_entity.0.as_slice()); + let end_entity = webpki::EndEntityCert::try_from(&end_entity) + .map_err(|_| HandshakeVerificationError::InvalidEndEntityCertificate)?; + + end_entity + .verify_signature(sig_alg, &message, &self.sig.sig) + .map_err(|_| HandshakeVerificationError::InvalidServerSignature)?; Ok(()) } @@ -338,58 +391,51 @@ impl ServerCertData { /// Errors that can occur when verifying a certificate chain or signature. #[derive(Debug, thiserror::Error)] #[allow(missing_docs)] -pub enum CertificateVerificationError { - #[error("invalid server identity: {0}")] - InvalidIdentity(ServerName), +pub enum HandshakeVerificationError { + #[error("invalid end entity certificate")] + InvalidEndEntityCertificate, #[error("missing server certificates")] MissingCerts, - #[error("invalid server certificate")] - InvalidCert, #[error("invalid server signature")] InvalidServerSignature, #[error("invalid server ephemeral key")] InvalidServerEphemeralKey, + #[error("server certificate verification failed: {0}")] + ServerCert(ServerCertVerifierError), + #[error("unsupported signature scheme: {0}")] + UnsupportedSignatureScheme(SignatureScheme), } #[cfg(test)] mod tests { use super::*; - use crate::{fixtures::ConnectionFixture, transcript::Transcript}; + use crate::{fixtures::ConnectionFixture, transcript::Transcript, webpki::RootCertStore}; use hex::FromHex; use rstest::*; - use tls_core::{ - anchors::{OwnedTrustAnchor, RootCertStore}, - verify::WebPkiVerifier, - }; use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON}; #[fixture] #[once] - fn verifier() -> WebPkiVerifier { - let mut root_store = RootCertStore::empty(); - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject.as_ref(), - ta.subject_public_key_info.as_ref(), - ta.name_constraints.as_ref().map(|nc| nc.as_ref()), - ) - })); + fn verifier() -> ServerCertVerifier { + let mut root_store = RootCertStore { + roots: webpki_root_certs::TLS_SERVER_ROOT_CERTS + .iter() + .map(|c| CertificateDer(c.to_vec())) + .collect(), + }; // Add a cert which is no longer included in the Mozilla root store. - let cert = tls_core::key::Certificate( + root_store.roots.push( appliedzkp() .server_cert_data .certs .last() .expect("chain is valid") - .0 .clone(), ); - root_store.add(&cert).unwrap(); - - WebPkiVerifier::new(root_store, None) + ServerCertVerifier::new(&root_store).unwrap() } fn tlsnotary() -> ConnectionFixture { @@ -405,7 +451,7 @@ mod tests { #[case::tlsnotary(tlsnotary())] #[case::appliedzkp(appliedzkp())] fn test_verify_cert_chain_sucess_ca_implicit( - verifier: &WebPkiVerifier, + verifier: &ServerCertVerifier, #[case] mut data: ConnectionFixture, ) { // Remove the CA cert @@ -417,7 +463,7 @@ mod tests { verifier, data.connection_info.time, data.server_ephemeral_key(), - &ServerName::from(data.server_name.as_ref()), + &data.server_name, ) .is_ok()); } @@ -428,7 +474,7 @@ mod tests { #[case::tlsnotary(tlsnotary())] #[case::appliedzkp(appliedzkp())] fn test_verify_cert_chain_success_ca_explicit( - verifier: &WebPkiVerifier, + verifier: &ServerCertVerifier, #[case] data: ConnectionFixture, ) { assert!(data @@ -437,7 +483,7 @@ mod tests { verifier, data.connection_info.time, data.server_ephemeral_key(), - &ServerName::from(data.server_name.as_ref()), + &data.server_name, ) .is_ok()); } @@ -447,7 +493,7 @@ mod tests { #[case::tlsnotary(tlsnotary())] #[case::appliedzkp(appliedzkp())] fn test_verify_cert_chain_fail_bad_time( - verifier: &WebPkiVerifier, + verifier: &ServerCertVerifier, #[case] data: ConnectionFixture, ) { // unix time when the cert chain was NOT valid @@ -457,12 +503,12 @@ mod tests { verifier, bad_time, data.server_ephemeral_key(), - &ServerName::from(data.server_name.as_ref()), + &data.server_name, ); assert!(matches!( err.unwrap_err(), - CertificateVerificationError::InvalidCert + HandshakeVerificationError::ServerCert(_) )); } @@ -471,7 +517,7 @@ mod tests { #[case::tlsnotary(tlsnotary())] #[case::appliedzkp(appliedzkp())] fn test_verify_cert_chain_fail_no_interm_cert( - verifier: &WebPkiVerifier, + verifier: &ServerCertVerifier, #[case] mut data: ConnectionFixture, ) { // Remove the CA cert @@ -483,12 +529,12 @@ mod tests { verifier, data.connection_info.time, data.server_ephemeral_key(), - &ServerName::from(data.server_name.as_ref()), + &data.server_name, ); assert!(matches!( err.unwrap_err(), - CertificateVerificationError::InvalidCert + HandshakeVerificationError::ServerCert(_) )); } @@ -498,7 +544,7 @@ mod tests { #[case::tlsnotary(tlsnotary())] #[case::appliedzkp(appliedzkp())] fn test_verify_cert_chain_fail_no_interm_cert_with_ca_cert( - verifier: &WebPkiVerifier, + verifier: &ServerCertVerifier, #[case] mut data: ConnectionFixture, ) { // Remove the intermediate cert @@ -508,12 +554,12 @@ mod tests { verifier, data.connection_info.time, data.server_ephemeral_key(), - &ServerName::from(data.server_name.as_ref()), + &data.server_name, ); assert!(matches!( err.unwrap_err(), - CertificateVerificationError::InvalidCert + HandshakeVerificationError::ServerCert(_) )); } @@ -522,24 +568,24 @@ mod tests { #[case::tlsnotary(tlsnotary())] #[case::appliedzkp(appliedzkp())] fn test_verify_cert_chain_fail_bad_ee_cert( - verifier: &WebPkiVerifier, + verifier: &ServerCertVerifier, #[case] mut data: ConnectionFixture, ) { let ee: &[u8] = include_bytes!("./fixtures/data/unknown/ee.der"); // Change the end entity cert - data.server_cert_data.certs[0] = Certificate(ee.to_vec()); + data.server_cert_data.certs[0] = CertificateDer(ee.to_vec()); let err = data.server_cert_data.verify( verifier, data.connection_info.time, data.server_ephemeral_key(), - &ServerName::from(data.server_name.as_ref()), + &data.server_name, ); assert!(matches!( err.unwrap_err(), - CertificateVerificationError::InvalidCert + HandshakeVerificationError::ServerCert(_) )); } @@ -548,23 +594,23 @@ mod tests { #[case::tlsnotary(tlsnotary())] #[case::appliedzkp(appliedzkp())] fn test_verify_sig_ke_params_fail_bad_client_random( - verifier: &WebPkiVerifier, + verifier: &ServerCertVerifier, #[case] mut data: ConnectionFixture, ) { - let HandshakeData::V1_2(HandshakeDataV1_2 { client_random, .. }) = - &mut data.server_cert_data.handshake; + let CertBinding::V1_2(CertBindingV1_2 { client_random, .. }) = + &mut data.server_cert_data.binding; client_random[31] = client_random[31].wrapping_add(1); let err = data.server_cert_data.verify( verifier, data.connection_info.time, data.server_ephemeral_key(), - &ServerName::from(data.server_name.as_ref()), + &data.server_name, ); assert!(matches!( err.unwrap_err(), - CertificateVerificationError::InvalidServerSignature + HandshakeVerificationError::InvalidServerSignature )); } @@ -573,7 +619,7 @@ mod tests { #[case::tlsnotary(tlsnotary())] #[case::appliedzkp(appliedzkp())] fn test_verify_sig_ke_params_fail_bad_sig( - verifier: &WebPkiVerifier, + verifier: &ServerCertVerifier, #[case] mut data: ConnectionFixture, ) { data.server_cert_data.sig.sig[31] = data.server_cert_data.sig.sig[31].wrapping_add(1); @@ -582,12 +628,12 @@ mod tests { verifier, data.connection_info.time, data.server_ephemeral_key(), - &ServerName::from(data.server_name.as_ref()), + &data.server_name, ); assert!(matches!( err.unwrap_err(), - CertificateVerificationError::InvalidServerSignature + HandshakeVerificationError::InvalidServerSignature )); } @@ -596,10 +642,10 @@ mod tests { #[case::tlsnotary(tlsnotary())] #[case::appliedzkp(appliedzkp())] fn test_check_dns_name_present_in_cert_fail_bad_host( - verifier: &WebPkiVerifier, + verifier: &ServerCertVerifier, #[case] data: ConnectionFixture, ) { - let bad_name = ServerName::from("badhost.com"); + let bad_name = ServerName::Dns(DnsName::try_from("badhost.com").unwrap()); let err = data.server_cert_data.verify( verifier, @@ -610,7 +656,7 @@ mod tests { assert!(matches!( err.unwrap_err(), - CertificateVerificationError::InvalidCert + HandshakeVerificationError::ServerCert(_) )); } @@ -618,7 +664,7 @@ mod tests { #[rstest] #[case::tlsnotary(tlsnotary())] #[case::appliedzkp(appliedzkp())] - fn test_invalid_ephemeral_key(verifier: &WebPkiVerifier, #[case] data: ConnectionFixture) { + fn test_invalid_ephemeral_key(verifier: &ServerCertVerifier, #[case] data: ConnectionFixture) { let wrong_ephemeral_key = ServerEphemKey { typ: KeyType::SECP256R1, key: Vec::::from_hex(include_bytes!("./fixtures/data/unknown/pubkey")).unwrap(), @@ -628,12 +674,12 @@ mod tests { verifier, data.connection_info.time, &wrong_ephemeral_key, - &ServerName::from(data.server_name.as_ref()), + &data.server_name, ); assert!(matches!( err.unwrap_err(), - CertificateVerificationError::InvalidServerEphemeralKey + HandshakeVerificationError::InvalidServerEphemeralKey )); } @@ -642,7 +688,7 @@ mod tests { #[case::tlsnotary(tlsnotary())] #[case::appliedzkp(appliedzkp())] fn test_verify_cert_chain_fail_no_cert( - verifier: &WebPkiVerifier, + verifier: &ServerCertVerifier, #[case] mut data: ConnectionFixture, ) { // Empty certs @@ -652,12 +698,12 @@ mod tests { verifier, data.connection_info.time, data.server_ephemeral_key(), - &ServerName::from(data.server_name.as_ref()), + &data.server_name, ); assert!(matches!( err.unwrap_err(), - CertificateVerificationError::MissingCerts + HandshakeVerificationError::MissingCerts )); } } diff --git a/crates/core/src/fixtures.rs b/crates/core/src/fixtures.rs index f1b332cb3..ce54cd03d 100644 --- a/crates/core/src/fixtures.rs +++ b/crates/core/src/fixtures.rs @@ -8,13 +8,14 @@ use hex::FromHex; use crate::{ connection::{ - Certificate, ConnectionInfo, HandshakeData, HandshakeDataV1_2, KeyType, ServerCertData, + CertBinding, CertBindingV1_2, ConnectionInfo, DnsName, HandshakeData, KeyType, ServerEphemKey, ServerName, ServerSignature, SignatureScheme, TlsVersion, TranscriptLength, }, transcript::{ encoding::{EncoderSecret, EncodingProvider}, Transcript, }, + webpki::CertificateDer, }; /// A fixture containing various TLS connection data. @@ -23,24 +24,26 @@ use crate::{ pub struct ConnectionFixture { pub server_name: ServerName, pub connection_info: ConnectionInfo, - pub server_cert_data: ServerCertData, + pub server_cert_data: HandshakeData, } impl ConnectionFixture { /// Returns a connection fixture for tlsnotary.org. pub fn tlsnotary(transcript_length: TranscriptLength) -> Self { ConnectionFixture { - server_name: ServerName::new("tlsnotary.org".to_string()), + server_name: ServerName::Dns(DnsName::try_from("tlsnotary.org").unwrap()), connection_info: ConnectionInfo { time: 1671637529, version: TlsVersion::V1_2, transcript_length, }, - server_cert_data: ServerCertData { + server_cert_data: HandshakeData { certs: vec![ - Certificate(include_bytes!("fixtures/data/tlsnotary.org/ee.der").to_vec()), - Certificate(include_bytes!("fixtures/data/tlsnotary.org/inter.der").to_vec()), - Certificate(include_bytes!("fixtures/data/tlsnotary.org/ca.der").to_vec()), + CertificateDer(include_bytes!("fixtures/data/tlsnotary.org/ee.der").to_vec()), + CertificateDer( + include_bytes!("fixtures/data/tlsnotary.org/inter.der").to_vec(), + ), + CertificateDer(include_bytes!("fixtures/data/tlsnotary.org/ca.der").to_vec()), ], sig: ServerSignature { scheme: SignatureScheme::RSA_PKCS1_SHA256, @@ -49,7 +52,7 @@ impl ConnectionFixture { )) .unwrap(), }, - handshake: HandshakeData::V1_2(HandshakeDataV1_2 { + binding: CertBinding::V1_2(CertBindingV1_2 { client_random: <[u8; 32]>::from_hex(include_bytes!( "fixtures/data/tlsnotary.org/client_random" )) @@ -73,17 +76,19 @@ impl ConnectionFixture { /// Returns a connection fixture for appliedzkp.org. pub fn appliedzkp(transcript_length: TranscriptLength) -> Self { ConnectionFixture { - server_name: ServerName::new("appliedzkp.org".to_string()), + server_name: ServerName::Dns(DnsName::try_from("appliedzkp.org").unwrap()), connection_info: ConnectionInfo { time: 1671637529, version: TlsVersion::V1_2, transcript_length, }, - server_cert_data: ServerCertData { + server_cert_data: HandshakeData { certs: vec![ - Certificate(include_bytes!("fixtures/data/appliedzkp.org/ee.der").to_vec()), - Certificate(include_bytes!("fixtures/data/appliedzkp.org/inter.der").to_vec()), - Certificate(include_bytes!("fixtures/data/appliedzkp.org/ca.der").to_vec()), + CertificateDer(include_bytes!("fixtures/data/appliedzkp.org/ee.der").to_vec()), + CertificateDer( + include_bytes!("fixtures/data/appliedzkp.org/inter.der").to_vec(), + ), + CertificateDer(include_bytes!("fixtures/data/appliedzkp.org/ca.der").to_vec()), ], sig: ServerSignature { scheme: SignatureScheme::ECDSA_NISTP256_SHA256, @@ -92,7 +97,7 @@ impl ConnectionFixture { )) .unwrap(), }, - handshake: HandshakeData::V1_2(HandshakeDataV1_2 { + binding: CertBinding::V1_2(CertBindingV1_2 { client_random: <[u8; 32]>::from_hex(include_bytes!( "fixtures/data/appliedzkp.org/client_random" )) @@ -115,10 +120,10 @@ impl ConnectionFixture { /// Returns the server_ephemeral_key fixture. pub fn server_ephemeral_key(&self) -> &ServerEphemKey { - let HandshakeData::V1_2(HandshakeDataV1_2 { + let CertBinding::V1_2(CertBindingV1_2 { server_ephemeral_key, .. - }) = &self.server_cert_data.handshake; + }) = &self.server_cert_data.binding; server_ephemeral_key } } diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 48a0bdf3e..d43cec39e 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -10,12 +10,13 @@ pub mod fixtures; pub mod hash; pub mod merkle; pub mod transcript; +pub mod webpki; use rangeset::ToRangeSet; use serde::{Deserialize, Serialize}; use crate::{ - connection::{ServerCertData, ServerName}, + connection::{HandshakeData, ServerName}, transcript::{ Direction, Idx, PartialTranscript, Transcript, TranscriptCommitConfig, TranscriptCommitRequest, TranscriptCommitment, TranscriptSecret, @@ -200,8 +201,8 @@ enum VerifyConfigBuilderErrorRepr {} #[doc(hidden)] #[derive(Debug, Serialize, Deserialize)] pub struct ProvePayload { - /// Server identity data. - pub server_identity: Option<(ServerName, ServerCertData)>, + /// Handshake data. + pub handshake: Option<(ServerName, HandshakeData)>, /// Transcript data. pub transcript: Option, /// Transcript commitment configuration. diff --git a/crates/core/src/transcript/tls.rs b/crates/core/src/transcript/tls.rs index 8fdc8e6d9..0c9028e2b 100644 --- a/crates/core/src/transcript/tls.rs +++ b/crates/core/src/transcript/tls.rs @@ -2,10 +2,10 @@ use crate::{ connection::{ - Certificate, HandshakeData, HandshakeDataV1_2, ServerEphemKey, ServerSignature, TlsVersion, - VerifyData, + CertBinding, CertBindingV1_2, ServerEphemKey, ServerSignature, TlsVersion, VerifyData, }, transcript::{Direction, Transcript}, + webpki::CertificateDer, }; use tls_core::msgs::{ alert::AlertMessagePayload, @@ -19,9 +19,9 @@ use tls_core::msgs::{ pub struct TlsTranscript { time: u64, version: TlsVersion, - server_cert_chain: Option>, + server_cert_chain: Option>, server_signature: Option, - handshake_data: HandshakeData, + certificate_binding: CertBinding, sent: Vec, recv: Vec, } @@ -32,9 +32,9 @@ impl TlsTranscript { pub fn new( time: u64, version: TlsVersion, - server_cert_chain: Option>, + server_cert_chain: Option>, server_signature: Option, - handshake_data: HandshakeData, + certificate_binding: CertBinding, verify_data: VerifyData, sent: Vec, recv: Vec, @@ -198,7 +198,7 @@ impl TlsTranscript { version, server_cert_chain, server_signature, - handshake_data, + certificate_binding, sent, recv, }) @@ -215,7 +215,7 @@ impl TlsTranscript { } /// Returns the server certificate chain. - pub fn server_cert_chain(&self) -> Option<&[Certificate]> { + pub fn server_cert_chain(&self) -> Option<&[CertificateDer]> { self.server_cert_chain.as_deref() } @@ -226,17 +226,17 @@ impl TlsTranscript { /// Returns the server ephemeral key used in the TLS handshake. pub fn server_ephemeral_key(&self) -> &ServerEphemKey { - match &self.handshake_data { - HandshakeData::V1_2(HandshakeDataV1_2 { + match &self.certificate_binding { + CertBinding::V1_2(CertBindingV1_2 { server_ephemeral_key, .. }) => server_ephemeral_key, } } - /// Returns the handshake data. - pub fn handshake_data(&self) -> &HandshakeData { - &self.handshake_data + /// Returns the certificate binding data. + pub fn certificate_binding(&self) -> &CertBinding { + &self.certificate_binding } /// Returns the sent records. diff --git a/crates/core/src/webpki.rs b/crates/core/src/webpki.rs new file mode 100644 index 000000000..83ef0cb95 --- /dev/null +++ b/crates/core/src/webpki.rs @@ -0,0 +1,168 @@ +//! Web PKI types. + +use std::time::Duration; + +use rustls_pki_types::{self as webpki_types, pem::PemObject}; +use serde::{Deserialize, Serialize}; + +use crate::connection::ServerName; + +/// X.509 certificate, DER encoded. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CertificateDer(pub Vec); + +impl CertificateDer { + /// Creates a DER-encoded certificate from a PEM-encoded certificate. + pub fn from_pem_slice(pem: &[u8]) -> Result { + let der = webpki_types::CertificateDer::from_pem_slice(pem).map_err(|_| PemError {})?; + + Ok(Self(der.to_vec())) + } +} + +/// Private key, DER encoded. +#[derive(Debug, Clone, zeroize::ZeroizeOnDrop, Serialize, Deserialize)] +pub struct PrivateKeyDer(pub Vec); + +impl PrivateKeyDer { + /// Creates a DER-encoded private key from a PEM-encoded private key. + pub fn from_pem_slice(pem: &[u8]) -> Result { + let der = webpki_types::PrivateKeyDer::from_pem_slice(pem).map_err(|_| PemError {})?; + + Ok(Self(der.secret_der().to_vec())) + } +} + +/// PEM parsing error. +#[derive(Debug, thiserror::Error)] +#[error("failed to parse PEM object")] +pub struct PemError {} + +/// Root certificate store. +/// +/// This stores root certificates which are used to verify end-entity +/// certificates presented by a TLS server. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RootCertStore { + /// Unvalidated DER-encoded X.509 root certificates. + pub roots: Vec, +} + +impl RootCertStore { + /// Creates an empty root certificate store. + pub fn empty() -> Self { + Self { roots: Vec::new() } + } +} + +/// Server certificate verifier. +#[derive(Debug)] +pub struct ServerCertVerifier { + roots: Vec>, +} + +impl ServerCertVerifier { + /// Creates a new server certificate verifier. + pub fn new(roots: &RootCertStore) -> Result { + let roots = roots + .roots + .iter() + .map(|cert| { + webpki::anchor_from_trusted_cert(&webpki_types::CertificateDer::from( + cert.0.as_slice(), + )) + .map(|anchor| anchor.to_owned()) + .map_err(|err| ServerCertVerifierError::InvalidRootCertificate { + cert: cert.clone(), + reason: err.to_string(), + }) + }) + .collect::, _>>()?; + + Ok(Self { roots }) + } + + /// Creates a new server certificate verifier with Mozilla root + /// certificates. + pub fn mozilla() -> Self { + Self { + roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), + } + } + + /// Verifies the server certificate was valid at the given time of + /// presentation. + /// + /// # Arguments + /// + /// * `end_entity` - End-entity certificate to verify. + /// * `intermediates` - Intermediate certificates to a trust anchor. + /// * `server_name` - Server DNS name. + /// * `time` - Unix time the certificate was presented. + pub fn verify_server_cert( + &self, + end_entity: &CertificateDer, + intermediates: &[CertificateDer], + server_name: &ServerName, + time: u64, + ) -> Result<(), ServerCertVerifierError> { + let cert = webpki_types::CertificateDer::from(end_entity.0.as_slice()); + let cert = webpki::EndEntityCert::try_from(&cert).map_err(|e| { + ServerCertVerifierError::InvalidEndEntityCertificate { + cert: end_entity.clone(), + reason: e.to_string(), + } + })?; + let intermediates = intermediates + .iter() + .map(|c| webpki_types::CertificateDer::from(c.0.as_slice())) + .collect::>(); + let server_name = server_name.to_webpki(); + let time = webpki_types::UnixTime::since_unix_epoch(Duration::from_secs(time)); + + cert.verify_for_usage( + webpki::ALL_VERIFICATION_ALGS, + &self.roots, + &intermediates, + time, + webpki::KeyUsage::server_auth(), + None, + None, + ) + .map(|_| ()) + .map_err(|_| ServerCertVerifierError::InvalidPath)?; + + cert.verify_is_valid_for_subject_name(&server_name) + .map_err(|_| ServerCertVerifierError::InvalidServerName)?; + + Ok(()) + } +} + +/// Error for [`ServerCertVerifier`]. +#[derive(Debug, thiserror::Error)] +#[error("server certificate verification failed: {0}")] +pub enum ServerCertVerifierError { + /// Root certificate store contains invalid certificate. + #[error("root certificate store contains invalid certificate: {reason}")] + InvalidRootCertificate { + /// Invalid certificate. + cert: CertificateDer, + /// Reason for invalidity. + reason: String, + }, + /// End-entity certificate is invalid. + #[error("end-entity certificate is invalid: {reason}")] + InvalidEndEntityCertificate { + /// Invalid certificate. + cert: CertificateDer, + /// Reason for invalidity. + reason: String, + }, + /// Failed to verify certificate path to provided trust anchors. + #[error("failed to verify certificate path to provided trust anchors")] + InvalidPath, + /// Failed to verify certificate is valid for provided server name. + #[error("failed to verify certificate is valid for provided server name")] + InvalidServerName, +} diff --git a/crates/examples/Cargo.toml b/crates/examples/Cargo.toml index 210777268..7c11bc415 100644 --- a/crates/examples/Cargo.toml +++ b/crates/examples/Cargo.toml @@ -8,10 +8,8 @@ version = "0.0.0" workspace = true [dependencies] -tlsn-core = { workspace = true } tlsn = { workspace = true } tlsn-formats = { workspace = true } -tlsn-tls-core = { workspace = true } tls-server-fixture = { workspace = true } tlsn-server-fixture = { workspace = true } tlsn-server-fixture-certs = { workspace = true } diff --git a/crates/examples/interactive/interactive.rs b/crates/examples/interactive/interactive.rs index 931ae6b18..574f5ccfb 100644 --- a/crates/examples/interactive/interactive.rs +++ b/crates/examples/interactive/interactive.rs @@ -12,7 +12,8 @@ use tracing::instrument; use tls_server_fixture::CA_CERT_DER; use tlsn::{ - config::{ProtocolConfig, ProtocolConfigValidator}, + config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore}, + connection::ServerName, prover::{ProveConfig, Prover, ProverConfig, TlsConfig}, transcript::PartialTranscript, verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig}, @@ -72,18 +73,16 @@ async fn prover( // Create a root certificate store with the server-fixture's self-signed // certificate. This is only required for offline testing with the // server-fixture. - let mut root_store = tls_core::anchors::RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); let mut tls_config_builder = TlsConfig::builder(); - tls_config_builder.root_store(root_store); + tls_config_builder.root_store(RootCertStore { + roots: vec![CertificateDer(CA_CERT_DER.to_vec())], + }); let tls_config = tls_config_builder.build().unwrap(); // Set up protocol configuration for prover. let mut prover_config_builder = ProverConfig::builder(); prover_config_builder - .server_name(server_domain) + .server_name(ServerName::Dns(server_domain.try_into().unwrap())) .tls_config(tls_config) .protocol_config( ProtocolConfig::builder() @@ -194,13 +193,10 @@ async fn verifier( // Create a root certificate store with the server-fixture's self-signed // certificate. This is only required for offline testing with the // server-fixture. - let mut root_store = tls_core::anchors::RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - let verifier_config = VerifierConfig::builder() - .root_store(root_store) + .root_store(RootCertStore { + roots: vec![CertificateDer(CA_CERT_DER.to_vec())], + }) .protocol_config_validator(config_validator) .build() .unwrap(); @@ -234,6 +230,7 @@ async fn verifier( .unwrap_or_else(|| panic!("Expected valid data from {SERVER_DOMAIN}")); // Check Session info: server name. + let ServerName::Dns(server_name) = server_name; assert_eq!(server_name.as_str(), SERVER_DOMAIN); transcript diff --git a/crates/harness/executor/Cargo.toml b/crates/harness/executor/Cargo.toml index d315a1915..915c4e784 100644 --- a/crates/harness/executor/Cargo.toml +++ b/crates/harness/executor/Cargo.toml @@ -14,8 +14,6 @@ wasm-opt = ["-O3"] [dependencies] tlsn-harness-core = { workspace = true } tlsn = { workspace = true } -tlsn-core = { workspace = true } -tlsn-tls-core = { workspace = true } tlsn-server-fixture-certs = { workspace = true } inventory = { workspace = true } diff --git a/crates/harness/executor/src/bench/prover.rs b/crates/harness/executor/src/bench/prover.rs index f359ba2ec..2d19e7a7b 100644 --- a/crates/harness/executor/src/bench/prover.rs +++ b/crates/harness/executor/src/bench/prover.rs @@ -5,7 +5,8 @@ use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt}; use harness_core::bench::{Bench, ProverMetrics}; use tlsn::{ - config::ProtocolConfig, + config::{CertificateDer, ProtocolConfig, RootCertStore}, + connection::ServerName, prover::{ProveConfig, Prover, ProverConfig, TlsConfig}, }; use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; @@ -32,20 +33,17 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result Result<()> let protocol_config = builder.build()?; - let mut root_store = tls_core::anchors::RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - let verifier = Verifier::new( VerifierConfig::builder() - .root_store(root_store) + .root_store(RootCertStore { + roots: vec![CertificateDer(CA_CERT_DER.to_vec())], + }) .protocol_config_validator(protocol_config) .build()?, ); diff --git a/crates/harness/executor/test_plugins/basic.rs b/crates/harness/executor/test_plugins/basic.rs index 35334c56e..98eb454bc 100644 --- a/crates/harness/executor/test_plugins/basic.rs +++ b/crates/harness/executor/test_plugins/basic.rs @@ -1,6 +1,6 @@ -use tls_core::anchors::RootCertStore; use tlsn::{ - config::{ProtocolConfig, ProtocolConfigValidator}, + config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore}, + connection::ServerName, hash::HashAlgId, prover::{ProveConfig, Prover, ProverConfig, TlsConfig}, transcript::{TranscriptCommitConfig, TranscriptCommitment, TranscriptCommitmentKind}, @@ -21,19 +21,17 @@ const MAX_RECV_DATA: usize = 1 << 11; crate::test!("basic", prover, verifier); async fn prover(provider: &IoProvider) { - let mut root_store = RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - let mut tls_config_builder = TlsConfig::builder(); - tls_config_builder.root_store(root_store); + tls_config_builder.root_store(RootCertStore { + roots: vec![CertificateDer(CA_CERT_DER.to_vec())], + }); let tls_config = tls_config_builder.build().unwrap(); + let server_name = ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()); let prover = Prover::new( ProverConfig::builder() - .server_name(SERVER_DOMAIN) + .server_name(server_name) .tls_config(tls_config) .protocol_config( ProtocolConfig::builder() @@ -114,11 +112,6 @@ async fn prover(provider: &IoProvider) { } async fn verifier(provider: &IoProvider) { - let mut root_store = RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - let config = VerifierConfig::builder() .protocol_config_validator( ProtocolConfigValidator::builder() @@ -127,7 +120,9 @@ async fn verifier(provider: &IoProvider) { .build() .unwrap(), ) - .root_store(root_store) + .root_store(RootCertStore { + roots: vec![CertificateDer(CA_CERT_DER.to_vec())], + }) .build() .unwrap(); @@ -145,7 +140,9 @@ async fn verifier(provider: &IoProvider) { .await .unwrap(); - assert_eq!(server_name.unwrap().as_str(), SERVER_DOMAIN); + let ServerName::Dns(server_name) = server_name.unwrap(); + + assert_eq!(server_name.as_str(), SERVER_DOMAIN); assert!( transcript_commitments .iter() diff --git a/crates/mpc-tls/Cargo.toml b/crates/mpc-tls/Cargo.toml index 105dce21b..a85fe8b99 100644 --- a/crates/mpc-tls/Cargo.toml +++ b/crates/mpc-tls/Cargo.toml @@ -72,3 +72,5 @@ tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } tokio-util = { workspace = true, features = ["compat"] } tracing-subscriber = { workspace = true } uid-mux = { workspace = true, features = ["serio", "test-utils"] } +rustls-pki-types = { workspace = true } +rustls-webpki = { workspace = true } diff --git a/crates/mpc-tls/src/follower.rs b/crates/mpc-tls/src/follower.rs index d4f37d94b..241820013 100644 --- a/crates/mpc-tls/src/follower.rs +++ b/crates/mpc-tls/src/follower.rs @@ -22,7 +22,7 @@ use serio::stream::IoStreamExt; use std::mem; use tls_core::msgs::enums::NamedGroup; use tlsn_core::{ - connection::{HandshakeData, HandshakeDataV1_2, TlsVersion, VerifyData}, + connection::{CertBinding, CertBindingV1_2, TlsVersion, VerifyData}, transcript::TlsTranscript, }; use tracing::{debug, instrument}; @@ -405,7 +405,7 @@ impl MpcTlsFollower { let cf_vd = cf_vd.ok_or(MpcTlsError::hs("client finished VD not computed"))?; let sf_vd = sf_vd.ok_or(MpcTlsError::hs("server finished VD not computed"))?; - let handshake_data = HandshakeData::V1_2(HandshakeDataV1_2 { + let handshake_data = CertBinding::V1_2(CertBindingV1_2 { client_random, server_random, server_ephemeral_key: server_key diff --git a/crates/mpc-tls/src/leader.rs b/crates/mpc-tls/src/leader.rs index 3d522f19b..926fda993 100644 --- a/crates/mpc-tls/src/leader.rs +++ b/crates/mpc-tls/src/leader.rs @@ -43,10 +43,9 @@ use tls_core::{ suites::SupportedCipherSuite, }; use tlsn_core::{ - connection::{ - Certificate, HandshakeData, HandshakeDataV1_2, ServerSignature, TlsVersion, VerifyData, - }, + connection::{CertBinding, CertBindingV1_2, ServerSignature, TlsVersion, VerifyData}, transcript::TlsTranscript, + webpki::CertificateDer, }; use tracing::{debug, instrument, trace, warn}; @@ -325,7 +324,7 @@ impl MpcTlsLeader { let server_cert_chain = server_cert_details .cert_chain() .iter() - .map(|cert| Certificate(cert.0.clone())) + .map(|cert| CertificateDer(cert.0.clone())) .collect(); let server_signature = ServerSignature { @@ -337,7 +336,7 @@ impl MpcTlsLeader { sig: server_kx_details.kx_sig().sig.0.clone(), }; - let handshake_data = HandshakeData::V1_2(HandshakeDataV1_2 { + let handshake_data = CertBinding::V1_2(CertBindingV1_2 { client_random: client_random.0, server_random: server_random.0, server_ephemeral_key: server_key diff --git a/crates/mpc-tls/tests/test.rs b/crates/mpc-tls/tests/test.rs index 523cff768..40d1186c4 100644 --- a/crates/mpc-tls/tests/test.rs +++ b/crates/mpc-tls/tests/test.rs @@ -12,11 +12,15 @@ use mpz_ot::{ rcot::shared::{SharedRCOTReceiver, SharedRCOTSender}, }; use rand::{rngs::StdRng, Rng, SeedableRng}; -use tls_client::Certificate; +use rustls_pki_types::CertificateDer; +use tls_client::RootCertStore; use tls_client_async::bind_client; use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN}; use tokio::sync::Mutex; use tokio_util::compat::TokioAsyncReadCompatExt; +use webpki::anchor_from_trusted_cert; + +const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER); #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[ignore = "expensive"] @@ -48,11 +52,11 @@ async fn leader_task(mut leader: MpcTlsLeader) { let (leader_ctrl, leader_fut) = leader.run(); tokio::spawn(async { leader_fut.await.unwrap() }); - let mut root_store = tls_client::RootCertStore::empty(); - root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap(); let config = tls_client::ClientConfig::builder() .with_safe_defaults() - .with_root_certificates(root_store) + .with_root_certificates(RootCertStore { + roots: vec![anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()], + }) .with_no_client_auth(); let server_name = SERVER_DOMAIN.try_into().unwrap(); diff --git a/crates/tls/client-async/Cargo.toml b/crates/tls/client-async/Cargo.toml index 6a205ac95..d4b4968eb 100644 --- a/crates/tls/client-async/Cargo.toml +++ b/crates/tls/client-async/Cargo.toml @@ -35,3 +35,5 @@ hyper = { workspace = true, features = ["client", "http1"] } hyper-util = { workspace = true, features = ["full"] } rstest = { workspace = true } tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] } +rustls-webpki = { workspace = true } +rustls-pki-types = { workspace = true } diff --git a/crates/tls/client-async/tests/test.rs b/crates/tls/client-async/tests/test.rs index 9c2c40f25..f588665db 100644 --- a/crates/tls/client-async/tests/test.rs +++ b/crates/tls/client-async/tests/test.rs @@ -6,7 +6,8 @@ use http_body_util::{BodyExt as _, Full}; use hyper::{body::Bytes, Request, StatusCode}; use hyper_util::rt::TokioIo; use rstest::{fixture, rstest}; -use tls_client::{Certificate, ClientConfig, ClientConnection, RustCryptoBackend, ServerName}; +use rustls_pki_types::CertificateDer; +use tls_client::{ClientConfig, ClientConnection, RustCryptoBackend, ServerName}; use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection}; use tls_server_fixture::{ bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY, @@ -14,6 +15,9 @@ use tls_server_fixture::{ }; use tokio::task::JoinHandle; use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; +use webpki::anchor_from_trusted_cert; + +const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER); // An established client TLS connection struct TlsFixture { @@ -30,7 +34,9 @@ async fn set_up_tls() -> TlsFixture { let _server_task = tokio::spawn(bind_test_server(server_socket.compat())); let mut root_store = tls_client::RootCertStore::empty(); - root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap(); + root_store + .roots + .push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()); let config = ClientConfig::builder() .with_safe_defaults() .with_root_certificates(root_store) @@ -75,7 +81,9 @@ async fn test_hyper_ok() { let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat())); let mut root_store = tls_client::RootCertStore::empty(); - root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap(); + root_store + .roots + .push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()); let config = ClientConfig::builder() .with_safe_defaults() .with_root_certificates(root_store) diff --git a/crates/tls/client/Cargo.toml b/crates/tls/client/Cargo.toml index 25a2ad0eb..4e3f32a9f 100644 --- a/crates/tls/client/Cargo.toml +++ b/crates/tls/client/Cargo.toml @@ -23,7 +23,8 @@ async-trait = { workspace = true } log = { workspace = true, optional = true } ring = { workspace = true } sct = { workspace = true } -webpki = { workspace = true, features = ["alloc", "std"] } +rustls-pki-types = { workspace = true } +rustls-webpki = { workspace = true } aes-gcm = { workspace = true } p256 = { workspace = true, features = ["ecdh"] } rand = { workspace = true } diff --git a/crates/tls/client/src/client/hs.rs b/crates/tls/client/src/client/hs.rs index eff6857b8..2032d99b0 100644 --- a/crates/tls/client/src/client/hs.rs +++ b/crates/tls/client/src/client/hs.rs @@ -7,7 +7,6 @@ use crate::{ conn::{CommonState, ConnectionRandoms, State}, error::Error, hash_hs::HandshakeHashBuffer, - msgs::persist, ticketer::TimeBase, }; use tls_core::{ @@ -42,35 +41,6 @@ pub(super) type NextState = Box>; pub(super) type NextStateOrError = Result; pub(super) type ClientContext<'a> = crate::conn::Context<'a>; -fn find_session( - server_name: &ServerName, - config: &ClientConfig, -) -> Option> { - let key = persist::ClientSessionKey::session_for_server_name(server_name); - let key_buf = key.get_encoding(); - - let value = config.session_storage.get(&key_buf).or_else(|| { - debug!("No cached session for {:?}", server_name); - None - })?; - - #[allow(unused_mut)] - let mut reader = Reader::init(&value[2..]); - #[allow(clippy::bind_instead_of_map)] // https://github.com/rust-lang/rust-clippy/issues/8082 - CipherSuite::read_bytes(&value[..2]) - .and_then(|suite| { - persist::ClientSessionValue::read(&mut reader, suite, &config.cipher_suites) - }) - .and_then(|resuming| { - let retrieved = persist::Retrieved::new(resuming, TimeBase::now().ok()?); - match retrieved.has_expired() { - false => Some(retrieved), - true => None, - } - }) - .and_then(Some) -} - pub(super) async fn start_handshake( server_name: ServerName, extra_exts: Vec, @@ -123,7 +93,6 @@ pub(super) async fn start_handshake( emit_client_hello_for_retry( config, cx, - None, random, false, transcript_buffer, @@ -142,7 +111,6 @@ pub(super) async fn start_handshake( struct ExpectServerHello { config: Arc, - resuming_session: Option>, server_name: ServerName, random: Random, using_ems: bool, @@ -162,7 +130,6 @@ struct ExpectServerHelloOrHelloRetryRequest { async fn emit_client_hello_for_retry( config: Arc, cx: &mut ClientContext<'_>, - resuming_session: Option>, random: Random, using_ems: bool, mut transcript_buffer: HandshakeHashBuffer, @@ -176,25 +143,6 @@ async fn emit_client_hello_for_retry( may_send_sct_list: bool, suite: Option, ) -> Result { - // For now we do not support session resumption - // - // Do we have a SessionID or ticket cached for this host? - // let (ticket, resume_version) = if let Some(resuming) = &resuming_session { - // match &resuming.value { - // persist::ClientSessionValue::Tls13(inner) => { - // (inner.ticket().to_vec(), ProtocolVersion::TLSv1_3) - // } - // #[cfg(feature = "tls12")] - // persist::ClientSessionValue::Tls12(inner) => { - // (inner.ticket().to_vec(), ProtocolVersion::TLSv1_2) - // } - // } - // } else { - // (Vec::new(), ProtocolVersion::Unknown(0)) - // }; - - // let (ticket, resume_version) = (Vec::new(), ProtocolVersion::Unknown(0)); - let support_tls12 = config.supports_version(ProtocolVersion::TLSv1_2); let support_tls13 = config.supports_version(ProtocolVersion::TLSv1_3); @@ -256,48 +204,6 @@ async fn emit_client_hello_for_retry( // Extra extensions must be placed before the PSK extension exts.extend(extra_exts.iter().cloned()); - // let fill_in_binder = if support_tls13 - // && config.enable_tickets - // && resume_version == ProtocolVersion::TLSv1_3 - // && !ticket.is_empty() - // { - // let resuming = - // resuming_session - // .as_ref() - // .and_then(|resuming| match (suite, resuming.tls13()) { - // (Some(suite), Some(resuming)) => { - // suite.tls13()?.can_resume_from(resuming.suite())?; - // Some(resuming) - // } - // (None, Some(resuming)) => Some(resuming), - // _ => None, - // }); - // if let Some(ref resuming) = resuming { - // tls13::prepare_resumption( - // &config, - // cx, - // ticket, - // &resuming, - // &mut exts, - // retryreq.is_some(), - // ) - // .await; - // } - // resuming - // } else if config.enable_tickets { - // // If we have a ticket, include it. Otherwise, request one. - // if ticket.is_empty() { - // exts.push(ClientExtension::SessionTicket(ClientSessionTicket::Request)); - // } else { - // exts.push(ClientExtension::SessionTicket(ClientSessionTicket::Offer( - // Payload::new(ticket), - // ))); - // } - // None - // } else { - // None - // }; - // Note what extensions we sent. hello.sent_extensions = exts.iter().map(ClientExtension::get_type).collect(); @@ -319,8 +225,8 @@ async fn emit_client_hello_for_retry( }; // let early_key_schedule = if let Some(resuming) = fill_in_binder { - // let schedule = tls13::fill_in_psk_binder(&resuming, &transcript_buffer, &mut chp); - // Some((resuming.suite(), schedule)) + // let schedule = tls13::fill_in_psk_binder(&resuming, &transcript_buffer, + // &mut chp); Some((resuming.suite(), schedule)) // } else { // None // }; @@ -350,7 +256,6 @@ async fn emit_client_hello_for_retry( let next = ExpectServerHello { config, - resuming_session, server_name, random, using_ems, @@ -551,19 +456,10 @@ impl State for ExpectServerHello { // handshake_traffic_secret. match suite { SupportedCipherSuite::Tls13(suite) => { - let resuming_session = - self.resuming_session - .and_then(|resuming| match resuming.value { - persist::ClientSessionValue::Tls13(inner) => Some(inner), - #[cfg(feature = "tls12")] - persist::ClientSessionValue::Tls12(_) => None, - }); - tls13::handle_server_hello( self.config, cx, server_hello, - resuming_session, self.server_name, randoms, suite, @@ -577,16 +473,8 @@ impl State for ExpectServerHello { } #[cfg(feature = "tls12")] SupportedCipherSuite::Tls12(suite) => { - let resuming_session = - self.resuming_session - .and_then(|resuming| match resuming.value { - persist::ClientSessionValue::Tls12(inner) => Some(inner), - persist::ClientSessionValue::Tls13(_) => None, - }); - tls12::CompleteServerHelloHandling { config: self.config, - resuming_session, server_name: self.server_name, randoms, using_ems: self.using_ems, @@ -723,7 +611,6 @@ impl ExpectServerHelloOrHelloRetryRequest { emit_client_hello_for_retry( self.next.config, cx, - self.next.resuming_session, self.next.random, self.next.using_ems, transcript_buffer, diff --git a/crates/tls/client/src/client/tls12.rs b/crates/tls/client/src/client/tls12.rs index 95ad0fcf8..7adeaff3d 100644 --- a/crates/tls/client/src/client/tls12.rs +++ b/crates/tls/client/src/client/tls12.rs @@ -10,7 +10,6 @@ use crate::{ conn::{CommonState, ConnectionRandoms, State}, error::Error, hash_hs::HandshakeHash, - msgs::persist, sign::Signer, ticketer::TimeBase, verify, @@ -49,7 +48,6 @@ mod server_hello { pub(in crate::client) struct CompleteServerHelloHandling { pub(in crate::client) config: Arc, - pub(in crate::client) resuming_session: Option, pub(in crate::client) server_name: ServerName, pub(in crate::client) randoms: ConnectionRandoms, pub(in crate::client) using_ems: bool, @@ -113,76 +111,8 @@ mod server_hello { None }; - // See if we're successfully resuming. - if let Some(ref _resuming) = self.resuming_session { - return Err(Error::General( - "client does not support resumption".to_string(), - )); - // if resuming.session_id == server_hello.session_id { - // debug!("Server agreed to resume"); - - // // Is the server telling lies about the ciphersuite? - // if resuming.suite() != suite { - // let error_msg = - // "abbreviated handshake offered, but with varied cs".to_string(); - // return Err(Error::PeerMisbehavedError(error_msg)); - // } - - // // And about EMS support? - // if resuming.extended_ms() != self.using_ems { - // let error_msg = "server varied ems support over resume".to_string(); - // return Err(Error::PeerMisbehavedError(error_msg)); - // } - - // let secrets = - // ConnectionSecrets::new_resume(self.randoms, suite, resuming.secret()); - // self.config.key_log.log( - // "CLIENT_RANDOM", - // &secrets.randoms.client, - // &secrets.master_secret, - // ); - // cx.common.start_encryption_tls12(&secrets, Side::Client); - - // // Since we're resuming, we verified the certificate and - // // proof of possession in the prior session. - // cx.common.peer_certificates = Some(resuming.server_cert_chain().to_vec()); - // let cert_verified = verify::ServerCertVerified::assertion(); - // let sig_verified = verify::HandshakeSignatureValid::assertion(); - - // return if must_issue_new_ticket { - // Ok(Box::new(ExpectNewTicket { - // config: self.config, - // secrets, - // resuming_session: self.resuming_session, - // session_id: server_hello.session_id, - // server_name: self.server_name, - // using_ems: self.using_ems, - // transcript: self.transcript, - // resuming: true, - // cert_verified, - // sig_verified, - // })) - // } else { - // Ok(Box::new(ExpectCcs { - // config: self.config, - // secrets, - // resuming_session: self.resuming_session, - // session_id: server_hello.session_id, - // server_name: self.server_name, - // using_ems: self.using_ems, - // transcript: self.transcript, - // ticket: None, - // resuming: true, - // cert_verified, - // sig_verified, - // })) - // }; - // } - } - Ok(Box::new(ExpectCertificate { config: self.config, - resuming_session: self.resuming_session, session_id: server_hello.session_id, server_name: self.server_name, randoms: self.randoms, @@ -199,7 +129,6 @@ mod server_hello { struct ExpectCertificate { config: Arc, - resuming_session: Option, session_id: SessionID, server_name: ServerName, randoms: ConnectionRandoms, @@ -228,7 +157,6 @@ impl State for ExpectCertificate { if self.may_send_cert_status { Ok(Box::new(ExpectCertificateStatusOrServerKx { config: self.config, - resuming_session: self.resuming_session, session_id: self.session_id, server_name: self.server_name, randoms: self.randoms, @@ -250,7 +178,6 @@ impl State for ExpectCertificate { Ok(Box::new(ExpectServerKx { config: self.config, - resuming_session: self.resuming_session, session_id: self.session_id, server_name: self.server_name, randoms: self.randoms, @@ -266,7 +193,6 @@ impl State for ExpectCertificate { struct ExpectCertificateStatusOrServerKx { config: Arc, - resuming_session: Option, session_id: SessionID, server_name: ServerName, randoms: ConnectionRandoms, @@ -303,7 +229,6 @@ impl State for ExpectCertificateStatusOrServerKx { Box::new(ExpectServerKx { config: self.config, - resuming_session: self.resuming_session, session_id: self.session_id, server_name: self.server_name, randoms: self.randoms, @@ -322,7 +247,6 @@ impl State for ExpectCertificateStatusOrServerKx { }) => { Box::new(ExpectCertificateStatus { config: self.config, - resuming_session: self.resuming_session, session_id: self.session_id, server_name: self.server_name, randoms: self.randoms, @@ -350,7 +274,6 @@ impl State for ExpectCertificateStatusOrServerKx { struct ExpectCertificateStatus { config: Arc, - resuming_session: Option, session_id: SessionID, server_name: ServerName, randoms: ConnectionRandoms, @@ -395,7 +318,6 @@ impl State for ExpectCertificateStatus { Ok(Box::new(ExpectServerKx { config: self.config, - resuming_session: self.resuming_session, session_id: self.session_id, server_name: self.server_name, randoms: self.randoms, @@ -410,7 +332,6 @@ impl State for ExpectCertificateStatus { struct ExpectServerKx { config: Arc, - resuming_session: Option, session_id: SessionID, server_name: ServerName, randoms: ConnectionRandoms, @@ -458,7 +379,6 @@ impl State for ExpectServerKx { Ok(Box::new(ExpectServerDoneOrCertReq { config: self.config, - resuming_session: self.resuming_session, session_id: self.session_id, server_name: self.server_name, randoms: self.randoms, @@ -570,7 +490,6 @@ async fn emit_finished( // client auth. Otherwise we go straight to ServerHelloDone. struct ExpectServerDoneOrCertReq { config: Arc, - resuming_session: Option, session_id: SessionID, server_name: ServerName, randoms: ConnectionRandoms, @@ -598,7 +517,6 @@ impl State for ExpectServerDoneOrCertReq { ) { Box::new(ExpectCertificateRequest { config: self.config, - resuming_session: self.resuming_session, session_id: self.session_id, server_name: self.server_name, randoms: self.randoms, @@ -616,7 +534,6 @@ impl State for ExpectServerDoneOrCertReq { Box::new(ExpectServerDone { config: self.config, - resuming_session: self.resuming_session, session_id: self.session_id, server_name: self.server_name, randoms: self.randoms, @@ -636,7 +553,6 @@ impl State for ExpectServerDoneOrCertReq { struct ExpectCertificateRequest { config: Arc, - resuming_session: Option, session_id: SessionID, server_name: ServerName, randoms: ConnectionRandoms, @@ -679,7 +595,6 @@ impl State for ExpectCertificateRequest { Ok(Box::new(ExpectServerDone { config: self.config, - resuming_session: self.resuming_session, session_id: self.session_id, server_name: self.server_name, randoms: self.randoms, @@ -696,7 +611,6 @@ impl State for ExpectCertificateRequest { struct ExpectServerDone { config: Arc, - resuming_session: Option, session_id: SessionID, server_name: ServerName, randoms: ConnectionRandoms, @@ -745,6 +659,7 @@ impl State for ExpectServerDone { // 3. Verify that the top certificate signed their kx. // 4. If doing client auth, send our Certificate. // 5. Complete the key exchange: + // // a) generate our kx pair // b) emit a ClientKeyExchange containing it // c) if doing client auth, emit a CertificateVerify @@ -891,7 +806,6 @@ impl State for ExpectServerDone { if st.must_issue_new_ticket { Ok(Box::new(ExpectNewTicket { config: st.config, - resuming_session: st.resuming_session, session_id: st.session_id, server_name: st.server_name, using_ems: st.using_ems, @@ -903,7 +817,6 @@ impl State for ExpectServerDone { } else { Ok(Box::new(ExpectCcs { config: st.config, - resuming_session: st.resuming_session, session_id: st.session_id, server_name: st.server_name, using_ems: st.using_ems, @@ -919,7 +832,6 @@ impl State for ExpectServerDone { struct ExpectNewTicket { config: Arc, - resuming_session: Option, session_id: SessionID, server_name: ServerName, using_ems: bool, @@ -946,7 +858,6 @@ impl State for ExpectNewTicket { Ok(Box::new(ExpectCcs { config: self.config, - resuming_session: self.resuming_session, session_id: self.session_id, server_name: self.server_name, using_ems: self.using_ems, @@ -962,7 +873,6 @@ impl State for ExpectNewTicket { // -- Waiting for their CCS -- struct ExpectCcs { config: Arc, - resuming_session: Option, session_id: SessionID, server_name: ServerName, using_ems: bool, @@ -998,7 +908,6 @@ impl State for ExpectCcs { Ok(Box::new(ExpectFinished { config: self.config, - resuming_session: self.resuming_session, session_id: self.session_id, server_name: self.server_name, using_ems: self.using_ems, @@ -1013,7 +922,6 @@ impl State for ExpectCcs { struct ExpectFinished { config: Arc, - resuming_session: Option, session_id: SessionID, server_name: ServerName, using_ems: bool, @@ -1024,60 +932,6 @@ struct ExpectFinished { sig_verified: verify::HandshakeSignatureValid, } -// impl ExpectFinished { -// // -- Waiting for their finished -- -// fn save_session(&mut self, cx: &mut ClientContext<'_>) { -// // Save a ticket. If we got a new ticket, save that. Otherwise, save the -// // original ticket again. -// let (mut ticket, lifetime) = match self.ticket.take() { -// Some(nst) => (nst.ticket.0, nst.lifetime_hint), -// None => (Vec::new(), 0), -// }; - -// if ticket.is_empty() { -// if let Some(resuming_session) = &mut self.resuming_session { -// ticket = resuming_session.take_ticket(); -// } -// } - -// if self.session_id.is_empty() && ticket.is_empty() { -// debug!("Session not saved: server didn't allocate id or ticket"); -// return; -// } - -// let time_now = match TimeBase::now() { -// Ok(time_now) => time_now, -// Err(e) => { -// debug!("Session not saved: {}", e); -// return; -// } -// }; - -// let key = persist::ClientSessionKey::session_for_server_name(&self.server_name); -// let value = persist::Tls12ClientSessionValue::new( -// self.secrets.suite(), -// self.session_id, -// ticket, -// self.secrets.get_master_secret(), -// cx.common.peer_certificates.clone().unwrap_or_default(), -// time_now, -// lifetime, -// self.using_ems, -// ); - -// let worked = self -// .config -// .session_storage -// .put(key.get_encoding(), value.get_encoding()); - -// if worked { -// debug!("Session saved"); -// } else { -// debug!("Session not saved"); -// } -// } -// } - #[async_trait] impl State for ExpectFinished { async fn handle( diff --git a/crates/tls/client/src/client/tls13.rs b/crates/tls/client/src/client/tls13.rs index 955e77654..c9263d03a 100644 --- a/crates/tls/client/src/client/tls13.rs +++ b/crates/tls/client/src/client/tls13.rs @@ -11,7 +11,6 @@ use crate::{ conn::{CommonState, ConnectionRandoms, State}, error::Error, hash_hs::{HandshakeHash, HandshakeHashBuffer}, - msgs::persist, sign, verify, KeyLog, }; #[allow(deprecated)] @@ -60,7 +59,6 @@ pub(super) async fn handle_server_hello( config: Arc, cx: &mut ClientContext<'_>, server_hello: &ServerHelloPayload, - resuming_session: Option, server_name: ServerName, randoms: ConnectionRandoms, suite: &'static Tls13CipherSuite, @@ -102,8 +100,8 @@ pub(super) async fn handle_server_hello( // }; // if server_hello.get_psk_index() != Some(0) { - // return Err(cx.common.illegal_param("server selected invalid psk").await); - // } + // return Err(cx.common.illegal_param("server selected invalid + // psk").await); } // debug!("Resuming using PSK"); // // The key schedule has been initialized and set in fill_in_psk_binder() @@ -143,7 +141,6 @@ pub(super) async fn handle_server_hello( Ok(Box::new(ExpectEncryptedExtensions { config, - resuming_session, server_name, randoms, suite, @@ -170,69 +167,6 @@ async fn validate_server_hello( Ok(()) } -// fn save_kx_hint(config: &ClientConfig, server_name: &ServerName, group: NamedGroup) { -// let key = persist::ClientSessionKey::hint_for_server_name(server_name); - -// config -// .session_storage -// .put(key.get_encoding(), group.get_encoding()); -// } - -// /// This implements the horrifying TLS1.3 hack where PSK binders have a -// /// data dependency on the message they are contained within. -// pub(super) fn fill_in_psk_binder( -// resuming: &persist::Tls13ClientSessionValue, -// transcript: &HandshakeHashBuffer, -// hmp: &mut HandshakeMessagePayload, -// ) -> KeyScheduleEarly { -// // We need to know the hash function of the suite we're trying to resume into. -// let hkdf_alg = &resuming.suite().hkdf_algorithm; -// let suite_hash = resuming.suite().hash_algorithm(); - -// // The binder is calculated over the clienthello, but doesn't include itself or its -// // length, or the length of its container. -// let binder_plaintext = hmp.get_encoding_for_binder_signing(); -// let handshake_hash = transcript.get_hash_given(suite_hash, &binder_plaintext); - -// // Run a fake key_schedule to simulate what the server will do if it chooses -// // to resume. -// let key_schedule = KeyScheduleEarly::new(hkdf_alg, resuming.secret()); -// let real_binder = key_schedule.resumption_psk_binder_key_and_sign_verify_data(&handshake_hash); - -// if let HandshakePayload::ClientHello(ref mut ch) = hmp.payload { -// ch.set_psk_binder(real_binder.as_ref()); -// }; - -// key_schedule -// } - -// pub(super) async fn prepare_resumption( -// config: &ClientConfig, -// cx: &mut ClientContext<'_>, -// ticket: Vec, -// resuming_session: &persist::Retrieved<&persist::Tls13ClientSessionValue>, -// exts: &mut Vec, -// doing_retry: bool, -// ) { -// let resuming_suite = resuming_session.suite(); -// cx.common.suite = Some(resuming_suite.into()); -// cx.data.resumption_ciphersuite = Some(resuming_suite.into()); - -// // Finally, and only for TLS1.3 with a ticket resumption, include a binder -// // for our ticket. This must go last. -// // -// // Include an empty binder. It gets filled in below because it depends on -// // the message it's contained in (!!!). -// let obfuscated_ticket_age = resuming_session.obfuscated_ticket_age(); - -// let binder_len = resuming_suite.hash_algorithm().output_len(); -// let binder = vec![0u8; binder_len]; - -// let psk_identity = PresharedKeyIdentity::new(ticket, obfuscated_ticket_age); -// let psk_ext = PresharedKeyOffer::new(psk_identity, binder); -// exts.push(ClientExtension::PresharedKey(psk_ext)); -// } - pub(super) async fn emit_fake_ccs( sent_tls13_fake_ccs: &mut bool, common: &mut CommonState, @@ -287,7 +221,6 @@ async fn validate_encrypted_extensions( struct ExpectEncryptedExtensions { config: Arc, - resuming_session: Option, server_name: ServerName, randoms: ConnectionRandoms, suite: &'static Tls13CipherSuite, @@ -313,52 +246,19 @@ impl State for ExpectEncryptedExtensions { validate_encrypted_extensions(cx.common, &self.hello, exts).await?; hs::process_alpn_protocol(cx.common, &self.config, exts.get_alpn_protocol()).await?; - if let Some(resuming_session) = self.resuming_session { - let was_early_traffic = cx.common.early_traffic; - if was_early_traffic { - if exts.early_data_extension_offered() { - cx.data.early_data.accepted(); - } else { - cx.data.early_data.rejected(); - cx.common.early_traffic = false; - } - } - - if was_early_traffic && !cx.common.early_traffic { - // If no early traffic, set the encryption key for handshakes - cx.common.record_layer.set_message_encrypter(); - } - - cx.common.peer_certificates = Some(resuming_session.server_cert_chain().to_vec()); - - // We *don't* reverify the certificate chain here: resumption is a - // continuation of the previous session in terms of security policy. - let cert_verified = verify::ServerCertVerified::assertion(); - let sig_verified = verify::HandshakeSignatureValid::assertion(); - Ok(Box::new(ExpectFinished { - config: self.config, - server_name: self.server_name, - randoms: self.randoms, - suite: self.suite, - transcript: self.transcript, - client_auth: None, - cert_verified, - sig_verified, - })) - } else { - if exts.early_data_extension_offered() { - let msg = "server sent early data extension without resumption".to_string(); - return Err(Error::PeerMisbehavedError(msg)); - } - Ok(Box::new(ExpectCertificateOrCertReq { - config: self.config, - server_name: self.server_name, - randoms: self.randoms, - suite: self.suite, - transcript: self.transcript, - may_send_sct_list: self.hello.server_may_send_sct_list(), - })) + if exts.early_data_extension_offered() { + let msg = "server sent early data extension without resumption".to_string(); + return Err(Error::PeerMisbehavedError(msg)); } + + Ok(Box::new(ExpectCertificateOrCertReq { + config: self.config, + server_name: self.server_name, + randoms: self.randoms, + suite: self.suite, + transcript: self.transcript, + may_send_sct_list: self.hello.server_may_send_sct_list(), + })) } } @@ -422,9 +322,9 @@ impl State for ExpectCertificateOrCertReq { } } -// TLS1.3 version of CertificateRequest handling. We then move to expecting the server -// Certificate. Unfortunately the CertificateRequest type changed in an annoying way -// in TLS1.3. +// TLS1.3 version of CertificateRequest handling. We then move to expecting the +// server Certificate. Unfortunately the CertificateRequest type changed in an +// annoying way in TLS1.3. struct ExpectCertificateRequest { config: Arc, server_name: ServerName, @@ -787,8 +687,8 @@ impl State for ExpectFinished { st.transcript.add_message(&m); - /* The EndOfEarlyData message to server is still encrypted with early data keys, - * but appears in the transcript after the server Finished. */ + /* The EndOfEarlyData message to server is still encrypted with early data + * keys, but appears in the transcript after the server Finished. */ if cx.common.early_traffic { emit_end_of_early_data_tls13(&mut st.transcript, cx.common).await?; cx.common.early_traffic = false; @@ -893,42 +793,6 @@ impl ExpectTraffic { )); } - // let handshake_hash = self.transcript.get_current_hash(); - // let secret = self - // .key_schedule - // .resumption_master_secret_and_derive_ticket_psk(&handshake_hash, &nst.nonce.0); - - // let time_now = match TimeBase::now() { - // Ok(t) => t, - // #[allow(unused_variables)] - // Err(e) => { - // debug!("Session not saved: {}", e); - // return Ok(()); - // } - // }; - - // let value = persist::Tls13ClientSessionValue::new( - // self.suite, - // nst.ticket.0.clone(), - // secret, - // cx.common.peer_certificates.clone().unwrap_or_default(), - // time_now, - // nst.lifetime, - // nst.age_add, - // nst.get_max_early_data_size().unwrap_or_default(), - // ); - - // let key = persist::ClientSessionKey::session_for_server_name(&self.server_name); - // #[allow(unused_mut)] - // let mut ticket = value.get_encoding(); - - // let worked = self.session_storage.put(key.get_encoding(), ticket); - - // if worked { - // debug!("Ticket saved"); - // } else { - // debug!("Ticket not saved"); - // } Ok(()) } @@ -948,27 +812,6 @@ impl ExpectTraffic { Err(Error::General( "received unsupported key update request from peer".to_string(), )) - - // match kur { - // KeyUpdateRequest::UpdateNotRequested => {} - // KeyUpdateRequest::UpdateRequested => { - // self.want_write_key_update = true; - // } - // _ => { - // common - // .send_fatal_alert(AlertDescription::IllegalParameter) - // .await; - // return Err(Error::CorruptMessagePayload(ContentType::Handshake)); - // } - // } - - // // Update our read-side keys. - // let new_read_key = self.key_schedule.next_server_application_traffic_secret(); - // common - // .record_layer - // .set_message_decrypter(self.suite.derive_decrypter(&new_read_key)); - - // Ok(()) } } @@ -1022,10 +865,11 @@ impl State for ExpectTraffic { // .send_msg_encrypt(Message::build_key_update_notify().into()) // .await; - // let write_key = self.key_schedule.next_client_application_traffic_secret(); + // let write_key = + // self.key_schedule.next_client_application_traffic_secret(); // common // .record_layer - // .set_message_encrypter(self.suite.derive_encrypter(&write_key)); - // } + // .set_message_encrypter(self.suite.derive_encrypter(& + // write_key)); } } } diff --git a/crates/tls/client/src/lib.rs b/crates/tls/client/src/lib.rs index 9555770dd..d6c1c25f0 100644 --- a/crates/tls/client/src/lib.rs +++ b/crates/tls/client/src/lib.rs @@ -301,15 +301,11 @@ mod conn; mod error; mod hash_hs; mod limited_cache; -mod msgs; mod rand; mod record_layer; //mod stream; mod vecbuf; -pub(crate) use tls_core::verify; -#[cfg(test)] -mod verifybench; -pub(crate) use tls_core::x509; +pub(crate) use tls_core::{verify, x509}; #[macro_use] mod check; mod bs_debug; @@ -330,7 +326,7 @@ pub mod internal { // The public interface is: pub use crate::{ - anchors::{OwnedTrustAnchor, RootCertStore}, + anchors::RootCertStore, builder::{ConfigBuilder, WantsCipherSuites, WantsKxGroups, WantsVerifier, WantsVersions}, conn::{CommonState, ConnectionCommon, IoState, Reader, SideData}, error::Error, diff --git a/crates/tls/client/src/msgs/mod.rs b/crates/tls/client/src/msgs/mod.rs deleted file mode 100644 index 652058cf3..000000000 --- a/crates/tls/client/src/msgs/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub(crate) mod persist; - -#[cfg(test)] -mod persist_test; diff --git a/crates/tls/client/src/msgs/persist.rs b/crates/tls/client/src/msgs/persist.rs deleted file mode 100644 index e70aed5e4..000000000 --- a/crates/tls/client/src/msgs/persist.rs +++ /dev/null @@ -1,526 +0,0 @@ -use crate::{client::ServerName, ticketer::TimeBase}; -use std::cmp; -#[cfg(feature = "tls12")] -use std::mem; -#[cfg(feature = "tls12")] -use tls_core::suites::Tls12CipherSuite; -use tls_core::{ - msgs::{ - base::{PayloadU16, PayloadU8}, - codec::{Codec, Reader}, - enums::{CipherSuite, ProtocolVersion}, - handshake::{CertificatePayload, SessionID}, - }, - suites::{SupportedCipherSuite, Tls13CipherSuite}, -}; - -// These are the keys and values we store in session storage. - -// --- Client types --- -/// Keys for session resumption and tickets. -/// Matching value is a `ClientSessionValue`. -#[derive(Debug)] -pub struct ClientSessionKey { - kind: &'static [u8], - name: Vec, -} - -impl Codec for ClientSessionKey { - fn encode(&self, bytes: &mut Vec) { - bytes.extend_from_slice(self.kind); - bytes.extend_from_slice(&self.name); - } - - // Don't need to read these. - fn read(_r: &mut Reader) -> Option { - None - } -} - -impl ClientSessionKey { - pub fn session_for_server_name(server_name: &ServerName) -> Self { - Self { - kind: b"session", - name: server_name.encode(), - } - } - - pub fn hint_for_server_name(server_name: &ServerName) -> Self { - Self { - kind: b"kx-hint", - name: server_name.encode(), - } - } -} - -#[derive(Debug)] -pub enum ClientSessionValue { - Tls13(Tls13ClientSessionValue), - #[cfg(feature = "tls12")] - Tls12(Tls12ClientSessionValue), -} - -impl ClientSessionValue { - pub fn read( - reader: &mut Reader<'_>, - suite: CipherSuite, - supported: &[SupportedCipherSuite], - ) -> Option { - match supported.iter().find(|s| s.suite() == suite)? { - SupportedCipherSuite::Tls13(inner) => { - Tls13ClientSessionValue::read(inner, reader).map(ClientSessionValue::Tls13) - } - #[cfg(feature = "tls12")] - SupportedCipherSuite::Tls12(inner) => { - Tls12ClientSessionValue::read(inner, reader).map(ClientSessionValue::Tls12) - } - } - } - - fn common(&self) -> &ClientSessionCommon { - match self { - Self::Tls13(inner) => &inner.common, - #[cfg(feature = "tls12")] - Self::Tls12(inner) => &inner.common, - } - } -} - -impl From for ClientSessionValue { - fn from(v: Tls13ClientSessionValue) -> Self { - Self::Tls13(v) - } -} - -#[cfg(feature = "tls12")] -impl From for ClientSessionValue { - fn from(v: Tls12ClientSessionValue) -> Self { - Self::Tls12(v) - } -} - -pub struct Retrieved { - pub value: T, - retrieved_at: TimeBase, -} - -impl Retrieved { - pub fn new(value: T, retrieved_at: TimeBase) -> Self { - Self { - value, - retrieved_at, - } - } -} - -impl Retrieved<&Tls13ClientSessionValue> { - pub fn obfuscated_ticket_age(&self) -> u32 { - let age_secs = self - .retrieved_at - .as_secs() - .saturating_sub(self.value.common.epoch); - let age_millis = age_secs as u32 * 1000; - age_millis.wrapping_add(self.value.age_add) - } -} - -impl Retrieved { - pub fn tls13(&self) -> Option> { - match &self.value { - ClientSessionValue::Tls13(value) => Some(Retrieved::new(value, self.retrieved_at)), - #[cfg(feature = "tls12")] - ClientSessionValue::Tls12(_) => None, - } - } - - pub fn has_expired(&self) -> bool { - let common = self.value.common(); - common.lifetime_secs != 0 - && common.epoch + u64::from(common.lifetime_secs) < self.retrieved_at.as_secs() - } -} - -impl std::ops::Deref for Retrieved { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.value - } -} - -#[derive(Debug)] -pub struct Tls13ClientSessionValue { - suite: &'static Tls13CipherSuite, - age_add: u32, - max_early_data_size: u32, - pub common: ClientSessionCommon, -} - -impl Tls13ClientSessionValue { - pub fn new( - suite: &'static Tls13CipherSuite, - ticket: Vec, - secret: Vec, - server_cert_chain: Vec, - time_now: TimeBase, - lifetime_secs: u32, - age_add: u32, - max_early_data_size: u32, - ) -> Self { - Self { - suite, - age_add, - max_early_data_size, - common: ClientSessionCommon::new( - ticket, - secret, - time_now, - lifetime_secs, - server_cert_chain, - ), - } - } - - /// [`Codec::read()`] with an extra `suite` argument. - /// - /// We decode the `suite` argument separately because it allows us to - /// decide whether we're decoding an 1.2 or 1.3 session value. - pub fn read(suite: &'static Tls13CipherSuite, r: &mut Reader) -> Option { - Some(Self { - suite, - age_add: u32::read(r)?, - max_early_data_size: u32::read(r)?, - common: ClientSessionCommon::read(r)?, - }) - } - - /// Inherent implementation of the [`Codec::get_encoding()`] method. - /// - /// (See `read()` for why this is inherent here.) - pub fn get_encoding(&self) -> Vec { - let mut bytes = Vec::with_capacity(16); - self.suite.common.suite.encode(&mut bytes); - self.age_add.encode(&mut bytes); - self.max_early_data_size.encode(&mut bytes); - self.common.encode(&mut bytes); - bytes - } - - pub fn max_early_data_size(&self) -> u32 { - self.max_early_data_size - } - - pub fn suite(&self) -> &'static Tls13CipherSuite { - self.suite - } -} - -impl std::ops::Deref for Tls13ClientSessionValue { - type Target = ClientSessionCommon; - - fn deref(&self) -> &Self::Target { - &self.common - } -} - -#[cfg(feature = "tls12")] -#[derive(Debug)] -pub struct Tls12ClientSessionValue { - suite: &'static Tls12CipherSuite, - pub session_id: SessionID, - extended_ms: bool, - pub common: ClientSessionCommon, -} - -#[cfg(feature = "tls12")] -impl Tls12ClientSessionValue { - pub fn new( - suite: &'static Tls12CipherSuite, - session_id: SessionID, - ticket: Vec, - master_secret: Vec, - server_cert_chain: Vec, - time_now: TimeBase, - lifetime_secs: u32, - extended_ms: bool, - ) -> Self { - Self { - suite, - session_id, - extended_ms, - common: ClientSessionCommon::new( - ticket, - master_secret, - time_now, - lifetime_secs, - server_cert_chain, - ), - } - } - - /// [`Codec::read()`] with an extra `suite` argument. - /// - /// We decode the `suite` argument separately because it allows us to - /// decide whether we're decoding an 1.2 or 1.3 session value. - fn read(suite: &'static Tls12CipherSuite, r: &mut Reader) -> Option { - Some(Self { - suite, - session_id: SessionID::read(r)?, - extended_ms: u8::read(r)? == 1, - common: ClientSessionCommon::read(r)?, - }) - } - - /// Inherent implementation of the [`Codec::get_encoding()`] method. - /// - /// (See `read()` for why this is inherent here.) - pub fn get_encoding(&self) -> Vec { - let mut bytes = Vec::with_capacity(16); - self.suite.common.suite.encode(&mut bytes); - self.session_id.encode(&mut bytes); - (if self.extended_ms { 1u8 } else { 0u8 }).encode(&mut bytes); - self.common.encode(&mut bytes); - bytes - } - - pub fn take_ticket(&mut self) -> Vec { - mem::take(&mut self.common.ticket.0) - } - - pub fn extended_ms(&self) -> bool { - self.extended_ms - } - - pub fn suite(&self) -> &'static Tls12CipherSuite { - self.suite - } -} - -#[cfg(feature = "tls12")] -impl std::ops::Deref for Tls12ClientSessionValue { - type Target = ClientSessionCommon; - - fn deref(&self) -> &Self::Target { - &self.common - } -} - -#[derive(Debug)] -pub struct ClientSessionCommon { - ticket: PayloadU16, - secret: PayloadU8, - epoch: u64, - lifetime_secs: u32, - server_cert_chain: CertificatePayload, -} - -impl ClientSessionCommon { - fn new( - ticket: Vec, - secret: Vec, - time_now: TimeBase, - lifetime_secs: u32, - server_cert_chain: Vec, - ) -> Self { - Self { - ticket: PayloadU16(ticket), - secret: PayloadU8(secret), - epoch: time_now.as_secs(), - lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME), - server_cert_chain, - } - } - - /// [`Codec::read()`] is inherent here to avoid leaking the [`Codec`] - /// implementation through [`Deref`] implementations on - /// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`]. - fn read(r: &mut Reader) -> Option { - Some(Self { - ticket: PayloadU16::read(r)?, - secret: PayloadU8::read(r)?, - epoch: u64::read(r)?, - lifetime_secs: u32::read(r)?, - server_cert_chain: CertificatePayload::read(r)?, - }) - } - - /// [`Codec::encode()`] is inherent here to avoid leaking the [`Codec`] - /// implementation through [`Deref`] implementations on - /// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`]. - fn encode(&self, bytes: &mut Vec) { - self.ticket.encode(bytes); - self.secret.encode(bytes); - self.epoch.encode(bytes); - self.lifetime_secs.encode(bytes); - self.server_cert_chain.encode(bytes); - } - - pub fn server_cert_chain(&self) -> &[tls_core::key::Certificate] { - self.server_cert_chain.as_ref() - } - - pub fn secret(&self) -> &[u8] { - self.secret.0.as_ref() - } - - pub fn ticket(&self) -> &[u8] { - self.ticket.0.as_ref() - } - - /// Test only: wind back epoch by delta seconds. - pub fn rewind_epoch(&mut self, delta: u32) { - self.epoch -= delta as u64; - } -} - -static MAX_TICKET_LIFETIME: u32 = 7 * 24 * 60 * 60; - -/// This is the maximum allowed skew between server and client clocks, over -/// the maximum ticket lifetime period. This encompasses TCP retransmission -/// times in case packet loss occurs when the client sends the ClientHello -/// or receives the NewSessionTicket, _and_ actual clock skew over this period. -static MAX_FRESHNESS_SKEW_MS: u32 = 60 * 1000; - -// --- Server types --- -pub type ServerSessionKey = SessionID; - -#[derive(Debug)] -pub struct ServerSessionValue { - pub sni: Option, - pub version: ProtocolVersion, - pub cipher_suite: CipherSuite, - pub master_secret: PayloadU8, - pub extended_ms: bool, - pub client_cert_chain: Option, - pub alpn: Option, - pub application_data: PayloadU16, - pub creation_time_sec: u64, - pub age_obfuscation_offset: u32, - freshness: Option, -} - -impl Codec for ServerSessionValue { - fn encode(&self, bytes: &mut Vec) { - if let Some(ref sni) = self.sni { - 1u8.encode(bytes); - let sni_bytes: &str = sni.as_ref().into(); - PayloadU8::new(Vec::from(sni_bytes)).encode(bytes); - } else { - 0u8.encode(bytes); - } - self.version.encode(bytes); - self.cipher_suite.encode(bytes); - self.master_secret.encode(bytes); - (if self.extended_ms { 1u8 } else { 0u8 }).encode(bytes); - if let Some(ref chain) = self.client_cert_chain { - 1u8.encode(bytes); - chain.encode(bytes); - } else { - 0u8.encode(bytes); - } - if let Some(ref alpn) = self.alpn { - 1u8.encode(bytes); - alpn.encode(bytes); - } else { - 0u8.encode(bytes); - } - self.application_data.encode(bytes); - self.creation_time_sec.encode(bytes); - self.age_obfuscation_offset.encode(bytes); - } - - fn read(r: &mut Reader) -> Option { - let has_sni = u8::read(r)?; - let sni = if has_sni == 1 { - let dns_name = PayloadU8::read(r)?; - let dns_name = webpki::DnsNameRef::try_from_ascii(&dns_name.0).ok()?; - Some(dns_name.into()) - } else { - None - }; - let v = ProtocolVersion::read(r)?; - let cs = CipherSuite::read(r)?; - let ms = PayloadU8::read(r)?; - let ems = u8::read(r)?; - let has_ccert = u8::read(r)? == 1; - let ccert = if has_ccert { - Some(CertificatePayload::read(r)?) - } else { - None - }; - let has_alpn = u8::read(r)? == 1; - let alpn = if has_alpn { - Some(PayloadU8::read(r)?) - } else { - None - }; - let application_data = PayloadU16::read(r)?; - let creation_time_sec = u64::read(r)?; - let age_obfuscation_offset = u32::read(r)?; - - Some(Self { - sni, - version: v, - cipher_suite: cs, - master_secret: ms, - extended_ms: ems == 1u8, - client_cert_chain: ccert, - alpn, - application_data, - creation_time_sec, - age_obfuscation_offset, - freshness: None, - }) - } -} - -impl ServerSessionValue { - pub fn new( - sni: Option<&webpki::DnsName>, - v: ProtocolVersion, - cs: CipherSuite, - ms: Vec, - client_cert_chain: Option, - alpn: Option>, - application_data: Vec, - creation_time: TimeBase, - age_obfuscation_offset: u32, - ) -> Self { - Self { - sni: sni.cloned(), - version: v, - cipher_suite: cs, - master_secret: PayloadU8::new(ms), - extended_ms: false, - client_cert_chain, - alpn: alpn.map(PayloadU8::new), - application_data: PayloadU16::new(application_data), - creation_time_sec: creation_time.as_secs(), - age_obfuscation_offset, - freshness: None, - } - } - - pub fn set_extended_ms_used(&mut self) { - self.extended_ms = true; - } - - pub fn set_freshness(mut self, obfuscated_client_age_ms: u32, time_now: TimeBase) -> Self { - let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset); - let server_age_ms = - (time_now.as_secs().saturating_sub(self.creation_time_sec) as u32).saturating_mul(1000); - - let age_difference = if client_age_ms < server_age_ms { - server_age_ms - client_age_ms - } else { - client_age_ms - server_age_ms - }; - - self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS); - self - } - - pub fn is_fresh(&self) -> bool { - self.freshness.unwrap_or_default() - } -} diff --git a/crates/tls/client/src/msgs/persist_test.rs b/crates/tls/client/src/msgs/persist_test.rs deleted file mode 100644 index a7f99abf8..000000000 --- a/crates/tls/client/src/msgs/persist_test.rs +++ /dev/null @@ -1,78 +0,0 @@ -use super::persist::*; -use crate::ticketer::TimeBase; -use std::convert::TryInto; -use tls_core::{ - key::Certificate, - msgs::{ - codec::{Codec, Reader}, - enums::*, - }, - suites::TLS13_AES_128_GCM_SHA256, -}; - -#[test] -fn clientsessionkey_is_debug() { - let name = "hello".try_into().unwrap(); - let csk = ClientSessionKey::session_for_server_name(&name); - println!("{:?}", csk); -} - -#[test] -fn clientsessionkey_cannot_be_read() { - let bytes = [0; 1]; - let mut rd = Reader::init(&bytes); - assert!(ClientSessionKey::read(&mut rd).is_none()); -} - -#[test] -fn clientsessionvalue_is_debug() { - let csv = ClientSessionValue::from(Tls13ClientSessionValue::new( - TLS13_AES_128_GCM_SHA256.tls13().unwrap(), - vec![], - vec![1, 2, 3], - vec![Certificate(b"abc".to_vec()), Certificate(b"def".to_vec())], - TimeBase::now().unwrap(), - 15, - 10, - 128, - )); - println!("{:?}", csv); -} - -#[test] -fn serversessionvalue_is_debug() { - let ssv = ServerSessionValue::new( - None, - ProtocolVersion::TLSv1_3, - CipherSuite::TLS13_AES_128_GCM_SHA256, - vec![1, 2, 3], - None, - None, - vec![4, 5, 6], - TimeBase::now().unwrap(), - 0x12345678, - ); - println!("{:?}", ssv); -} - -#[test] -fn serversessionvalue_no_sni() { - let bytes = [ - 0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, - 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, 0xfe, 0xed, 0xf0, 0x0d, - ]; - let mut rd = Reader::init(&bytes); - let ssv = ServerSessionValue::read(&mut rd).unwrap(); - assert_eq!(ssv.get_encoding(), bytes); -} - -#[test] -fn serversessionvalue_with_cert() { - let bytes = [ - 0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, - 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, 0xfe, 0xed, 0xf0, 0x0d, - ]; - let mut rd = Reader::init(&bytes); - let ssv = ServerSessionValue::read(&mut rd).unwrap(); - assert_eq!(ssv.get_encoding(), bytes); -} diff --git a/crates/tls/client/src/sign.rs b/crates/tls/client/src/sign.rs index b9ac2f285..312ec4ee0 100644 --- a/crates/tls/client/src/sign.rs +++ b/crates/tls/client/src/sign.rs @@ -6,6 +6,7 @@ use ring::{ io::der, signature::{self, EcdsaKeyPair, Ed25519KeyPair, RsaKeyPair}, }; +use rustls_pki_types as pki_types; use std::{convert::TryFrom, error::Error as StdError, fmt, sync::Arc}; use tls_core::{ key, @@ -71,51 +72,6 @@ impl CertifiedKey { pub fn end_entity_cert(&self) -> Result<&key::Certificate, SignError> { self.cert.first().ok_or(SignError(())) } - - /// Check the certificate chain for validity: - /// - it should be non-empty list - /// - the first certificate should be parsable as a x509v3, - /// - the first certificate should quote the given server name - /// (if provided) - /// - /// These checks are not security-sensitive. They are the - /// *server* attempting to detect accidental misconfiguration. - pub(crate) fn cross_check_end_entity_cert( - &self, - name: Option, - ) -> Result<(), Error> { - // Always reject an empty certificate chain. - let end_entity_cert = self.end_entity_cert().map_err(|SignError(())| { - Error::General("No end-entity certificate in certificate chain".to_string()) - })?; - - // Reject syntactically-invalid end-entity certificates. - let end_entity_cert = - webpki::EndEntityCert::try_from(end_entity_cert.as_ref()).map_err(|_| { - Error::General( - "End-entity certificate in certificate \ - chain is syntactically invalid" - .to_string(), - ) - })?; - - if let Some(name) = name { - // If SNI was offered then the certificate must be valid for - // that hostname. Note that this doesn't fully validate that the - // certificate is valid; it only validates that the name is one - // that the certificate is valid for, if the certificate is - // valid. - if end_entity_cert.verify_is_valid_for_dns_name(name).is_err() { - return Err(Error::General( - "The server certificate is not \ - valid for the given name" - .to_string(), - )); - } - } - - Ok(()) - } } /// Parse `der` as any supported key encoding/type, returning diff --git a/crates/tls/client/src/verifybench.rs b/crates/tls/client/src/verifybench.rs deleted file mode 100644 index 6d5be2f1d..000000000 --- a/crates/tls/client/src/verifybench.rs +++ /dev/null @@ -1,223 +0,0 @@ -// This program does benchmarking of the functions in verify.rs, -// that do certificate chain validation and signature verification. -// -// Note: we don't use any of the standard 'cargo bench', 'test::Bencher', -// etc. because it's unstable at the time of writing. - -use crate::{anchors, verify, verify::ServerCertVerifier, OwnedTrustAnchor}; -use std::convert::TryInto; -use web_time::{Duration, Instant, SystemTime}; - -use webpki_roots; - -fn duration_nanos(d: Duration) -> u64 { - ((d.as_secs() as f64) * 1e9 + (d.subsec_nanos() as f64)) as u64 -} - -#[test] -fn test_reddit_cert() { - Context::new( - "reddit", - "reddit.com", - &[ - include_bytes!("testdata/cert-reddit.0.der"), - include_bytes!("testdata/cert-reddit.1.der"), - ], - ) - .bench(100) -} - -#[test] -fn test_github_cert() { - Context::new( - "github", - "github.com", - &[ - include_bytes!("testdata/cert-github.0.der"), - include_bytes!("testdata/cert-github.1.der"), - ], - ) - .bench(100) -} - -#[test] -fn test_arstechnica_cert() { - Context::new( - "arstechnica", - "arstechnica.com", - &[ - include_bytes!("testdata/cert-arstechnica.0.der"), - include_bytes!("testdata/cert-arstechnica.1.der"), - include_bytes!("testdata/cert-arstechnica.2.der"), - include_bytes!("testdata/cert-arstechnica.3.der"), - ], - ) - .bench(100) -} - -#[test] -fn test_twitter_cert() { - Context::new( - "twitter", - "twitter.com", - &[ - include_bytes!("testdata/cert-twitter.0.der"), - include_bytes!("testdata/cert-twitter.1.der"), - ], - ) - .bench(100) -} - -#[test] -fn test_wikipedia_cert() { - Context::new( - "wikipedia", - "wikipedia.org", - &[ - include_bytes!("testdata/cert-wikipedia.0.der"), - include_bytes!("testdata/cert-wikipedia.1.der"), - ], - ) - .bench(100) -} - -#[test] -fn test_google_cert() { - Context::new( - "google", - "www.google.com", - &[ - include_bytes!("testdata/cert-google.0.der"), - include_bytes!("testdata/cert-google.1.der"), - ], - ) - .bench(100) -} - -#[test] -fn test_hn_cert() { - Context::new( - "hn", - "news.ycombinator.com", - &[ - include_bytes!("testdata/cert-hn.0.der"), - include_bytes!("testdata/cert-hn.1.der"), - ], - ) - .bench(100) -} - -#[test] -fn test_stackoverflow_cert() { - Context::new( - "stackoverflow", - "stackoverflow.com", - &[ - include_bytes!("testdata/cert-stackoverflow.0.der"), - include_bytes!("testdata/cert-stackoverflow.1.der"), - ], - ) - .bench(100) -} - -#[test] -fn test_duckduckgo_cert() { - Context::new( - "duckduckgo", - "duckduckgo.com", - &[ - include_bytes!("testdata/cert-duckduckgo.0.der"), - include_bytes!("testdata/cert-duckduckgo.1.der"), - ], - ) - .bench(100) -} - -#[test] -fn test_rustlang_cert() { - Context::new( - "rustlang", - "www.rust-lang.org", - &[ - include_bytes!("testdata/cert-rustlang.0.der"), - include_bytes!("testdata/cert-rustlang.1.der"), - include_bytes!("testdata/cert-rustlang.2.der"), - ], - ) - .bench(100) -} - -#[test] -fn test_wapo_cert() { - Context::new( - "wapo", - "www.washingtonpost.com", - &[ - include_bytes!("testdata/cert-wapo.0.der"), - include_bytes!("testdata/cert-wapo.1.der"), - ], - ) - .bench(100) -} - -struct Context { - name: &'static str, - domain: &'static str, - roots: anchors::RootCertStore, - chain: Vec, - now: SystemTime, -} - -impl Context { - fn new(name: &'static str, domain: &'static str, certs: &[&'static [u8]]) -> Self { - let mut roots = anchors::RootCertStore::empty(); - roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject.as_ref(), - ta.subject_public_key_info.as_ref(), - ta.name_constraints.as_ref().map(|nc| nc.as_ref()), - ) - })); - Self { - name, - domain, - roots, - chain: certs - .iter() - .copied() - .map(|bytes| tls_core::key::Certificate(bytes.to_vec())) - .collect(), - now: SystemTime::UNIX_EPOCH + Duration::from_secs(1640870720), - } - } - - fn bench(&self, count: usize) { - let verifier = verify::WebPkiVerifier::new(self.roots.clone(), None); - const SCTS: &[&[u8]] = &[]; - const OCSP_RESPONSE: &[u8] = &[]; - let mut times = Vec::new(); - - let (end_entity, intermediates) = self.chain.split_first().unwrap(); - for _ in 0..count { - let start = Instant::now(); - let server_name = self.domain.try_into().unwrap(); - verifier - .verify_server_cert( - end_entity, - intermediates, - &server_name, - &mut SCTS.iter().copied(), - OCSP_RESPONSE, - self.now, - ) - .unwrap(); - times.push(duration_nanos(Instant::now().duration_since(start))); - } - - println!( - "verify_server_cert({}): min {:?}us", - self.name, - times.iter().min().unwrap() / 1000 - ); - } -} diff --git a/crates/tls/client/tests/api.rs b/crates/tls/client/tests/api.rs index 0bfcb7339..c562aab9e 100644 --- a/crates/tls/client/tests/api.rs +++ b/crates/tls/client/tests/api.rs @@ -780,14 +780,12 @@ async fn client_checks_server_certificate_with_given_name() { let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap(); let err = do_handshake_until_error(&mut client, &mut server).await; - assert_eq!( + assert!(matches!( err, Err(ErrorFromPeer::Client(Error::CoreError( - tls_core::Error::InvalidCertificateData( - "invalid peer certificate: CertNotValidForName".into(), - ) + tls_core::Error::InvalidCertificateData(_) ))) - ); + )); } } } diff --git a/crates/tls/client/tests/common/mod.rs b/crates/tls/client/tests/common/mod.rs index 4cfec734c..6e7a5a785 100644 --- a/crates/tls/client/tests/common/mod.rs +++ b/crates/tls/client/tests/common/mod.rs @@ -2,6 +2,7 @@ use futures::{AsyncRead, AsyncWrite}; use rustls::{server::AllowAnyAuthenticatedClient, ServerConfig, ServerConnection}; +use rustls_pki_types::CertificateDer; use std::{ convert::{TryFrom, TryInto}, io, @@ -15,6 +16,7 @@ use tls_client::{ Certificate, ClientConfig, ClientConnection, Error, PrivateKey, RootCertStore, RustCryptoBackend, }; +use webpki::anchor_from_trusted_cert; macro_rules! embed_files { ( @@ -409,9 +411,17 @@ pub fn finish_client_config( kt: KeyType, config: tls_client::ConfigBuilder, ) -> ClientConfig { - let mut root_store = RootCertStore::empty(); let mut rootbuf = io::BufReader::new(kt.bytes_for("ca.cert")); - root_store.add_parsable_certificates(&rustls_pemfile::certs(&mut rootbuf).unwrap()); + let roots = rustls_pemfile::certs(&mut rootbuf) + .unwrap() + .into_iter() + .map(|cert| { + let der = CertificateDer::from_slice(&cert); + anchor_from_trusted_cert(&der).unwrap().to_owned() + }) + .collect(); + + let root_store = RootCertStore { roots }; config .with_root_certificates(root_store) @@ -422,9 +432,17 @@ pub fn finish_client_config_with_creds( kt: KeyType, config: tls_client::ConfigBuilder, ) -> ClientConfig { - let mut root_store = RootCertStore::empty(); let mut rootbuf = io::BufReader::new(kt.bytes_for("ca.cert")); - root_store.add_parsable_certificates(&rustls_pemfile::certs(&mut rootbuf).unwrap()); + let roots = rustls_pemfile::certs(&mut rootbuf) + .unwrap() + .into_iter() + .map(|cert| { + let der = CertificateDer::from_slice(&cert); + anchor_from_trusted_cert(&der).unwrap().to_owned() + }) + .collect(); + + let root_store = RootCertStore { roots }; config .with_root_certificates(root_store) diff --git a/crates/tls/core/Cargo.toml b/crates/tls/core/Cargo.toml index 49391afb3..f40763f7e 100644 --- a/crates/tls/core/Cargo.toml +++ b/crates/tls/core/Cargo.toml @@ -35,4 +35,5 @@ sha2 = { workspace = true, optional = true } thiserror = { workspace = true } tracing = { workspace = true, optional = true } web-time = { workspace = true } -webpki = { workspace = true, features = ["alloc", "std"] } +rustls-webpki = { workspace = true, features = ["ring"] } +rustls-pki-types = { workspace = true } diff --git a/crates/tls/core/src/anchors.rs b/crates/tls/core/src/anchors.rs index 1fb30981a..0122e2aad 100644 --- a/crates/tls/core/src/anchors.rs +++ b/crates/tls/core/src/anchors.rs @@ -1,65 +1,16 @@ +use rustls_pki_types::TrustAnchor; + use crate::{ msgs::handshake::{DistinguishedName, DistinguishedNames}, x509, }; -/// A trust anchor, commonly known as a "Root Certificate." -#[derive(Debug, Clone)] -pub struct OwnedTrustAnchor { - subject: Vec, - spki: Vec, - name_constraints: Option>, -} - -impl OwnedTrustAnchor { - /// Get a `webpki::TrustAnchor` by borrowing the owned elements. - pub(crate) fn to_trust_anchor(&self) -> webpki::TrustAnchor<'_> { - webpki::TrustAnchor { - subject: &self.subject, - spki: &self.spki, - name_constraints: self.name_constraints.as_deref(), - } - } - - /// Constructs an `OwnedTrustAnchor` from its components. - /// - /// `subject` is the subject field of the trust anchor. - /// - /// `spki` is the `subjectPublicKeyInfo` field of the trust anchor. - /// - /// `name_constraints` is the value of a DER-encoded name constraints to - /// apply for this trust anchor, if any. - pub fn from_subject_spki_name_constraints( - subject: impl Into>, - spki: impl Into>, - name_constraints: Option>>, - ) -> Self { - Self { - subject: subject.into(), - spki: spki.into(), - name_constraints: name_constraints.map(|x| x.into()), - } - } -} - -/// Errors that can occur during operations with RootCertStore -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum RootCertStoreError { - #[error(transparent)] - WebpkiError(#[from] webpki::Error), - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error("Unexpected PEM certificate count. Expected 1 certificate, got {0}")] - PemCertUnexpectedCount(usize), -} - /// A container for root certificates able to provide a root-of-trust /// for connection authentication. #[derive(Debug, Clone)] pub struct RootCertStore { /// The list of roots. - pub roots: Vec, + pub roots: Vec>, } impl RootCertStore { @@ -91,100 +42,4 @@ impl RootCertStore { r } - - /// Add a single DER-encoded certificate to the store. - pub fn add(&mut self, der: &crate::key::Certificate) -> Result<(), RootCertStoreError> { - let ta = webpki::TrustAnchor::try_from_cert_der(&der.0)?; - let ota = OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ); - self.roots.push(ota); - Ok(()) - } - - /// Adds a single PEM-encoded certificate to the store. - pub fn add_pem(&mut self, pem: &str) -> Result<(), RootCertStoreError> { - let mut certificates = rustls_pemfile::certs(&mut pem.as_bytes())?; - - if certificates.len() != 1 { - return Err(RootCertStoreError::PemCertUnexpectedCount( - certificates.len(), - )); - } - - self.add(&crate::key::Certificate(certificates.remove(0)))?; - - Ok(()) - } - - /// Adds all the given TrustAnchors `anchors`. This does not - /// fail. - pub fn add_server_trust_anchors( - &mut self, - trust_anchors: impl Iterator, - ) { - self.roots.extend(trust_anchors) - } - - /// Parse the given DER-encoded certificates and add all that can be parsed - /// in a best-effort fashion. - /// - /// This is because large collections of root certificates often - /// include ancient or syntactically invalid certificates. - /// - /// Returns the number of certificates added, and the number that were ignored. - pub fn add_parsable_certificates(&mut self, der_certs: &[Vec]) -> (usize, usize) { - let mut valid_count = 0; - let mut invalid_count = 0; - - for der_cert in der_certs { - match self.add(&crate::key::Certificate(der_cert.clone())) { - Ok(_) => valid_count += 1, - Err(_err) => invalid_count += 1, - } - } - - (valid_count, invalid_count) - } -} - -#[cfg(test)] -mod tests { - use super::*; - const CA_PEM_CERT: &[u8] = include_bytes!("../testdata/cert-digicert.pem"); - - #[test] - fn test_add_pem_ok() { - let pem = std::str::from_utf8(CA_PEM_CERT).unwrap(); - assert!(RootCertStore::empty().add_pem(pem).is_ok()); - } - - #[test] - fn test_add_pem_err_bad_cert() { - assert_eq!( - RootCertStore::empty() - .add_pem("bad pem") - .err() - .unwrap() - .to_string(), - "Unexpected PEM certificate count. Expected 1 certificate, got 0" - ); - } - - #[test] - fn test_add_pem_err_more_than_one_cert() { - let pem1 = std::str::from_utf8(CA_PEM_CERT).unwrap(); - let pem2 = pem1; - - assert_eq!( - RootCertStore::empty() - .add_pem((pem1.to_owned() + pem2).as_str()) - .err() - .unwrap() - .to_string(), - "Unexpected PEM certificate count. Expected 1 certificate, got 2" - ); - } } diff --git a/crates/tls/core/src/dns.rs b/crates/tls/core/src/dns.rs index aab903554..81f438ef6 100644 --- a/crates/tls/core/src/dns.rs +++ b/crates/tls/core/src/dns.rs @@ -1,6 +1,6 @@ use std::{error::Error as StdError, fmt}; -use crate::verify; +use rustls_pki_types as pki_types; /// Encodes ways a client can know the expected name of the server. /// @@ -26,20 +26,16 @@ use crate::verify; /// ``` #[non_exhaustive] #[derive(Debug, PartialEq, Clone)] -pub enum ServerName { - /// The server is identified by a DNS name. The name - /// is sent in the TLS Server Name Indication (SNI) - /// extension. - DnsName(verify::DnsName), -} +pub struct ServerName(pub(crate) pki_types::ServerName<'static>); impl ServerName { /// Return the name that should go in the SNI extension. /// If [`None`] is returned, the SNI extension is not included /// in the handshake. - pub fn for_sni(&self) -> Option> { - match self { - Self::DnsName(dns_name) => Some(dns_name.0.as_ref()), + pub fn for_sni(&self) -> Option> { + match &self.0 { + pki_types::ServerName::DnsName(dns_name) => Some(dns_name.clone()), + _ => None, } } @@ -49,13 +45,17 @@ impl ServerName { DnsName = 0x01, } - let Self::DnsName(dns_name) = self; - let bytes = dns_name.0.as_ref(); + let bytes: &[u8] = match &self.0 { + pki_types::ServerName::DnsName(dns_name) => dns_name.as_ref().as_ref(), + pki_types::ServerName::IpAddress(pki_types::IpAddr::V4(ip)) => ip.as_ref(), + pki_types::ServerName::IpAddress(pki_types::IpAddr::V6(ip)) => ip.as_ref(), + _ => unreachable!(), + }; - let mut r = Vec::with_capacity(2 + bytes.as_ref().len()); + let mut r = Vec::with_capacity(2 + bytes.len()); r.push(UniqueTypeCode::DnsName as u8); - r.push(bytes.as_ref().len() as u8); - r.extend_from_slice(bytes.as_ref()); + r.push(bytes.len() as u8); + r.extend_from_slice(bytes); r } @@ -66,9 +66,9 @@ impl ServerName { impl TryFrom<&str> for ServerName { type Error = InvalidDnsNameError; fn try_from(s: &str) -> Result { - match webpki::DnsNameRef::try_from_ascii_str(s) { - Ok(dns) => Ok(Self::DnsName(verify::DnsName(dns.into()))), - Err(webpki::InvalidDnsNameError) => Err(InvalidDnsNameError), + match pki_types::DnsName::try_from(s) { + Ok(dns) => Ok(Self(pki_types::ServerName::DnsName(dns.to_owned()))), + Err(_) => Err(InvalidDnsNameError), } } } diff --git a/crates/tls/core/src/error.rs b/crates/tls/core/src/error.rs index 994f19219..c20b99296 100644 --- a/crates/tls/core/src/error.rs +++ b/crates/tls/core/src/error.rs @@ -1,5 +1,6 @@ use crate::msgs::enums::{AlertDescription, ContentType, HandshakeType}; -use std::{error::Error as StdError, fmt, time::SystemTimeError}; +use std::{error::Error as StdError, fmt}; +use web_time::SystemTimeError; /// rustls reports protocol errors using this type. #[derive(Debug, PartialEq, Clone)] @@ -41,8 +42,9 @@ pub enum Error { /// We couldn't decrypt a message. This is invariably fatal. DecryptError, - /// We couldn't encrypt a message because it was larger than the allowed message size. - /// This should never happen if the application is using valid record sizes. + /// We couldn't encrypt a message because it was larger than the allowed + /// message size. This should never happen if the application is using + /// valid record sizes. EncryptError, /// The peer doesn't support a protocol version/feature we require. diff --git a/crates/tls/core/src/msgs/handshake.rs b/crates/tls/core/src/msgs/handshake.rs index 52c17ee2f..2423da4a5 100644 --- a/crates/tls/core/src/msgs/handshake.rs +++ b/crates/tls/core/src/msgs/handshake.rs @@ -1,3 +1,5 @@ +use rustls_pki_types as pki_types; + use crate::{ key, msgs::{ @@ -249,14 +251,14 @@ impl DecomposedSignatureScheme for SignatureScheme { #[derive(Clone, Debug)] pub enum ServerNamePayload { // Stored twice, bytes so we can round-trip, and DnsName for use - HostName((PayloadU16, webpki::DnsName)), + HostName((PayloadU16, pki_types::DnsName<'static>)), Unknown(Payload), } impl ServerNamePayload { - pub fn new_hostname(hostname: webpki::DnsName) -> Self { + pub fn new_hostname(hostname: pki_types::DnsName<'static>) -> Self { let raw = { - let s: &str = hostname.as_ref().into(); + let s: &str = hostname.as_ref(); PayloadU16::new(s.as_bytes().into()) }; Self::HostName((raw, hostname)) @@ -266,8 +268,8 @@ impl ServerNamePayload { let raw = PayloadU16::read(r)?; let dns_name = { - match webpki::DnsNameRef::try_from_ascii(&raw.0) { - Ok(dns_name) => dns_name.into(), + match pki_types::DnsName::try_from(raw.0.as_slice()) { + Ok(dns_name) => dns_name.to_owned(), Err(_) => { warn!("Illegal SNI hostname received {:?}", raw.0); return None; @@ -313,11 +315,12 @@ declare_u16_vec!(ServerNameRequest, ServerName); pub trait ConvertServerNameList { fn has_duplicate_names_for_type(&self) -> bool; - fn get_single_hostname(&self) -> Option>; + fn get_single_hostname(&self) -> Option>; } impl ConvertServerNameList for ServerNameRequest { - /// RFC6066: "The ServerNameList MUST NOT contain more than one name of the same name_type." + /// RFC6066: "The ServerNameList MUST NOT contain more than one name of the + /// same name_type." fn has_duplicate_names_for_type(&self) -> bool { let mut seen = collections::HashSet::new(); @@ -330,10 +333,10 @@ impl ConvertServerNameList for ServerNameRequest { false } - fn get_single_hostname(&self) -> Option> { - fn only_dns_hostnames(name: &ServerName) -> Option> { + fn get_single_hostname(&self) -> Option> { + fn only_dns_hostnames(name: &ServerName) -> Option> { if let ServerNamePayload::HostName((_, ref dns)) = name.payload { - Some(dns.as_ref()) + Some(dns.clone()) } else { None } @@ -694,16 +697,16 @@ impl Codec for ClientExtension { } } -fn trim_hostname_trailing_dot_for_sni(dns_name: webpki::DnsNameRef) -> webpki::DnsName { - let dns_name_str: &str = dns_name.into(); +fn trim_hostname_trailing_dot_for_sni( + dns_name: pki_types::DnsName<'static>, +) -> pki_types::DnsName<'static> { + let dns_name_str: &str = dns_name.as_ref(); // RFC6066: "The hostname is represented as a byte string using // ASCII encoding without a trailing dot" if dns_name_str.ends_with('.') { let trimmed = &dns_name_str[0..dns_name_str.len() - 1]; - webpki::DnsNameRef::try_from_ascii_str(trimmed) - .unwrap() - .to_owned() + pki_types::DnsName::try_from(trimmed).unwrap().to_owned() } else { dns_name.to_owned() } @@ -711,7 +714,7 @@ fn trim_hostname_trailing_dot_for_sni(dns_name: webpki::DnsNameRef) -> webpki::D impl ClientExtension { /// Make a basic SNI ServerNameRequest quoting `hostname`. - pub fn make_sni(dns_name: webpki::DnsNameRef) -> Self { + pub fn make_sni(dns_name: pki_types::DnsName<'static>) -> Self { let name = ServerName { typ: ServerNameType::HostName, payload: ServerNamePayload::new_hostname(trim_hostname_trailing_dot_for_sni(dns_name)), diff --git a/crates/tls/core/src/msgs/handshake_test.rs b/crates/tls/core/src/msgs/handshake_test.rs index ed99e579c..81ca310f5 100644 --- a/crates/tls/core/src/msgs/handshake_test.rs +++ b/crates/tls/core/src/msgs/handshake_test.rs @@ -1,3 +1,5 @@ +use rustls_pki_types as pki_types; + use super::{ base::{Payload, PayloadU16, PayloadU24, PayloadU8}, codec::{put_u16, Codec, Reader}, @@ -5,7 +7,6 @@ use super::{ handshake::*, }; use crate::key::Certificate; -use webpki::DnsNameRef; #[test] fn rejects_short_random() { @@ -186,7 +187,8 @@ fn can_roundtrip_multiname_sni() { assert!(req.has_duplicate_names_for_type()); - let dns_name_str: &str = req.get_single_hostname().unwrap().into(); + let dns_name = req.get_single_hostname().unwrap(); + let dns_name_str: &str = dns_name.as_ref(); assert_eq!(dns_name_str, "hi"); assert_eq!(req[0].typ, ServerNameType::HostName); @@ -363,7 +365,7 @@ fn get_sample_clienthellopayload() -> ClientHelloPayload { ClientExtension::ECPointFormats(ECPointFormatList::supported()), ClientExtension::NamedGroups(vec![NamedGroup::X25519]), ClientExtension::SignatureAlgorithms(vec![SignatureScheme::ECDSA_NISTP256_SHA256]), - ClientExtension::make_sni(DnsNameRef::try_from_ascii_str("hello").unwrap()), + ClientExtension::make_sni(pki_types::DnsName::try_from("hello").unwrap().to_owned()), ClientExtension::SessionTicket(ClientSessionTicket::Request), ClientExtension::SessionTicket(ClientSessionTicket::Offer(Payload(vec![]))), ClientExtension::Protocols(vec![PayloadU8(vec![0])]), diff --git a/crates/tls/core/src/verify.rs b/crates/tls/core/src/verify.rs index 10ac68821..ed26e7a3a 100644 --- a/crates/tls/core/src/verify.rs +++ b/crates/tls/core/src/verify.rs @@ -1,5 +1,5 @@ use crate::{ - anchors::{OwnedTrustAnchor, RootCertStore}, + anchors::RootCertStore, dns::ServerName, error::Error, key::Certificate, @@ -9,27 +9,10 @@ use crate::{ }, }; use ring::digest::Digest; -use std::convert::TryFrom; -use web_time::SystemTime; +use rustls_pki_types as pki_types; +use web_time::{SystemTime, UNIX_EPOCH}; -type SignatureAlgorithms = &'static [&'static webpki::SignatureAlgorithm]; - -/// Which signature verification mechanisms we support. No particular -/// order. -static SUPPORTED_SIG_ALGS: SignatureAlgorithms = &[ - &webpki::ECDSA_P256_SHA256, - &webpki::ECDSA_P256_SHA384, - &webpki::ECDSA_P384_SHA256, - &webpki::ECDSA_P384_SHA384, - &webpki::ED25519, - &webpki::RSA_PSS_2048_8192_SHA256_LEGACY_KEY, - &webpki::RSA_PSS_2048_8192_SHA384_LEGACY_KEY, - &webpki::RSA_PSS_2048_8192_SHA512_LEGACY_KEY, - &webpki::RSA_PKCS1_2048_8192_SHA256, - &webpki::RSA_PKCS1_2048_8192_SHA384, - &webpki::RSA_PKCS1_2048_8192_SHA512, - &webpki::RSA_PKCS1_3072_8192_SHA384, -]; +type SignatureAlgorithms = &'static [&'static dyn pki_types::SignatureVerificationAlgorithm]; // Marker types. These are used to bind the fact some verification // (certificate chain or handshake signature) has taken place into @@ -170,7 +153,7 @@ pub trait ServerCertVerifier: Send + Sync { /// A type which encapsuates a string that is a syntactically valid DNS name. #[derive(Clone, Debug, PartialEq)] -pub struct DnsName(pub(crate) webpki::DnsName); +pub struct DnsName(pub(crate) pki_types::DnsName<'static>); impl AsRef for DnsName { fn as_ref(&self) -> &str { @@ -289,35 +272,33 @@ impl ServerCertVerifier for WebPkiVerifier { _ocsp_response: &[u8], now: SystemTime, ) -> Result { - let (cert, chain, trustroots) = prepare(end_entity, intermediates, &self.roots)?; - // `webpki::Time::try_from` does not work with `web_time::SystemTime`. - // To workaround this we convert `SystemTime` to seconds and use - // `webpki::Time::from_seconds_since_unix_epoch` instead. - let duration_since_epoch = now - .duration_since(web_time::UNIX_EPOCH) - .map_err(|_| Error::FailedToGetCurrentTime)?; - let seconds_since_unix_epoch = duration_since_epoch.as_secs(); - let webpki_now = webpki::Time::from_seconds_since_unix_epoch(seconds_since_unix_epoch); + let cert = pki_types::CertificateDer::from(end_entity.0.as_slice()); + let cert = webpki::EndEntityCert::try_from(&cert).map_err(pki_error)?; + let intermediates = intermediates + .iter() + .map(|c| pki_types::CertificateDer::from(c.0.as_slice())) + .collect::>(); + let time = pki_types::UnixTime::since_unix_epoch(now.duration_since(UNIX_EPOCH)?); - let ServerName::DnsName(dns_name) = server_name; - - let cert = cert - .verify_is_valid_tls_server_cert( - SUPPORTED_SIG_ALGS, - &webpki::TlsServerTrustAnchors(&trustroots), - &chain, - webpki_now, - ) - .map_err(pki_error) - .map(|_| cert)?; + cert.verify_for_usage( + webpki::ALL_VERIFICATION_ALGS, + &self.roots.roots, + &intermediates, + time, + webpki::KeyUsage::server_auth(), + None, + None, + ) + .map(|_| ()) + .map_err(pki_error)?; if let Some(policy) = &self.ct_policy { policy.verify(end_entity, now, scts)?; } - cert.verify_is_valid_for_dns_name(dns_name.0.as_ref()) - .map_err(pki_error) + cert.verify_is_valid_for_subject_name(&server_name.0) .map(|_| ServerCertVerified::assertion()) + .map_err(pki_error) } } @@ -429,31 +410,6 @@ impl CertificateTransparencyPolicy { } } -type CertChainAndRoots<'a, 'b> = ( - webpki::EndEntityCert<'a>, - Vec<&'a [u8]>, - Vec>, -); - -fn prepare<'a, 'b>( - end_entity: &'a Certificate, - intermediates: &'a [Certificate], - roots: &'b RootCertStore, -) -> Result, Error> { - // EE cert must appear first. - let cert = webpki::EndEntityCert::try_from(end_entity.0.as_ref()).map_err(pki_error)?; - - let intermediates: Vec<&'a [u8]> = intermediates.iter().map(|cert| cert.0.as_ref()).collect(); - - let trustroots: Vec = roots - .roots - .iter() - .map(OwnedTrustAnchor::to_trust_anchor) - .collect(); - - Ok((cert, intermediates, trustroots)) -} - pub(crate) fn pki_error(error: webpki::Error) -> Error { use webpki::Error::*; match error { @@ -466,20 +422,24 @@ pub(crate) fn pki_error(error: webpki::Error) -> Error { } } -static ECDSA_SHA256: SignatureAlgorithms = - &[&webpki::ECDSA_P256_SHA256, &webpki::ECDSA_P384_SHA256]; +static ECDSA_SHA256: SignatureAlgorithms = &[ + webpki::ring::ECDSA_P256_SHA256, + webpki::ring::ECDSA_P384_SHA256, +]; -static ECDSA_SHA384: SignatureAlgorithms = - &[&webpki::ECDSA_P256_SHA384, &webpki::ECDSA_P384_SHA384]; +static ECDSA_SHA384: SignatureAlgorithms = &[ + webpki::ring::ECDSA_P256_SHA384, + webpki::ring::ECDSA_P384_SHA384, +]; -static ED25519: SignatureAlgorithms = &[&webpki::ED25519]; +static ED25519: SignatureAlgorithms = &[webpki::ring::ED25519]; -static RSA_SHA256: SignatureAlgorithms = &[&webpki::RSA_PKCS1_2048_8192_SHA256]; -static RSA_SHA384: SignatureAlgorithms = &[&webpki::RSA_PKCS1_2048_8192_SHA384]; -static RSA_SHA512: SignatureAlgorithms = &[&webpki::RSA_PKCS1_2048_8192_SHA512]; -static RSA_PSS_SHA256: SignatureAlgorithms = &[&webpki::RSA_PSS_2048_8192_SHA256_LEGACY_KEY]; -static RSA_PSS_SHA384: SignatureAlgorithms = &[&webpki::RSA_PSS_2048_8192_SHA384_LEGACY_KEY]; -static RSA_PSS_SHA512: SignatureAlgorithms = &[&webpki::RSA_PSS_2048_8192_SHA512_LEGACY_KEY]; +static RSA_SHA256: SignatureAlgorithms = &[webpki::ring::RSA_PKCS1_2048_8192_SHA256]; +static RSA_SHA384: SignatureAlgorithms = &[webpki::ring::RSA_PKCS1_2048_8192_SHA384]; +static RSA_SHA512: SignatureAlgorithms = &[webpki::ring::RSA_PKCS1_2048_8192_SHA512]; +static RSA_PSS_SHA256: SignatureAlgorithms = &[webpki::ring::RSA_PSS_2048_8192_SHA256_LEGACY_KEY]; +static RSA_PSS_SHA384: SignatureAlgorithms = &[webpki::ring::RSA_PSS_2048_8192_SHA384_LEGACY_KEY]; +static RSA_PSS_SHA512: SignatureAlgorithms = &[webpki::ring::RSA_PSS_2048_8192_SHA512_LEGACY_KEY]; fn convert_scheme(scheme: SignatureScheme) -> Result { match scheme { @@ -514,7 +474,7 @@ fn verify_sig_using_any_alg( // webpki::SignatureAlgorithm. Therefore, convert_algs maps to several and // we try them all. for alg in algs { - match cert.verify_signature(alg, message, sig) { + match cert.verify_signature(*alg, message, sig) { Err(webpki::Error::UnsupportedSignatureAlgorithmForPublicKey) => continue, res => return res, } @@ -529,7 +489,9 @@ fn verify_signed_struct( dss: &DigitallySignedStruct, ) -> Result { let possible_algs = convert_scheme(dss.scheme)?; - let cert = webpki::EndEntityCert::try_from(cert.0.as_ref()).map_err(pki_error)?; + + let cert = pki_types::CertificateDer::from(cert.0.as_slice()); + let cert = webpki::EndEntityCert::try_from(&cert).map_err(pki_error)?; verify_sig_using_any_alg(&cert, possible_algs, message, &dss.sig.0) .map_err(pki_error) @@ -538,16 +500,16 @@ fn verify_signed_struct( fn convert_alg_tls13( scheme: SignatureScheme, -) -> Result<&'static webpki::SignatureAlgorithm, Error> { +) -> Result<&'static dyn pki_types::SignatureVerificationAlgorithm, Error> { use crate::msgs::enums::SignatureScheme::*; match scheme { - ECDSA_NISTP256_SHA256 => Ok(&webpki::ECDSA_P256_SHA256), - ECDSA_NISTP384_SHA384 => Ok(&webpki::ECDSA_P384_SHA384), - ED25519 => Ok(&webpki::ED25519), - RSA_PSS_SHA256 => Ok(&webpki::RSA_PSS_2048_8192_SHA256_LEGACY_KEY), - RSA_PSS_SHA384 => Ok(&webpki::RSA_PSS_2048_8192_SHA384_LEGACY_KEY), - RSA_PSS_SHA512 => Ok(&webpki::RSA_PSS_2048_8192_SHA512_LEGACY_KEY), + ECDSA_NISTP256_SHA256 => Ok(webpki::ring::ECDSA_P256_SHA256), + ECDSA_NISTP384_SHA384 => Ok(webpki::ring::ECDSA_P384_SHA384), + ED25519 => Ok(webpki::ring::ED25519), + RSA_PSS_SHA256 => Ok(webpki::ring::RSA_PSS_2048_8192_SHA256_LEGACY_KEY), + RSA_PSS_SHA384 => Ok(webpki::ring::RSA_PSS_2048_8192_SHA384_LEGACY_KEY), + RSA_PSS_SHA512 => Ok(webpki::ring::RSA_PSS_2048_8192_SHA512_LEGACY_KEY), _ => { let error_msg = format!("received unsupported sig scheme {scheme:?}"); Err(Error::PeerMisbehavedError(error_msg)) @@ -583,7 +545,8 @@ fn verify_tls13( ) -> Result { let alg = convert_alg_tls13(dss.scheme)?; - let cert = webpki::EndEntityCert::try_from(cert.0.as_ref()).map_err(pki_error)?; + let cert = pki_types::CertificateDer::from(cert.0.as_slice()); + let cert = webpki::EndEntityCert::try_from(&cert).map_err(pki_error)?; cert.verify_signature(alg, msg, &dss.sig.0) .map_err(pki_error) diff --git a/crates/tlsn/Cargo.toml b/crates/tlsn/Cargo.toml index b240edf8d..4506ebafc 100644 --- a/crates/tlsn/Cargo.toml +++ b/crates/tlsn/Cargo.toml @@ -45,7 +45,8 @@ derive_builder = { workspace = true } futures = { workspace = true } opaque-debug = { workspace = true } rand = { workspace = true } -rustls-pki-types = "1.12.0" +rustls-pki-types = { workspace = true } +rustls-webpki = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } tokio = { workspace = true, features = ["sync"] } diff --git a/crates/tlsn/src/config.rs b/crates/tlsn/src/config.rs index 544ac8f89..472c3fa60 100644 --- a/crates/tlsn/src/config.rs +++ b/crates/tlsn/src/config.rs @@ -5,6 +5,8 @@ use semver::Version; use serde::{Deserialize, Serialize}; use std::error::Error; +pub use tlsn_core::webpki::{CertificateDer, PrivateKeyDer, RootCertStore}; + // Default is 32 bytes to decrypt the TLS protocol messages. const DEFAULT_MAX_RECV_ONLINE: usize = 32; // Default maximum number of TLS records to allow. diff --git a/crates/tlsn/src/msg.rs b/crates/tlsn/src/msg.rs index 3dce3c32f..995153e9f 100644 --- a/crates/tlsn/src/msg.rs +++ b/crates/tlsn/src/msg.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; -use tlsn_core::connection::{ServerCertData, ServerName}; +use tlsn_core::connection::{HandshakeData, ServerName}; /// Message sent from Prover to Verifier to prove the server identity. #[derive(Debug, Serialize, Deserialize)] @@ -10,5 +10,5 @@ pub(crate) struct ServerIdentityProof { /// Server name. pub name: ServerName, /// Server identity data. - pub data: ServerCertData, + pub data: HandshakeData, } diff --git a/crates/tlsn/src/prover.rs b/crates/tlsn/src/prover.rs index beb704694..ee1b38ad1 100644 --- a/crates/tlsn/src/prover.rs +++ b/crates/tlsn/src/prover.rs @@ -8,12 +8,14 @@ pub mod state; pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder}; pub use error::ProverError; pub use future::ProverFuture; +use rustls_pki_types::CertificateDer; pub use tlsn_core::{ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProverOutput}; use mpz_common::Context; use mpz_core::Block; use mpz_garble_core::Delta; use mpz_vm_core::prelude::*; +use webpki::anchor_from_trusted_cert; use crate::{ Role, @@ -39,7 +41,7 @@ use tls_client_async::{TlsConnection, bind_client}; use tls_core::msgs::enums::ContentType; use tlsn_core::{ ProvePayload, - connection::ServerCertData, + connection::{HandshakeData, ServerName}, hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256}, transcript::{TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret}, }; @@ -179,21 +181,40 @@ impl Prover { let (mpc_ctrl, mpc_fut) = mpc_tls.run(); + let ServerName::Dns(server_name) = self.config.server_name(); let server_name = - TlsServerName::try_from(self.config.server_name().as_str()).map_err(|_| { - ProverError::config(format!( - "invalid server name: {}", - self.config.server_name() - )) - })?; + TlsServerName::try_from(server_name.as_ref()).expect("name was validated"); + + let root_store = if let Some(root_store) = self.config.tls_config().root_store() { + let roots = root_store + .roots + .iter() + .map(|cert| { + let der = CertificateDer::from_slice(&cert.0); + anchor_from_trusted_cert(&der) + .map(|anchor| anchor.to_owned()) + .map_err(ProverError::config) + }) + .collect::, _>>()?; + tls_client::RootCertStore { roots } + } else { + tls_client::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), + } + }; let config = tls_client::ClientConfig::builder() .with_safe_defaults() - .with_root_certificates(self.config.tls_config().root_store().clone()); + .with_root_certificates(root_store); let config = if let Some((cert, key)) = self.config.tls_config().client_auth() { config - .with_single_cert(cert.clone(), key.clone()) + .with_single_cert( + cert.iter() + .map(|cert| tls_client::Certificate(cert.0.clone())) + .collect(), + tls_client::PrivateKey(key.0.clone()), + ) .map_err(ProverError::config)? } else { config.with_no_client_auth() @@ -350,10 +371,10 @@ impl Prover { }; let payload = ProvePayload { - server_identity: config.server_identity().then(|| { + handshake: config.server_identity().then(|| { ( self.config.server_name().clone(), - ServerCertData { + HandshakeData { certs: tls_transcript .server_cert_chain() .expect("server cert chain is present") @@ -362,7 +383,7 @@ impl Prover { .server_signature() .expect("server signature is present") .clone(), - handshake: tls_transcript.handshake_data().clone(), + binding: tls_transcript.certificate_binding().clone(), }, ) }), diff --git a/crates/tlsn/src/prover/config.rs b/crates/tlsn/src/prover/config.rs index b946a5592..d4d790c71 100644 --- a/crates/tlsn/src/prover/config.rs +++ b/crates/tlsn/src/prover/config.rs @@ -1,11 +1,10 @@ -use crate::config::{NetworkSetting, ProtocolConfig}; use mpc_tls::Config; -use rustls_pki_types::{CertificateDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer, pem::PemObject}; -use tls_core::{ - anchors::{OwnedTrustAnchor, RootCertStore}, - key, +use tlsn_core::{ + connection::ServerName, + webpki::{CertificateDer, PrivateKeyDer, RootCertStore}, }; -use tlsn_core::connection::ServerName; + +use crate::config::{NetworkSetting, ProtocolConfig}; /// Configuration for the prover. #[derive(Debug, Clone, derive_builder::Builder)] @@ -67,31 +66,13 @@ impl ProverConfig { } /// Configuration for the prover's TLS connection. -#[derive(Debug, Clone)] +#[derive(Default, Debug, Clone)] pub struct TlsConfig { /// Root certificates. - root_store: RootCertStore, + root_store: Option, /// Certificate chain and a matching private key for client /// authentication. - client_auth: Option<(Vec, key::PrivateKey)>, -} - -impl Default for TlsConfig { - fn default() -> Self { - let mut root_store = RootCertStore::empty(); - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject.as_ref(), - ta.subject_public_key_info.as_ref(), - ta.name_constraints.as_ref().map(|nc| nc.as_ref()), - ) - })); - - Self { - root_store, - client_auth: None, - } - } + client_auth: Option<(Vec, PrivateKeyDer)>, } impl TlsConfig { @@ -100,13 +81,13 @@ impl TlsConfig { TlsConfigBuilder::default() } - pub(crate) fn root_store(&self) -> &RootCertStore { - &self.root_store + pub(crate) fn root_store(&self) -> Option<&RootCertStore> { + self.root_store.as_ref() } /// Returns a certificate chain and a matching private key for client /// authentication. - pub fn client_auth(&self) -> &Option<(Vec, key::PrivateKey)> { + pub fn client_auth(&self) -> &Option<(Vec, PrivateKeyDer)> { &self.client_auth } } @@ -115,7 +96,7 @@ impl TlsConfig { #[derive(Debug, Default)] pub struct TlsConfigBuilder { root_store: Option, - client_auth: Option<(Vec, key::PrivateKey)>, + client_auth: Option<(Vec, PrivateKeyDer)>, } impl TlsConfigBuilder { @@ -138,74 +119,16 @@ impl TlsConfigBuilder { /// /// - Each certificate in the chain must be in the X.509 format. /// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1). - pub fn client_auth(&mut self, cert_key: (Vec>, Vec)) -> &mut Self { - let certs = cert_key - .0 - .into_iter() - .map(key::Certificate) - .collect::>(); - - self.client_auth = Some((certs, key::PrivateKey(cert_key.1))); + pub fn client_auth(&mut self, cert_key: (Vec, PrivateKeyDer)) -> &mut Self { + self.client_auth = Some(cert_key); self } - /// Sets a PEM-encoded certificate chain and a matching private key for - /// client authentication. - /// - /// Often the chain will consist of a single end-entity certificate. - /// - /// # Arguments - /// - /// * `cert_key` - A tuple containing the certificate chain and the private - /// key. - /// - /// - Each certificate in the chain must be in the X.509 format. - /// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1). - pub fn client_auth_pem( - &mut self, - cert_key: (Vec>, Vec), - ) -> Result<&mut Self, TlsConfigError> { - let key = match PrivatePkcs8KeyDer::from_pem_slice(&cert_key.1) { - // Try to parse as PEM PKCS#8. - Ok(key) => (*key.secret_pkcs8_der()).to_vec(), - // Otherwise, try to parse as PEM PKCS#1. - Err(_) => match PrivatePkcs1KeyDer::from_pem_slice(&cert_key.1) { - Ok(key) => (*key.secret_pkcs1_der()).to_vec(), - Err(_) => return Err(ErrorRepr::InvalidKey.into()), - }, - }; - - let certs = cert_key - .0 - .iter() - .map(|c| { - let c = - CertificateDer::from_pem_slice(c).map_err(|_| ErrorRepr::InvalidCertificate)?; - Ok::(key::Certificate(c.as_ref().to_vec())) - }) - .collect::, _>>()?; - - self.client_auth = Some((certs, key::PrivateKey(key))); - Ok(self) - } - /// Builds the TLS configuration. - pub fn build(&self) -> Result { + pub fn build(self) -> Result { Ok(TlsConfig { - root_store: self.root_store.clone().unwrap_or_else(|| { - let mut root_store = RootCertStore::empty(); - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map( - |ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject.as_ref(), - ta.subject_public_key_info.as_ref(), - ta.name_constraints.as_ref().map(|nc| nc.as_ref()), - ) - }, - )); - root_store - }), - client_auth: self.client_auth.clone(), + root_store: self.root_store, + client_auth: self.client_auth, }) } } @@ -216,10 +139,5 @@ impl TlsConfigBuilder { pub struct TlsConfigError(#[from] ErrorRepr); #[derive(Debug, thiserror::Error)] -#[error("tls config error: {0}")] -enum ErrorRepr { - #[error("the certificate for client authentication is invalid")] - InvalidCertificate, - #[error("the private key for client authentication is invalid")] - InvalidKey, -} +#[error("tls config error")] +enum ErrorRepr {} diff --git a/crates/tlsn/src/verifier.rs b/crates/tlsn/src/verifier.rs index e56cd62a8..dc6c88d3c 100644 --- a/crates/tlsn/src/verifier.rs +++ b/crates/tlsn/src/verifier.rs @@ -8,7 +8,10 @@ use std::sync::Arc; pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderError}; pub use error::VerifierError; -pub use tlsn_core::{VerifierOutput, VerifyConfig, VerifyConfigBuilder, VerifyConfigBuilderError}; +pub use tlsn_core::{ + VerifierOutput, VerifyConfig, VerifyConfigBuilder, VerifyConfigBuilderError, + webpki::ServerCertVerifier, +}; use crate::{ Role, @@ -31,7 +34,7 @@ use mpz_core::Block; use mpz_garble_core::Delta; use mpz_vm_core::prelude::*; use serio::stream::IoStreamExt; -use tls_core::{msgs::enums::ContentType, verify::WebPkiVerifier}; +use tls_core::msgs::enums::ContentType; use tlsn_core::{ ProvePayload, connection::{ConnectionInfo, ServerName}, @@ -305,15 +308,20 @@ impl Verifier { } = &mut self.state; let ProvePayload { - server_identity, + handshake, transcript, transcript_commit, } = mux_fut .poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from)) .await?; - let verifier = WebPkiVerifier::new(self.config.root_store().clone(), None); - let server_name = if let Some((name, cert_data)) = server_identity { + let verifier = if let Some(root_store) = self.config.root_store() { + ServerCertVerifier::new(root_store).map_err(VerifierError::config)? + } else { + ServerCertVerifier::mozilla() + }; + + let server_name = if let Some((name, cert_data)) = handshake { cert_data .verify( &verifier, diff --git a/crates/tlsn/src/verifier/config.rs b/crates/tlsn/src/verifier/config.rs index 97361fb88..46e6bda26 100644 --- a/crates/tlsn/src/verifier/config.rs +++ b/crates/tlsn/src/verifier/config.rs @@ -2,7 +2,7 @@ use std::fmt::{Debug, Formatter, Result}; use crate::config::{NetworkSetting, ProtocolConfig, ProtocolConfigValidator}; use mpc_tls::Config; -use tls_core::anchors::{OwnedTrustAnchor, RootCertStore}; +use tlsn_core::webpki::RootCertStore; /// Configuration for the [`Verifier`](crate::tls::Verifier). #[allow(missing_docs)] @@ -10,8 +10,8 @@ use tls_core::anchors::{OwnedTrustAnchor, RootCertStore}; #[builder(pattern = "owned")] pub struct VerifierConfig { protocol_config_validator: ProtocolConfigValidator, - #[builder(default = "default_root_store()")] - root_store: RootCertStore, + #[builder(setter(strip_option))] + root_store: Option, } impl Debug for VerifierConfig { @@ -34,8 +34,8 @@ impl VerifierConfig { } /// Returns the root certificate store. - pub fn root_store(&self) -> &RootCertStore { - &self.root_store + pub fn root_store(&self) -> Option<&RootCertStore> { + self.root_store.as_ref() } pub(crate) fn build_mpc_tls_config(&self, protocol_config: &ProtocolConfig) -> Config { @@ -61,16 +61,3 @@ impl VerifierConfig { builder.build().unwrap() } } - -fn default_root_store() -> RootCertStore { - let mut root_store = RootCertStore::empty(); - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject.as_ref(), - ta.subject_public_key_info.as_ref(), - ta.name_constraints.as_ref().map(|nc| nc.as_ref()), - ) - })); - - root_store -} diff --git a/crates/tlsn/src/verifier/error.rs b/crates/tlsn/src/verifier/error.rs index 665c7550e..f3b018196 100644 --- a/crates/tlsn/src/verifier/error.rs +++ b/crates/tlsn/src/verifier/error.rs @@ -20,6 +20,13 @@ impl VerifierError { } } + pub(crate) fn config(source: E) -> Self + where + E: Into>, + { + Self::new(ErrorKind::Config, source) + } + pub(crate) fn mpc(source: E) -> Self where E: Into>, diff --git a/crates/tlsn/tests/test.rs b/crates/tlsn/tests/test.rs index 1987b6e0c..15507a59d 100644 --- a/crates/tlsn/tests/test.rs +++ b/crates/tlsn/tests/test.rs @@ -1,6 +1,7 @@ use futures::{AsyncReadExt, AsyncWriteExt}; use tlsn::{ - config::{ProtocolConfig, ProtocolConfigValidator}, + config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore}, + connection::ServerName, prover::{ProveConfig, Prover, ProverConfig, TlsConfig}, transcript::{TranscriptCommitConfig, TranscriptCommitment}, verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig}, @@ -37,19 +38,17 @@ async fn prover(verifier_soc let server_task = tokio::spawn(bind(server_socket.compat())); - let mut root_store = tls_core::anchors::RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - let mut tls_config_builder = TlsConfig::builder(); - tls_config_builder.root_store(root_store); + tls_config_builder.root_store(RootCertStore { + roots: vec![CertificateDer(CA_CERT_DER.to_vec())], + }); let tls_config = tls_config_builder.build().unwrap(); + let server_name = ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()); let prover = Prover::new( ProverConfig::builder() - .server_name(SERVER_DOMAIN) + .server_name(server_name) .tls_config(tls_config) .protocol_config( ProtocolConfig::builder() @@ -110,11 +109,6 @@ async fn prover(verifier_soc #[instrument(skip(socket))] async fn verifier(socket: T) { - let mut root_store = tls_core::anchors::RootCertStore::empty(); - root_store - .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec())) - .unwrap(); - let config_validator = ProtocolConfigValidator::builder() .max_sent_data(MAX_SENT_DATA) .max_recv_data(MAX_RECV_DATA) @@ -123,7 +117,9 @@ async fn verifier(soc let verifier = Verifier::new( VerifierConfig::builder() - .root_store(root_store) + .root_store(RootCertStore { + roots: vec![CertificateDer(CA_CERT_DER.to_vec())], + }) .protocol_config_validator(config_validator) .build() .unwrap(), @@ -140,7 +136,9 @@ async fn verifier(soc let transcript = transcript.unwrap(); - assert_eq!(server_name.unwrap().as_str(), SERVER_DOMAIN); + let ServerName::Dns(server_name) = server_name.unwrap(); + + assert_eq!(server_name.as_str(), SERVER_DOMAIN); assert!(!transcript.is_complete()); assert_eq!( transcript.sent_authed().iter_ranges().next().unwrap(), diff --git a/crates/wasm/src/prover/config.rs b/crates/wasm/src/prover/config.rs index 6ac9d9887..107122384 100644 --- a/crates/wasm/src/prover/config.rs +++ b/crates/wasm/src/prover/config.rs @@ -1,7 +1,11 @@ use crate::types::NetworkSetting; use serde::Deserialize; -use tlsn::config::ProtocolConfig; +use tlsn::{ + config::{CertificateDer, PrivateKeyDer, ProtocolConfig}, + connection::ServerName, +}; use tsify_next::Tsify; +use wasm_bindgen::JsError; #[derive(Debug, Tsify, Deserialize)] #[tsify(from_wasm_abi)] @@ -17,8 +21,10 @@ pub struct ProverConfig { pub client_auth: Option<(Vec>, Vec)>, } -impl From for tlsn::prover::ProverConfig { - fn from(value: ProverConfig) -> Self { +impl TryFrom for tlsn::prover::ProverConfig { + type Error = JsError; + + fn try_from(value: ProverConfig) -> Result { let mut builder = ProtocolConfig::builder(); builder.max_sent_data(value.max_sent_data); @@ -44,21 +50,36 @@ impl From for tlsn::prover::ProverConfig { let protocol_config = builder.build().unwrap(); let mut builder = tlsn::prover::TlsConfig::builder(); - if let Some(cert_key) = value.client_auth { - // Try to parse as PEM-encoded. - if builder.client_auth_pem(cert_key.clone()).is_err() { - // Otherwise assume DER encoding. - builder.client_auth(cert_key); - } + if let Some((certs, key)) = value.client_auth { + let certs = certs + .into_iter() + .map(|cert| { + // Try to parse as PEM-encoded, otherwise assume DER. + if let Ok(cert) = CertificateDer::from_pem_slice(&cert) { + cert + } else { + CertificateDer(cert) + } + }) + .collect(); + let key = PrivateKeyDer(key); + builder.client_auth((certs, key)); } let tls_config = builder.build().unwrap(); + let server_name = ServerName::Dns( + value + .server_name + .try_into() + .map_err(|_| JsError::new("invalid server name"))?, + ); + let mut builder = tlsn::prover::ProverConfig::builder(); builder - .server_name(value.server_name.as_ref()) + .server_name(server_name) .protocol_config(protocol_config) .tls_config(tls_config); - builder.build().unwrap() + Ok(builder.build().unwrap()) } } diff --git a/crates/wasm/src/prover/mod.rs b/crates/wasm/src/prover/mod.rs index 27111924e..a3390db25 100644 --- a/crates/wasm/src/prover/mod.rs +++ b/crates/wasm/src/prover/mod.rs @@ -41,10 +41,10 @@ impl State { #[wasm_bindgen(js_class = Prover)] impl JsProver { #[wasm_bindgen(constructor)] - pub fn new(config: ProverConfig) -> JsProver { - JsProver { - state: State::Initialized(Prover::new(config.into())), - } + pub fn new(config: ProverConfig) -> Result { + Ok(JsProver { + state: State::Initialized(Prover::new(config.try_into()?)), + }) } /// Set up the prover. diff --git a/crates/wasm/src/verifier/mod.rs b/crates/wasm/src/verifier/mod.rs index 46a0595c4..2f7c75817 100644 --- a/crates/wasm/src/verifier/mod.rs +++ b/crates/wasm/src/verifier/mod.rs @@ -5,7 +5,7 @@ pub use config::VerifierConfig; use enum_try_as_inner::EnumTryAsInner; use tls_core::msgs::enums::ContentType; use tlsn::{ - connection::{ConnectionInfo, TranscriptLength}, + connection::{ConnectionInfo, ServerName, TranscriptLength}, verifier::{ state::{self, Initialized}, Verifier, VerifyConfig, @@ -106,7 +106,10 @@ impl JsVerifier { self.state = State::Complete; Ok(VerifierOutput { - server_name: output.server_name.map(|s| s.as_str().to_string()), + server_name: output.server_name.map(|name| { + let ServerName::Dns(name) = name; + name.to_string() + }), connection_info: connection_info.into(), transcript: output.transcript.map(|t| t.into()), })