refactor: clean up web pki (#967)

* refactor: clean up web pki

* fix time import

* clippy

* fix wasm
This commit is contained in:
sinu.eth
2025-08-18 08:36:04 -07:00
committed by GitHub
parent cca9a318a4
commit 21086d2883
59 changed files with 819 additions and 2088 deletions

38
Cargo.lock generated
View File

@@ -5931,6 +5931,7 @@ dependencies = [
"rangeset",
"rstest",
"rustls-pki-types",
"rustls-webpki",
"semver 1.0.26",
"serde",
"serio",
@@ -5951,7 +5952,7 @@ dependencies = [
"tracing-subscriber",
"uid-mux",
"web-spawn",
"webpki-roots 0.26.11",
"webpki-roots",
]
[[package]]
@@ -5975,7 +5976,6 @@ dependencies = [
"tlsn-core",
"tlsn-data-fixtures",
"tlsn-tls-core",
"webpki-roots 0.26.11",
]
[[package]]
@@ -6013,6 +6013,8 @@ dependencies = [
"rangeset",
"rs_merkle",
"rstest",
"rustls-pki-types",
"rustls-webpki",
"serde",
"sha2",
"thiserror 1.0.69",
@@ -6021,7 +6023,9 @@ dependencies = [
"tlsn-tls-core",
"tlsn-utils",
"web-time 0.2.4",
"webpki-roots 0.26.11",
"webpki-root-certs",
"webpki-roots",
"zeroize",
]
[[package]]
@@ -6068,11 +6072,9 @@ dependencies = [
"spansy",
"tls-server-fixture",
"tlsn",
"tlsn-core",
"tlsn-formats",
"tlsn-server-fixture",
"tlsn-server-fixture-certs",
"tlsn-tls-core",
"tokio",
"tokio-util",
"tracing",
@@ -6119,10 +6121,8 @@ dependencies = [
"serde_json",
"serio",
"tlsn",
"tlsn-core",
"tlsn-harness-core",
"tlsn-server-fixture-certs",
"tlsn-tls-core",
"tlsn-wasm",
"tokio",
"tokio-util",
@@ -6250,6 +6250,8 @@ dependencies = [
"rand 0.9.1",
"rand_chacha 0.9.0",
"rstest",
"rustls-pki-types",
"rustls-webpki",
"serde",
"serio",
"thiserror 1.0.69",
@@ -6324,14 +6326,15 @@ dependencies = [
"ring 0.17.14",
"rustls 0.20.9",
"rustls-pemfile",
"rustls-pki-types",
"rustls-webpki",
"sct",
"sha2",
"tlsn-tls-backend",
"tlsn-tls-core",
"tokio",
"web-time 0.2.4",
"webpki",
"webpki-roots 0.26.11",
"webpki-roots",
]
[[package]]
@@ -6344,6 +6347,8 @@ dependencies = [
"hyper",
"hyper-util",
"rstest",
"rustls-pki-types",
"rustls-webpki",
"thiserror 1.0.69",
"tls-server-fixture",
"tlsn-tls-client",
@@ -6361,13 +6366,14 @@ dependencies = [
"rand 0.9.1",
"ring 0.17.14",
"rustls-pemfile",
"rustls-pki-types",
"rustls-webpki",
"sct",
"serde",
"sha2",
"thiserror 1.0.69",
"tracing",
"web-time 0.2.4",
"webpki",
]
[[package]]
@@ -7064,19 +7070,19 @@ dependencies = [
]
[[package]]
name = "webpki-roots"
version = "0.26.11"
name = "webpki-root-certs"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9"
checksum = "4e4ffd8df1c57e87c325000a3d6ef93db75279dc3a231125aac571650f22b12a"
dependencies = [
"webpki-roots 1.0.1",
"rustls-pki-types",
]
[[package]]
name = "webpki-roots"
version = "1.0.1"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8782dd5a41a24eed3a4f40b606249b3e236ca61adf1f25ea4d45c73de122b502"
checksum = "7e8983c3ab33d6fb807cfcdad2491c4ea8cbc8ed839181c7dfd9c67c83e261b2"
dependencies = [
"rustls-pki-types",
]

View File

@@ -137,6 +137,8 @@ rs_merkle = { git = "https://github.com/tlsnotary/rs-merkle.git", rev = "85f3e82
rstest = { version = "0.17" }
rustls = { version = "0.21" }
rustls-pemfile = { version = "1.0" }
rustls-webpki = { version = "0.103" }
rustls-pki-types = { version = "1.12" }
sct = { version = "0.7" }
semver = { version = "1.0" }
serde = { version = "1.0" }
@@ -157,7 +159,8 @@ wasm-bindgen = { version = "0.2" }
wasm-bindgen-futures = { version = "0.4" }
web-spawn = { version = "0.2" }
web-time = { version = "0.2" }
webpki = { version = "0.22" }
webpki-roots = { version = "0.26" }
webpki-roots = { version = "1.0" }
webpki-root-certs = { version = "1.0" }
# Use the patched ws_stream_wasm to fix the issue https://github.com/najamelan/ws_stream_wasm/issues/12#issuecomment-1711902958
ws_stream_wasm = { git = "https://github.com/tlsnotary/ws_stream_wasm", rev = "2ed12aad9f0236e5321f577672f309920b2aef51" }
zeroize = { version = "1.8" }

View File

@@ -21,7 +21,6 @@ rand = { workspace = true }
serde = { workspace = true, features = ["derive"] }
thiserror = { workspace = true }
tiny-keccak = { workspace = true, features = ["keccak"] }
webpki-roots = { workspace = true }
[dev-dependencies]
alloy-primitives = { version = "0.8.22", default-features = false }

View File

@@ -242,7 +242,7 @@ impl std::fmt::Display for AttestationBuilderError {
mod test {
use rstest::{fixture, rstest};
use tlsn_core::{
connection::{HandshakeData, HandshakeDataV1_2},
connection::{CertBinding, CertBindingV1_2},
fixtures::{ConnectionFixture, encoding_provider},
hash::Blake3,
transcript::Transcript,
@@ -399,10 +399,10 @@ mod test {
server_cert_data, ..
} = connection;
let HandshakeData::V1_2(HandshakeDataV1_2 {
let CertBinding::V1_2(CertBindingV1_2 {
server_ephemeral_key,
..
}) = server_cert_data.handshake
}) = server_cert_data.binding
else {
panic!("expected v1.2 handshake data");
};
@@ -470,10 +470,10 @@ mod test {
..
} = connection;
let HandshakeData::V1_2(HandshakeDataV1_2 {
let CertBinding::V1_2(CertBindingV1_2 {
server_ephemeral_key,
..
}) = server_cert_data.handshake
}) = server_cert_data.binding
else {
panic!("expected v1.2 handshake data");
};

View File

@@ -22,7 +22,7 @@
use serde::{Deserialize, Serialize};
use tlsn_core::{
connection::{CertificateVerificationError, ServerCertData, ServerEphemKey, ServerName},
connection::{HandshakeData, HandshakeVerificationError, ServerEphemKey, ServerName},
hash::{Blinded, HashAlgorithm, HashProviderError, TypedHash},
};
@@ -30,14 +30,14 @@ use crate::{CryptoProvider, hash::HashAlgorithmExt, serialize::impl_domain_separ
/// Opens a [`ServerCertCommitment`].
#[derive(Clone, Serialize, Deserialize)]
pub struct ServerCertOpening(Blinded<ServerCertData>);
pub struct ServerCertOpening(Blinded<HandshakeData>);
impl_domain_separator!(ServerCertOpening);
opaque_debug::implement!(ServerCertOpening);
impl ServerCertOpening {
pub(crate) fn new(data: ServerCertData) -> Self {
pub(crate) fn new(data: HandshakeData) -> Self {
Self(Blinded::new(data))
}
@@ -49,7 +49,7 @@ impl ServerCertOpening {
}
/// Returns the server identity data.
pub fn data(&self) -> &ServerCertData {
pub fn data(&self) -> &HandshakeData {
self.0.data()
}
}
@@ -122,8 +122,8 @@ impl From<HashProviderError> for ServerIdentityProofError {
}
}
impl From<CertificateVerificationError> for ServerIdentityProofError {
fn from(err: CertificateVerificationError) -> Self {
impl From<HandshakeVerificationError> for ServerIdentityProofError {
fn from(err: HandshakeVerificationError) -> Self {
Self {
kind: ErrorKind::Certificate,
message: err.to_string(),

View File

@@ -1,7 +1,7 @@
//! Attestation fixtures.
use tlsn_core::{
connection::{HandshakeData, HandshakeDataV1_2},
connection::{CertBinding, CertBindingV1_2},
fixtures::ConnectionFixture,
hash::HashAlgorithm,
transcript::{
@@ -67,7 +67,7 @@ pub fn request_fixture(
let mut request_builder = Request::builder(&request_config);
request_builder
.server_name(server_name)
.server_cert_data(server_cert_data)
.handshake_data(server_cert_data)
.transcript(transcript);
let (request, _) = request_builder.build(&provider).unwrap();
@@ -91,12 +91,12 @@ pub fn attestation_fixture(
..
} = connection;
let HandshakeData::V1_2(HandshakeDataV1_2 {
let CertBinding::V1_2(CertBindingV1_2 {
server_ephemeral_key,
..
}) = server_cert_data.handshake
}) = server_cert_data.binding
else {
panic!("expected v1.2 handshake data");
panic!("expected v1.2 binding data");
};
let mut provider = CryptoProvider::default();

View File

@@ -1,8 +1,4 @@
use tls_core::{
anchors::{OwnedTrustAnchor, RootCertStore},
verify::WebPkiVerifier,
};
use tlsn_core::hash::HashProvider;
use tlsn_core::{hash::HashProvider, webpki::ServerCertVerifier};
use crate::signing::{SignatureVerifierProvider, SignerProvider};
@@ -28,7 +24,7 @@ pub struct CryptoProvider {
/// This is used to verify the server's certificate chain.
///
/// The default verifier uses the Mozilla root certificates.
pub cert: WebPkiVerifier,
pub cert: ServerCertVerifier,
/// Signer provider.
///
/// This is used for signing attestations.
@@ -45,21 +41,9 @@ impl Default for CryptoProvider {
fn default() -> Self {
Self {
hash: Default::default(),
cert: default_cert_verifier(),
cert: ServerCertVerifier::mozilla(),
signer: Default::default(),
signature: Default::default(),
}
}
}
pub(crate) fn default_cert_verifier() -> WebPkiVerifier {
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject.as_ref(),
ta.subject_public_key_info.as_ref(),
ta.name_constraints.as_ref().map(|nc| nc.as_ref()),
)
}));
WebPkiVerifier::new(root_store, None)
}

View File

@@ -1,5 +1,5 @@
use tlsn_core::{
connection::{ServerCertData, ServerName},
connection::{HandshakeData, ServerName},
transcript::{Transcript, TranscriptCommitment, TranscriptSecret},
};
@@ -13,7 +13,7 @@ use crate::{
pub struct RequestBuilder<'a> {
config: &'a RequestConfig,
server_name: Option<ServerName>,
server_cert_data: Option<ServerCertData>,
handshake_data: Option<HandshakeData>,
transcript: Option<Transcript>,
transcript_commitments: Vec<TranscriptCommitment>,
transcript_commitment_secrets: Vec<TranscriptSecret>,
@@ -25,7 +25,7 @@ impl<'a> RequestBuilder<'a> {
Self {
config,
server_name: None,
server_cert_data: None,
handshake_data: None,
transcript: None,
transcript_commitments: Vec::new(),
transcript_commitment_secrets: Vec::new(),
@@ -38,9 +38,9 @@ impl<'a> RequestBuilder<'a> {
self
}
/// Sets the server identity data.
pub fn server_cert_data(&mut self, data: ServerCertData) -> &mut Self {
self.server_cert_data = Some(data);
/// Sets the handshake data.
pub fn handshake_data(&mut self, data: HandshakeData) -> &mut Self {
self.handshake_data = Some(data);
self
}
@@ -69,7 +69,7 @@ impl<'a> RequestBuilder<'a> {
let Self {
config,
server_name,
server_cert_data,
handshake_data: server_cert_data,
transcript,
transcript_commitments,
transcript_commitment_secrets,

View File

@@ -46,7 +46,7 @@ pub(crate) use impl_domain_separator;
impl_domain_separator!(tlsn_core::connection::ServerEphemKey);
impl_domain_separator!(tlsn_core::connection::ConnectionInfo);
impl_domain_separator!(tlsn_core::connection::HandshakeData);
impl_domain_separator!(tlsn_core::connection::CertBinding);
impl_domain_separator!(tlsn_core::transcript::TranscriptCommitment);
impl_domain_separator!(tlsn_core::transcript::TranscriptSecret);
impl_domain_separator!(tlsn_core::transcript::encoding::EncodingCommitment);

View File

@@ -5,7 +5,7 @@ use tlsn_attestation::{
signing::SignatureAlgId,
};
use tlsn_core::{
connection::{HandshakeData, HandshakeDataV1_2},
connection::{CertBinding, CertBindingV1_2},
fixtures::{self, ConnectionFixture, encoder_secret},
hash::Blake3,
transcript::{
@@ -36,10 +36,10 @@ fn test_api() {
server_cert_data,
} = ConnectionFixture::tlsnotary(transcript.length());
let HandshakeData::V1_2(HandshakeDataV1_2 {
let CertBinding::V1_2(CertBindingV1_2 {
server_ephemeral_key,
..
}) = server_cert_data.handshake.clone()
}) = server_cert_data.binding.clone()
else {
unreachable!()
};
@@ -72,7 +72,7 @@ fn test_api() {
request_builder
.server_name(server_name.clone())
.server_cert_data(server_cert_data)
.handshake_data(server_cert_data)
.transcript(transcript)
.transcript_commitments(
vec![TranscriptSecret::Encoding(encoding_tree)],

View File

@@ -36,10 +36,14 @@ thiserror = { workspace = true }
tiny-keccak = { workspace = true, features = ["keccak"] }
web-time = { workspace = true }
webpki-roots = { workspace = true }
rustls-webpki = { workspace = true, features = ["ring"] }
rustls-pki-types = { workspace = true }
itybity = { workspace = true }
zeroize = { workspace = true }
[dev-dependencies]
bincode = { workspace = true }
hex = { workspace = true }
rstest = { workspace = true }
tlsn-data-fixtures = { workspace = true }
webpki-root-certs = { workspace = true }

View File

@@ -2,16 +2,11 @@
use std::fmt;
use rustls_pki_types as webpki_types;
use serde::{Deserialize, Serialize};
use tls_core::{
msgs::{
codec::Codec,
enums::NamedGroup,
handshake::{DigitallySignedStruct, ServerECDHParams},
},
verify::{ServerCertVerifier as _, WebPkiVerifier},
};
use web_time::{Duration, UNIX_EPOCH};
use tls_core::msgs::{codec::Codec, enums::NamedGroup, handshake::ServerECDHParams};
use crate::webpki::{CertificateDer, ServerCertVerifier, ServerCertVerifierError};
/// TLS version.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
@@ -35,40 +30,82 @@ impl TryFrom<tls_core::msgs::enums::ProtocolVersion> for TlsVersion {
}
}
/// Server's name, a.k.a. the DNS name.
/// Server's name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ServerName(String);
pub enum ServerName {
/// DNS name.
Dns(DnsName),
}
impl ServerName {
/// Creates a new server name.
pub fn new(name: String) -> Self {
Self(name)
}
/// Returns the name as a string.
pub fn as_str(&self) -> &str {
&self.0
}
}
impl From<&str> for ServerName {
fn from(name: &str) -> Self {
Self(name.to_string())
}
}
impl AsRef<str> for ServerName {
fn as_ref(&self) -> &str {
&self.0
pub(crate) fn to_webpki(&self) -> webpki_types::ServerName<'static> {
match self {
ServerName::Dns(name) => webpki_types::ServerName::DnsName(
webpki_types::DnsName::try_from(name.0.as_str())
.expect("name was validated")
.to_owned(),
),
}
}
}
impl fmt::Display for ServerName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ServerName::Dns(name) => write!(f, "{name}"),
}
}
}
/// DNS name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(try_from = "String")]
pub struct DnsName(String);
impl DnsName {
/// Returns the DNS name as a string.
pub fn as_str(&self) -> &str {
self.0.as_str()
}
}
impl fmt::Display for DnsName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl AsRef<str> for DnsName {
fn as_ref(&self) -> &str {
&self.0
}
}
/// Error returned when a DNS name is invalid.
#[derive(Debug, thiserror::Error)]
#[error("invalid DNS name")]
pub struct InvalidDnsNameError {}
impl TryFrom<&str> for DnsName {
type Error = InvalidDnsNameError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
// Borrow validation from rustls
match webpki_types::DnsName::try_from_str(value) {
Ok(_) => Ok(DnsName(value.to_string())),
Err(_) => Err(InvalidDnsNameError {}),
}
}
}
impl TryFrom<String> for DnsName {
type Error = InvalidDnsNameError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::try_from(value.as_str())
}
}
/// Type of a public key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
@@ -98,6 +135,25 @@ pub enum SignatureScheme {
ED25519 = 0x0807,
}
impl fmt::Display for SignatureScheme {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SignatureScheme::RSA_PKCS1_SHA1 => write!(f, "RSA_PKCS1_SHA1"),
SignatureScheme::ECDSA_SHA1_Legacy => write!(f, "ECDSA_SHA1_Legacy"),
SignatureScheme::RSA_PKCS1_SHA256 => write!(f, "RSA_PKCS1_SHA256"),
SignatureScheme::ECDSA_NISTP256_SHA256 => write!(f, "ECDSA_NISTP256_SHA256"),
SignatureScheme::RSA_PKCS1_SHA384 => write!(f, "RSA_PKCS1_SHA384"),
SignatureScheme::ECDSA_NISTP384_SHA384 => write!(f, "ECDSA_NISTP384_SHA384"),
SignatureScheme::RSA_PKCS1_SHA512 => write!(f, "RSA_PKCS1_SHA512"),
SignatureScheme::ECDSA_NISTP521_SHA512 => write!(f, "ECDSA_NISTP521_SHA512"),
SignatureScheme::RSA_PSS_SHA256 => write!(f, "RSA_PSS_SHA256"),
SignatureScheme::RSA_PSS_SHA384 => write!(f, "RSA_PSS_SHA384"),
SignatureScheme::RSA_PSS_SHA512 => write!(f, "RSA_PSS_SHA512"),
SignatureScheme::ED25519 => write!(f, "ED25519"),
}
}
}
impl TryFrom<tls_core::msgs::enums::SignatureScheme> for SignatureScheme {
type Error = &'static str;
@@ -142,16 +198,6 @@ impl From<SignatureScheme> for tls_core::msgs::enums::SignatureScheme {
}
}
/// X.509 certificate, DER encoded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Certificate(pub Vec<u8>);
impl From<tls_core::key::Certificate> for Certificate {
fn from(cert: tls_core::key::Certificate) -> Self {
Self(cert.0)
}
}
/// Server's signature of the key exchange parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerSignature {
@@ -220,9 +266,9 @@ pub struct TranscriptLength {
pub received: u32,
}
/// TLS 1.2 handshake data.
/// TLS 1.2 certificate binding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeDataV1_2 {
pub struct CertBindingV1_2 {
/// Client random.
pub client_random: [u8; 32],
/// Server random.
@@ -231,13 +277,18 @@ pub struct HandshakeDataV1_2 {
pub server_ephemeral_key: ServerEphemKey,
}
/// TLS handshake data.
/// TLS certificate binding.
///
/// This is the data that the server signs using its public key in the
/// certificate it presents during the TLS handshake. This provides a binding
/// between the server's identity and the ephemeral keys used to authenticate
/// the TLS session.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum HandshakeData {
/// TLS 1.2 handshake data.
V1_2(HandshakeDataV1_2),
pub enum CertBinding {
/// TLS 1.2 certificate binding.
V1_2(CertBindingV1_2),
}
/// Verify data from the TLS handshake finished messages.
@@ -249,19 +300,19 @@ pub struct VerifyData {
pub server_finished: Vec<u8>,
}
/// Server certificate and handshake data.
/// TLS handshake data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerCertData {
/// Certificate chain.
pub certs: Vec<Certificate>,
/// Server signature of the key exchange parameters.
pub struct HandshakeData {
/// Server certificate chain.
pub certs: Vec<CertificateDer>,
/// Server certificate signature over the binding message.
pub sig: ServerSignature,
/// TLS handshake data.
pub handshake: HandshakeData,
/// Certificate binding.
pub binding: CertBinding,
}
impl ServerCertData {
/// Verifies the server certificate data.
impl HandshakeData {
/// Verifies the handshake data.
///
/// # Arguments
///
@@ -271,53 +322,35 @@ impl ServerCertData {
/// * `server_name` - The server name.
pub fn verify(
&self,
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
time: u64,
server_ephemeral_key: &ServerEphemKey,
server_name: &ServerName,
) -> Result<(), CertificateVerificationError> {
) -> Result<(), HandshakeVerificationError> {
#[allow(irrefutable_let_patterns)]
let HandshakeData::V1_2(HandshakeDataV1_2 {
let CertBinding::V1_2(CertBindingV1_2 {
client_random,
server_random,
server_ephemeral_key: expected_server_ephemeral_key,
}) = &self.handshake
}) = &self.binding
else {
unreachable!("only TLS 1.2 is implemented")
};
if server_ephemeral_key != expected_server_ephemeral_key {
return Err(CertificateVerificationError::InvalidServerEphemeralKey);
return Err(HandshakeVerificationError::InvalidServerEphemeralKey);
}
// Verify server name.
let server_name = tls_core::dns::ServerName::try_from(server_name.as_ref())
.map_err(|_| CertificateVerificationError::InvalidIdentity(server_name.clone()))?;
// Verify server certificate.
let cert_chain = self
let (end_entity, intermediates) = self
.certs
.clone()
.into_iter()
.map(|cert| tls_core::key::Certificate(cert.0))
.collect::<Vec<_>>();
let (end_entity, intermediates) = cert_chain
.split_first()
.ok_or(CertificateVerificationError::MissingCerts)?;
.ok_or(HandshakeVerificationError::MissingCerts)?;
// Verify the end entity cert is valid for the provided server name
// and that it chains to at least one of the roots we trust.
verifier
.verify_server_cert(
end_entity,
intermediates,
&server_name,
&mut [].into_iter(),
&[],
UNIX_EPOCH + Duration::from_secs(time),
)
.map_err(|_| CertificateVerificationError::InvalidCert)?;
.verify_server_cert(end_entity, intermediates, server_name, time)
.map_err(HandshakeVerificationError::ServerCert)?;
// Verify the signature matches the certificate and key exchange parameters.
let mut message = Vec::new();
@@ -325,11 +358,31 @@ impl ServerCertData {
message.extend_from_slice(server_random);
message.extend_from_slice(&server_ephemeral_key.kx_params());
let dss = DigitallySignedStruct::new(self.sig.scheme.into(), self.sig.sig.clone());
use webpki::ring as alg;
let sig_alg = match self.sig.scheme {
SignatureScheme::RSA_PKCS1_SHA256 => alg::RSA_PKCS1_2048_8192_SHA256,
SignatureScheme::RSA_PKCS1_SHA384 => alg::RSA_PKCS1_2048_8192_SHA384,
SignatureScheme::RSA_PKCS1_SHA512 => alg::RSA_PKCS1_2048_8192_SHA512,
SignatureScheme::RSA_PSS_SHA256 => alg::RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
SignatureScheme::RSA_PSS_SHA384 => alg::RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
SignatureScheme::RSA_PSS_SHA512 => alg::RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
SignatureScheme::ECDSA_NISTP256_SHA256 => alg::ECDSA_P256_SHA256,
SignatureScheme::ECDSA_NISTP384_SHA384 => alg::ECDSA_P384_SHA384,
SignatureScheme::ED25519 => alg::ED25519,
scheme => {
return Err(HandshakeVerificationError::UnsupportedSignatureScheme(
scheme,
))
}
};
verifier
.verify_tls12_signature(&message, end_entity, &dss)
.map_err(|_| CertificateVerificationError::InvalidServerSignature)?;
let end_entity = webpki_types::CertificateDer::from(end_entity.0.as_slice());
let end_entity = webpki::EndEntityCert::try_from(&end_entity)
.map_err(|_| HandshakeVerificationError::InvalidEndEntityCertificate)?;
end_entity
.verify_signature(sig_alg, &message, &self.sig.sig)
.map_err(|_| HandshakeVerificationError::InvalidServerSignature)?;
Ok(())
}
@@ -338,58 +391,51 @@ impl ServerCertData {
/// Errors that can occur when verifying a certificate chain or signature.
#[derive(Debug, thiserror::Error)]
#[allow(missing_docs)]
pub enum CertificateVerificationError {
#[error("invalid server identity: {0}")]
InvalidIdentity(ServerName),
pub enum HandshakeVerificationError {
#[error("invalid end entity certificate")]
InvalidEndEntityCertificate,
#[error("missing server certificates")]
MissingCerts,
#[error("invalid server certificate")]
InvalidCert,
#[error("invalid server signature")]
InvalidServerSignature,
#[error("invalid server ephemeral key")]
InvalidServerEphemeralKey,
#[error("server certificate verification failed: {0}")]
ServerCert(ServerCertVerifierError),
#[error("unsupported signature scheme: {0}")]
UnsupportedSignatureScheme(SignatureScheme),
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{fixtures::ConnectionFixture, transcript::Transcript};
use crate::{fixtures::ConnectionFixture, transcript::Transcript, webpki::RootCertStore};
use hex::FromHex;
use rstest::*;
use tls_core::{
anchors::{OwnedTrustAnchor, RootCertStore},
verify::WebPkiVerifier,
};
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
#[fixture]
#[once]
fn verifier() -> WebPkiVerifier {
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject.as_ref(),
ta.subject_public_key_info.as_ref(),
ta.name_constraints.as_ref().map(|nc| nc.as_ref()),
)
}));
fn verifier() -> ServerCertVerifier {
let mut root_store = RootCertStore {
roots: webpki_root_certs::TLS_SERVER_ROOT_CERTS
.iter()
.map(|c| CertificateDer(c.to_vec()))
.collect(),
};
// Add a cert which is no longer included in the Mozilla root store.
let cert = tls_core::key::Certificate(
root_store.roots.push(
appliedzkp()
.server_cert_data
.certs
.last()
.expect("chain is valid")
.0
.clone(),
);
root_store.add(&cert).unwrap();
WebPkiVerifier::new(root_store, None)
ServerCertVerifier::new(&root_store).unwrap()
}
fn tlsnotary() -> ConnectionFixture {
@@ -405,7 +451,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_sucess_ca_implicit(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
// Remove the CA cert
@@ -417,7 +463,7 @@ mod tests {
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
)
.is_ok());
}
@@ -428,7 +474,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_success_ca_explicit(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] data: ConnectionFixture,
) {
assert!(data
@@ -437,7 +483,7 @@ mod tests {
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
)
.is_ok());
}
@@ -447,7 +493,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_fail_bad_time(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] data: ConnectionFixture,
) {
// unix time when the cert chain was NOT valid
@@ -457,12 +503,12 @@ mod tests {
verifier,
bad_time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidCert
HandshakeVerificationError::ServerCert(_)
));
}
@@ -471,7 +517,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_fail_no_interm_cert(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
// Remove the CA cert
@@ -483,12 +529,12 @@ mod tests {
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidCert
HandshakeVerificationError::ServerCert(_)
));
}
@@ -498,7 +544,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_fail_no_interm_cert_with_ca_cert(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
// Remove the intermediate cert
@@ -508,12 +554,12 @@ mod tests {
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidCert
HandshakeVerificationError::ServerCert(_)
));
}
@@ -522,24 +568,24 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_fail_bad_ee_cert(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
let ee: &[u8] = include_bytes!("./fixtures/data/unknown/ee.der");
// Change the end entity cert
data.server_cert_data.certs[0] = Certificate(ee.to_vec());
data.server_cert_data.certs[0] = CertificateDer(ee.to_vec());
let err = data.server_cert_data.verify(
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidCert
HandshakeVerificationError::ServerCert(_)
));
}
@@ -548,23 +594,23 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_sig_ke_params_fail_bad_client_random(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
let HandshakeData::V1_2(HandshakeDataV1_2 { client_random, .. }) =
&mut data.server_cert_data.handshake;
let CertBinding::V1_2(CertBindingV1_2 { client_random, .. }) =
&mut data.server_cert_data.binding;
client_random[31] = client_random[31].wrapping_add(1);
let err = data.server_cert_data.verify(
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidServerSignature
HandshakeVerificationError::InvalidServerSignature
));
}
@@ -573,7 +619,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_sig_ke_params_fail_bad_sig(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
data.server_cert_data.sig.sig[31] = data.server_cert_data.sig.sig[31].wrapping_add(1);
@@ -582,12 +628,12 @@ mod tests {
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidServerSignature
HandshakeVerificationError::InvalidServerSignature
));
}
@@ -596,10 +642,10 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_check_dns_name_present_in_cert_fail_bad_host(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] data: ConnectionFixture,
) {
let bad_name = ServerName::from("badhost.com");
let bad_name = ServerName::Dns(DnsName::try_from("badhost.com").unwrap());
let err = data.server_cert_data.verify(
verifier,
@@ -610,7 +656,7 @@ mod tests {
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidCert
HandshakeVerificationError::ServerCert(_)
));
}
@@ -618,7 +664,7 @@ mod tests {
#[rstest]
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_invalid_ephemeral_key(verifier: &WebPkiVerifier, #[case] data: ConnectionFixture) {
fn test_invalid_ephemeral_key(verifier: &ServerCertVerifier, #[case] data: ConnectionFixture) {
let wrong_ephemeral_key = ServerEphemKey {
typ: KeyType::SECP256R1,
key: Vec::<u8>::from_hex(include_bytes!("./fixtures/data/unknown/pubkey")).unwrap(),
@@ -628,12 +674,12 @@ mod tests {
verifier,
data.connection_info.time,
&wrong_ephemeral_key,
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidServerEphemeralKey
HandshakeVerificationError::InvalidServerEphemeralKey
));
}
@@ -642,7 +688,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_fail_no_cert(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
// Empty certs
@@ -652,12 +698,12 @@ mod tests {
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::MissingCerts
HandshakeVerificationError::MissingCerts
));
}
}

View File

@@ -8,13 +8,14 @@ use hex::FromHex;
use crate::{
connection::{
Certificate, ConnectionInfo, HandshakeData, HandshakeDataV1_2, KeyType, ServerCertData,
CertBinding, CertBindingV1_2, ConnectionInfo, DnsName, HandshakeData, KeyType,
ServerEphemKey, ServerName, ServerSignature, SignatureScheme, TlsVersion, TranscriptLength,
},
transcript::{
encoding::{EncoderSecret, EncodingProvider},
Transcript,
},
webpki::CertificateDer,
};
/// A fixture containing various TLS connection data.
@@ -23,24 +24,26 @@ use crate::{
pub struct ConnectionFixture {
pub server_name: ServerName,
pub connection_info: ConnectionInfo,
pub server_cert_data: ServerCertData,
pub server_cert_data: HandshakeData,
}
impl ConnectionFixture {
/// Returns a connection fixture for tlsnotary.org.
pub fn tlsnotary(transcript_length: TranscriptLength) -> Self {
ConnectionFixture {
server_name: ServerName::new("tlsnotary.org".to_string()),
server_name: ServerName::Dns(DnsName::try_from("tlsnotary.org").unwrap()),
connection_info: ConnectionInfo {
time: 1671637529,
version: TlsVersion::V1_2,
transcript_length,
},
server_cert_data: ServerCertData {
server_cert_data: HandshakeData {
certs: vec![
Certificate(include_bytes!("fixtures/data/tlsnotary.org/ee.der").to_vec()),
Certificate(include_bytes!("fixtures/data/tlsnotary.org/inter.der").to_vec()),
Certificate(include_bytes!("fixtures/data/tlsnotary.org/ca.der").to_vec()),
CertificateDer(include_bytes!("fixtures/data/tlsnotary.org/ee.der").to_vec()),
CertificateDer(
include_bytes!("fixtures/data/tlsnotary.org/inter.der").to_vec(),
),
CertificateDer(include_bytes!("fixtures/data/tlsnotary.org/ca.der").to_vec()),
],
sig: ServerSignature {
scheme: SignatureScheme::RSA_PKCS1_SHA256,
@@ -49,7 +52,7 @@ impl ConnectionFixture {
))
.unwrap(),
},
handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
binding: CertBinding::V1_2(CertBindingV1_2 {
client_random: <[u8; 32]>::from_hex(include_bytes!(
"fixtures/data/tlsnotary.org/client_random"
))
@@ -73,17 +76,19 @@ impl ConnectionFixture {
/// Returns a connection fixture for appliedzkp.org.
pub fn appliedzkp(transcript_length: TranscriptLength) -> Self {
ConnectionFixture {
server_name: ServerName::new("appliedzkp.org".to_string()),
server_name: ServerName::Dns(DnsName::try_from("appliedzkp.org").unwrap()),
connection_info: ConnectionInfo {
time: 1671637529,
version: TlsVersion::V1_2,
transcript_length,
},
server_cert_data: ServerCertData {
server_cert_data: HandshakeData {
certs: vec![
Certificate(include_bytes!("fixtures/data/appliedzkp.org/ee.der").to_vec()),
Certificate(include_bytes!("fixtures/data/appliedzkp.org/inter.der").to_vec()),
Certificate(include_bytes!("fixtures/data/appliedzkp.org/ca.der").to_vec()),
CertificateDer(include_bytes!("fixtures/data/appliedzkp.org/ee.der").to_vec()),
CertificateDer(
include_bytes!("fixtures/data/appliedzkp.org/inter.der").to_vec(),
),
CertificateDer(include_bytes!("fixtures/data/appliedzkp.org/ca.der").to_vec()),
],
sig: ServerSignature {
scheme: SignatureScheme::ECDSA_NISTP256_SHA256,
@@ -92,7 +97,7 @@ impl ConnectionFixture {
))
.unwrap(),
},
handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
binding: CertBinding::V1_2(CertBindingV1_2 {
client_random: <[u8; 32]>::from_hex(include_bytes!(
"fixtures/data/appliedzkp.org/client_random"
))
@@ -115,10 +120,10 @@ impl ConnectionFixture {
/// Returns the server_ephemeral_key fixture.
pub fn server_ephemeral_key(&self) -> &ServerEphemKey {
let HandshakeData::V1_2(HandshakeDataV1_2 {
let CertBinding::V1_2(CertBindingV1_2 {
server_ephemeral_key,
..
}) = &self.server_cert_data.handshake;
}) = &self.server_cert_data.binding;
server_ephemeral_key
}
}

View File

@@ -10,12 +10,13 @@ pub mod fixtures;
pub mod hash;
pub mod merkle;
pub mod transcript;
pub mod webpki;
use rangeset::ToRangeSet;
use serde::{Deserialize, Serialize};
use crate::{
connection::{ServerCertData, ServerName},
connection::{HandshakeData, ServerName},
transcript::{
Direction, Idx, PartialTranscript, Transcript, TranscriptCommitConfig,
TranscriptCommitRequest, TranscriptCommitment, TranscriptSecret,
@@ -200,8 +201,8 @@ enum VerifyConfigBuilderErrorRepr {}
#[doc(hidden)]
#[derive(Debug, Serialize, Deserialize)]
pub struct ProvePayload {
/// Server identity data.
pub server_identity: Option<(ServerName, ServerCertData)>,
/// Handshake data.
pub handshake: Option<(ServerName, HandshakeData)>,
/// Transcript data.
pub transcript: Option<PartialTranscript>,
/// Transcript commitment configuration.

View File

@@ -2,10 +2,10 @@
use crate::{
connection::{
Certificate, HandshakeData, HandshakeDataV1_2, ServerEphemKey, ServerSignature, TlsVersion,
VerifyData,
CertBinding, CertBindingV1_2, ServerEphemKey, ServerSignature, TlsVersion, VerifyData,
},
transcript::{Direction, Transcript},
webpki::CertificateDer,
};
use tls_core::msgs::{
alert::AlertMessagePayload,
@@ -19,9 +19,9 @@ use tls_core::msgs::{
pub struct TlsTranscript {
time: u64,
version: TlsVersion,
server_cert_chain: Option<Vec<Certificate>>,
server_cert_chain: Option<Vec<CertificateDer>>,
server_signature: Option<ServerSignature>,
handshake_data: HandshakeData,
certificate_binding: CertBinding,
sent: Vec<Record>,
recv: Vec<Record>,
}
@@ -32,9 +32,9 @@ impl TlsTranscript {
pub fn new(
time: u64,
version: TlsVersion,
server_cert_chain: Option<Vec<Certificate>>,
server_cert_chain: Option<Vec<CertificateDer>>,
server_signature: Option<ServerSignature>,
handshake_data: HandshakeData,
certificate_binding: CertBinding,
verify_data: VerifyData,
sent: Vec<Record>,
recv: Vec<Record>,
@@ -198,7 +198,7 @@ impl TlsTranscript {
version,
server_cert_chain,
server_signature,
handshake_data,
certificate_binding,
sent,
recv,
})
@@ -215,7 +215,7 @@ impl TlsTranscript {
}
/// Returns the server certificate chain.
pub fn server_cert_chain(&self) -> Option<&[Certificate]> {
pub fn server_cert_chain(&self) -> Option<&[CertificateDer]> {
self.server_cert_chain.as_deref()
}
@@ -226,17 +226,17 @@ impl TlsTranscript {
/// Returns the server ephemeral key used in the TLS handshake.
pub fn server_ephemeral_key(&self) -> &ServerEphemKey {
match &self.handshake_data {
HandshakeData::V1_2(HandshakeDataV1_2 {
match &self.certificate_binding {
CertBinding::V1_2(CertBindingV1_2 {
server_ephemeral_key,
..
}) => server_ephemeral_key,
}
}
/// Returns the handshake data.
pub fn handshake_data(&self) -> &HandshakeData {
&self.handshake_data
/// Returns the certificate binding data.
pub fn certificate_binding(&self) -> &CertBinding {
&self.certificate_binding
}
/// Returns the sent records.

168
crates/core/src/webpki.rs Normal file
View File

@@ -0,0 +1,168 @@
//! Web PKI types.
use std::time::Duration;
use rustls_pki_types::{self as webpki_types, pem::PemObject};
use serde::{Deserialize, Serialize};
use crate::connection::ServerName;
/// X.509 certificate, DER encoded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertificateDer(pub Vec<u8>);
impl CertificateDer {
/// Creates a DER-encoded certificate from a PEM-encoded certificate.
pub fn from_pem_slice(pem: &[u8]) -> Result<Self, PemError> {
let der = webpki_types::CertificateDer::from_pem_slice(pem).map_err(|_| PemError {})?;
Ok(Self(der.to_vec()))
}
}
/// Private key, DER encoded.
#[derive(Debug, Clone, zeroize::ZeroizeOnDrop, Serialize, Deserialize)]
pub struct PrivateKeyDer(pub Vec<u8>);
impl PrivateKeyDer {
/// Creates a DER-encoded private key from a PEM-encoded private key.
pub fn from_pem_slice(pem: &[u8]) -> Result<Self, PemError> {
let der = webpki_types::PrivateKeyDer::from_pem_slice(pem).map_err(|_| PemError {})?;
Ok(Self(der.secret_der().to_vec()))
}
}
/// PEM parsing error.
#[derive(Debug, thiserror::Error)]
#[error("failed to parse PEM object")]
pub struct PemError {}
/// Root certificate store.
///
/// This stores root certificates which are used to verify end-entity
/// certificates presented by a TLS server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RootCertStore {
/// Unvalidated DER-encoded X.509 root certificates.
pub roots: Vec<CertificateDer>,
}
impl RootCertStore {
/// Creates an empty root certificate store.
pub fn empty() -> Self {
Self { roots: Vec::new() }
}
}
/// Server certificate verifier.
#[derive(Debug)]
pub struct ServerCertVerifier {
roots: Vec<webpki_types::TrustAnchor<'static>>,
}
impl ServerCertVerifier {
/// Creates a new server certificate verifier.
pub fn new(roots: &RootCertStore) -> Result<Self, ServerCertVerifierError> {
let roots = roots
.roots
.iter()
.map(|cert| {
webpki::anchor_from_trusted_cert(&webpki_types::CertificateDer::from(
cert.0.as_slice(),
))
.map(|anchor| anchor.to_owned())
.map_err(|err| ServerCertVerifierError::InvalidRootCertificate {
cert: cert.clone(),
reason: err.to_string(),
})
})
.collect::<Result<Vec<_>, _>>()?;
Ok(Self { roots })
}
/// Creates a new server certificate verifier with Mozilla root
/// certificates.
pub fn mozilla() -> Self {
Self {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
}
}
/// Verifies the server certificate was valid at the given time of
/// presentation.
///
/// # Arguments
///
/// * `end_entity` - End-entity certificate to verify.
/// * `intermediates` - Intermediate certificates to a trust anchor.
/// * `server_name` - Server DNS name.
/// * `time` - Unix time the certificate was presented.
pub fn verify_server_cert(
&self,
end_entity: &CertificateDer,
intermediates: &[CertificateDer],
server_name: &ServerName,
time: u64,
) -> Result<(), ServerCertVerifierError> {
let cert = webpki_types::CertificateDer::from(end_entity.0.as_slice());
let cert = webpki::EndEntityCert::try_from(&cert).map_err(|e| {
ServerCertVerifierError::InvalidEndEntityCertificate {
cert: end_entity.clone(),
reason: e.to_string(),
}
})?;
let intermediates = intermediates
.iter()
.map(|c| webpki_types::CertificateDer::from(c.0.as_slice()))
.collect::<Vec<_>>();
let server_name = server_name.to_webpki();
let time = webpki_types::UnixTime::since_unix_epoch(Duration::from_secs(time));
cert.verify_for_usage(
webpki::ALL_VERIFICATION_ALGS,
&self.roots,
&intermediates,
time,
webpki::KeyUsage::server_auth(),
None,
None,
)
.map(|_| ())
.map_err(|_| ServerCertVerifierError::InvalidPath)?;
cert.verify_is_valid_for_subject_name(&server_name)
.map_err(|_| ServerCertVerifierError::InvalidServerName)?;
Ok(())
}
}
/// Error for [`ServerCertVerifier`].
#[derive(Debug, thiserror::Error)]
#[error("server certificate verification failed: {0}")]
pub enum ServerCertVerifierError {
/// Root certificate store contains invalid certificate.
#[error("root certificate store contains invalid certificate: {reason}")]
InvalidRootCertificate {
/// Invalid certificate.
cert: CertificateDer,
/// Reason for invalidity.
reason: String,
},
/// End-entity certificate is invalid.
#[error("end-entity certificate is invalid: {reason}")]
InvalidEndEntityCertificate {
/// Invalid certificate.
cert: CertificateDer,
/// Reason for invalidity.
reason: String,
},
/// Failed to verify certificate path to provided trust anchors.
#[error("failed to verify certificate path to provided trust anchors")]
InvalidPath,
/// Failed to verify certificate is valid for provided server name.
#[error("failed to verify certificate is valid for provided server name")]
InvalidServerName,
}

View File

@@ -8,10 +8,8 @@ version = "0.0.0"
workspace = true
[dependencies]
tlsn-core = { workspace = true }
tlsn = { workspace = true }
tlsn-formats = { workspace = true }
tlsn-tls-core = { workspace = true }
tls-server-fixture = { workspace = true }
tlsn-server-fixture = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }

View File

@@ -12,7 +12,8 @@ use tracing::instrument;
use tls_server_fixture::CA_CERT_DER;
use tlsn::{
config::{ProtocolConfig, ProtocolConfigValidator},
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
connection::ServerName,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
transcript::PartialTranscript,
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
@@ -72,18 +73,16 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(root_store);
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build().unwrap();
// Set up protocol configuration for prover.
let mut prover_config_builder = ProverConfig::builder();
prover_config_builder
.server_name(server_domain)
.server_name(ServerName::Dns(server_domain.try_into().unwrap()))
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
@@ -194,13 +193,10 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let verifier_config = VerifierConfig::builder()
.root_store(root_store)
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()
.unwrap();
@@ -234,6 +230,7 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.unwrap_or_else(|| panic!("Expected valid data from {SERVER_DOMAIN}"));
// Check Session info: server name.
let ServerName::Dns(server_name) = server_name;
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
transcript

View File

@@ -14,8 +14,6 @@ wasm-opt = ["-O3"]
[dependencies]
tlsn-harness-core = { workspace = true }
tlsn = { workspace = true }
tlsn-core = { workspace = true }
tlsn-tls-core = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }
inventory = { workspace = true }

View File

@@ -5,7 +5,8 @@ use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
use harness_core::bench::{Bench, ProverMetrics};
use tlsn::{
config::ProtocolConfig,
config::{CertificateDer, ProtocolConfig, RootCertStore},
connection::ServerName,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
};
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
@@ -32,20 +33,17 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let protocol_config = builder.build()?;
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(root_store);
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build()?;
let prover = Prover::new(
ProverConfig::builder()
.tls_config(tls_config)
.protocol_config(protocol_config)
.server_name(SERVER_DOMAIN)
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.build()?,
);

View File

@@ -2,7 +2,7 @@ use anyhow::Result;
use harness_core::bench::Bench;
use tlsn::{
config::ProtocolConfigValidator,
config::{CertificateDer, ProtocolConfigValidator, RootCertStore},
verifier::{Verifier, VerifierConfig, VerifyConfig},
};
use tlsn_server_fixture_certs::CA_CERT_DER;
@@ -17,14 +17,11 @@ pub async fn bench_verifier(provider: &IoProvider, config: &Bench) -> Result<()>
let protocol_config = builder.build()?;
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let verifier = Verifier::new(
VerifierConfig::builder()
.root_store(root_store)
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(protocol_config)
.build()?,
);

View File

@@ -1,6 +1,6 @@
use tls_core::anchors::RootCertStore;
use tlsn::{
config::{ProtocolConfig, ProtocolConfigValidator},
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
connection::ServerName,
hash::HashAlgId,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
transcript::{TranscriptCommitConfig, TranscriptCommitment, TranscriptCommitmentKind},
@@ -21,19 +21,17 @@ const MAX_RECV_DATA: usize = 1 << 11;
crate::test!("basic", prover, verifier);
async fn prover(provider: &IoProvider) {
let mut root_store = RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(root_store);
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build().unwrap();
let server_name = ServerName::Dns(SERVER_DOMAIN.try_into().unwrap());
let prover = Prover::new(
ProverConfig::builder()
.server_name(SERVER_DOMAIN)
.server_name(server_name)
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
@@ -114,11 +112,6 @@ async fn prover(provider: &IoProvider) {
}
async fn verifier(provider: &IoProvider) {
let mut root_store = RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let config = VerifierConfig::builder()
.protocol_config_validator(
ProtocolConfigValidator::builder()
@@ -127,7 +120,9 @@ async fn verifier(provider: &IoProvider) {
.build()
.unwrap(),
)
.root_store(root_store)
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()
.unwrap();
@@ -145,7 +140,9 @@ async fn verifier(provider: &IoProvider) {
.await
.unwrap();
assert_eq!(server_name.unwrap().as_str(), SERVER_DOMAIN);
let ServerName::Dns(server_name) = server_name.unwrap();
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
assert!(
transcript_commitments
.iter()

View File

@@ -72,3 +72,5 @@ tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
tokio-util = { workspace = true, features = ["compat"] }
tracing-subscriber = { workspace = true }
uid-mux = { workspace = true, features = ["serio", "test-utils"] }
rustls-pki-types = { workspace = true }
rustls-webpki = { workspace = true }

View File

@@ -22,7 +22,7 @@ use serio::stream::IoStreamExt;
use std::mem;
use tls_core::msgs::enums::NamedGroup;
use tlsn_core::{
connection::{HandshakeData, HandshakeDataV1_2, TlsVersion, VerifyData},
connection::{CertBinding, CertBindingV1_2, TlsVersion, VerifyData},
transcript::TlsTranscript,
};
use tracing::{debug, instrument};
@@ -405,7 +405,7 @@ impl MpcTlsFollower {
let cf_vd = cf_vd.ok_or(MpcTlsError::hs("client finished VD not computed"))?;
let sf_vd = sf_vd.ok_or(MpcTlsError::hs("server finished VD not computed"))?;
let handshake_data = HandshakeData::V1_2(HandshakeDataV1_2 {
let handshake_data = CertBinding::V1_2(CertBindingV1_2 {
client_random,
server_random,
server_ephemeral_key: server_key

View File

@@ -43,10 +43,9 @@ use tls_core::{
suites::SupportedCipherSuite,
};
use tlsn_core::{
connection::{
Certificate, HandshakeData, HandshakeDataV1_2, ServerSignature, TlsVersion, VerifyData,
},
connection::{CertBinding, CertBindingV1_2, ServerSignature, TlsVersion, VerifyData},
transcript::TlsTranscript,
webpki::CertificateDer,
};
use tracing::{debug, instrument, trace, warn};
@@ -325,7 +324,7 @@ impl MpcTlsLeader {
let server_cert_chain = server_cert_details
.cert_chain()
.iter()
.map(|cert| Certificate(cert.0.clone()))
.map(|cert| CertificateDer(cert.0.clone()))
.collect();
let server_signature = ServerSignature {
@@ -337,7 +336,7 @@ impl MpcTlsLeader {
sig: server_kx_details.kx_sig().sig.0.clone(),
};
let handshake_data = HandshakeData::V1_2(HandshakeDataV1_2 {
let handshake_data = CertBinding::V1_2(CertBindingV1_2 {
client_random: client_random.0,
server_random: server_random.0,
server_ephemeral_key: server_key

View File

@@ -12,11 +12,15 @@ use mpz_ot::{
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use tls_client::Certificate;
use rustls_pki_types::CertificateDer;
use tls_client::RootCertStore;
use tls_client_async::bind_client;
use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN};
use tokio::sync::Mutex;
use tokio_util::compat::TokioAsyncReadCompatExt;
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "expensive"]
@@ -48,11 +52,11 @@ async fn leader_task(mut leader: MpcTlsLeader) {
let (leader_ctrl, leader_fut) = leader.run();
tokio::spawn(async { leader_fut.await.unwrap() });
let mut root_store = tls_client::RootCertStore::empty();
root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap();
let config = tls_client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_root_certificates(RootCertStore {
roots: vec![anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()],
})
.with_no_client_auth();
let server_name = SERVER_DOMAIN.try_into().unwrap();

View File

@@ -35,3 +35,5 @@ hyper = { workspace = true, features = ["client", "http1"] }
hyper-util = { workspace = true, features = ["full"] }
rstest = { workspace = true }
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
rustls-webpki = { workspace = true }
rustls-pki-types = { workspace = true }

View File

@@ -6,7 +6,8 @@ use http_body_util::{BodyExt as _, Full};
use hyper::{body::Bytes, Request, StatusCode};
use hyper_util::rt::TokioIo;
use rstest::{fixture, rstest};
use tls_client::{Certificate, ClientConfig, ClientConnection, RustCryptoBackend, ServerName};
use rustls_pki_types::CertificateDer;
use tls_client::{ClientConfig, ClientConnection, RustCryptoBackend, ServerName};
use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection};
use tls_server_fixture::{
bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY,
@@ -14,6 +15,9 @@ use tls_server_fixture::{
};
use tokio::task::JoinHandle;
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
// An established client TLS connection
struct TlsFixture {
@@ -30,7 +34,9 @@ async fn set_up_tls() -> TlsFixture {
let _server_task = tokio::spawn(bind_test_server(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
@@ -75,7 +81,9 @@ async fn test_hyper_ok() {
let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)

View File

@@ -23,7 +23,8 @@ async-trait = { workspace = true }
log = { workspace = true, optional = true }
ring = { workspace = true }
sct = { workspace = true }
webpki = { workspace = true, features = ["alloc", "std"] }
rustls-pki-types = { workspace = true }
rustls-webpki = { workspace = true }
aes-gcm = { workspace = true }
p256 = { workspace = true, features = ["ecdh"] }
rand = { workspace = true }

View File

@@ -7,7 +7,6 @@ use crate::{
conn::{CommonState, ConnectionRandoms, State},
error::Error,
hash_hs::HandshakeHashBuffer,
msgs::persist,
ticketer::TimeBase,
};
use tls_core::{
@@ -42,35 +41,6 @@ pub(super) type NextState = Box<dyn State<ClientConnectionData>>;
pub(super) type NextStateOrError = Result<NextState, Error>;
pub(super) type ClientContext<'a> = crate::conn::Context<'a>;
fn find_session(
server_name: &ServerName,
config: &ClientConfig,
) -> Option<persist::Retrieved<persist::ClientSessionValue>> {
let key = persist::ClientSessionKey::session_for_server_name(server_name);
let key_buf = key.get_encoding();
let value = config.session_storage.get(&key_buf).or_else(|| {
debug!("No cached session for {:?}", server_name);
None
})?;
#[allow(unused_mut)]
let mut reader = Reader::init(&value[2..]);
#[allow(clippy::bind_instead_of_map)] // https://github.com/rust-lang/rust-clippy/issues/8082
CipherSuite::read_bytes(&value[..2])
.and_then(|suite| {
persist::ClientSessionValue::read(&mut reader, suite, &config.cipher_suites)
})
.and_then(|resuming| {
let retrieved = persist::Retrieved::new(resuming, TimeBase::now().ok()?);
match retrieved.has_expired() {
false => Some(retrieved),
true => None,
}
})
.and_then(Some)
}
pub(super) async fn start_handshake(
server_name: ServerName,
extra_exts: Vec<ClientExtension>,
@@ -123,7 +93,6 @@ pub(super) async fn start_handshake(
emit_client_hello_for_retry(
config,
cx,
None,
random,
false,
transcript_buffer,
@@ -142,7 +111,6 @@ pub(super) async fn start_handshake(
struct ExpectServerHello {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Retrieved<persist::ClientSessionValue>>,
server_name: ServerName,
random: Random,
using_ems: bool,
@@ -162,7 +130,6 @@ struct ExpectServerHelloOrHelloRetryRequest {
async fn emit_client_hello_for_retry(
config: Arc<ClientConfig>,
cx: &mut ClientContext<'_>,
resuming_session: Option<persist::Retrieved<persist::ClientSessionValue>>,
random: Random,
using_ems: bool,
mut transcript_buffer: HandshakeHashBuffer,
@@ -176,25 +143,6 @@ async fn emit_client_hello_for_retry(
may_send_sct_list: bool,
suite: Option<SupportedCipherSuite>,
) -> Result<NextState, Error> {
// For now we do not support session resumption
//
// Do we have a SessionID or ticket cached for this host?
// let (ticket, resume_version) = if let Some(resuming) = &resuming_session {
// match &resuming.value {
// persist::ClientSessionValue::Tls13(inner) => {
// (inner.ticket().to_vec(), ProtocolVersion::TLSv1_3)
// }
// #[cfg(feature = "tls12")]
// persist::ClientSessionValue::Tls12(inner) => {
// (inner.ticket().to_vec(), ProtocolVersion::TLSv1_2)
// }
// }
// } else {
// (Vec::new(), ProtocolVersion::Unknown(0))
// };
// let (ticket, resume_version) = (Vec::new(), ProtocolVersion::Unknown(0));
let support_tls12 = config.supports_version(ProtocolVersion::TLSv1_2);
let support_tls13 = config.supports_version(ProtocolVersion::TLSv1_3);
@@ -256,48 +204,6 @@ async fn emit_client_hello_for_retry(
// Extra extensions must be placed before the PSK extension
exts.extend(extra_exts.iter().cloned());
// let fill_in_binder = if support_tls13
// && config.enable_tickets
// && resume_version == ProtocolVersion::TLSv1_3
// && !ticket.is_empty()
// {
// let resuming =
// resuming_session
// .as_ref()
// .and_then(|resuming| match (suite, resuming.tls13()) {
// (Some(suite), Some(resuming)) => {
// suite.tls13()?.can_resume_from(resuming.suite())?;
// Some(resuming)
// }
// (None, Some(resuming)) => Some(resuming),
// _ => None,
// });
// if let Some(ref resuming) = resuming {
// tls13::prepare_resumption(
// &config,
// cx,
// ticket,
// &resuming,
// &mut exts,
// retryreq.is_some(),
// )
// .await;
// }
// resuming
// } else if config.enable_tickets {
// // If we have a ticket, include it. Otherwise, request one.
// if ticket.is_empty() {
// exts.push(ClientExtension::SessionTicket(ClientSessionTicket::Request));
// } else {
// exts.push(ClientExtension::SessionTicket(ClientSessionTicket::Offer(
// Payload::new(ticket),
// )));
// }
// None
// } else {
// None
// };
// Note what extensions we sent.
hello.sent_extensions = exts.iter().map(ClientExtension::get_type).collect();
@@ -319,8 +225,8 @@ async fn emit_client_hello_for_retry(
};
// let early_key_schedule = if let Some(resuming) = fill_in_binder {
// let schedule = tls13::fill_in_psk_binder(&resuming, &transcript_buffer, &mut chp);
// Some((resuming.suite(), schedule))
// let schedule = tls13::fill_in_psk_binder(&resuming, &transcript_buffer,
// &mut chp); Some((resuming.suite(), schedule))
// } else {
// None
// };
@@ -350,7 +256,6 @@ async fn emit_client_hello_for_retry(
let next = ExpectServerHello {
config,
resuming_session,
server_name,
random,
using_ems,
@@ -551,19 +456,10 @@ impl State<ClientConnectionData> for ExpectServerHello {
// handshake_traffic_secret.
match suite {
SupportedCipherSuite::Tls13(suite) => {
let resuming_session =
self.resuming_session
.and_then(|resuming| match resuming.value {
persist::ClientSessionValue::Tls13(inner) => Some(inner),
#[cfg(feature = "tls12")]
persist::ClientSessionValue::Tls12(_) => None,
});
tls13::handle_server_hello(
self.config,
cx,
server_hello,
resuming_session,
self.server_name,
randoms,
suite,
@@ -577,16 +473,8 @@ impl State<ClientConnectionData> for ExpectServerHello {
}
#[cfg(feature = "tls12")]
SupportedCipherSuite::Tls12(suite) => {
let resuming_session =
self.resuming_session
.and_then(|resuming| match resuming.value {
persist::ClientSessionValue::Tls12(inner) => Some(inner),
persist::ClientSessionValue::Tls13(_) => None,
});
tls12::CompleteServerHelloHandling {
config: self.config,
resuming_session,
server_name: self.server_name,
randoms,
using_ems: self.using_ems,
@@ -723,7 +611,6 @@ impl ExpectServerHelloOrHelloRetryRequest {
emit_client_hello_for_retry(
self.next.config,
cx,
self.next.resuming_session,
self.next.random,
self.next.using_ems,
transcript_buffer,

View File

@@ -10,7 +10,6 @@ use crate::{
conn::{CommonState, ConnectionRandoms, State},
error::Error,
hash_hs::HandshakeHash,
msgs::persist,
sign::Signer,
ticketer::TimeBase,
verify,
@@ -49,7 +48,6 @@ mod server_hello {
pub(in crate::client) struct CompleteServerHelloHandling {
pub(in crate::client) config: Arc<ClientConfig>,
pub(in crate::client) resuming_session: Option<persist::Tls12ClientSessionValue>,
pub(in crate::client) server_name: ServerName,
pub(in crate::client) randoms: ConnectionRandoms,
pub(in crate::client) using_ems: bool,
@@ -113,76 +111,8 @@ mod server_hello {
None
};
// See if we're successfully resuming.
if let Some(ref _resuming) = self.resuming_session {
return Err(Error::General(
"client does not support resumption".to_string(),
));
// if resuming.session_id == server_hello.session_id {
// debug!("Server agreed to resume");
// // Is the server telling lies about the ciphersuite?
// if resuming.suite() != suite {
// let error_msg =
// "abbreviated handshake offered, but with varied cs".to_string();
// return Err(Error::PeerMisbehavedError(error_msg));
// }
// // And about EMS support?
// if resuming.extended_ms() != self.using_ems {
// let error_msg = "server varied ems support over resume".to_string();
// return Err(Error::PeerMisbehavedError(error_msg));
// }
// let secrets =
// ConnectionSecrets::new_resume(self.randoms, suite, resuming.secret());
// self.config.key_log.log(
// "CLIENT_RANDOM",
// &secrets.randoms.client,
// &secrets.master_secret,
// );
// cx.common.start_encryption_tls12(&secrets, Side::Client);
// // Since we're resuming, we verified the certificate and
// // proof of possession in the prior session.
// cx.common.peer_certificates = Some(resuming.server_cert_chain().to_vec());
// let cert_verified = verify::ServerCertVerified::assertion();
// let sig_verified = verify::HandshakeSignatureValid::assertion();
// return if must_issue_new_ticket {
// Ok(Box::new(ExpectNewTicket {
// config: self.config,
// secrets,
// resuming_session: self.resuming_session,
// session_id: server_hello.session_id,
// server_name: self.server_name,
// using_ems: self.using_ems,
// transcript: self.transcript,
// resuming: true,
// cert_verified,
// sig_verified,
// }))
// } else {
// Ok(Box::new(ExpectCcs {
// config: self.config,
// secrets,
// resuming_session: self.resuming_session,
// session_id: server_hello.session_id,
// server_name: self.server_name,
// using_ems: self.using_ems,
// transcript: self.transcript,
// ticket: None,
// resuming: true,
// cert_verified,
// sig_verified,
// }))
// };
// }
}
Ok(Box::new(ExpectCertificate {
config: self.config,
resuming_session: self.resuming_session,
session_id: server_hello.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -199,7 +129,6 @@ mod server_hello {
struct ExpectCertificate {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -228,7 +157,6 @@ impl State<ClientConnectionData> for ExpectCertificate {
if self.may_send_cert_status {
Ok(Box::new(ExpectCertificateStatusOrServerKx {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -250,7 +178,6 @@ impl State<ClientConnectionData> for ExpectCertificate {
Ok(Box::new(ExpectServerKx {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -266,7 +193,6 @@ impl State<ClientConnectionData> for ExpectCertificate {
struct ExpectCertificateStatusOrServerKx {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -303,7 +229,6 @@ impl State<ClientConnectionData> for ExpectCertificateStatusOrServerKx {
Box::new(ExpectServerKx {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -322,7 +247,6 @@ impl State<ClientConnectionData> for ExpectCertificateStatusOrServerKx {
}) => {
Box::new(ExpectCertificateStatus {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -350,7 +274,6 @@ impl State<ClientConnectionData> for ExpectCertificateStatusOrServerKx {
struct ExpectCertificateStatus {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -395,7 +318,6 @@ impl State<ClientConnectionData> for ExpectCertificateStatus {
Ok(Box::new(ExpectServerKx {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -410,7 +332,6 @@ impl State<ClientConnectionData> for ExpectCertificateStatus {
struct ExpectServerKx {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -458,7 +379,6 @@ impl State<ClientConnectionData> for ExpectServerKx {
Ok(Box::new(ExpectServerDoneOrCertReq {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -570,7 +490,6 @@ async fn emit_finished(
// client auth. Otherwise we go straight to ServerHelloDone.
struct ExpectServerDoneOrCertReq {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -598,7 +517,6 @@ impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
) {
Box::new(ExpectCertificateRequest {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -616,7 +534,6 @@ impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
Box::new(ExpectServerDone {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -636,7 +553,6 @@ impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
struct ExpectCertificateRequest {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -679,7 +595,6 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
Ok(Box::new(ExpectServerDone {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -696,7 +611,6 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
struct ExpectServerDone {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -745,6 +659,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
// 3. Verify that the top certificate signed their kx.
// 4. If doing client auth, send our Certificate.
// 5. Complete the key exchange:
//
// a) generate our kx pair
// b) emit a ClientKeyExchange containing it
// c) if doing client auth, emit a CertificateVerify
@@ -891,7 +806,6 @@ impl State<ClientConnectionData> for ExpectServerDone {
if st.must_issue_new_ticket {
Ok(Box::new(ExpectNewTicket {
config: st.config,
resuming_session: st.resuming_session,
session_id: st.session_id,
server_name: st.server_name,
using_ems: st.using_ems,
@@ -903,7 +817,6 @@ impl State<ClientConnectionData> for ExpectServerDone {
} else {
Ok(Box::new(ExpectCcs {
config: st.config,
resuming_session: st.resuming_session,
session_id: st.session_id,
server_name: st.server_name,
using_ems: st.using_ems,
@@ -919,7 +832,6 @@ impl State<ClientConnectionData> for ExpectServerDone {
struct ExpectNewTicket {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
using_ems: bool,
@@ -946,7 +858,6 @@ impl State<ClientConnectionData> for ExpectNewTicket {
Ok(Box::new(ExpectCcs {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
using_ems: self.using_ems,
@@ -962,7 +873,6 @@ impl State<ClientConnectionData> for ExpectNewTicket {
// -- Waiting for their CCS --
struct ExpectCcs {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
using_ems: bool,
@@ -998,7 +908,6 @@ impl State<ClientConnectionData> for ExpectCcs {
Ok(Box::new(ExpectFinished {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
using_ems: self.using_ems,
@@ -1013,7 +922,6 @@ impl State<ClientConnectionData> for ExpectCcs {
struct ExpectFinished {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
using_ems: bool,
@@ -1024,60 +932,6 @@ struct ExpectFinished {
sig_verified: verify::HandshakeSignatureValid,
}
// impl ExpectFinished {
// // -- Waiting for their finished --
// fn save_session(&mut self, cx: &mut ClientContext<'_>) {
// // Save a ticket. If we got a new ticket, save that. Otherwise, save the
// // original ticket again.
// let (mut ticket, lifetime) = match self.ticket.take() {
// Some(nst) => (nst.ticket.0, nst.lifetime_hint),
// None => (Vec::new(), 0),
// };
// if ticket.is_empty() {
// if let Some(resuming_session) = &mut self.resuming_session {
// ticket = resuming_session.take_ticket();
// }
// }
// if self.session_id.is_empty() && ticket.is_empty() {
// debug!("Session not saved: server didn't allocate id or ticket");
// return;
// }
// let time_now = match TimeBase::now() {
// Ok(time_now) => time_now,
// Err(e) => {
// debug!("Session not saved: {}", e);
// return;
// }
// };
// let key = persist::ClientSessionKey::session_for_server_name(&self.server_name);
// let value = persist::Tls12ClientSessionValue::new(
// self.secrets.suite(),
// self.session_id,
// ticket,
// self.secrets.get_master_secret(),
// cx.common.peer_certificates.clone().unwrap_or_default(),
// time_now,
// lifetime,
// self.using_ems,
// );
// let worked = self
// .config
// .session_storage
// .put(key.get_encoding(), value.get_encoding());
// if worked {
// debug!("Session saved");
// } else {
// debug!("Session not saved");
// }
// }
// }
#[async_trait]
impl State<ClientConnectionData> for ExpectFinished {
async fn handle(

View File

@@ -11,7 +11,6 @@ use crate::{
conn::{CommonState, ConnectionRandoms, State},
error::Error,
hash_hs::{HandshakeHash, HandshakeHashBuffer},
msgs::persist,
sign, verify, KeyLog,
};
#[allow(deprecated)]
@@ -60,7 +59,6 @@ pub(super) async fn handle_server_hello(
config: Arc<ClientConfig>,
cx: &mut ClientContext<'_>,
server_hello: &ServerHelloPayload,
resuming_session: Option<persist::Tls13ClientSessionValue>,
server_name: ServerName,
randoms: ConnectionRandoms,
suite: &'static Tls13CipherSuite,
@@ -102,8 +100,8 @@ pub(super) async fn handle_server_hello(
// };
// if server_hello.get_psk_index() != Some(0) {
// return Err(cx.common.illegal_param("server selected invalid psk").await);
// }
// return Err(cx.common.illegal_param("server selected invalid
// psk").await); }
// debug!("Resuming using PSK");
// // The key schedule has been initialized and set in fill_in_psk_binder()
@@ -143,7 +141,6 @@ pub(super) async fn handle_server_hello(
Ok(Box::new(ExpectEncryptedExtensions {
config,
resuming_session,
server_name,
randoms,
suite,
@@ -170,69 +167,6 @@ async fn validate_server_hello(
Ok(())
}
// fn save_kx_hint(config: &ClientConfig, server_name: &ServerName, group: NamedGroup) {
// let key = persist::ClientSessionKey::hint_for_server_name(server_name);
// config
// .session_storage
// .put(key.get_encoding(), group.get_encoding());
// }
// /// This implements the horrifying TLS1.3 hack where PSK binders have a
// /// data dependency on the message they are contained within.
// pub(super) fn fill_in_psk_binder(
// resuming: &persist::Tls13ClientSessionValue,
// transcript: &HandshakeHashBuffer,
// hmp: &mut HandshakeMessagePayload,
// ) -> KeyScheduleEarly {
// // We need to know the hash function of the suite we're trying to resume into.
// let hkdf_alg = &resuming.suite().hkdf_algorithm;
// let suite_hash = resuming.suite().hash_algorithm();
// // The binder is calculated over the clienthello, but doesn't include itself or its
// // length, or the length of its container.
// let binder_plaintext = hmp.get_encoding_for_binder_signing();
// let handshake_hash = transcript.get_hash_given(suite_hash, &binder_plaintext);
// // Run a fake key_schedule to simulate what the server will do if it chooses
// // to resume.
// let key_schedule = KeyScheduleEarly::new(hkdf_alg, resuming.secret());
// let real_binder = key_schedule.resumption_psk_binder_key_and_sign_verify_data(&handshake_hash);
// if let HandshakePayload::ClientHello(ref mut ch) = hmp.payload {
// ch.set_psk_binder(real_binder.as_ref());
// };
// key_schedule
// }
// pub(super) async fn prepare_resumption(
// config: &ClientConfig,
// cx: &mut ClientContext<'_>,
// ticket: Vec<u8>,
// resuming_session: &persist::Retrieved<&persist::Tls13ClientSessionValue>,
// exts: &mut Vec<ClientExtension>,
// doing_retry: bool,
// ) {
// let resuming_suite = resuming_session.suite();
// cx.common.suite = Some(resuming_suite.into());
// cx.data.resumption_ciphersuite = Some(resuming_suite.into());
// // Finally, and only for TLS1.3 with a ticket resumption, include a binder
// // for our ticket. This must go last.
// //
// // Include an empty binder. It gets filled in below because it depends on
// // the message it's contained in (!!!).
// let obfuscated_ticket_age = resuming_session.obfuscated_ticket_age();
// let binder_len = resuming_suite.hash_algorithm().output_len();
// let binder = vec![0u8; binder_len];
// let psk_identity = PresharedKeyIdentity::new(ticket, obfuscated_ticket_age);
// let psk_ext = PresharedKeyOffer::new(psk_identity, binder);
// exts.push(ClientExtension::PresharedKey(psk_ext));
// }
pub(super) async fn emit_fake_ccs(
sent_tls13_fake_ccs: &mut bool,
common: &mut CommonState,
@@ -287,7 +221,6 @@ async fn validate_encrypted_extensions(
struct ExpectEncryptedExtensions {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls13ClientSessionValue>,
server_name: ServerName,
randoms: ConnectionRandoms,
suite: &'static Tls13CipherSuite,
@@ -313,52 +246,19 @@ impl State<ClientConnectionData> for ExpectEncryptedExtensions {
validate_encrypted_extensions(cx.common, &self.hello, exts).await?;
hs::process_alpn_protocol(cx.common, &self.config, exts.get_alpn_protocol()).await?;
if let Some(resuming_session) = self.resuming_session {
let was_early_traffic = cx.common.early_traffic;
if was_early_traffic {
if exts.early_data_extension_offered() {
cx.data.early_data.accepted();
} else {
cx.data.early_data.rejected();
cx.common.early_traffic = false;
}
}
if was_early_traffic && !cx.common.early_traffic {
// If no early traffic, set the encryption key for handshakes
cx.common.record_layer.set_message_encrypter();
}
cx.common.peer_certificates = Some(resuming_session.server_cert_chain().to_vec());
// We *don't* reverify the certificate chain here: resumption is a
// continuation of the previous session in terms of security policy.
let cert_verified = verify::ServerCertVerified::assertion();
let sig_verified = verify::HandshakeSignatureValid::assertion();
Ok(Box::new(ExpectFinished {
config: self.config,
server_name: self.server_name,
randoms: self.randoms,
suite: self.suite,
transcript: self.transcript,
client_auth: None,
cert_verified,
sig_verified,
}))
} else {
if exts.early_data_extension_offered() {
let msg = "server sent early data extension without resumption".to_string();
return Err(Error::PeerMisbehavedError(msg));
}
Ok(Box::new(ExpectCertificateOrCertReq {
config: self.config,
server_name: self.server_name,
randoms: self.randoms,
suite: self.suite,
transcript: self.transcript,
may_send_sct_list: self.hello.server_may_send_sct_list(),
}))
if exts.early_data_extension_offered() {
let msg = "server sent early data extension without resumption".to_string();
return Err(Error::PeerMisbehavedError(msg));
}
Ok(Box::new(ExpectCertificateOrCertReq {
config: self.config,
server_name: self.server_name,
randoms: self.randoms,
suite: self.suite,
transcript: self.transcript,
may_send_sct_list: self.hello.server_may_send_sct_list(),
}))
}
}
@@ -422,9 +322,9 @@ impl State<ClientConnectionData> for ExpectCertificateOrCertReq {
}
}
// TLS1.3 version of CertificateRequest handling. We then move to expecting the server
// Certificate. Unfortunately the CertificateRequest type changed in an annoying way
// in TLS1.3.
// TLS1.3 version of CertificateRequest handling. We then move to expecting the
// server Certificate. Unfortunately the CertificateRequest type changed in an
// annoying way in TLS1.3.
struct ExpectCertificateRequest {
config: Arc<ClientConfig>,
server_name: ServerName,
@@ -787,8 +687,8 @@ impl State<ClientConnectionData> for ExpectFinished {
st.transcript.add_message(&m);
/* The EndOfEarlyData message to server is still encrypted with early data keys,
* but appears in the transcript after the server Finished. */
/* The EndOfEarlyData message to server is still encrypted with early data
* keys, but appears in the transcript after the server Finished. */
if cx.common.early_traffic {
emit_end_of_early_data_tls13(&mut st.transcript, cx.common).await?;
cx.common.early_traffic = false;
@@ -893,42 +793,6 @@ impl ExpectTraffic {
));
}
// let handshake_hash = self.transcript.get_current_hash();
// let secret = self
// .key_schedule
// .resumption_master_secret_and_derive_ticket_psk(&handshake_hash, &nst.nonce.0);
// let time_now = match TimeBase::now() {
// Ok(t) => t,
// #[allow(unused_variables)]
// Err(e) => {
// debug!("Session not saved: {}", e);
// return Ok(());
// }
// };
// let value = persist::Tls13ClientSessionValue::new(
// self.suite,
// nst.ticket.0.clone(),
// secret,
// cx.common.peer_certificates.clone().unwrap_or_default(),
// time_now,
// nst.lifetime,
// nst.age_add,
// nst.get_max_early_data_size().unwrap_or_default(),
// );
// let key = persist::ClientSessionKey::session_for_server_name(&self.server_name);
// #[allow(unused_mut)]
// let mut ticket = value.get_encoding();
// let worked = self.session_storage.put(key.get_encoding(), ticket);
// if worked {
// debug!("Ticket saved");
// } else {
// debug!("Ticket not saved");
// }
Ok(())
}
@@ -948,27 +812,6 @@ impl ExpectTraffic {
Err(Error::General(
"received unsupported key update request from peer".to_string(),
))
// match kur {
// KeyUpdateRequest::UpdateNotRequested => {}
// KeyUpdateRequest::UpdateRequested => {
// self.want_write_key_update = true;
// }
// _ => {
// common
// .send_fatal_alert(AlertDescription::IllegalParameter)
// .await;
// return Err(Error::CorruptMessagePayload(ContentType::Handshake));
// }
// }
// // Update our read-side keys.
// let new_read_key = self.key_schedule.next_server_application_traffic_secret();
// common
// .record_layer
// .set_message_decrypter(self.suite.derive_decrypter(&new_read_key));
// Ok(())
}
}
@@ -1022,10 +865,11 @@ impl State<ClientConnectionData> for ExpectTraffic {
// .send_msg_encrypt(Message::build_key_update_notify().into())
// .await;
// let write_key = self.key_schedule.next_client_application_traffic_secret();
// let write_key =
// self.key_schedule.next_client_application_traffic_secret();
// common
// .record_layer
// .set_message_encrypter(self.suite.derive_encrypter(&write_key));
// }
// .set_message_encrypter(self.suite.derive_encrypter(&
// write_key)); }
}
}

View File

@@ -301,15 +301,11 @@ mod conn;
mod error;
mod hash_hs;
mod limited_cache;
mod msgs;
mod rand;
mod record_layer;
//mod stream;
mod vecbuf;
pub(crate) use tls_core::verify;
#[cfg(test)]
mod verifybench;
pub(crate) use tls_core::x509;
pub(crate) use tls_core::{verify, x509};
#[macro_use]
mod check;
mod bs_debug;
@@ -330,7 +326,7 @@ pub mod internal {
// The public interface is:
pub use crate::{
anchors::{OwnedTrustAnchor, RootCertStore},
anchors::RootCertStore,
builder::{ConfigBuilder, WantsCipherSuites, WantsKxGroups, WantsVerifier, WantsVersions},
conn::{CommonState, ConnectionCommon, IoState, Reader, SideData},
error::Error,

View File

@@ -1,4 +0,0 @@
pub(crate) mod persist;
#[cfg(test)]
mod persist_test;

View File

@@ -1,526 +0,0 @@
use crate::{client::ServerName, ticketer::TimeBase};
use std::cmp;
#[cfg(feature = "tls12")]
use std::mem;
#[cfg(feature = "tls12")]
use tls_core::suites::Tls12CipherSuite;
use tls_core::{
msgs::{
base::{PayloadU16, PayloadU8},
codec::{Codec, Reader},
enums::{CipherSuite, ProtocolVersion},
handshake::{CertificatePayload, SessionID},
},
suites::{SupportedCipherSuite, Tls13CipherSuite},
};
// These are the keys and values we store in session storage.
// --- Client types ---
/// Keys for session resumption and tickets.
/// Matching value is a `ClientSessionValue`.
#[derive(Debug)]
pub struct ClientSessionKey {
kind: &'static [u8],
name: Vec<u8>,
}
impl Codec for ClientSessionKey {
fn encode(&self, bytes: &mut Vec<u8>) {
bytes.extend_from_slice(self.kind);
bytes.extend_from_slice(&self.name);
}
// Don't need to read these.
fn read(_r: &mut Reader) -> Option<Self> {
None
}
}
impl ClientSessionKey {
pub fn session_for_server_name(server_name: &ServerName) -> Self {
Self {
kind: b"session",
name: server_name.encode(),
}
}
pub fn hint_for_server_name(server_name: &ServerName) -> Self {
Self {
kind: b"kx-hint",
name: server_name.encode(),
}
}
}
#[derive(Debug)]
pub enum ClientSessionValue {
Tls13(Tls13ClientSessionValue),
#[cfg(feature = "tls12")]
Tls12(Tls12ClientSessionValue),
}
impl ClientSessionValue {
pub fn read(
reader: &mut Reader<'_>,
suite: CipherSuite,
supported: &[SupportedCipherSuite],
) -> Option<Self> {
match supported.iter().find(|s| s.suite() == suite)? {
SupportedCipherSuite::Tls13(inner) => {
Tls13ClientSessionValue::read(inner, reader).map(ClientSessionValue::Tls13)
}
#[cfg(feature = "tls12")]
SupportedCipherSuite::Tls12(inner) => {
Tls12ClientSessionValue::read(inner, reader).map(ClientSessionValue::Tls12)
}
}
}
fn common(&self) -> &ClientSessionCommon {
match self {
Self::Tls13(inner) => &inner.common,
#[cfg(feature = "tls12")]
Self::Tls12(inner) => &inner.common,
}
}
}
impl From<Tls13ClientSessionValue> for ClientSessionValue {
fn from(v: Tls13ClientSessionValue) -> Self {
Self::Tls13(v)
}
}
#[cfg(feature = "tls12")]
impl From<Tls12ClientSessionValue> for ClientSessionValue {
fn from(v: Tls12ClientSessionValue) -> Self {
Self::Tls12(v)
}
}
pub struct Retrieved<T> {
pub value: T,
retrieved_at: TimeBase,
}
impl<T> Retrieved<T> {
pub fn new(value: T, retrieved_at: TimeBase) -> Self {
Self {
value,
retrieved_at,
}
}
}
impl Retrieved<&Tls13ClientSessionValue> {
pub fn obfuscated_ticket_age(&self) -> u32 {
let age_secs = self
.retrieved_at
.as_secs()
.saturating_sub(self.value.common.epoch);
let age_millis = age_secs as u32 * 1000;
age_millis.wrapping_add(self.value.age_add)
}
}
impl Retrieved<ClientSessionValue> {
pub fn tls13(&self) -> Option<Retrieved<&Tls13ClientSessionValue>> {
match &self.value {
ClientSessionValue::Tls13(value) => Some(Retrieved::new(value, self.retrieved_at)),
#[cfg(feature = "tls12")]
ClientSessionValue::Tls12(_) => None,
}
}
pub fn has_expired(&self) -> bool {
let common = self.value.common();
common.lifetime_secs != 0
&& common.epoch + u64::from(common.lifetime_secs) < self.retrieved_at.as_secs()
}
}
impl<T> std::ops::Deref for Retrieved<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.value
}
}
#[derive(Debug)]
pub struct Tls13ClientSessionValue {
suite: &'static Tls13CipherSuite,
age_add: u32,
max_early_data_size: u32,
pub common: ClientSessionCommon,
}
impl Tls13ClientSessionValue {
pub fn new(
suite: &'static Tls13CipherSuite,
ticket: Vec<u8>,
secret: Vec<u8>,
server_cert_chain: Vec<tls_core::key::Certificate>,
time_now: TimeBase,
lifetime_secs: u32,
age_add: u32,
max_early_data_size: u32,
) -> Self {
Self {
suite,
age_add,
max_early_data_size,
common: ClientSessionCommon::new(
ticket,
secret,
time_now,
lifetime_secs,
server_cert_chain,
),
}
}
/// [`Codec::read()`] with an extra `suite` argument.
///
/// We decode the `suite` argument separately because it allows us to
/// decide whether we're decoding an 1.2 or 1.3 session value.
pub fn read(suite: &'static Tls13CipherSuite, r: &mut Reader) -> Option<Self> {
Some(Self {
suite,
age_add: u32::read(r)?,
max_early_data_size: u32::read(r)?,
common: ClientSessionCommon::read(r)?,
})
}
/// Inherent implementation of the [`Codec::get_encoding()`] method.
///
/// (See `read()` for why this is inherent here.)
pub fn get_encoding(&self) -> Vec<u8> {
let mut bytes = Vec::with_capacity(16);
self.suite.common.suite.encode(&mut bytes);
self.age_add.encode(&mut bytes);
self.max_early_data_size.encode(&mut bytes);
self.common.encode(&mut bytes);
bytes
}
pub fn max_early_data_size(&self) -> u32 {
self.max_early_data_size
}
pub fn suite(&self) -> &'static Tls13CipherSuite {
self.suite
}
}
impl std::ops::Deref for Tls13ClientSessionValue {
type Target = ClientSessionCommon;
fn deref(&self) -> &Self::Target {
&self.common
}
}
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub struct Tls12ClientSessionValue {
suite: &'static Tls12CipherSuite,
pub session_id: SessionID,
extended_ms: bool,
pub common: ClientSessionCommon,
}
#[cfg(feature = "tls12")]
impl Tls12ClientSessionValue {
pub fn new(
suite: &'static Tls12CipherSuite,
session_id: SessionID,
ticket: Vec<u8>,
master_secret: Vec<u8>,
server_cert_chain: Vec<tls_core::key::Certificate>,
time_now: TimeBase,
lifetime_secs: u32,
extended_ms: bool,
) -> Self {
Self {
suite,
session_id,
extended_ms,
common: ClientSessionCommon::new(
ticket,
master_secret,
time_now,
lifetime_secs,
server_cert_chain,
),
}
}
/// [`Codec::read()`] with an extra `suite` argument.
///
/// We decode the `suite` argument separately because it allows us to
/// decide whether we're decoding an 1.2 or 1.3 session value.
fn read(suite: &'static Tls12CipherSuite, r: &mut Reader) -> Option<Self> {
Some(Self {
suite,
session_id: SessionID::read(r)?,
extended_ms: u8::read(r)? == 1,
common: ClientSessionCommon::read(r)?,
})
}
/// Inherent implementation of the [`Codec::get_encoding()`] method.
///
/// (See `read()` for why this is inherent here.)
pub fn get_encoding(&self) -> Vec<u8> {
let mut bytes = Vec::with_capacity(16);
self.suite.common.suite.encode(&mut bytes);
self.session_id.encode(&mut bytes);
(if self.extended_ms { 1u8 } else { 0u8 }).encode(&mut bytes);
self.common.encode(&mut bytes);
bytes
}
pub fn take_ticket(&mut self) -> Vec<u8> {
mem::take(&mut self.common.ticket.0)
}
pub fn extended_ms(&self) -> bool {
self.extended_ms
}
pub fn suite(&self) -> &'static Tls12CipherSuite {
self.suite
}
}
#[cfg(feature = "tls12")]
impl std::ops::Deref for Tls12ClientSessionValue {
type Target = ClientSessionCommon;
fn deref(&self) -> &Self::Target {
&self.common
}
}
#[derive(Debug)]
pub struct ClientSessionCommon {
ticket: PayloadU16,
secret: PayloadU8,
epoch: u64,
lifetime_secs: u32,
server_cert_chain: CertificatePayload,
}
impl ClientSessionCommon {
fn new(
ticket: Vec<u8>,
secret: Vec<u8>,
time_now: TimeBase,
lifetime_secs: u32,
server_cert_chain: Vec<tls_core::key::Certificate>,
) -> Self {
Self {
ticket: PayloadU16(ticket),
secret: PayloadU8(secret),
epoch: time_now.as_secs(),
lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME),
server_cert_chain,
}
}
/// [`Codec::read()`] is inherent here to avoid leaking the [`Codec`]
/// implementation through [`Deref`] implementations on
/// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`].
fn read(r: &mut Reader) -> Option<Self> {
Some(Self {
ticket: PayloadU16::read(r)?,
secret: PayloadU8::read(r)?,
epoch: u64::read(r)?,
lifetime_secs: u32::read(r)?,
server_cert_chain: CertificatePayload::read(r)?,
})
}
/// [`Codec::encode()`] is inherent here to avoid leaking the [`Codec`]
/// implementation through [`Deref`] implementations on
/// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`].
fn encode(&self, bytes: &mut Vec<u8>) {
self.ticket.encode(bytes);
self.secret.encode(bytes);
self.epoch.encode(bytes);
self.lifetime_secs.encode(bytes);
self.server_cert_chain.encode(bytes);
}
pub fn server_cert_chain(&self) -> &[tls_core::key::Certificate] {
self.server_cert_chain.as_ref()
}
pub fn secret(&self) -> &[u8] {
self.secret.0.as_ref()
}
pub fn ticket(&self) -> &[u8] {
self.ticket.0.as_ref()
}
/// Test only: wind back epoch by delta seconds.
pub fn rewind_epoch(&mut self, delta: u32) {
self.epoch -= delta as u64;
}
}
static MAX_TICKET_LIFETIME: u32 = 7 * 24 * 60 * 60;
/// This is the maximum allowed skew between server and client clocks, over
/// the maximum ticket lifetime period. This encompasses TCP retransmission
/// times in case packet loss occurs when the client sends the ClientHello
/// or receives the NewSessionTicket, _and_ actual clock skew over this period.
static MAX_FRESHNESS_SKEW_MS: u32 = 60 * 1000;
// --- Server types ---
pub type ServerSessionKey = SessionID;
#[derive(Debug)]
pub struct ServerSessionValue {
pub sni: Option<webpki::DnsName>,
pub version: ProtocolVersion,
pub cipher_suite: CipherSuite,
pub master_secret: PayloadU8,
pub extended_ms: bool,
pub client_cert_chain: Option<CertificatePayload>,
pub alpn: Option<PayloadU8>,
pub application_data: PayloadU16,
pub creation_time_sec: u64,
pub age_obfuscation_offset: u32,
freshness: Option<bool>,
}
impl Codec for ServerSessionValue {
fn encode(&self, bytes: &mut Vec<u8>) {
if let Some(ref sni) = self.sni {
1u8.encode(bytes);
let sni_bytes: &str = sni.as_ref().into();
PayloadU8::new(Vec::from(sni_bytes)).encode(bytes);
} else {
0u8.encode(bytes);
}
self.version.encode(bytes);
self.cipher_suite.encode(bytes);
self.master_secret.encode(bytes);
(if self.extended_ms { 1u8 } else { 0u8 }).encode(bytes);
if let Some(ref chain) = self.client_cert_chain {
1u8.encode(bytes);
chain.encode(bytes);
} else {
0u8.encode(bytes);
}
if let Some(ref alpn) = self.alpn {
1u8.encode(bytes);
alpn.encode(bytes);
} else {
0u8.encode(bytes);
}
self.application_data.encode(bytes);
self.creation_time_sec.encode(bytes);
self.age_obfuscation_offset.encode(bytes);
}
fn read(r: &mut Reader) -> Option<Self> {
let has_sni = u8::read(r)?;
let sni = if has_sni == 1 {
let dns_name = PayloadU8::read(r)?;
let dns_name = webpki::DnsNameRef::try_from_ascii(&dns_name.0).ok()?;
Some(dns_name.into())
} else {
None
};
let v = ProtocolVersion::read(r)?;
let cs = CipherSuite::read(r)?;
let ms = PayloadU8::read(r)?;
let ems = u8::read(r)?;
let has_ccert = u8::read(r)? == 1;
let ccert = if has_ccert {
Some(CertificatePayload::read(r)?)
} else {
None
};
let has_alpn = u8::read(r)? == 1;
let alpn = if has_alpn {
Some(PayloadU8::read(r)?)
} else {
None
};
let application_data = PayloadU16::read(r)?;
let creation_time_sec = u64::read(r)?;
let age_obfuscation_offset = u32::read(r)?;
Some(Self {
sni,
version: v,
cipher_suite: cs,
master_secret: ms,
extended_ms: ems == 1u8,
client_cert_chain: ccert,
alpn,
application_data,
creation_time_sec,
age_obfuscation_offset,
freshness: None,
})
}
}
impl ServerSessionValue {
pub fn new(
sni: Option<&webpki::DnsName>,
v: ProtocolVersion,
cs: CipherSuite,
ms: Vec<u8>,
client_cert_chain: Option<CertificatePayload>,
alpn: Option<Vec<u8>>,
application_data: Vec<u8>,
creation_time: TimeBase,
age_obfuscation_offset: u32,
) -> Self {
Self {
sni: sni.cloned(),
version: v,
cipher_suite: cs,
master_secret: PayloadU8::new(ms),
extended_ms: false,
client_cert_chain,
alpn: alpn.map(PayloadU8::new),
application_data: PayloadU16::new(application_data),
creation_time_sec: creation_time.as_secs(),
age_obfuscation_offset,
freshness: None,
}
}
pub fn set_extended_ms_used(&mut self) {
self.extended_ms = true;
}
pub fn set_freshness(mut self, obfuscated_client_age_ms: u32, time_now: TimeBase) -> Self {
let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset);
let server_age_ms =
(time_now.as_secs().saturating_sub(self.creation_time_sec) as u32).saturating_mul(1000);
let age_difference = if client_age_ms < server_age_ms {
server_age_ms - client_age_ms
} else {
client_age_ms - server_age_ms
};
self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS);
self
}
pub fn is_fresh(&self) -> bool {
self.freshness.unwrap_or_default()
}
}

View File

@@ -1,78 +0,0 @@
use super::persist::*;
use crate::ticketer::TimeBase;
use std::convert::TryInto;
use tls_core::{
key::Certificate,
msgs::{
codec::{Codec, Reader},
enums::*,
},
suites::TLS13_AES_128_GCM_SHA256,
};
#[test]
fn clientsessionkey_is_debug() {
let name = "hello".try_into().unwrap();
let csk = ClientSessionKey::session_for_server_name(&name);
println!("{:?}", csk);
}
#[test]
fn clientsessionkey_cannot_be_read() {
let bytes = [0; 1];
let mut rd = Reader::init(&bytes);
assert!(ClientSessionKey::read(&mut rd).is_none());
}
#[test]
fn clientsessionvalue_is_debug() {
let csv = ClientSessionValue::from(Tls13ClientSessionValue::new(
TLS13_AES_128_GCM_SHA256.tls13().unwrap(),
vec![],
vec![1, 2, 3],
vec![Certificate(b"abc".to_vec()), Certificate(b"def".to_vec())],
TimeBase::now().unwrap(),
15,
10,
128,
));
println!("{:?}", csv);
}
#[test]
fn serversessionvalue_is_debug() {
let ssv = ServerSessionValue::new(
None,
ProtocolVersion::TLSv1_3,
CipherSuite::TLS13_AES_128_GCM_SHA256,
vec![1, 2, 3],
None,
None,
vec![4, 5, 6],
TimeBase::now().unwrap(),
0x12345678,
);
println!("{:?}", ssv);
}
#[test]
fn serversessionvalue_no_sni() {
let bytes = [
0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12,
0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, 0xfe, 0xed, 0xf0, 0x0d,
];
let mut rd = Reader::init(&bytes);
let ssv = ServerSessionValue::read(&mut rd).unwrap();
assert_eq!(ssv.get_encoding(), bytes);
}
#[test]
fn serversessionvalue_with_cert() {
let bytes = [
0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12,
0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, 0xfe, 0xed, 0xf0, 0x0d,
];
let mut rd = Reader::init(&bytes);
let ssv = ServerSessionValue::read(&mut rd).unwrap();
assert_eq!(ssv.get_encoding(), bytes);
}

View File

@@ -6,6 +6,7 @@ use ring::{
io::der,
signature::{self, EcdsaKeyPair, Ed25519KeyPair, RsaKeyPair},
};
use rustls_pki_types as pki_types;
use std::{convert::TryFrom, error::Error as StdError, fmt, sync::Arc};
use tls_core::{
key,
@@ -71,51 +72,6 @@ impl CertifiedKey {
pub fn end_entity_cert(&self) -> Result<&key::Certificate, SignError> {
self.cert.first().ok_or(SignError(()))
}
/// Check the certificate chain for validity:
/// - it should be non-empty list
/// - the first certificate should be parsable as a x509v3,
/// - the first certificate should quote the given server name
/// (if provided)
///
/// These checks are not security-sensitive. They are the
/// *server* attempting to detect accidental misconfiguration.
pub(crate) fn cross_check_end_entity_cert(
&self,
name: Option<webpki::DnsNameRef>,
) -> Result<(), Error> {
// Always reject an empty certificate chain.
let end_entity_cert = self.end_entity_cert().map_err(|SignError(())| {
Error::General("No end-entity certificate in certificate chain".to_string())
})?;
// Reject syntactically-invalid end-entity certificates.
let end_entity_cert =
webpki::EndEntityCert::try_from(end_entity_cert.as_ref()).map_err(|_| {
Error::General(
"End-entity certificate in certificate \
chain is syntactically invalid"
.to_string(),
)
})?;
if let Some(name) = name {
// If SNI was offered then the certificate must be valid for
// that hostname. Note that this doesn't fully validate that the
// certificate is valid; it only validates that the name is one
// that the certificate is valid for, if the certificate is
// valid.
if end_entity_cert.verify_is_valid_for_dns_name(name).is_err() {
return Err(Error::General(
"The server certificate is not \
valid for the given name"
.to_string(),
));
}
}
Ok(())
}
}
/// Parse `der` as any supported key encoding/type, returning

View File

@@ -1,223 +0,0 @@
// This program does benchmarking of the functions in verify.rs,
// that do certificate chain validation and signature verification.
//
// Note: we don't use any of the standard 'cargo bench', 'test::Bencher',
// etc. because it's unstable at the time of writing.
use crate::{anchors, verify, verify::ServerCertVerifier, OwnedTrustAnchor};
use std::convert::TryInto;
use web_time::{Duration, Instant, SystemTime};
use webpki_roots;
fn duration_nanos(d: Duration) -> u64 {
((d.as_secs() as f64) * 1e9 + (d.subsec_nanos() as f64)) as u64
}
#[test]
fn test_reddit_cert() {
Context::new(
"reddit",
"reddit.com",
&[
include_bytes!("testdata/cert-reddit.0.der"),
include_bytes!("testdata/cert-reddit.1.der"),
],
)
.bench(100)
}
#[test]
fn test_github_cert() {
Context::new(
"github",
"github.com",
&[
include_bytes!("testdata/cert-github.0.der"),
include_bytes!("testdata/cert-github.1.der"),
],
)
.bench(100)
}
#[test]
fn test_arstechnica_cert() {
Context::new(
"arstechnica",
"arstechnica.com",
&[
include_bytes!("testdata/cert-arstechnica.0.der"),
include_bytes!("testdata/cert-arstechnica.1.der"),
include_bytes!("testdata/cert-arstechnica.2.der"),
include_bytes!("testdata/cert-arstechnica.3.der"),
],
)
.bench(100)
}
#[test]
fn test_twitter_cert() {
Context::new(
"twitter",
"twitter.com",
&[
include_bytes!("testdata/cert-twitter.0.der"),
include_bytes!("testdata/cert-twitter.1.der"),
],
)
.bench(100)
}
#[test]
fn test_wikipedia_cert() {
Context::new(
"wikipedia",
"wikipedia.org",
&[
include_bytes!("testdata/cert-wikipedia.0.der"),
include_bytes!("testdata/cert-wikipedia.1.der"),
],
)
.bench(100)
}
#[test]
fn test_google_cert() {
Context::new(
"google",
"www.google.com",
&[
include_bytes!("testdata/cert-google.0.der"),
include_bytes!("testdata/cert-google.1.der"),
],
)
.bench(100)
}
#[test]
fn test_hn_cert() {
Context::new(
"hn",
"news.ycombinator.com",
&[
include_bytes!("testdata/cert-hn.0.der"),
include_bytes!("testdata/cert-hn.1.der"),
],
)
.bench(100)
}
#[test]
fn test_stackoverflow_cert() {
Context::new(
"stackoverflow",
"stackoverflow.com",
&[
include_bytes!("testdata/cert-stackoverflow.0.der"),
include_bytes!("testdata/cert-stackoverflow.1.der"),
],
)
.bench(100)
}
#[test]
fn test_duckduckgo_cert() {
Context::new(
"duckduckgo",
"duckduckgo.com",
&[
include_bytes!("testdata/cert-duckduckgo.0.der"),
include_bytes!("testdata/cert-duckduckgo.1.der"),
],
)
.bench(100)
}
#[test]
fn test_rustlang_cert() {
Context::new(
"rustlang",
"www.rust-lang.org",
&[
include_bytes!("testdata/cert-rustlang.0.der"),
include_bytes!("testdata/cert-rustlang.1.der"),
include_bytes!("testdata/cert-rustlang.2.der"),
],
)
.bench(100)
}
#[test]
fn test_wapo_cert() {
Context::new(
"wapo",
"www.washingtonpost.com",
&[
include_bytes!("testdata/cert-wapo.0.der"),
include_bytes!("testdata/cert-wapo.1.der"),
],
)
.bench(100)
}
struct Context {
name: &'static str,
domain: &'static str,
roots: anchors::RootCertStore,
chain: Vec<tls_core::key::Certificate>,
now: SystemTime,
}
impl Context {
fn new(name: &'static str, domain: &'static str, certs: &[&'static [u8]]) -> Self {
let mut roots = anchors::RootCertStore::empty();
roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject.as_ref(),
ta.subject_public_key_info.as_ref(),
ta.name_constraints.as_ref().map(|nc| nc.as_ref()),
)
}));
Self {
name,
domain,
roots,
chain: certs
.iter()
.copied()
.map(|bytes| tls_core::key::Certificate(bytes.to_vec()))
.collect(),
now: SystemTime::UNIX_EPOCH + Duration::from_secs(1640870720),
}
}
fn bench(&self, count: usize) {
let verifier = verify::WebPkiVerifier::new(self.roots.clone(), None);
const SCTS: &[&[u8]] = &[];
const OCSP_RESPONSE: &[u8] = &[];
let mut times = Vec::new();
let (end_entity, intermediates) = self.chain.split_first().unwrap();
for _ in 0..count {
let start = Instant::now();
let server_name = self.domain.try_into().unwrap();
verifier
.verify_server_cert(
end_entity,
intermediates,
&server_name,
&mut SCTS.iter().copied(),
OCSP_RESPONSE,
self.now,
)
.unwrap();
times.push(duration_nanos(Instant::now().duration_since(start)));
}
println!(
"verify_server_cert({}): min {:?}us",
self.name,
times.iter().min().unwrap() / 1000
);
}
}

View File

@@ -780,14 +780,12 @@ async fn client_checks_server_certificate_with_given_name() {
let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap();
let err = do_handshake_until_error(&mut client, &mut server).await;
assert_eq!(
assert!(matches!(
err,
Err(ErrorFromPeer::Client(Error::CoreError(
tls_core::Error::InvalidCertificateData(
"invalid peer certificate: CertNotValidForName".into(),
)
tls_core::Error::InvalidCertificateData(_)
)))
);
));
}
}
}

View File

@@ -2,6 +2,7 @@
use futures::{AsyncRead, AsyncWrite};
use rustls::{server::AllowAnyAuthenticatedClient, ServerConfig, ServerConnection};
use rustls_pki_types::CertificateDer;
use std::{
convert::{TryFrom, TryInto},
io,
@@ -15,6 +16,7 @@ use tls_client::{
Certificate, ClientConfig, ClientConnection, Error, PrivateKey, RootCertStore,
RustCryptoBackend,
};
use webpki::anchor_from_trusted_cert;
macro_rules! embed_files {
(
@@ -409,9 +411,17 @@ pub fn finish_client_config(
kt: KeyType,
config: tls_client::ConfigBuilder<tls_client::WantsVerifier>,
) -> ClientConfig {
let mut root_store = RootCertStore::empty();
let mut rootbuf = io::BufReader::new(kt.bytes_for("ca.cert"));
root_store.add_parsable_certificates(&rustls_pemfile::certs(&mut rootbuf).unwrap());
let roots = rustls_pemfile::certs(&mut rootbuf)
.unwrap()
.into_iter()
.map(|cert| {
let der = CertificateDer::from_slice(&cert);
anchor_from_trusted_cert(&der).unwrap().to_owned()
})
.collect();
let root_store = RootCertStore { roots };
config
.with_root_certificates(root_store)
@@ -422,9 +432,17 @@ pub fn finish_client_config_with_creds(
kt: KeyType,
config: tls_client::ConfigBuilder<tls_client::WantsVerifier>,
) -> ClientConfig {
let mut root_store = RootCertStore::empty();
let mut rootbuf = io::BufReader::new(kt.bytes_for("ca.cert"));
root_store.add_parsable_certificates(&rustls_pemfile::certs(&mut rootbuf).unwrap());
let roots = rustls_pemfile::certs(&mut rootbuf)
.unwrap()
.into_iter()
.map(|cert| {
let der = CertificateDer::from_slice(&cert);
anchor_from_trusted_cert(&der).unwrap().to_owned()
})
.collect();
let root_store = RootCertStore { roots };
config
.with_root_certificates(root_store)

View File

@@ -35,4 +35,5 @@ sha2 = { workspace = true, optional = true }
thiserror = { workspace = true }
tracing = { workspace = true, optional = true }
web-time = { workspace = true }
webpki = { workspace = true, features = ["alloc", "std"] }
rustls-webpki = { workspace = true, features = ["ring"] }
rustls-pki-types = { workspace = true }

View File

@@ -1,65 +1,16 @@
use rustls_pki_types::TrustAnchor;
use crate::{
msgs::handshake::{DistinguishedName, DistinguishedNames},
x509,
};
/// A trust anchor, commonly known as a "Root Certificate."
#[derive(Debug, Clone)]
pub struct OwnedTrustAnchor {
subject: Vec<u8>,
spki: Vec<u8>,
name_constraints: Option<Vec<u8>>,
}
impl OwnedTrustAnchor {
/// Get a `webpki::TrustAnchor` by borrowing the owned elements.
pub(crate) fn to_trust_anchor(&self) -> webpki::TrustAnchor<'_> {
webpki::TrustAnchor {
subject: &self.subject,
spki: &self.spki,
name_constraints: self.name_constraints.as_deref(),
}
}
/// Constructs an `OwnedTrustAnchor` from its components.
///
/// `subject` is the subject field of the trust anchor.
///
/// `spki` is the `subjectPublicKeyInfo` field of the trust anchor.
///
/// `name_constraints` is the value of a DER-encoded name constraints to
/// apply for this trust anchor, if any.
pub fn from_subject_spki_name_constraints(
subject: impl Into<Vec<u8>>,
spki: impl Into<Vec<u8>>,
name_constraints: Option<impl Into<Vec<u8>>>,
) -> Self {
Self {
subject: subject.into(),
spki: spki.into(),
name_constraints: name_constraints.map(|x| x.into()),
}
}
}
/// Errors that can occur during operations with RootCertStore
#[derive(Debug, thiserror::Error)]
#[allow(missing_docs)]
pub enum RootCertStoreError {
#[error(transparent)]
WebpkiError(#[from] webpki::Error),
#[error(transparent)]
IOError(#[from] std::io::Error),
#[error("Unexpected PEM certificate count. Expected 1 certificate, got {0}")]
PemCertUnexpectedCount(usize),
}
/// A container for root certificates able to provide a root-of-trust
/// for connection authentication.
#[derive(Debug, Clone)]
pub struct RootCertStore {
/// The list of roots.
pub roots: Vec<OwnedTrustAnchor>,
pub roots: Vec<TrustAnchor<'static>>,
}
impl RootCertStore {
@@ -91,100 +42,4 @@ impl RootCertStore {
r
}
/// Add a single DER-encoded certificate to the store.
pub fn add(&mut self, der: &crate::key::Certificate) -> Result<(), RootCertStoreError> {
let ta = webpki::TrustAnchor::try_from_cert_der(&der.0)?;
let ota = OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
);
self.roots.push(ota);
Ok(())
}
/// Adds a single PEM-encoded certificate to the store.
pub fn add_pem(&mut self, pem: &str) -> Result<(), RootCertStoreError> {
let mut certificates = rustls_pemfile::certs(&mut pem.as_bytes())?;
if certificates.len() != 1 {
return Err(RootCertStoreError::PemCertUnexpectedCount(
certificates.len(),
));
}
self.add(&crate::key::Certificate(certificates.remove(0)))?;
Ok(())
}
/// Adds all the given TrustAnchors `anchors`. This does not
/// fail.
pub fn add_server_trust_anchors(
&mut self,
trust_anchors: impl Iterator<Item = OwnedTrustAnchor>,
) {
self.roots.extend(trust_anchors)
}
/// Parse the given DER-encoded certificates and add all that can be parsed
/// in a best-effort fashion.
///
/// This is because large collections of root certificates often
/// include ancient or syntactically invalid certificates.
///
/// Returns the number of certificates added, and the number that were ignored.
pub fn add_parsable_certificates(&mut self, der_certs: &[Vec<u8>]) -> (usize, usize) {
let mut valid_count = 0;
let mut invalid_count = 0;
for der_cert in der_certs {
match self.add(&crate::key::Certificate(der_cert.clone())) {
Ok(_) => valid_count += 1,
Err(_err) => invalid_count += 1,
}
}
(valid_count, invalid_count)
}
}
#[cfg(test)]
mod tests {
use super::*;
const CA_PEM_CERT: &[u8] = include_bytes!("../testdata/cert-digicert.pem");
#[test]
fn test_add_pem_ok() {
let pem = std::str::from_utf8(CA_PEM_CERT).unwrap();
assert!(RootCertStore::empty().add_pem(pem).is_ok());
}
#[test]
fn test_add_pem_err_bad_cert() {
assert_eq!(
RootCertStore::empty()
.add_pem("bad pem")
.err()
.unwrap()
.to_string(),
"Unexpected PEM certificate count. Expected 1 certificate, got 0"
);
}
#[test]
fn test_add_pem_err_more_than_one_cert() {
let pem1 = std::str::from_utf8(CA_PEM_CERT).unwrap();
let pem2 = pem1;
assert_eq!(
RootCertStore::empty()
.add_pem((pem1.to_owned() + pem2).as_str())
.err()
.unwrap()
.to_string(),
"Unexpected PEM certificate count. Expected 1 certificate, got 2"
);
}
}

View File

@@ -1,6 +1,6 @@
use std::{error::Error as StdError, fmt};
use crate::verify;
use rustls_pki_types as pki_types;
/// Encodes ways a client can know the expected name of the server.
///
@@ -26,20 +26,16 @@ use crate::verify;
/// ```
#[non_exhaustive]
#[derive(Debug, PartialEq, Clone)]
pub enum ServerName {
/// The server is identified by a DNS name. The name
/// is sent in the TLS Server Name Indication (SNI)
/// extension.
DnsName(verify::DnsName),
}
pub struct ServerName(pub(crate) pki_types::ServerName<'static>);
impl ServerName {
/// Return the name that should go in the SNI extension.
/// If [`None`] is returned, the SNI extension is not included
/// in the handshake.
pub fn for_sni(&self) -> Option<webpki::DnsNameRef<'_>> {
match self {
Self::DnsName(dns_name) => Some(dns_name.0.as_ref()),
pub fn for_sni(&self) -> Option<pki_types::DnsName<'static>> {
match &self.0 {
pki_types::ServerName::DnsName(dns_name) => Some(dns_name.clone()),
_ => None,
}
}
@@ -49,13 +45,17 @@ impl ServerName {
DnsName = 0x01,
}
let Self::DnsName(dns_name) = self;
let bytes = dns_name.0.as_ref();
let bytes: &[u8] = match &self.0 {
pki_types::ServerName::DnsName(dns_name) => dns_name.as_ref().as_ref(),
pki_types::ServerName::IpAddress(pki_types::IpAddr::V4(ip)) => ip.as_ref(),
pki_types::ServerName::IpAddress(pki_types::IpAddr::V6(ip)) => ip.as_ref(),
_ => unreachable!(),
};
let mut r = Vec::with_capacity(2 + bytes.as_ref().len());
let mut r = Vec::with_capacity(2 + bytes.len());
r.push(UniqueTypeCode::DnsName as u8);
r.push(bytes.as_ref().len() as u8);
r.extend_from_slice(bytes.as_ref());
r.push(bytes.len() as u8);
r.extend_from_slice(bytes);
r
}
@@ -66,9 +66,9 @@ impl ServerName {
impl TryFrom<&str> for ServerName {
type Error = InvalidDnsNameError;
fn try_from(s: &str) -> Result<Self, Self::Error> {
match webpki::DnsNameRef::try_from_ascii_str(s) {
Ok(dns) => Ok(Self::DnsName(verify::DnsName(dns.into()))),
Err(webpki::InvalidDnsNameError) => Err(InvalidDnsNameError),
match pki_types::DnsName::try_from(s) {
Ok(dns) => Ok(Self(pki_types::ServerName::DnsName(dns.to_owned()))),
Err(_) => Err(InvalidDnsNameError),
}
}
}

View File

@@ -1,5 +1,6 @@
use crate::msgs::enums::{AlertDescription, ContentType, HandshakeType};
use std::{error::Error as StdError, fmt, time::SystemTimeError};
use std::{error::Error as StdError, fmt};
use web_time::SystemTimeError;
/// rustls reports protocol errors using this type.
#[derive(Debug, PartialEq, Clone)]
@@ -41,8 +42,9 @@ pub enum Error {
/// We couldn't decrypt a message. This is invariably fatal.
DecryptError,
/// We couldn't encrypt a message because it was larger than the allowed message size.
/// This should never happen if the application is using valid record sizes.
/// We couldn't encrypt a message because it was larger than the allowed
/// message size. This should never happen if the application is using
/// valid record sizes.
EncryptError,
/// The peer doesn't support a protocol version/feature we require.

View File

@@ -1,3 +1,5 @@
use rustls_pki_types as pki_types;
use crate::{
key,
msgs::{
@@ -249,14 +251,14 @@ impl DecomposedSignatureScheme for SignatureScheme {
#[derive(Clone, Debug)]
pub enum ServerNamePayload {
// Stored twice, bytes so we can round-trip, and DnsName for use
HostName((PayloadU16, webpki::DnsName)),
HostName((PayloadU16, pki_types::DnsName<'static>)),
Unknown(Payload),
}
impl ServerNamePayload {
pub fn new_hostname(hostname: webpki::DnsName) -> Self {
pub fn new_hostname(hostname: pki_types::DnsName<'static>) -> Self {
let raw = {
let s: &str = hostname.as_ref().into();
let s: &str = hostname.as_ref();
PayloadU16::new(s.as_bytes().into())
};
Self::HostName((raw, hostname))
@@ -266,8 +268,8 @@ impl ServerNamePayload {
let raw = PayloadU16::read(r)?;
let dns_name = {
match webpki::DnsNameRef::try_from_ascii(&raw.0) {
Ok(dns_name) => dns_name.into(),
match pki_types::DnsName::try_from(raw.0.as_slice()) {
Ok(dns_name) => dns_name.to_owned(),
Err(_) => {
warn!("Illegal SNI hostname received {:?}", raw.0);
return None;
@@ -313,11 +315,12 @@ declare_u16_vec!(ServerNameRequest, ServerName);
pub trait ConvertServerNameList {
fn has_duplicate_names_for_type(&self) -> bool;
fn get_single_hostname(&self) -> Option<webpki::DnsNameRef<'_>>;
fn get_single_hostname(&self) -> Option<pki_types::DnsName<'static>>;
}
impl ConvertServerNameList for ServerNameRequest {
/// RFC6066: "The ServerNameList MUST NOT contain more than one name of the same name_type."
/// RFC6066: "The ServerNameList MUST NOT contain more than one name of the
/// same name_type."
fn has_duplicate_names_for_type(&self) -> bool {
let mut seen = collections::HashSet::new();
@@ -330,10 +333,10 @@ impl ConvertServerNameList for ServerNameRequest {
false
}
fn get_single_hostname(&self) -> Option<webpki::DnsNameRef<'_>> {
fn only_dns_hostnames(name: &ServerName) -> Option<webpki::DnsNameRef<'_>> {
fn get_single_hostname(&self) -> Option<pki_types::DnsName<'static>> {
fn only_dns_hostnames(name: &ServerName) -> Option<pki_types::DnsName<'static>> {
if let ServerNamePayload::HostName((_, ref dns)) = name.payload {
Some(dns.as_ref())
Some(dns.clone())
} else {
None
}
@@ -694,16 +697,16 @@ impl Codec for ClientExtension {
}
}
fn trim_hostname_trailing_dot_for_sni(dns_name: webpki::DnsNameRef) -> webpki::DnsName {
let dns_name_str: &str = dns_name.into();
fn trim_hostname_trailing_dot_for_sni(
dns_name: pki_types::DnsName<'static>,
) -> pki_types::DnsName<'static> {
let dns_name_str: &str = dns_name.as_ref();
// RFC6066: "The hostname is represented as a byte string using
// ASCII encoding without a trailing dot"
if dns_name_str.ends_with('.') {
let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
webpki::DnsNameRef::try_from_ascii_str(trimmed)
.unwrap()
.to_owned()
pki_types::DnsName::try_from(trimmed).unwrap().to_owned()
} else {
dns_name.to_owned()
}
@@ -711,7 +714,7 @@ fn trim_hostname_trailing_dot_for_sni(dns_name: webpki::DnsNameRef) -> webpki::D
impl ClientExtension {
/// Make a basic SNI ServerNameRequest quoting `hostname`.
pub fn make_sni(dns_name: webpki::DnsNameRef) -> Self {
pub fn make_sni(dns_name: pki_types::DnsName<'static>) -> Self {
let name = ServerName {
typ: ServerNameType::HostName,
payload: ServerNamePayload::new_hostname(trim_hostname_trailing_dot_for_sni(dns_name)),

View File

@@ -1,3 +1,5 @@
use rustls_pki_types as pki_types;
use super::{
base::{Payload, PayloadU16, PayloadU24, PayloadU8},
codec::{put_u16, Codec, Reader},
@@ -5,7 +7,6 @@ use super::{
handshake::*,
};
use crate::key::Certificate;
use webpki::DnsNameRef;
#[test]
fn rejects_short_random() {
@@ -186,7 +187,8 @@ fn can_roundtrip_multiname_sni() {
assert!(req.has_duplicate_names_for_type());
let dns_name_str: &str = req.get_single_hostname().unwrap().into();
let dns_name = req.get_single_hostname().unwrap();
let dns_name_str: &str = dns_name.as_ref();
assert_eq!(dns_name_str, "hi");
assert_eq!(req[0].typ, ServerNameType::HostName);
@@ -363,7 +365,7 @@ fn get_sample_clienthellopayload() -> ClientHelloPayload {
ClientExtension::ECPointFormats(ECPointFormatList::supported()),
ClientExtension::NamedGroups(vec![NamedGroup::X25519]),
ClientExtension::SignatureAlgorithms(vec![SignatureScheme::ECDSA_NISTP256_SHA256]),
ClientExtension::make_sni(DnsNameRef::try_from_ascii_str("hello").unwrap()),
ClientExtension::make_sni(pki_types::DnsName::try_from("hello").unwrap().to_owned()),
ClientExtension::SessionTicket(ClientSessionTicket::Request),
ClientExtension::SessionTicket(ClientSessionTicket::Offer(Payload(vec![]))),
ClientExtension::Protocols(vec![PayloadU8(vec![0])]),

View File

@@ -1,5 +1,5 @@
use crate::{
anchors::{OwnedTrustAnchor, RootCertStore},
anchors::RootCertStore,
dns::ServerName,
error::Error,
key::Certificate,
@@ -9,27 +9,10 @@ use crate::{
},
};
use ring::digest::Digest;
use std::convert::TryFrom;
use web_time::SystemTime;
use rustls_pki_types as pki_types;
use web_time::{SystemTime, UNIX_EPOCH};
type SignatureAlgorithms = &'static [&'static webpki::SignatureAlgorithm];
/// Which signature verification mechanisms we support. No particular
/// order.
static SUPPORTED_SIG_ALGS: SignatureAlgorithms = &[
&webpki::ECDSA_P256_SHA256,
&webpki::ECDSA_P256_SHA384,
&webpki::ECDSA_P384_SHA256,
&webpki::ECDSA_P384_SHA384,
&webpki::ED25519,
&webpki::RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
&webpki::RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
&webpki::RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
&webpki::RSA_PKCS1_2048_8192_SHA256,
&webpki::RSA_PKCS1_2048_8192_SHA384,
&webpki::RSA_PKCS1_2048_8192_SHA512,
&webpki::RSA_PKCS1_3072_8192_SHA384,
];
type SignatureAlgorithms = &'static [&'static dyn pki_types::SignatureVerificationAlgorithm];
// Marker types. These are used to bind the fact some verification
// (certificate chain or handshake signature) has taken place into
@@ -170,7 +153,7 @@ pub trait ServerCertVerifier: Send + Sync {
/// A type which encapsuates a string that is a syntactically valid DNS name.
#[derive(Clone, Debug, PartialEq)]
pub struct DnsName(pub(crate) webpki::DnsName);
pub struct DnsName(pub(crate) pki_types::DnsName<'static>);
impl AsRef<str> for DnsName {
fn as_ref(&self) -> &str {
@@ -289,35 +272,33 @@ impl ServerCertVerifier for WebPkiVerifier {
_ocsp_response: &[u8],
now: SystemTime,
) -> Result<ServerCertVerified, Error> {
let (cert, chain, trustroots) = prepare(end_entity, intermediates, &self.roots)?;
// `webpki::Time::try_from` does not work with `web_time::SystemTime`.
// To workaround this we convert `SystemTime` to seconds and use
// `webpki::Time::from_seconds_since_unix_epoch` instead.
let duration_since_epoch = now
.duration_since(web_time::UNIX_EPOCH)
.map_err(|_| Error::FailedToGetCurrentTime)?;
let seconds_since_unix_epoch = duration_since_epoch.as_secs();
let webpki_now = webpki::Time::from_seconds_since_unix_epoch(seconds_since_unix_epoch);
let cert = pki_types::CertificateDer::from(end_entity.0.as_slice());
let cert = webpki::EndEntityCert::try_from(&cert).map_err(pki_error)?;
let intermediates = intermediates
.iter()
.map(|c| pki_types::CertificateDer::from(c.0.as_slice()))
.collect::<Vec<_>>();
let time = pki_types::UnixTime::since_unix_epoch(now.duration_since(UNIX_EPOCH)?);
let ServerName::DnsName(dns_name) = server_name;
let cert = cert
.verify_is_valid_tls_server_cert(
SUPPORTED_SIG_ALGS,
&webpki::TlsServerTrustAnchors(&trustroots),
&chain,
webpki_now,
)
.map_err(pki_error)
.map(|_| cert)?;
cert.verify_for_usage(
webpki::ALL_VERIFICATION_ALGS,
&self.roots.roots,
&intermediates,
time,
webpki::KeyUsage::server_auth(),
None,
None,
)
.map(|_| ())
.map_err(pki_error)?;
if let Some(policy) = &self.ct_policy {
policy.verify(end_entity, now, scts)?;
}
cert.verify_is_valid_for_dns_name(dns_name.0.as_ref())
.map_err(pki_error)
cert.verify_is_valid_for_subject_name(&server_name.0)
.map(|_| ServerCertVerified::assertion())
.map_err(pki_error)
}
}
@@ -429,31 +410,6 @@ impl CertificateTransparencyPolicy {
}
}
type CertChainAndRoots<'a, 'b> = (
webpki::EndEntityCert<'a>,
Vec<&'a [u8]>,
Vec<webpki::TrustAnchor<'b>>,
);
fn prepare<'a, 'b>(
end_entity: &'a Certificate,
intermediates: &'a [Certificate],
roots: &'b RootCertStore,
) -> Result<CertChainAndRoots<'a, 'b>, Error> {
// EE cert must appear first.
let cert = webpki::EndEntityCert::try_from(end_entity.0.as_ref()).map_err(pki_error)?;
let intermediates: Vec<&'a [u8]> = intermediates.iter().map(|cert| cert.0.as_ref()).collect();
let trustroots: Vec<webpki::TrustAnchor> = roots
.roots
.iter()
.map(OwnedTrustAnchor::to_trust_anchor)
.collect();
Ok((cert, intermediates, trustroots))
}
pub(crate) fn pki_error(error: webpki::Error) -> Error {
use webpki::Error::*;
match error {
@@ -466,20 +422,24 @@ pub(crate) fn pki_error(error: webpki::Error) -> Error {
}
}
static ECDSA_SHA256: SignatureAlgorithms =
&[&webpki::ECDSA_P256_SHA256, &webpki::ECDSA_P384_SHA256];
static ECDSA_SHA256: SignatureAlgorithms = &[
webpki::ring::ECDSA_P256_SHA256,
webpki::ring::ECDSA_P384_SHA256,
];
static ECDSA_SHA384: SignatureAlgorithms =
&[&webpki::ECDSA_P256_SHA384, &webpki::ECDSA_P384_SHA384];
static ECDSA_SHA384: SignatureAlgorithms = &[
webpki::ring::ECDSA_P256_SHA384,
webpki::ring::ECDSA_P384_SHA384,
];
static ED25519: SignatureAlgorithms = &[&webpki::ED25519];
static ED25519: SignatureAlgorithms = &[webpki::ring::ED25519];
static RSA_SHA256: SignatureAlgorithms = &[&webpki::RSA_PKCS1_2048_8192_SHA256];
static RSA_SHA384: SignatureAlgorithms = &[&webpki::RSA_PKCS1_2048_8192_SHA384];
static RSA_SHA512: SignatureAlgorithms = &[&webpki::RSA_PKCS1_2048_8192_SHA512];
static RSA_PSS_SHA256: SignatureAlgorithms = &[&webpki::RSA_PSS_2048_8192_SHA256_LEGACY_KEY];
static RSA_PSS_SHA384: SignatureAlgorithms = &[&webpki::RSA_PSS_2048_8192_SHA384_LEGACY_KEY];
static RSA_PSS_SHA512: SignatureAlgorithms = &[&webpki::RSA_PSS_2048_8192_SHA512_LEGACY_KEY];
static RSA_SHA256: SignatureAlgorithms = &[webpki::ring::RSA_PKCS1_2048_8192_SHA256];
static RSA_SHA384: SignatureAlgorithms = &[webpki::ring::RSA_PKCS1_2048_8192_SHA384];
static RSA_SHA512: SignatureAlgorithms = &[webpki::ring::RSA_PKCS1_2048_8192_SHA512];
static RSA_PSS_SHA256: SignatureAlgorithms = &[webpki::ring::RSA_PSS_2048_8192_SHA256_LEGACY_KEY];
static RSA_PSS_SHA384: SignatureAlgorithms = &[webpki::ring::RSA_PSS_2048_8192_SHA384_LEGACY_KEY];
static RSA_PSS_SHA512: SignatureAlgorithms = &[webpki::ring::RSA_PSS_2048_8192_SHA512_LEGACY_KEY];
fn convert_scheme(scheme: SignatureScheme) -> Result<SignatureAlgorithms, Error> {
match scheme {
@@ -514,7 +474,7 @@ fn verify_sig_using_any_alg(
// webpki::SignatureAlgorithm. Therefore, convert_algs maps to several and
// we try them all.
for alg in algs {
match cert.verify_signature(alg, message, sig) {
match cert.verify_signature(*alg, message, sig) {
Err(webpki::Error::UnsupportedSignatureAlgorithmForPublicKey) => continue,
res => return res,
}
@@ -529,7 +489,9 @@ fn verify_signed_struct(
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
let possible_algs = convert_scheme(dss.scheme)?;
let cert = webpki::EndEntityCert::try_from(cert.0.as_ref()).map_err(pki_error)?;
let cert = pki_types::CertificateDer::from(cert.0.as_slice());
let cert = webpki::EndEntityCert::try_from(&cert).map_err(pki_error)?;
verify_sig_using_any_alg(&cert, possible_algs, message, &dss.sig.0)
.map_err(pki_error)
@@ -538,16 +500,16 @@ fn verify_signed_struct(
fn convert_alg_tls13(
scheme: SignatureScheme,
) -> Result<&'static webpki::SignatureAlgorithm, Error> {
) -> Result<&'static dyn pki_types::SignatureVerificationAlgorithm, Error> {
use crate::msgs::enums::SignatureScheme::*;
match scheme {
ECDSA_NISTP256_SHA256 => Ok(&webpki::ECDSA_P256_SHA256),
ECDSA_NISTP384_SHA384 => Ok(&webpki::ECDSA_P384_SHA384),
ED25519 => Ok(&webpki::ED25519),
RSA_PSS_SHA256 => Ok(&webpki::RSA_PSS_2048_8192_SHA256_LEGACY_KEY),
RSA_PSS_SHA384 => Ok(&webpki::RSA_PSS_2048_8192_SHA384_LEGACY_KEY),
RSA_PSS_SHA512 => Ok(&webpki::RSA_PSS_2048_8192_SHA512_LEGACY_KEY),
ECDSA_NISTP256_SHA256 => Ok(webpki::ring::ECDSA_P256_SHA256),
ECDSA_NISTP384_SHA384 => Ok(webpki::ring::ECDSA_P384_SHA384),
ED25519 => Ok(webpki::ring::ED25519),
RSA_PSS_SHA256 => Ok(webpki::ring::RSA_PSS_2048_8192_SHA256_LEGACY_KEY),
RSA_PSS_SHA384 => Ok(webpki::ring::RSA_PSS_2048_8192_SHA384_LEGACY_KEY),
RSA_PSS_SHA512 => Ok(webpki::ring::RSA_PSS_2048_8192_SHA512_LEGACY_KEY),
_ => {
let error_msg = format!("received unsupported sig scheme {scheme:?}");
Err(Error::PeerMisbehavedError(error_msg))
@@ -583,7 +545,8 @@ fn verify_tls13(
) -> Result<HandshakeSignatureValid, Error> {
let alg = convert_alg_tls13(dss.scheme)?;
let cert = webpki::EndEntityCert::try_from(cert.0.as_ref()).map_err(pki_error)?;
let cert = pki_types::CertificateDer::from(cert.0.as_slice());
let cert = webpki::EndEntityCert::try_from(&cert).map_err(pki_error)?;
cert.verify_signature(alg, msg, &dss.sig.0)
.map_err(pki_error)

View File

@@ -45,7 +45,8 @@ derive_builder = { workspace = true }
futures = { workspace = true }
opaque-debug = { workspace = true }
rand = { workspace = true }
rustls-pki-types = "1.12.0"
rustls-pki-types = { workspace = true }
rustls-webpki = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
tokio = { workspace = true, features = ["sync"] }

View File

@@ -5,6 +5,8 @@ use semver::Version;
use serde::{Deserialize, Serialize};
use std::error::Error;
pub use tlsn_core::webpki::{CertificateDer, PrivateKeyDer, RootCertStore};
// Default is 32 bytes to decrypt the TLS protocol messages.
const DEFAULT_MAX_RECV_ONLINE: usize = 32;
// Default maximum number of TLS records to allow.

View File

@@ -2,7 +2,7 @@
use serde::{Deserialize, Serialize};
use tlsn_core::connection::{ServerCertData, ServerName};
use tlsn_core::connection::{HandshakeData, ServerName};
/// Message sent from Prover to Verifier to prove the server identity.
#[derive(Debug, Serialize, Deserialize)]
@@ -10,5 +10,5 @@ pub(crate) struct ServerIdentityProof {
/// Server name.
pub name: ServerName,
/// Server identity data.
pub data: ServerCertData,
pub data: HandshakeData,
}

View File

@@ -8,12 +8,14 @@ pub mod state;
pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder};
pub use error::ProverError;
pub use future::ProverFuture;
use rustls_pki_types::CertificateDer;
pub use tlsn_core::{ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProverOutput};
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use webpki::anchor_from_trusted_cert;
use crate::{
Role,
@@ -39,7 +41,7 @@ use tls_client_async::{TlsConnection, bind_client};
use tls_core::msgs::enums::ContentType;
use tlsn_core::{
ProvePayload,
connection::ServerCertData,
connection::{HandshakeData, ServerName},
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
transcript::{TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret},
};
@@ -179,21 +181,40 @@ impl Prover<state::Setup> {
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
let ServerName::Dns(server_name) = self.config.server_name();
let server_name =
TlsServerName::try_from(self.config.server_name().as_str()).map_err(|_| {
ProverError::config(format!(
"invalid server name: {}",
self.config.server_name()
))
})?;
TlsServerName::try_from(server_name.as_ref()).expect("name was validated");
let root_store = if let Some(root_store) = self.config.tls_config().root_store() {
let roots = root_store
.roots
.iter()
.map(|cert| {
let der = CertificateDer::from_slice(&cert.0);
anchor_from_trusted_cert(&der)
.map(|anchor| anchor.to_owned())
.map_err(ProverError::config)
})
.collect::<Result<Vec<_>, _>>()?;
tls_client::RootCertStore { roots }
} else {
tls_client::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
}
};
let config = tls_client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(self.config.tls_config().root_store().clone());
.with_root_certificates(root_store);
let config = if let Some((cert, key)) = self.config.tls_config().client_auth() {
config
.with_single_cert(cert.clone(), key.clone())
.with_single_cert(
cert.iter()
.map(|cert| tls_client::Certificate(cert.0.clone()))
.collect(),
tls_client::PrivateKey(key.0.clone()),
)
.map_err(ProverError::config)?
} else {
config.with_no_client_auth()
@@ -350,10 +371,10 @@ impl Prover<state::Committed> {
};
let payload = ProvePayload {
server_identity: config.server_identity().then(|| {
handshake: config.server_identity().then(|| {
(
self.config.server_name().clone(),
ServerCertData {
HandshakeData {
certs: tls_transcript
.server_cert_chain()
.expect("server cert chain is present")
@@ -362,7 +383,7 @@ impl Prover<state::Committed> {
.server_signature()
.expect("server signature is present")
.clone(),
handshake: tls_transcript.handshake_data().clone(),
binding: tls_transcript.certificate_binding().clone(),
},
)
}),

View File

@@ -1,11 +1,10 @@
use crate::config::{NetworkSetting, ProtocolConfig};
use mpc_tls::Config;
use rustls_pki_types::{CertificateDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer, pem::PemObject};
use tls_core::{
anchors::{OwnedTrustAnchor, RootCertStore},
key,
use tlsn_core::{
connection::ServerName,
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
};
use tlsn_core::connection::ServerName;
use crate::config::{NetworkSetting, ProtocolConfig};
/// Configuration for the prover.
#[derive(Debug, Clone, derive_builder::Builder)]
@@ -67,31 +66,13 @@ impl ProverConfig {
}
/// Configuration for the prover's TLS connection.
#[derive(Debug, Clone)]
#[derive(Default, Debug, Clone)]
pub struct TlsConfig {
/// Root certificates.
root_store: RootCertStore,
root_store: Option<RootCertStore>,
/// Certificate chain and a matching private key for client
/// authentication.
client_auth: Option<(Vec<key::Certificate>, key::PrivateKey)>,
}
impl Default for TlsConfig {
fn default() -> Self {
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject.as_ref(),
ta.subject_public_key_info.as_ref(),
ta.name_constraints.as_ref().map(|nc| nc.as_ref()),
)
}));
Self {
root_store,
client_auth: None,
}
}
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsConfig {
@@ -100,13 +81,13 @@ impl TlsConfig {
TlsConfigBuilder::default()
}
pub(crate) fn root_store(&self) -> &RootCertStore {
&self.root_store
pub(crate) fn root_store(&self) -> Option<&RootCertStore> {
self.root_store.as_ref()
}
/// Returns a certificate chain and a matching private key for client
/// authentication.
pub fn client_auth(&self) -> &Option<(Vec<key::Certificate>, key::PrivateKey)> {
pub fn client_auth(&self) -> &Option<(Vec<CertificateDer>, PrivateKeyDer)> {
&self.client_auth
}
}
@@ -115,7 +96,7 @@ impl TlsConfig {
#[derive(Debug, Default)]
pub struct TlsConfigBuilder {
root_store: Option<RootCertStore>,
client_auth: Option<(Vec<key::Certificate>, key::PrivateKey)>,
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsConfigBuilder {
@@ -138,74 +119,16 @@ impl TlsConfigBuilder {
///
/// - Each certificate in the chain must be in the X.509 format.
/// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
pub fn client_auth(&mut self, cert_key: (Vec<Vec<u8>>, Vec<u8>)) -> &mut Self {
let certs = cert_key
.0
.into_iter()
.map(key::Certificate)
.collect::<Vec<_>>();
self.client_auth = Some((certs, key::PrivateKey(cert_key.1)));
pub fn client_auth(&mut self, cert_key: (Vec<CertificateDer>, PrivateKeyDer)) -> &mut Self {
self.client_auth = Some(cert_key);
self
}
/// Sets a PEM-encoded certificate chain and a matching private key for
/// client authentication.
///
/// Often the chain will consist of a single end-entity certificate.
///
/// # Arguments
///
/// * `cert_key` - A tuple containing the certificate chain and the private
/// key.
///
/// - Each certificate in the chain must be in the X.509 format.
/// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
pub fn client_auth_pem(
&mut self,
cert_key: (Vec<Vec<u8>>, Vec<u8>),
) -> Result<&mut Self, TlsConfigError> {
let key = match PrivatePkcs8KeyDer::from_pem_slice(&cert_key.1) {
// Try to parse as PEM PKCS#8.
Ok(key) => (*key.secret_pkcs8_der()).to_vec(),
// Otherwise, try to parse as PEM PKCS#1.
Err(_) => match PrivatePkcs1KeyDer::from_pem_slice(&cert_key.1) {
Ok(key) => (*key.secret_pkcs1_der()).to_vec(),
Err(_) => return Err(ErrorRepr::InvalidKey.into()),
},
};
let certs = cert_key
.0
.iter()
.map(|c| {
let c =
CertificateDer::from_pem_slice(c).map_err(|_| ErrorRepr::InvalidCertificate)?;
Ok::<key::Certificate, TlsConfigError>(key::Certificate(c.as_ref().to_vec()))
})
.collect::<Result<Vec<_>, _>>()?;
self.client_auth = Some((certs, key::PrivateKey(key)));
Ok(self)
}
/// Builds the TLS configuration.
pub fn build(&self) -> Result<TlsConfig, TlsConfigError> {
pub fn build(self) -> Result<TlsConfig, TlsConfigError> {
Ok(TlsConfig {
root_store: self.root_store.clone().unwrap_or_else(|| {
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(
|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject.as_ref(),
ta.subject_public_key_info.as_ref(),
ta.name_constraints.as_ref().map(|nc| nc.as_ref()),
)
},
));
root_store
}),
client_auth: self.client_auth.clone(),
root_store: self.root_store,
client_auth: self.client_auth,
})
}
}
@@ -216,10 +139,5 @@ impl TlsConfigBuilder {
pub struct TlsConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("tls config error: {0}")]
enum ErrorRepr {
#[error("the certificate for client authentication is invalid")]
InvalidCertificate,
#[error("the private key for client authentication is invalid")]
InvalidKey,
}
#[error("tls config error")]
enum ErrorRepr {}

View File

@@ -8,7 +8,10 @@ use std::sync::Arc;
pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderError};
pub use error::VerifierError;
pub use tlsn_core::{VerifierOutput, VerifyConfig, VerifyConfigBuilder, VerifyConfigBuilderError};
pub use tlsn_core::{
VerifierOutput, VerifyConfig, VerifyConfigBuilder, VerifyConfigBuilderError,
webpki::ServerCertVerifier,
};
use crate::{
Role,
@@ -31,7 +34,7 @@ use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use serio::stream::IoStreamExt;
use tls_core::{msgs::enums::ContentType, verify::WebPkiVerifier};
use tls_core::msgs::enums::ContentType;
use tlsn_core::{
ProvePayload,
connection::{ConnectionInfo, ServerName},
@@ -305,15 +308,20 @@ impl Verifier<state::Committed> {
} = &mut self.state;
let ProvePayload {
server_identity,
handshake,
transcript,
transcript_commit,
} = mux_fut
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
.await?;
let verifier = WebPkiVerifier::new(self.config.root_store().clone(), None);
let server_name = if let Some((name, cert_data)) = server_identity {
let verifier = if let Some(root_store) = self.config.root_store() {
ServerCertVerifier::new(root_store).map_err(VerifierError::config)?
} else {
ServerCertVerifier::mozilla()
};
let server_name = if let Some((name, cert_data)) = handshake {
cert_data
.verify(
&verifier,

View File

@@ -2,7 +2,7 @@ use std::fmt::{Debug, Formatter, Result};
use crate::config::{NetworkSetting, ProtocolConfig, ProtocolConfigValidator};
use mpc_tls::Config;
use tls_core::anchors::{OwnedTrustAnchor, RootCertStore};
use tlsn_core::webpki::RootCertStore;
/// Configuration for the [`Verifier`](crate::tls::Verifier).
#[allow(missing_docs)]
@@ -10,8 +10,8 @@ use tls_core::anchors::{OwnedTrustAnchor, RootCertStore};
#[builder(pattern = "owned")]
pub struct VerifierConfig {
protocol_config_validator: ProtocolConfigValidator,
#[builder(default = "default_root_store()")]
root_store: RootCertStore,
#[builder(setter(strip_option))]
root_store: Option<RootCertStore>,
}
impl Debug for VerifierConfig {
@@ -34,8 +34,8 @@ impl VerifierConfig {
}
/// Returns the root certificate store.
pub fn root_store(&self) -> &RootCertStore {
&self.root_store
pub fn root_store(&self) -> Option<&RootCertStore> {
self.root_store.as_ref()
}
pub(crate) fn build_mpc_tls_config(&self, protocol_config: &ProtocolConfig) -> Config {
@@ -61,16 +61,3 @@ impl VerifierConfig {
builder.build().unwrap()
}
}
fn default_root_store() -> RootCertStore {
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject.as_ref(),
ta.subject_public_key_info.as_ref(),
ta.name_constraints.as_ref().map(|nc| nc.as_ref()),
)
}));
root_store
}

View File

@@ -20,6 +20,13 @@ impl VerifierError {
}
}
pub(crate) fn config<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Config, source)
}
pub(crate) fn mpc<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,

View File

@@ -1,6 +1,7 @@
use futures::{AsyncReadExt, AsyncWriteExt};
use tlsn::{
config::{ProtocolConfig, ProtocolConfigValidator},
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
connection::ServerName,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
transcript::{TranscriptCommitConfig, TranscriptCommitment},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
@@ -37,19 +38,17 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_soc
let server_task = tokio::spawn(bind(server_socket.compat()));
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(root_store);
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build().unwrap();
let server_name = ServerName::Dns(SERVER_DOMAIN.try_into().unwrap());
let prover = Prover::new(
ProverConfig::builder()
.server_name(SERVER_DOMAIN)
.server_name(server_name)
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
@@ -110,11 +109,6 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_soc
#[instrument(skip(socket))]
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(socket: T) {
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
@@ -123,7 +117,9 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(soc
let verifier = Verifier::new(
VerifierConfig::builder()
.root_store(root_store)
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()
.unwrap(),
@@ -140,7 +136,9 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(soc
let transcript = transcript.unwrap();
assert_eq!(server_name.unwrap().as_str(), SERVER_DOMAIN);
let ServerName::Dns(server_name) = server_name.unwrap();
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
assert!(!transcript.is_complete());
assert_eq!(
transcript.sent_authed().iter_ranges().next().unwrap(),

View File

@@ -1,7 +1,11 @@
use crate::types::NetworkSetting;
use serde::Deserialize;
use tlsn::config::ProtocolConfig;
use tlsn::{
config::{CertificateDer, PrivateKeyDer, ProtocolConfig},
connection::ServerName,
};
use tsify_next::Tsify;
use wasm_bindgen::JsError;
#[derive(Debug, Tsify, Deserialize)]
#[tsify(from_wasm_abi)]
@@ -17,8 +21,10 @@ pub struct ProverConfig {
pub client_auth: Option<(Vec<Vec<u8>>, Vec<u8>)>,
}
impl From<ProverConfig> for tlsn::prover::ProverConfig {
fn from(value: ProverConfig) -> Self {
impl TryFrom<ProverConfig> for tlsn::prover::ProverConfig {
type Error = JsError;
fn try_from(value: ProverConfig) -> Result<Self, Self::Error> {
let mut builder = ProtocolConfig::builder();
builder.max_sent_data(value.max_sent_data);
@@ -44,21 +50,36 @@ impl From<ProverConfig> for tlsn::prover::ProverConfig {
let protocol_config = builder.build().unwrap();
let mut builder = tlsn::prover::TlsConfig::builder();
if let Some(cert_key) = value.client_auth {
// Try to parse as PEM-encoded.
if builder.client_auth_pem(cert_key.clone()).is_err() {
// Otherwise assume DER encoding.
builder.client_auth(cert_key);
}
if let Some((certs, key)) = value.client_auth {
let certs = certs
.into_iter()
.map(|cert| {
// Try to parse as PEM-encoded, otherwise assume DER.
if let Ok(cert) = CertificateDer::from_pem_slice(&cert) {
cert
} else {
CertificateDer(cert)
}
})
.collect();
let key = PrivateKeyDer(key);
builder.client_auth((certs, key));
}
let tls_config = builder.build().unwrap();
let server_name = ServerName::Dns(
value
.server_name
.try_into()
.map_err(|_| JsError::new("invalid server name"))?,
);
let mut builder = tlsn::prover::ProverConfig::builder();
builder
.server_name(value.server_name.as_ref())
.server_name(server_name)
.protocol_config(protocol_config)
.tls_config(tls_config);
builder.build().unwrap()
Ok(builder.build().unwrap())
}
}

View File

@@ -41,10 +41,10 @@ impl State {
#[wasm_bindgen(js_class = Prover)]
impl JsProver {
#[wasm_bindgen(constructor)]
pub fn new(config: ProverConfig) -> JsProver {
JsProver {
state: State::Initialized(Prover::new(config.into())),
}
pub fn new(config: ProverConfig) -> Result<JsProver> {
Ok(JsProver {
state: State::Initialized(Prover::new(config.try_into()?)),
})
}
/// Set up the prover.

View File

@@ -5,7 +5,7 @@ pub use config::VerifierConfig;
use enum_try_as_inner::EnumTryAsInner;
use tls_core::msgs::enums::ContentType;
use tlsn::{
connection::{ConnectionInfo, TranscriptLength},
connection::{ConnectionInfo, ServerName, TranscriptLength},
verifier::{
state::{self, Initialized},
Verifier, VerifyConfig,
@@ -106,7 +106,10 @@ impl JsVerifier {
self.state = State::Complete;
Ok(VerifierOutput {
server_name: output.server_name.map(|s| s.as_str().to_string()),
server_name: output.server_name.map(|name| {
let ServerName::Dns(name) = name;
name.to_string()
}),
connection_info: connection_info.into(),
transcript: output.transcript.map(|t| t.into()),
})