use std::{ env, net::{IpAddr, SocketAddr}, }; use http_body_util::Empty; use hyper::{body::Bytes, Request, StatusCode, Uri}; use hyper_util::rt::TokioIo; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; use tracing::instrument; use tlsn::{ config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore}, connection::ServerName, prover::{ProveConfig, Prover, ProverConfig, TlsConfig}, transcript::PartialTranscript, verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig}, }; use tlsn_server_fixture::DEFAULT_FIXTURE_PORT; use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; const SECRET: &str = "TLSNotary's private key 🤡"; // Maximum number of bytes that can be sent from prover to server. const MAX_SENT_DATA: usize = 1 << 12; // Maximum number of bytes that can be received by prover from server. const MAX_RECV_DATA: usize = 1 << 14; #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into()); let server_port: u16 = env::var("SERVER_PORT") .map(|port| port.parse().expect("port should be valid integer")) .unwrap_or(DEFAULT_FIXTURE_PORT); // We use SERVER_DOMAIN here to make sure it matches the domain in the test // server's certificate. let uri = format!("https://{SERVER_DOMAIN}:{server_port}/formats/html"); let server_ip: IpAddr = server_host.parse().expect("Invalid IP address"); let server_addr = SocketAddr::from((server_ip, server_port)); // Connect prover and verifier. let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23); let prover = prover(prover_socket, &server_addr, &uri); let verifier = verifier(verifier_socket); let (_, transcript) = tokio::join!(prover, verifier); println!("Successfully verified {}", &uri); println!( "Verified sent data:\n{}", bytes_to_redacted_string(transcript.sent_unsafe()) ); println!( "Verified received data:\n{}", bytes_to_redacted_string(transcript.received_unsafe()) ); } #[instrument(skip(verifier_socket))] async fn prover( verifier_socket: T, server_addr: &SocketAddr, uri: &str, ) { let uri = uri.parse::().unwrap(); assert_eq!(uri.scheme().unwrap().as_str(), "https"); let server_domain = uri.authority().unwrap().host(); // Create a root certificate store with the server-fixture's self-signed // certificate. This is only required for offline testing with the // server-fixture. let mut tls_config_builder = TlsConfig::builder(); tls_config_builder.root_store(RootCertStore { roots: vec![CertificateDer(CA_CERT_DER.to_vec())], }); let tls_config = tls_config_builder.build().unwrap(); // Set up protocol configuration for prover. let mut prover_config_builder = ProverConfig::builder(); prover_config_builder .server_name(ServerName::Dns(server_domain.try_into().unwrap())) .tls_config(tls_config) .protocol_config( ProtocolConfig::builder() .max_sent_data(MAX_SENT_DATA) .max_recv_data(MAX_RECV_DATA) .build() .unwrap(), ); let prover_config = prover_config_builder.build().unwrap(); // Create prover and connect to verifier. // // Perform the setup phase with the verifier. let prover = Prover::new(prover_config) .setup(verifier_socket.compat()) .await .unwrap(); // Connect to TLS Server. let tls_client_socket = tokio::net::TcpStream::connect(server_addr).await.unwrap(); // Pass server connection into the prover. let (mpc_tls_connection, prover_fut) = prover.connect(tls_client_socket.compat()).await.unwrap(); // Wrap the connection in a TokioIo compatibility layer to use it with hyper. let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat()); // Spawn the Prover to run in the background. let prover_task = tokio::spawn(prover_fut); // MPC-TLS Handshake. let (mut request_sender, connection) = hyper::client::conn::http1::handshake(mpc_tls_connection) .await .unwrap(); // Spawn the connection to run in the background. tokio::spawn(connection); // MPC-TLS: Send Request and wait for Response. let request = Request::builder() .uri(uri.clone()) .header("Host", server_domain) .header("Connection", "close") .header("Secret", SECRET) .method("GET") .body(Empty::::new()) .unwrap(); let response = request_sender.send_request(request).await.unwrap(); assert!(response.status() == StatusCode::OK); // Create proof for the Verifier. let mut prover = prover_task.await.unwrap().unwrap(); let mut builder = ProveConfig::builder(prover.transcript()); // Reveal the DNS name. builder.server_identity(); // Find the secret in the request. let pos = prover .transcript() .sent() .windows(SECRET.len()) .position(|w| w == SECRET.as_bytes()) .expect("the secret should be in the sent data"); // Reveal everything except for the secret. builder.reveal_sent(&(0..pos)).unwrap(); builder .reveal_sent(&(pos + SECRET.len()..prover.transcript().sent().len())) .unwrap(); // Find the substring "Dick". let pos = prover .transcript() .received() .windows(4) .position(|w| w == b"Dick") .expect("the substring 'Dick' should be in the received data"); // Reveal everything except for the substring. builder.reveal_recv(&(0..pos)).unwrap(); builder .reveal_recv(&(pos + 4..prover.transcript().received().len())) .unwrap(); let config = builder.build().unwrap(); prover.prove(&config).await.unwrap(); prover.close().await.unwrap(); } #[instrument(skip(socket))] async fn verifier( socket: T, ) -> PartialTranscript { // Set up Verifier. let config_validator = ProtocolConfigValidator::builder() .max_sent_data(MAX_SENT_DATA) .max_recv_data(MAX_RECV_DATA) .build() .unwrap(); // Create a root certificate store with the server-fixture's self-signed // certificate. This is only required for offline testing with the // server-fixture. let verifier_config = VerifierConfig::builder() .root_store(RootCertStore { roots: vec![CertificateDer(CA_CERT_DER.to_vec())], }) .protocol_config_validator(config_validator) .build() .unwrap(); let verifier = Verifier::new(verifier_config); // Receive authenticated data. let VerifierOutput { server_name, transcript, .. } = verifier .verify(socket.compat(), &VerifyConfig::default()) .await .unwrap(); let server_name = server_name.expect("prover should have revealed server name"); let transcript = transcript.expect("prover should have revealed transcript data"); // Check sent data. let sent = transcript.sent_unsafe().to_vec(); let sent_data = String::from_utf8(sent.clone()).expect("Verifier expected sent data"); sent_data .find(SERVER_DOMAIN) .unwrap_or_else(|| panic!("Verification failed: Expected host {SERVER_DOMAIN}")); // Check received data. let received = transcript.received_unsafe().to_vec(); let response = String::from_utf8(received.clone()).expect("Verifier expected received data"); response .find("Herman Melville") .unwrap_or_else(|| panic!("Expected valid data from {SERVER_DOMAIN}")); // Check Session info: server name. let ServerName::Dns(server_name) = server_name; assert_eq!(server_name.as_str(), SERVER_DOMAIN); transcript } /// Render redacted bytes as `🙈`. fn bytes_to_redacted_string(bytes: &[u8]) -> String { String::from_utf8(bytes.to_vec()) .unwrap() .replace('\0', "🙈") }