mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-09 14:48:13 -05:00
perf: MPC-TLS upgrade (#698)
* fix: add new Cargo.toml * (alpha.8) - Refactor key-exchange crate (#685) * refactor(key-exchange): adapt key-exchange to new vm * fix: fix feature flags * simplify * delete old msg module * clean up error --------- Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com> * (alpha.8) - Refactor prf crate (#684) * refactor(prf): adapt prf to new mpz vm Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com> * refactor: remove preprocessing bench * fix: fix feature flags * clean up attributes --------- Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com> * refactor: key exchange interface (#688) * refactor: prf interface (#689) * (alpha.8) - Create cipher crate (#683) * feat(cipher): add cipher crate, replacing stream/block cipher and aead * delete old config module * remove mpz generics --------- Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com> * refactor(core): decouple encoder from mpz (#692) * WIP: Adding new encoding logic... * feat: add new encoder * add feedback * rename conversions * feat: DEAP VM (#690) * feat: DEAP VM * use rangeset, add desync guard * move MPC execution up in finalization * refactor: MPC-TLS (#693) * refactor: MPC-TLS Co-authored-by: th4s <th4s@metavoid.xyz> * output key references * bump deps --------- Co-authored-by: th4s <th4s@metavoid.xyz> * refactor: prover + verifier (#696) * refactor: wasm crates (#697) * chore: appease clippy (#699) * chore: rustfmt * chore: appease clippy more * chore: more rustfmt! * chore: clippy is stubborn * chore: rustfmt sorting change is annoying! * fix: remove wasm bundling hack * fix: aes ctr test * chore: clippy * fix: flush client when sending close notify * fix: failing tests --------- Co-authored-by: th4s <th4s@metavoid.xyz>
This commit is contained in:
51
Cargo.toml
51
Cargo.toml
@@ -6,13 +6,11 @@ members = [
|
||||
"crates/benches/browser/wasm",
|
||||
"crates/benches/library",
|
||||
"crates/common",
|
||||
"crates/components/aead",
|
||||
"crates/components/block-cipher",
|
||||
"crates/components/deap",
|
||||
"crates/components/cipher",
|
||||
"crates/components/hmac-sha256",
|
||||
"crates/components/hmac-sha256-circuits",
|
||||
"crates/components/key-exchange",
|
||||
"crates/components/stream-cipher",
|
||||
"crates/components/universal-hash",
|
||||
"crates/core",
|
||||
"crates/data-fixtures",
|
||||
"crates/examples",
|
||||
@@ -28,7 +26,7 @@ members = [
|
||||
"crates/tls/client",
|
||||
"crates/tls/client-async",
|
||||
"crates/tls/core",
|
||||
"crates/tls/mpc",
|
||||
"crates/mpc-tls",
|
||||
"crates/tls/server-fixture",
|
||||
"crates/verifier",
|
||||
"crates/wasm",
|
||||
@@ -44,45 +42,47 @@ opt-level = 1
|
||||
notary-client = { path = "crates/notary/client" }
|
||||
notary-server = { path = "crates/notary/server" }
|
||||
tls-server-fixture = { path = "crates/tls/server-fixture" }
|
||||
tlsn-aead = { path = "crates/components/aead" }
|
||||
tlsn-cipher = { path = "crates/components/cipher" }
|
||||
tlsn-benches-browser-core = { path = "crates/benches/browser/core" }
|
||||
tlsn-benches-browser-native = { path = "crates/benches/browser/native" }
|
||||
tlsn-benches-library = { path = "crates/benches/library" }
|
||||
tlsn-block-cipher = { path = "crates/components/block-cipher" }
|
||||
tlsn-common = { path = "crates/common" }
|
||||
tlsn-core = { path = "crates/core" }
|
||||
tlsn-data-fixtures = { path = "crates/data-fixtures" }
|
||||
tlsn-deap = { path = "crates/components/deap" }
|
||||
tlsn-formats = { path = "crates/formats" }
|
||||
tlsn-hmac-sha256 = { path = "crates/components/hmac-sha256" }
|
||||
tlsn-hmac-sha256-circuits = { path = "crates/components/hmac-sha256-circuits" }
|
||||
tlsn-key-exchange = { path = "crates/components/key-exchange" }
|
||||
tlsn-mpc-tls = { path = "crates/mpc-tls" }
|
||||
tlsn-prover = { path = "crates/prover" }
|
||||
tlsn-server-fixture = { path = "crates/server-fixture/server" }
|
||||
tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" }
|
||||
tlsn-stream-cipher = { path = "crates/components/stream-cipher" }
|
||||
tlsn-tls-backend = { path = "crates/tls/backend" }
|
||||
tlsn-tls-client = { path = "crates/tls/client" }
|
||||
tlsn-tls-client-async = { path = "crates/tls/client-async" }
|
||||
tlsn-tls-core = { path = "crates/tls/core" }
|
||||
tlsn-tls-mpc = { path = "crates/tls/mpc" }
|
||||
tlsn-universal-hash = { path = "crates/components/universal-hash" }
|
||||
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "0040a00" }
|
||||
tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "0040a00" }
|
||||
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8555275" }
|
||||
tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8555275" }
|
||||
tlsn-verifier = { path = "crates/verifier" }
|
||||
|
||||
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
|
||||
mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
|
||||
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
|
||||
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
|
||||
mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
|
||||
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
|
||||
mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
|
||||
mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
|
||||
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
|
||||
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
|
||||
mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
|
||||
mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
|
||||
|
||||
serio = { version = "0.1" }
|
||||
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "0040a00" }
|
||||
uid-mux = { version = "0.1", features = ["serio"] }
|
||||
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "0040a00" }
|
||||
serio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8555275" }
|
||||
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8555275" }
|
||||
uid-mux = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8555275" }
|
||||
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8555275" }
|
||||
|
||||
aes = { version = "0.8" }
|
||||
aes-gcm = { version = "0.9" }
|
||||
@@ -113,6 +113,7 @@ http = { version = "1.1" }
|
||||
http-body-util = { version = "0.1" }
|
||||
hyper = { version = "1.1" }
|
||||
hyper-util = { version = "0.1" }
|
||||
itybity = { version = "0.2" }
|
||||
k256 = { version = "0.13" }
|
||||
log = { version = "0.4" }
|
||||
once_cell = { version = "1.19" }
|
||||
@@ -123,6 +124,7 @@ pin-project-lite = { version = "0.2" }
|
||||
rand = { version = "0.8" }
|
||||
rand_chacha = { version = "0.3" }
|
||||
rand_core = { version = "0.6" }
|
||||
rayon = { version = "1.10" }
|
||||
regex = { version = "1.10" }
|
||||
ring = { version = "0.17" }
|
||||
rs_merkle = { git = "https://github.com/tlsnotary/rs-merkle.git", rev = "85f3e82" }
|
||||
@@ -141,6 +143,7 @@ tokio-util = { version = "0.7" }
|
||||
tracing = { version = "0.1" }
|
||||
tracing-subscriber = { version = "0.3" }
|
||||
uuid = { version = "1.4" }
|
||||
web-spawn = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "2d93c56" }
|
||||
web-time = { version = "0.2" }
|
||||
webpki = { version = "0.22" }
|
||||
webpki-roots = { version = "0.26" }
|
||||
|
||||
@@ -9,8 +9,5 @@ fi
|
||||
# Run the benchmark binary.
|
||||
../../../target/release/bench
|
||||
|
||||
# Run the benchmark binary in memory profiling mode.
|
||||
../../../target/release/bench --memory-profiling
|
||||
|
||||
# Plot the results.
|
||||
../../../target/release/plot metrics.csv
|
||||
|
||||
@@ -1,71 +1,5 @@
|
||||
use hmac_sha256::{MpcPrf, Prf, PrfConfig, Role};
|
||||
use mpz_common::executor::test_st_executor;
|
||||
use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPThread, Memory};
|
||||
use mpz_ot::ideal::ot::ideal_ot;
|
||||
use hmac_sha256::build_circuits;
|
||||
|
||||
pub async fn preprocess_prf_circuits() {
|
||||
let pms = [42u8; 32];
|
||||
let client_random = [69u8; 32];
|
||||
|
||||
let (leader_ctx_0, follower_ctx_0) = test_st_executor(128);
|
||||
let (leader_ctx_1, follower_ctx_1) = test_st_executor(128);
|
||||
|
||||
let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot();
|
||||
let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot();
|
||||
let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot();
|
||||
let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot();
|
||||
|
||||
let leader_thread_0 = DEAPThread::new(
|
||||
DEAPRole::Leader,
|
||||
[0u8; 32],
|
||||
leader_ctx_0,
|
||||
leader_ot_send_0,
|
||||
leader_ot_recv_0,
|
||||
);
|
||||
let leader_thread_1 = leader_thread_0
|
||||
.new_thread(leader_ctx_1, leader_ot_send_1, leader_ot_recv_1)
|
||||
.unwrap();
|
||||
|
||||
let follower_thread_0 = DEAPThread::new(
|
||||
DEAPRole::Follower,
|
||||
[0u8; 32],
|
||||
follower_ctx_0,
|
||||
follower_ot_send_0,
|
||||
follower_ot_recv_0,
|
||||
);
|
||||
let follower_thread_1 = follower_thread_0
|
||||
.new_thread(follower_ctx_1, follower_ot_send_1, follower_ot_recv_1)
|
||||
.unwrap();
|
||||
|
||||
// Set up public PMS for testing.
|
||||
let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap();
|
||||
let follower_pms = follower_thread_0
|
||||
.new_public_input::<[u8; 32]>("pms")
|
||||
.unwrap();
|
||||
|
||||
leader_thread_0.assign(&leader_pms, pms).unwrap();
|
||||
|
||||
let mut leader = MpcPrf::new(
|
||||
PrfConfig::builder().role(Role::Leader).build().unwrap(),
|
||||
leader_thread_0,
|
||||
leader_thread_1,
|
||||
);
|
||||
let mut follower = MpcPrf::new(
|
||||
PrfConfig::builder().role(Role::Follower).build().unwrap(),
|
||||
follower_thread_0,
|
||||
follower_thread_1,
|
||||
);
|
||||
|
||||
futures::join!(
|
||||
async {
|
||||
leader.setup(leader_pms).await.unwrap();
|
||||
leader.set_client_random(Some(client_random)).await.unwrap();
|
||||
leader.preprocess().await.unwrap();
|
||||
},
|
||||
async {
|
||||
follower.setup(follower_pms).await.unwrap();
|
||||
follower.set_client_random(None).await.unwrap();
|
||||
follower.preprocess().await.unwrap();
|
||||
}
|
||||
);
|
||||
build_circuits().await;
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ async fn run_instance<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
|
||||
.build()?,
|
||||
);
|
||||
|
||||
_ = verifier.verify(io.compat()).await?;
|
||||
verifier.verify(io.compat()).await?;
|
||||
|
||||
println!("verifier done");
|
||||
|
||||
|
||||
@@ -20,12 +20,11 @@ wasm-bindgen = { version = "0.2.87" }
|
||||
wasm-bindgen-futures = { version = "0.4.37" }
|
||||
web-time = { workspace = true }
|
||||
# Use the patched ws_stream_wasm to fix the issue https://github.com/najamelan/ws_stream_wasm/issues/12#issuecomment-1711902958
|
||||
ws_stream_wasm = { version = "0.7.4", git = "https://github.com/tlsnotary/ws_stream_wasm", rev = "2ed12aad9f0236e5321f577672f309920b2aef51", features = ["tokio_io"]}
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
wasm-bindgen-rayon = { version = "1.2", features = ["no-bundler"] }
|
||||
ws_stream_wasm = { version = "0.7.4", git = "https://github.com/tlsnotary/ws_stream_wasm", rev = "2ed12aad9f0236e5321f577672f309920b2aef51", features = [
|
||||
"tokio_io",
|
||||
] }
|
||||
|
||||
[package.metadata.wasm-pack.profile.release]
|
||||
# Note: these wasm-pack options should match those in crates/wasm/Cargo.toml
|
||||
opt-level = "z"
|
||||
opt-level = "z"
|
||||
wasm-opt = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import * as Comlink from "./comlink.mjs";
|
||||
|
||||
import init, { wasm_main, initThreadPool, init_logging } from './tlsn_benches_browser_wasm.js';
|
||||
import init, { wasm_main, initialize } from './tlsn_benches_browser_wasm.js';
|
||||
|
||||
class Worker {
|
||||
async init() {
|
||||
@@ -12,7 +12,7 @@ class Worker {
|
||||
// crate_filters: undefined,
|
||||
// span_events: undefined,
|
||||
// });
|
||||
await initThreadPool(navigator.hardwareConcurrency);
|
||||
await initialize({ thread_count: navigator.hardwareConcurrency });
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
throw e;
|
||||
@@ -20,12 +20,12 @@ class Worker {
|
||||
}
|
||||
|
||||
async run(
|
||||
ws_ip,
|
||||
ws_port,
|
||||
wasm_to_server_port,
|
||||
wasm_to_verifier_port,
|
||||
wasm_to_native_port
|
||||
) {
|
||||
ws_ip,
|
||||
ws_port,
|
||||
wasm_to_server_port,
|
||||
wasm_to_verifier_port,
|
||||
wasm_to_native_port
|
||||
) {
|
||||
try {
|
||||
await wasm_main(
|
||||
ws_ip,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#![cfg(target_arch = "wasm32")]
|
||||
|
||||
//! Contains the wasm component of the browser prover.
|
||||
//!
|
||||
//! Conceptually the browser prover consists of the native and the wasm
|
||||
@@ -9,13 +11,10 @@ use tlsn_benches_browser_core::{
|
||||
FramedIo,
|
||||
};
|
||||
use tlsn_benches_library::run_prover;
|
||||
pub use tlsn_wasm::init_logging;
|
||||
|
||||
use anyhow::Result;
|
||||
use tracing::info;
|
||||
use wasm_bindgen::prelude::*;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub use wasm_bindgen_rayon::init_thread_pool;
|
||||
use web_time::Instant;
|
||||
use ws_stream_wasm::WsMeta;
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use futures::{future::join, AsyncReadExt as _, AsyncWriteExt as _};
|
||||
use futures::{future::try_join, AsyncReadExt as _, AsyncWriteExt as _, TryFutureExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
@@ -106,12 +106,14 @@ pub async fn run_prover(
|
||||
let mut response = vec![];
|
||||
mpc_tls_connection.read_to_end(&mut response).await?;
|
||||
|
||||
dbg!(response.len());
|
||||
|
||||
Ok::<(), anyhow::Error>(())
|
||||
};
|
||||
|
||||
let (prover_task, _) = join(prover_fut, tls_fut).await;
|
||||
let (prover_task, _) = try_join(prover_fut.map_err(anyhow::Error::from), tls_fut).await?;
|
||||
|
||||
let mut prover = prover_task?.start_prove();
|
||||
let mut prover = prover_task.start_prove();
|
||||
|
||||
let (sent_len, recv_len) = prover.transcript().len();
|
||||
prover
|
||||
|
||||
@@ -9,13 +9,21 @@ default = []
|
||||
|
||||
[dependencies]
|
||||
tlsn-core = { workspace = true }
|
||||
tlsn-tls-core = { workspace = true }
|
||||
tlsn-cipher = { workspace = true }
|
||||
tlsn-utils = { workspace = true }
|
||||
mpz-core = { workspace = true }
|
||||
mpz-common = { workspace = true }
|
||||
mpz-garble = { workspace = true }
|
||||
mpz-memory-core = { workspace = true }
|
||||
mpz-ot = { workspace = true }
|
||||
mpz-vm-core = { workspace = true }
|
||||
mpz-zk = { workspace = true }
|
||||
|
||||
async-trait = { workspace = true }
|
||||
derive_builder = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
opaque-debug = { workspace = true }
|
||||
serio = { workspace = true, features = ["codec", "bincode"] }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
@@ -23,5 +31,9 @@ uid-mux = { workspace = true, features = ["serio"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
semver = { version = "1.0", features = ["serde"] }
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
wasm-bindgen = { version = "0.2" }
|
||||
web-spawn = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
rstest = { workspace = true }
|
||||
|
||||
108
crates/common/src/commit.rs
Normal file
108
crates/common/src/commit.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
//! Plaintext commitment and proof of encryption.
|
||||
|
||||
use mpz_core::bitvec::BitVec;
|
||||
use mpz_memory_core::{binary::Binary, DecodeFutureTyped};
|
||||
use mpz_vm_core::{prelude::*, Vm};
|
||||
|
||||
use crate::{
|
||||
transcript::Record,
|
||||
zk_aes::{ZkAesCtr, ZkAesCtrError},
|
||||
Role,
|
||||
};
|
||||
|
||||
/// Commits the plaintext of the provided records, returning a proof of
|
||||
/// encryption.
|
||||
///
|
||||
/// Writes the plaintext VM reference to the provided records.
|
||||
pub fn commit_records<'record>(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
aes: &mut ZkAesCtr,
|
||||
records: impl IntoIterator<Item = &'record mut Record>,
|
||||
) -> Result<RecordProof, RecordProofError> {
|
||||
let mut ciphertexts = Vec::new();
|
||||
for record in records {
|
||||
if record.plaintext_ref.is_some() {
|
||||
return Err(ErrorRepr::PlaintextRefAlreadySet.into());
|
||||
}
|
||||
|
||||
let (plaintext_ref, ciphertext_ref) = aes
|
||||
.encrypt(vm, record.explicit_nonce.clone(), record.ciphertext.len())
|
||||
.map_err(ErrorRepr::Aes)?;
|
||||
|
||||
record.plaintext_ref = Some(plaintext_ref);
|
||||
|
||||
if let Role::Prover = aes.role() {
|
||||
let Some(plaintext) = record.plaintext.clone() else {
|
||||
return Err(ErrorRepr::MissingPlaintext.into());
|
||||
};
|
||||
|
||||
vm.assign(plaintext_ref, plaintext)
|
||||
.map_err(RecordProofError::vm)?;
|
||||
}
|
||||
vm.commit(plaintext_ref).map_err(RecordProofError::vm)?;
|
||||
|
||||
let ciphertext = vm.decode(ciphertext_ref).map_err(RecordProofError::vm)?;
|
||||
ciphertexts.push((ciphertext, record.ciphertext.clone()));
|
||||
}
|
||||
|
||||
Ok(RecordProof { ciphertexts })
|
||||
}
|
||||
|
||||
/// Proof of encryption.
|
||||
#[derive(Debug)]
|
||||
#[must_use]
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub struct RecordProof {
|
||||
ciphertexts: Vec<(DecodeFutureTyped<BitVec, Vec<u8>>, Vec<u8>)>,
|
||||
}
|
||||
|
||||
impl RecordProof {
|
||||
/// Verifies the proof.
|
||||
pub fn verify(self) -> Result<(), RecordProofError> {
|
||||
let Self { ciphertexts } = self;
|
||||
|
||||
for (mut ciphertext, expected) in ciphertexts {
|
||||
let ciphertext = ciphertext
|
||||
.try_recv()
|
||||
.map_err(RecordProofError::vm)?
|
||||
.ok_or_else(|| ErrorRepr::NotDecoded)?;
|
||||
|
||||
if ciphertext != expected {
|
||||
return Err(ErrorRepr::InvalidCiphertext.into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`RecordProof`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct RecordProofError(#[from] ErrorRepr);
|
||||
|
||||
impl RecordProofError {
|
||||
fn vm<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Vm(err.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("record proof error: {0}")]
|
||||
enum ErrorRepr {
|
||||
#[error("VM error: {0}")]
|
||||
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
#[error("zk aes error: {0}")]
|
||||
Aes(ZkAesCtrError),
|
||||
#[error("plaintext is missing")]
|
||||
MissingPlaintext,
|
||||
#[error("plaintext reference is already set")]
|
||||
PlaintextRefAlreadySet,
|
||||
#[error("ciphertext was not decoded")]
|
||||
NotDecoded,
|
||||
#[error("ciphertext does not match expected")]
|
||||
InvalidCiphertext,
|
||||
}
|
||||
@@ -5,16 +5,8 @@ use semver::Version;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
|
||||
use crate::Role;
|
||||
|
||||
// Extra cushion room, eg. for sharing J0 blocks.
|
||||
const EXTRA_OTS: usize = 16384;
|
||||
|
||||
const OTS_PER_BYTE_SENT: usize = 8;
|
||||
|
||||
// Without deferred decryption we use 16, with it we use 8.
|
||||
const OTS_PER_BYTE_RECV_ONLINE: usize = 16;
|
||||
const OTS_PER_BYTE_RECV_DEFER: usize = 8;
|
||||
// Default is 32 bytes to decrypt the TLS protocol messages.
|
||||
const DEFAULT_MAX_RECV_ONLINE: usize = 32;
|
||||
|
||||
// Current version that is running.
|
||||
static VERSION: Lazy<Version> = Lazy::new(|| {
|
||||
@@ -31,7 +23,7 @@ pub struct ProtocolConfig {
|
||||
max_sent_data: usize,
|
||||
/// Maximum number of bytes that can be decrypted online, i.e. while the
|
||||
/// MPC-TLS connection is active.
|
||||
#[builder(default = "0")]
|
||||
#[builder(default = "DEFAULT_MAX_RECV_ONLINE")]
|
||||
max_recv_data_online: usize,
|
||||
/// Maximum number of bytes that can be received.
|
||||
max_recv_data: usize,
|
||||
@@ -71,26 +63,6 @@ impl ProtocolConfig {
|
||||
pub fn max_recv_data(&self) -> usize {
|
||||
self.max_recv_data
|
||||
}
|
||||
|
||||
/// Returns OT sender setup count.
|
||||
pub fn ot_sender_setup_count(&self, role: Role) -> usize {
|
||||
ot_send_estimate(
|
||||
role,
|
||||
self.max_sent_data,
|
||||
self.max_recv_data_online,
|
||||
self.max_recv_data,
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns OT receiver setup count.
|
||||
pub fn ot_receiver_setup_count(&self, role: Role) -> usize {
|
||||
ot_recv_estimate(
|
||||
role,
|
||||
self.max_sent_data,
|
||||
self.max_recv_data_online,
|
||||
self.max_recv_data,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Protocol configuration validator used by checker (i.e. verifier) to perform
|
||||
@@ -222,42 +194,6 @@ enum ErrorKind {
|
||||
Version,
|
||||
}
|
||||
|
||||
/// Returns an estimate of the number of OTs that will be sent.
|
||||
pub fn ot_send_estimate(
|
||||
role: Role,
|
||||
max_sent_data: usize,
|
||||
max_recv_data_online: usize,
|
||||
max_recv_data: usize,
|
||||
) -> usize {
|
||||
match role {
|
||||
Role::Prover => EXTRA_OTS,
|
||||
Role::Verifier => {
|
||||
EXTRA_OTS
|
||||
+ (max_sent_data * OTS_PER_BYTE_SENT)
|
||||
+ (max_recv_data_online * OTS_PER_BYTE_RECV_ONLINE)
|
||||
+ ((max_recv_data - max_recv_data_online) * OTS_PER_BYTE_RECV_DEFER)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an estimate of the number of OTs that will be received.
|
||||
pub fn ot_recv_estimate(
|
||||
role: Role,
|
||||
max_sent_data: usize,
|
||||
max_recv_data_online: usize,
|
||||
max_recv_data: usize,
|
||||
) -> usize {
|
||||
match role {
|
||||
Role::Prover => {
|
||||
EXTRA_OTS
|
||||
+ (max_sent_data * OTS_PER_BYTE_SENT)
|
||||
+ (max_recv_data_online * OTS_PER_BYTE_RECV_ONLINE)
|
||||
+ ((max_recv_data - max_recv_data_online) * OTS_PER_BYTE_RECV_DEFER)
|
||||
}
|
||||
Role::Verifier => EXTRA_OTS,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
21
crates/common/src/context.rs
Normal file
21
crates/common/src/context.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
//! Execution context.
|
||||
|
||||
use mpz_common::context::Multithread;
|
||||
|
||||
use crate::mux::MuxControl;
|
||||
|
||||
/// Maximum concurrency for multi-threaded context.
|
||||
pub const MAX_CONCURRENCY: usize = 8;
|
||||
|
||||
/// Builds a multi-threaded context with the given muxer.
|
||||
pub fn build_mt_context(mux: MuxControl) -> Multithread {
|
||||
let builder = Multithread::builder().mux(mux).concurrency(MAX_CONCURRENCY);
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
let builder = builder.spawn_handler(|f| {
|
||||
let _ = web_spawn::spawn(f);
|
||||
Ok(())
|
||||
});
|
||||
|
||||
builder.build().unwrap()
|
||||
}
|
||||
173
crates/common/src/encoding.rs
Normal file
173
crates/common/src/encoding.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
//! Encoding commitment protocol.
|
||||
|
||||
use mpz_common::Context;
|
||||
use mpz_core::Block;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serio::{stream::IoStreamExt, SinkExt};
|
||||
use tlsn_core::transcript::{
|
||||
encoding::{new_encoder, Encoder, EncoderSecret, EncodingProvider},
|
||||
Direction, Idx,
|
||||
};
|
||||
|
||||
/// Bytes of encoding, per byte.
|
||||
const ENCODING_SIZE: usize = 128;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Encodings {
|
||||
sent: Vec<u8>,
|
||||
recv: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Transfers the encodings using the provided seed and keys.
|
||||
pub async fn transfer(
|
||||
ctx: &mut Context,
|
||||
secret: &EncoderSecret,
|
||||
sent_keys: impl IntoIterator<Item = &'_ Block>,
|
||||
recv_keys: impl IntoIterator<Item = &'_ Block>,
|
||||
) -> Result<(), EncodingError> {
|
||||
let encoder = new_encoder(secret);
|
||||
|
||||
let sent_keys: Vec<u8> = sent_keys
|
||||
.into_iter()
|
||||
.flat_map(|key| key.as_bytes())
|
||||
.copied()
|
||||
.collect();
|
||||
let recv_keys: Vec<u8> = recv_keys
|
||||
.into_iter()
|
||||
.flat_map(|key| key.as_bytes())
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
assert_eq!(sent_keys.len() % ENCODING_SIZE, 0);
|
||||
assert_eq!(recv_keys.len() % ENCODING_SIZE, 0);
|
||||
|
||||
let mut sent_encoding = encoder.encode_idx(
|
||||
Direction::Sent,
|
||||
&Idx::new(0..sent_keys.len() / ENCODING_SIZE),
|
||||
);
|
||||
let mut recv_encoding = encoder.encode_idx(
|
||||
Direction::Received,
|
||||
&Idx::new(0..recv_keys.len() / ENCODING_SIZE),
|
||||
);
|
||||
|
||||
sent_encoding
|
||||
.iter_mut()
|
||||
.zip(sent_keys)
|
||||
.for_each(|(enc, key)| *enc ^= key);
|
||||
recv_encoding
|
||||
.iter_mut()
|
||||
.zip(recv_keys)
|
||||
.for_each(|(enc, key)| *enc ^= key);
|
||||
|
||||
ctx.io_mut()
|
||||
.send(Encodings {
|
||||
sent: sent_encoding,
|
||||
recv: recv_encoding,
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receives the encodings using the provided MACs.
|
||||
pub async fn receive(
|
||||
ctx: &mut Context,
|
||||
sent_macs: impl IntoIterator<Item = &'_ Block>,
|
||||
recv_macs: impl IntoIterator<Item = &'_ Block>,
|
||||
) -> Result<impl EncodingProvider, EncodingError> {
|
||||
let Encodings { mut sent, mut recv } = ctx.io_mut().expect_next().await?;
|
||||
|
||||
let sent_macs: Vec<u8> = sent_macs
|
||||
.into_iter()
|
||||
.flat_map(|mac| mac.as_bytes())
|
||||
.copied()
|
||||
.collect();
|
||||
let recv_macs: Vec<u8> = recv_macs
|
||||
.into_iter()
|
||||
.flat_map(|mac| mac.as_bytes())
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
assert_eq!(sent_macs.len() % ENCODING_SIZE, 0);
|
||||
assert_eq!(recv_macs.len() % ENCODING_SIZE, 0);
|
||||
|
||||
if sent.len() != sent_macs.len() {
|
||||
return Err(ErrorRepr::IncorrectMacCount {
|
||||
direction: Direction::Sent,
|
||||
expected: sent_macs.len(),
|
||||
got: sent.len(),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
if recv.len() != recv_macs.len() {
|
||||
return Err(ErrorRepr::IncorrectMacCount {
|
||||
direction: Direction::Received,
|
||||
expected: recv_macs.len(),
|
||||
got: recv.len(),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
sent.iter_mut()
|
||||
.zip(sent_macs)
|
||||
.for_each(|(enc, mac)| *enc ^= mac);
|
||||
recv.iter_mut()
|
||||
.zip(recv_macs)
|
||||
.for_each(|(enc, mac)| *enc ^= mac);
|
||||
|
||||
Ok(Provider { sent, recv })
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Provider {
|
||||
sent: Vec<u8>,
|
||||
recv: Vec<u8>,
|
||||
}
|
||||
|
||||
impl EncodingProvider for Provider {
|
||||
fn provide_encoding(&self, direction: Direction, idx: &Idx) -> Option<Vec<u8>> {
|
||||
let encodings = match direction {
|
||||
Direction::Sent => &self.sent,
|
||||
Direction::Received => &self.recv,
|
||||
};
|
||||
|
||||
let mut encoding = Vec::with_capacity(idx.len());
|
||||
for range in idx.iter_ranges() {
|
||||
let start = range.start * ENCODING_SIZE;
|
||||
let end = range.end * ENCODING_SIZE;
|
||||
|
||||
if end > encodings.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
encoding.extend_from_slice(&encodings[start..end]);
|
||||
}
|
||||
|
||||
Some(encoding)
|
||||
}
|
||||
}
|
||||
|
||||
/// Encoding protocol error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct EncodingError(#[from] ErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("encoding protocol error: {0}")]
|
||||
enum ErrorRepr {
|
||||
#[error("I/O error: {0}")]
|
||||
Io(std::io::Error),
|
||||
#[error("incorrect MAC count for {direction}: expected {expected}, got {got}")]
|
||||
IncorrectMacCount {
|
||||
direction: Direction,
|
||||
expected: usize,
|
||||
got: usize,
|
||||
},
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for EncodingError {
|
||||
fn from(value: std::io::Error) -> Self {
|
||||
Self(ErrorRepr::Io(value))
|
||||
}
|
||||
}
|
||||
@@ -4,34 +4,19 @@
|
||||
#![deny(clippy::all)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
pub mod commit;
|
||||
pub mod config;
|
||||
pub mod context;
|
||||
pub mod encoding;
|
||||
pub mod msg;
|
||||
pub mod mux;
|
||||
|
||||
use serio::codec::Codec;
|
||||
|
||||
use crate::mux::MuxControl;
|
||||
|
||||
/// IO type.
|
||||
pub type Io = <serio::codec::Bincode as Codec<uid_mux::yamux::Stream>>::Framed;
|
||||
/// Base OT sender.
|
||||
pub type BaseOTSender = mpz_ot::chou_orlandi::Sender;
|
||||
/// Base OT receiver.
|
||||
pub type BaseOTReceiver = mpz_ot::chou_orlandi::Receiver;
|
||||
/// OT sender.
|
||||
pub type OTSender = mpz_ot::kos::SharedSender<BaseOTReceiver>;
|
||||
/// OT receiver.
|
||||
pub type OTReceiver = mpz_ot::kos::SharedReceiver<BaseOTSender>;
|
||||
/// MPC executor.
|
||||
pub type Executor = mpz_common::executor::MTExecutor<MuxControl>;
|
||||
/// MPC thread context.
|
||||
pub type Context = mpz_common::executor::MTContext<MuxControl, Io>;
|
||||
/// DEAP thread.
|
||||
pub type DEAPThread = mpz_garble::protocol::deap::DEAPThread<Context, OTSender, OTReceiver>;
|
||||
pub mod transcript;
|
||||
pub mod zk_aes;
|
||||
|
||||
/// The party's role in the TLSN protocol.
|
||||
///
|
||||
/// A Notary is classified as a Verifier.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Role {
|
||||
/// The prover.
|
||||
Prover,
|
||||
|
||||
@@ -6,16 +6,15 @@ use futures::{
|
||||
future::{FusedFuture, FutureExt},
|
||||
AsyncRead, AsyncWrite, Future,
|
||||
};
|
||||
use serio::codec::Bincode;
|
||||
use tracing::error;
|
||||
use uid_mux::{yamux, FramedMux};
|
||||
use uid_mux::yamux;
|
||||
|
||||
use crate::Role;
|
||||
|
||||
/// Multiplexer supporting unique deterministic stream IDs.
|
||||
pub type Mux<Io> = yamux::Yamux<Io>;
|
||||
/// Multiplexer controller providing streams with a codec attached.
|
||||
pub type MuxControl = FramedMux<yamux::YamuxCtrl, Bincode>;
|
||||
/// Multiplexer controller providing streams.
|
||||
pub type MuxControl = yamux::YamuxCtrl;
|
||||
|
||||
/// Multiplexer future which must be polled for the muxer to make progress.
|
||||
pub struct MuxFuture(
|
||||
@@ -73,7 +72,7 @@ pub fn attach_mux<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
role: Role,
|
||||
) -> (MuxFuture, MuxControl) {
|
||||
let mut mux_config = yamux::Config::default();
|
||||
mux_config.set_max_num_streams(64);
|
||||
mux_config.set_max_num_streams(32);
|
||||
|
||||
let mux_role = match role {
|
||||
Role::Prover => yamux::Mode::Client,
|
||||
@@ -81,10 +80,10 @@ pub fn attach_mux<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
};
|
||||
|
||||
let mux = Mux::new(socket, mux_config, mux_role);
|
||||
let ctrl = FramedMux::new(mux.control(), Bincode);
|
||||
let ctrl = mux.control();
|
||||
|
||||
if let Role::Prover = role {
|
||||
ctrl.mux().alloc(64);
|
||||
ctrl.alloc(32);
|
||||
}
|
||||
|
||||
(MuxFuture(Box::new(mux.into_future().fuse())), ctrl)
|
||||
|
||||
169
crates/common/src/transcript.rs
Normal file
169
crates/common/src/transcript.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
//! TLS transcript.
|
||||
|
||||
use mpz_memory_core::{binary::U8, Vector};
|
||||
use tls_core::msgs::enums::ContentType;
|
||||
use tlsn_core::transcript::{Direction, Idx, Transcript};
|
||||
use utils::range::Intersection;
|
||||
|
||||
/// A transcript of sent and received TLS records.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct TlsTranscript {
|
||||
/// Records sent by the prover.
|
||||
pub sent: Vec<Record>,
|
||||
/// Records received by the prover.
|
||||
pub recv: Vec<Record>,
|
||||
}
|
||||
|
||||
impl TlsTranscript {
|
||||
/// Returns the application data transcript.
|
||||
pub fn to_transcript(&self) -> Result<Transcript, IncompleteTranscript> {
|
||||
let mut sent = Vec::new();
|
||||
let mut recv = Vec::new();
|
||||
|
||||
for record in self
|
||||
.sent
|
||||
.iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData)
|
||||
{
|
||||
let plaintext = record
|
||||
.plaintext
|
||||
.as_ref()
|
||||
.ok_or(IncompleteTranscript {})?
|
||||
.clone();
|
||||
sent.extend_from_slice(&plaintext);
|
||||
}
|
||||
|
||||
for record in self
|
||||
.recv
|
||||
.iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData)
|
||||
{
|
||||
let plaintext = record
|
||||
.plaintext
|
||||
.as_ref()
|
||||
.ok_or(IncompleteTranscript {})?
|
||||
.clone();
|
||||
recv.extend_from_slice(&plaintext);
|
||||
}
|
||||
|
||||
Ok(Transcript::new(sent, recv))
|
||||
}
|
||||
|
||||
/// Returns the application data transcript references.
|
||||
pub fn to_transcript_refs(&self) -> Result<TranscriptRefs, IncompleteTranscript> {
|
||||
let mut sent = Vec::new();
|
||||
let mut recv = Vec::new();
|
||||
|
||||
for record in self
|
||||
.sent
|
||||
.iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData)
|
||||
{
|
||||
let plaintext_ref = record
|
||||
.plaintext_ref
|
||||
.as_ref()
|
||||
.ok_or(IncompleteTranscript {})?;
|
||||
sent.push(*plaintext_ref);
|
||||
}
|
||||
|
||||
for record in self
|
||||
.recv
|
||||
.iter()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData)
|
||||
{
|
||||
let plaintext_ref = record
|
||||
.plaintext_ref
|
||||
.as_ref()
|
||||
.ok_or(IncompleteTranscript {})?;
|
||||
recv.push(*plaintext_ref);
|
||||
}
|
||||
|
||||
Ok(TranscriptRefs { sent, recv })
|
||||
}
|
||||
}
|
||||
|
||||
/// A TLS record.
|
||||
#[derive(Clone)]
|
||||
pub struct Record {
|
||||
/// Sequence number.
|
||||
pub seq: u64,
|
||||
/// Content type.
|
||||
pub typ: ContentType,
|
||||
/// Plaintext.
|
||||
pub plaintext: Option<Vec<u8>>,
|
||||
/// VM reference to the plaintext.
|
||||
pub plaintext_ref: Option<Vector<U8>>,
|
||||
/// Explicit nonce.
|
||||
pub explicit_nonce: Vec<u8>,
|
||||
/// Ciphertext.
|
||||
pub ciphertext: Vec<u8>,
|
||||
}
|
||||
|
||||
opaque_debug::implement!(Record);
|
||||
|
||||
/// References to the application plaintext in the transcript.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct TranscriptRefs {
|
||||
sent: Vec<Vector<U8>>,
|
||||
recv: Vec<Vector<U8>>,
|
||||
}
|
||||
|
||||
impl TranscriptRefs {
|
||||
/// Returns the sent plaintext references.
|
||||
pub fn sent(&self) -> &[Vector<U8>] {
|
||||
&self.sent
|
||||
}
|
||||
|
||||
/// Returns the received plaintext references.
|
||||
pub fn recv(&self) -> &[Vector<U8>] {
|
||||
&self.recv
|
||||
}
|
||||
|
||||
/// Returns VM references for the given direction and index, otherwise
|
||||
/// `None` if the index is out of bounds.
|
||||
pub fn get(&self, direction: Direction, idx: &Idx) -> Option<Vec<Vector<U8>>> {
|
||||
if idx.is_empty() {
|
||||
return Some(Vec::new());
|
||||
}
|
||||
|
||||
let refs = match direction {
|
||||
Direction::Sent => &self.sent,
|
||||
Direction::Received => &self.recv,
|
||||
};
|
||||
|
||||
// Computes the transcript range for each reference.
|
||||
let mut start = 0;
|
||||
let mut slice_iter = refs.iter().map(move |slice| {
|
||||
let out = (slice, start..start + slice.len());
|
||||
start += slice.len();
|
||||
out
|
||||
});
|
||||
|
||||
let mut slices = Vec::new();
|
||||
let (mut slice, mut slice_range) = slice_iter.next()?;
|
||||
for range in idx.iter_ranges() {
|
||||
loop {
|
||||
if let Some(intersection) = slice_range.intersection(&range) {
|
||||
let start = intersection.start - slice_range.start;
|
||||
let end = intersection.end - slice_range.start;
|
||||
slices.push(slice.get(start..end).expect("range should be in bounds"));
|
||||
}
|
||||
|
||||
// Proceed to next range if the current slice extends beyond. Otherwise, proceed
|
||||
// to the next slice.
|
||||
if range.end <= slice_range.end {
|
||||
break;
|
||||
} else {
|
||||
(slice, slice_range) = slice_iter.next()?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(slices)
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`TranscriptRefs::from_transcript`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("not all application plaintext was committed to in the TLS transcript")]
|
||||
pub struct IncompleteTranscript {}
|
||||
212
crates/common/src/zk_aes.rs
Normal file
212
crates/common/src/zk_aes.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
//! Zero-knowledge AES-CTR encryption.
|
||||
|
||||
use cipher::{
|
||||
aes::{Aes128, AesError},
|
||||
Cipher, CipherError, Keystream,
|
||||
};
|
||||
use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
Array, Vector,
|
||||
};
|
||||
use mpz_vm_core::{prelude::*, Vm};
|
||||
|
||||
use crate::Role;
|
||||
|
||||
type Nonce = Array<U8, 8>;
|
||||
type Ctr = Array<U8, 4>;
|
||||
type Block = Array<U8, 16>;
|
||||
|
||||
const START_CTR: u32 = 2;
|
||||
|
||||
/// ZK AES-CTR encryption.
|
||||
#[derive(Debug)]
|
||||
pub struct ZkAesCtr {
|
||||
role: Role,
|
||||
aes: Aes128,
|
||||
state: State,
|
||||
}
|
||||
|
||||
impl ZkAesCtr {
|
||||
/// Creates a new ZK AES-CTR instance.
|
||||
pub fn new(role: Role) -> Self {
|
||||
Self {
|
||||
role,
|
||||
aes: Aes128::default(),
|
||||
state: State::Init,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the role.
|
||||
pub fn role(&self) -> &Role {
|
||||
&self.role
|
||||
}
|
||||
|
||||
/// Allocates `len` bytes for encryption.
|
||||
pub fn alloc(&mut self, vm: &mut dyn Vm<Binary>, len: usize) -> Result<(), ZkAesCtrError> {
|
||||
let State::Init = self.state.take() else {
|
||||
Err(ErrorRepr::State {
|
||||
reason: "must be in init state to allocate",
|
||||
})?
|
||||
};
|
||||
|
||||
// Round up to the nearest block size.
|
||||
let len = 16 * len.div_ceil(16);
|
||||
|
||||
let input = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
|
||||
let keystream = self.aes.alloc_keystream(vm, len)?;
|
||||
let output = keystream.apply(vm, input)?;
|
||||
|
||||
match self.role {
|
||||
Role::Prover => vm.mark_private(input).map_err(ZkAesCtrError::vm)?,
|
||||
Role::Verifier => vm.mark_blind(input).map_err(ZkAesCtrError::vm)?,
|
||||
}
|
||||
|
||||
self.state = State::Ready {
|
||||
input,
|
||||
keystream,
|
||||
output,
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sets the key and IV for the cipher.
|
||||
pub fn set_key(&mut self, key: Array<U8, 16>, iv: Array<U8, 4>) {
|
||||
self.aes.set_key(key);
|
||||
self.aes.set_iv(iv);
|
||||
}
|
||||
|
||||
/// Proves the encryption of `len` bytes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `explicit_nonce` - Explicit nonce.
|
||||
/// * `len` - Length of the plaintext in bytes.
|
||||
pub fn encrypt(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
explicit_nonce: Vec<u8>,
|
||||
len: usize,
|
||||
) -> Result<(Vector<U8>, Vector<U8>), ZkAesCtrError> {
|
||||
let State::Ready {
|
||||
input,
|
||||
keystream,
|
||||
output,
|
||||
} = &mut self.state
|
||||
else {
|
||||
Err(ErrorRepr::State {
|
||||
reason: "must be in ready state to encrypt",
|
||||
})?
|
||||
};
|
||||
|
||||
let explicit_nonce: [u8; 8] =
|
||||
explicit_nonce
|
||||
.try_into()
|
||||
.map_err(|explicit_nonce: Vec<_>| ErrorRepr::ExplicitNonceLength {
|
||||
expected: 8,
|
||||
actual: explicit_nonce.len(),
|
||||
})?;
|
||||
|
||||
let block_count = len.div_ceil(16);
|
||||
let padded_len = block_count * 16;
|
||||
let padding_len = padded_len - len;
|
||||
|
||||
if padded_len > input.len() {
|
||||
Err(ErrorRepr::InsufficientPreprocessing {
|
||||
expected: padded_len,
|
||||
actual: input.len(),
|
||||
})?
|
||||
}
|
||||
|
||||
let mut input = input.split_off(input.len() - padded_len);
|
||||
let keystream = keystream.consume(padded_len)?;
|
||||
let mut output = output.split_off(output.len() - padded_len);
|
||||
|
||||
// Assign counter block inputs.
|
||||
let mut ctr = START_CTR..;
|
||||
keystream.assign(vm, explicit_nonce, move || {
|
||||
ctr.next().expect("range is unbounded").to_be_bytes()
|
||||
})?;
|
||||
|
||||
// Assign zeroes to the padding.
|
||||
if padding_len > 0 {
|
||||
let padding = input.split_off(input.len() - padding_len);
|
||||
if let Role::Prover = self.role {
|
||||
vm.assign(padding, vec![0; padding_len])
|
||||
.map_err(ZkAesCtrError::vm)?;
|
||||
}
|
||||
vm.commit(padding).map_err(ZkAesCtrError::vm)?;
|
||||
output.truncate(len);
|
||||
}
|
||||
|
||||
Ok((input, output))
|
||||
}
|
||||
}
|
||||
|
||||
enum State {
|
||||
Init,
|
||||
Ready {
|
||||
input: Vector<U8>,
|
||||
keystream: Keystream<Nonce, Ctr, Block>,
|
||||
output: Vector<U8>,
|
||||
},
|
||||
Error,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn take(&mut self) -> Self {
|
||||
std::mem::replace(self, State::Error)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for State {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
State::Init => write!(f, "Init"),
|
||||
State::Ready { .. } => write!(f, "Ready"),
|
||||
State::Error => write!(f, "Error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`ZkAesCtr`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct ZkAesCtrError(#[from] ErrorRepr);
|
||||
|
||||
impl ZkAesCtrError {
|
||||
fn vm<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Vm(err.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("zk aes error")]
|
||||
enum ErrorRepr {
|
||||
#[error("invalid state: {reason}")]
|
||||
State { reason: &'static str },
|
||||
#[error("cipher error: {0}")]
|
||||
Cipher(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
#[error("vm error: {0}")]
|
||||
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
#[error("invalid explicit nonce length: expected {expected}, got {actual}")]
|
||||
ExplicitNonceLength { expected: usize, actual: usize },
|
||||
#[error("insufficient preprocessing: expected {expected}, got {actual}")]
|
||||
InsufficientPreprocessing { expected: usize, actual: usize },
|
||||
}
|
||||
|
||||
impl From<AesError> for ZkAesCtrError {
|
||||
fn from(err: AesError) -> Self {
|
||||
Self(ErrorRepr::Cipher(Box::new(err)))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CipherError> for ZkAesCtrError {
|
||||
fn from(err: CipherError) -> Self {
|
||||
Self(ErrorRepr::Cipher(Box::new(err)))
|
||||
}
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
[package]
|
||||
name = "tlsn-aead"
|
||||
authors = ["TLSNotary Team"]
|
||||
description = "This crate provides an implementation of a two-party version of AES-GCM behind an AEAD trait"
|
||||
keywords = ["tls", "mpc", "2pc", "aead", "aes", "aes-gcm"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.8-pre"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "aead"
|
||||
|
||||
[features]
|
||||
default = ["mock"]
|
||||
mock = ["mpz-common/test-utils", "dep:mpz-ot"]
|
||||
|
||||
[dependencies]
|
||||
tlsn-block-cipher = { workspace = true }
|
||||
tlsn-stream-cipher = { workspace = true }
|
||||
tlsn-universal-hash = { workspace = true }
|
||||
|
||||
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", optional = true, features = [
|
||||
"ideal",
|
||||
] }
|
||||
|
||||
serio = { workspace = true }
|
||||
|
||||
async-trait = { workspace = true }
|
||||
derive_builder = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
|
||||
aes-gcm = { workspace = true }
|
||||
@@ -1,36 +0,0 @@
|
||||
use derive_builder::Builder;
|
||||
|
||||
/// Protocol role.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum Role {
|
||||
Leader,
|
||||
Follower,
|
||||
}
|
||||
|
||||
/// Configuration for AES-GCM.
|
||||
#[derive(Debug, Clone, Builder)]
|
||||
pub struct AesGcmConfig {
|
||||
/// The id of this instance.
|
||||
#[builder(setter(into))]
|
||||
id: String,
|
||||
/// The protocol role.
|
||||
role: Role,
|
||||
}
|
||||
|
||||
impl AesGcmConfig {
|
||||
/// Creates a new builder for the AES-GCM configuration.
|
||||
pub fn builder() -> AesGcmConfigBuilder {
|
||||
AesGcmConfigBuilder::default()
|
||||
}
|
||||
|
||||
/// Returns the id of this instance.
|
||||
pub fn id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
|
||||
/// Returns the protocol role.
|
||||
pub fn role(&self) -> &Role {
|
||||
&self.role
|
||||
}
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
/// AES-GCM error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub struct AesGcmError {
|
||||
kind: ErrorKind,
|
||||
#[source]
|
||||
source: Option<Box<dyn std::error::Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl AesGcmError {
|
||||
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
{
|
||||
Self {
|
||||
kind,
|
||||
source: Some(source.into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn kind(&self) -> ErrorKind {
|
||||
self.kind
|
||||
}
|
||||
|
||||
pub(crate) fn invalid_tag() -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Tag,
|
||||
source: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn peer(reason: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::PeerMisbehaved,
|
||||
source: Some(reason.into().into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn payload(reason: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Payload,
|
||||
source: Some(reason.into().into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) enum ErrorKind {
|
||||
Io,
|
||||
BlockCipher,
|
||||
StreamCipher,
|
||||
Ghash,
|
||||
Tag,
|
||||
PeerMisbehaved,
|
||||
Payload,
|
||||
}
|
||||
|
||||
impl Display for AesGcmError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self.kind {
|
||||
ErrorKind::Io => write!(f, "io error")?,
|
||||
ErrorKind::BlockCipher => write!(f, "block cipher error")?,
|
||||
ErrorKind::StreamCipher => write!(f, "stream cipher error")?,
|
||||
ErrorKind::Ghash => write!(f, "ghash error")?,
|
||||
ErrorKind::Tag => write!(f, "payload has corrupted tag")?,
|
||||
ErrorKind::PeerMisbehaved => write!(f, "peer misbehaved")?,
|
||||
ErrorKind::Payload => write!(f, "payload error")?,
|
||||
}
|
||||
|
||||
if let Some(source) = &self.source {
|
||||
write!(f, " caused by: {}", source)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for AesGcmError {
|
||||
fn from(err: std::io::Error) -> Self {
|
||||
Self::new(ErrorKind::Io, err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<block_cipher::BlockCipherError> for AesGcmError {
|
||||
fn from(err: block_cipher::BlockCipherError) -> Self {
|
||||
Self::new(ErrorKind::BlockCipher, err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tlsn_stream_cipher::StreamCipherError> for AesGcmError {
|
||||
fn from(err: tlsn_stream_cipher::StreamCipherError) -> Self {
|
||||
Self::new(ErrorKind::StreamCipher, err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tlsn_universal_hash::UniversalHashError> for AesGcmError {
|
||||
fn from(err: tlsn_universal_hash::UniversalHashError) -> Self {
|
||||
Self::new(ErrorKind::Ghash, err)
|
||||
}
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
//! Mock implementation of AES-GCM for testing purposes.
|
||||
|
||||
use block_cipher::{BlockCipherConfig, MpcBlockCipher};
|
||||
use mpz_common::executor::{test_st_executor, STExecutor};
|
||||
use mpz_garble::protocol::deap::mock::{MockFollower, MockLeader};
|
||||
use mpz_ot::ideal::ot::ideal_ot;
|
||||
use serio::channel::MemoryDuplex;
|
||||
use tlsn_stream_cipher::{MpcStreamCipher, StreamCipherConfig};
|
||||
use tlsn_universal_hash::ghash::ideal_ghash;
|
||||
|
||||
use super::*;
|
||||
|
||||
/// Creates a mock AES-GCM pair.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `id` - The id of the AES-GCM instances.
|
||||
/// * `(leader, follower)` - The leader and follower vms.
|
||||
/// * `leader_config` - The configuration of the leader.
|
||||
/// * `follower_config` - The configuration of the follower.
|
||||
pub async fn create_mock_aes_gcm_pair(
|
||||
id: &str,
|
||||
(leader, follower): (MockLeader, MockFollower),
|
||||
leader_config: AesGcmConfig,
|
||||
follower_config: AesGcmConfig,
|
||||
) -> (
|
||||
MpcAesGcm<STExecutor<MemoryDuplex>>,
|
||||
MpcAesGcm<STExecutor<MemoryDuplex>>,
|
||||
) {
|
||||
let block_cipher_id = format!("{}/block_cipher", id);
|
||||
let (ctx_leader, ctx_follower) = test_st_executor(128);
|
||||
|
||||
let (leader_ot_send, follower_ot_recv) = ideal_ot();
|
||||
let (follower_ot_send, leader_ot_recv) = ideal_ot();
|
||||
|
||||
let block_leader = leader
|
||||
.new_thread(ctx_leader, leader_ot_send, leader_ot_recv)
|
||||
.unwrap();
|
||||
|
||||
let block_follower = follower
|
||||
.new_thread(ctx_follower, follower_ot_send, follower_ot_recv)
|
||||
.unwrap();
|
||||
|
||||
let leader_block_cipher = MpcBlockCipher::new(
|
||||
BlockCipherConfig::builder()
|
||||
.id(block_cipher_id.clone())
|
||||
.build()
|
||||
.unwrap(),
|
||||
block_leader,
|
||||
);
|
||||
let follower_block_cipher = MpcBlockCipher::new(
|
||||
BlockCipherConfig::builder()
|
||||
.id(block_cipher_id.clone())
|
||||
.build()
|
||||
.unwrap(),
|
||||
block_follower,
|
||||
);
|
||||
|
||||
let stream_cipher_id = format!("{}/stream_cipher", id);
|
||||
let leader_stream_cipher = MpcStreamCipher::new(
|
||||
StreamCipherConfig::builder()
|
||||
.id(stream_cipher_id.clone())
|
||||
.build()
|
||||
.unwrap(),
|
||||
leader,
|
||||
);
|
||||
let follower_stream_cipher = MpcStreamCipher::new(
|
||||
StreamCipherConfig::builder()
|
||||
.id(stream_cipher_id.clone())
|
||||
.build()
|
||||
.unwrap(),
|
||||
follower,
|
||||
);
|
||||
|
||||
let (ctx_a, ctx_b) = test_st_executor(128);
|
||||
let (leader_ghash, follower_ghash) = ideal_ghash(ctx_a, ctx_b);
|
||||
|
||||
let (ctx_a, ctx_b) = test_st_executor(128);
|
||||
let leader = MpcAesGcm::new(
|
||||
leader_config,
|
||||
ctx_a,
|
||||
Box::new(leader_block_cipher),
|
||||
Box::new(leader_stream_cipher),
|
||||
Box::new(leader_ghash),
|
||||
);
|
||||
|
||||
let follower = MpcAesGcm::new(
|
||||
follower_config,
|
||||
ctx_b,
|
||||
Box::new(follower_block_cipher),
|
||||
Box::new(follower_stream_cipher),
|
||||
Box::new(follower_ghash),
|
||||
);
|
||||
|
||||
(leader, follower)
|
||||
}
|
||||
@@ -1,712 +0,0 @@
|
||||
//! This module provides an implementation of 2PC AES-GCM.
|
||||
|
||||
mod config;
|
||||
mod error;
|
||||
#[cfg(feature = "mock")]
|
||||
pub mod mock;
|
||||
mod tag;
|
||||
|
||||
pub use config::{AesGcmConfig, AesGcmConfigBuilder, AesGcmConfigBuilderError, Role};
|
||||
pub use error::AesGcmError;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use block_cipher::{Aes128, BlockCipher};
|
||||
use futures::TryFutureExt;
|
||||
use mpz_common::Context;
|
||||
use mpz_garble::value::ValueRef;
|
||||
use tlsn_stream_cipher::{Aes128Ctr, StreamCipher};
|
||||
use tlsn_universal_hash::UniversalHash;
|
||||
use tracing::instrument;
|
||||
|
||||
use crate::{
|
||||
aes_gcm::tag::{compute_tag, verify_tag, TAG_LEN},
|
||||
Aead,
|
||||
};
|
||||
|
||||
/// MPC AES-GCM.
|
||||
pub struct MpcAesGcm<Ctx> {
|
||||
config: AesGcmConfig,
|
||||
ctx: Ctx,
|
||||
aes_block: Box<dyn BlockCipher<Aes128>>,
|
||||
aes_ctr: Box<dyn StreamCipher<Aes128Ctr>>,
|
||||
ghash: Box<dyn UniversalHash>,
|
||||
}
|
||||
|
||||
impl<Ctx> std::fmt::Debug for MpcAesGcm<Ctx> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("MpcAesGcm")
|
||||
.field("config", &self.config)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Ctx: Context> MpcAesGcm<Ctx> {
|
||||
/// Creates a new instance of [`MpcAesGcm`].
|
||||
pub fn new(
|
||||
config: AesGcmConfig,
|
||||
context: Ctx,
|
||||
aes_block: Box<dyn BlockCipher<Aes128>>,
|
||||
aes_ctr: Box<dyn StreamCipher<Aes128Ctr>>,
|
||||
ghash: Box<dyn UniversalHash>,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
ctx: context,
|
||||
aes_block,
|
||||
aes_ctr,
|
||||
ghash,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<Ctx: Context> Aead for MpcAesGcm<Ctx> {
|
||||
type Error = AesGcmError;
|
||||
|
||||
#[instrument(level = "info", skip_all, err)]
|
||||
async fn set_key(&mut self, key: ValueRef, iv: ValueRef) -> Result<(), AesGcmError> {
|
||||
self.aes_block.set_key(key.clone());
|
||||
self.aes_ctr.set_key(key, iv);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "info", skip_all, err)]
|
||||
async fn decode_key_private(&mut self) -> Result<(), AesGcmError> {
|
||||
self.aes_ctr
|
||||
.decode_key_private()
|
||||
.await
|
||||
.map_err(AesGcmError::from)
|
||||
}
|
||||
|
||||
#[instrument(level = "info", skip_all, err)]
|
||||
async fn decode_key_blind(&mut self) -> Result<(), AesGcmError> {
|
||||
self.aes_ctr
|
||||
.decode_key_blind()
|
||||
.await
|
||||
.map_err(AesGcmError::from)
|
||||
}
|
||||
|
||||
fn set_transcript_id(&mut self, id: &str) {
|
||||
self.aes_ctr.set_transcript_id(id)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip(self), err)]
|
||||
async fn setup(&mut self) -> Result<(), AesGcmError> {
|
||||
self.ghash.setup().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip(self), err)]
|
||||
async fn preprocess(&mut self, len: usize) -> Result<(), AesGcmError> {
|
||||
futures::try_join!(
|
||||
// Preprocess the GHASH key block.
|
||||
self.aes_block
|
||||
.preprocess(block_cipher::Visibility::Public, 1)
|
||||
.map_err(AesGcmError::from),
|
||||
self.aes_ctr.preprocess(len).map_err(AesGcmError::from),
|
||||
self.ghash.preprocess().map_err(AesGcmError::from),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn start(&mut self) -> Result<(), AesGcmError> {
|
||||
let h_share = self.aes_block.encrypt_share(vec![0u8; 16]).await?;
|
||||
self.ghash.set_key(h_share).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn encrypt_public(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
plaintext: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<Vec<u8>, AesGcmError> {
|
||||
let ciphertext = self
|
||||
.aes_ctr
|
||||
.encrypt_public(explicit_nonce.clone(), plaintext)
|
||||
.await?;
|
||||
|
||||
let tag = compute_tag(
|
||||
&mut self.ctx,
|
||||
self.aes_ctr.as_mut(),
|
||||
self.ghash.as_mut(),
|
||||
explicit_nonce,
|
||||
ciphertext.clone(),
|
||||
aad,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut payload = ciphertext;
|
||||
payload.extend(tag);
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn encrypt_private(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
plaintext: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<Vec<u8>, AesGcmError> {
|
||||
let ciphertext = self
|
||||
.aes_ctr
|
||||
.encrypt_private(explicit_nonce.clone(), plaintext)
|
||||
.await?;
|
||||
|
||||
let tag = compute_tag(
|
||||
&mut self.ctx,
|
||||
self.aes_ctr.as_mut(),
|
||||
self.ghash.as_mut(),
|
||||
explicit_nonce,
|
||||
ciphertext.clone(),
|
||||
aad,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut payload = ciphertext;
|
||||
payload.extend(tag);
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn encrypt_blind(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
plaintext_len: usize,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<Vec<u8>, AesGcmError> {
|
||||
let ciphertext = self
|
||||
.aes_ctr
|
||||
.encrypt_blind(explicit_nonce.clone(), plaintext_len)
|
||||
.await?;
|
||||
|
||||
let tag = compute_tag(
|
||||
&mut self.ctx,
|
||||
self.aes_ctr.as_mut(),
|
||||
self.ghash.as_mut(),
|
||||
explicit_nonce,
|
||||
ciphertext.clone(),
|
||||
aad,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut payload = ciphertext;
|
||||
payload.extend(tag);
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn decrypt_public(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
mut payload: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<Vec<u8>, AesGcmError> {
|
||||
let purported_tag: [u8; TAG_LEN] = payload
|
||||
.split_off(payload.len() - TAG_LEN)
|
||||
.try_into()
|
||||
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?;
|
||||
let ciphertext = payload;
|
||||
|
||||
verify_tag(
|
||||
&mut self.ctx,
|
||||
self.aes_ctr.as_mut(),
|
||||
self.ghash.as_mut(),
|
||||
*self.config.role(),
|
||||
explicit_nonce.clone(),
|
||||
ciphertext.clone(),
|
||||
aad,
|
||||
purported_tag,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let plaintext = self
|
||||
.aes_ctr
|
||||
.decrypt_public(explicit_nonce, ciphertext)
|
||||
.await?;
|
||||
|
||||
Ok(plaintext)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn decrypt_private(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
mut payload: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<Vec<u8>, AesGcmError> {
|
||||
let purported_tag: [u8; TAG_LEN] = payload
|
||||
.split_off(payload.len() - TAG_LEN)
|
||||
.try_into()
|
||||
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?;
|
||||
let ciphertext = payload;
|
||||
|
||||
verify_tag(
|
||||
&mut self.ctx,
|
||||
self.aes_ctr.as_mut(),
|
||||
self.ghash.as_mut(),
|
||||
*self.config.role(),
|
||||
explicit_nonce.clone(),
|
||||
ciphertext.clone(),
|
||||
aad,
|
||||
purported_tag,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let plaintext = self
|
||||
.aes_ctr
|
||||
.decrypt_private(explicit_nonce, ciphertext)
|
||||
.await?;
|
||||
|
||||
Ok(plaintext)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn decrypt_blind(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
mut payload: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<(), AesGcmError> {
|
||||
let purported_tag: [u8; TAG_LEN] = payload
|
||||
.split_off(payload.len() - TAG_LEN)
|
||||
.try_into()
|
||||
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?;
|
||||
let ciphertext = payload;
|
||||
|
||||
verify_tag(
|
||||
&mut self.ctx,
|
||||
self.aes_ctr.as_mut(),
|
||||
self.ghash.as_mut(),
|
||||
*self.config.role(),
|
||||
explicit_nonce.clone(),
|
||||
ciphertext.clone(),
|
||||
aad,
|
||||
purported_tag,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.aes_ctr
|
||||
.decrypt_blind(explicit_nonce, ciphertext)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn verify_tag(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
mut payload: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<(), AesGcmError> {
|
||||
let purported_tag: [u8; TAG_LEN] = payload
|
||||
.split_off(payload.len() - TAG_LEN)
|
||||
.try_into()
|
||||
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?;
|
||||
let ciphertext = payload;
|
||||
|
||||
verify_tag(
|
||||
&mut self.ctx,
|
||||
self.aes_ctr.as_mut(),
|
||||
self.ghash.as_mut(),
|
||||
*self.config.role(),
|
||||
explicit_nonce,
|
||||
ciphertext,
|
||||
aad,
|
||||
purported_tag,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn prove_plaintext(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
mut payload: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<Vec<u8>, AesGcmError> {
|
||||
let purported_tag: [u8; TAG_LEN] = payload
|
||||
.split_off(payload.len() - TAG_LEN)
|
||||
.try_into()
|
||||
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?;
|
||||
let ciphertext = payload;
|
||||
|
||||
verify_tag(
|
||||
&mut self.ctx,
|
||||
self.aes_ctr.as_mut(),
|
||||
self.ghash.as_mut(),
|
||||
*self.config.role(),
|
||||
explicit_nonce.clone(),
|
||||
ciphertext.clone(),
|
||||
aad,
|
||||
purported_tag,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let plaintext = self
|
||||
.aes_ctr
|
||||
.prove_plaintext(explicit_nonce, ciphertext)
|
||||
.await?;
|
||||
|
||||
Ok(plaintext)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn prove_plaintext_no_tag(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, AesGcmError> {
|
||||
self.aes_ctr
|
||||
.prove_plaintext(explicit_nonce, ciphertext)
|
||||
.map_err(AesGcmError::from)
|
||||
.await
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn verify_plaintext(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
mut payload: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<(), AesGcmError> {
|
||||
let purported_tag: [u8; TAG_LEN] = payload
|
||||
.split_off(payload.len() - TAG_LEN)
|
||||
.try_into()
|
||||
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?;
|
||||
let ciphertext = payload;
|
||||
|
||||
verify_tag(
|
||||
&mut self.ctx,
|
||||
self.aes_ctr.as_mut(),
|
||||
self.ghash.as_mut(),
|
||||
*self.config.role(),
|
||||
explicit_nonce.clone(),
|
||||
ciphertext.clone(),
|
||||
aad,
|
||||
purported_tag,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.aes_ctr
|
||||
.verify_plaintext(explicit_nonce, ciphertext)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn verify_plaintext_no_tag(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<(), AesGcmError> {
|
||||
self.aes_ctr
|
||||
.verify_plaintext(explicit_nonce, ciphertext)
|
||||
.map_err(AesGcmError::from)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
aes_gcm::{mock::create_mock_aes_gcm_pair, AesGcmConfigBuilder, Role},
|
||||
Aead,
|
||||
};
|
||||
use ::aes_gcm::{aead::AeadInPlace, Aes128Gcm, NewAead, Nonce};
|
||||
use error::ErrorKind;
|
||||
use mpz_common::executor::STExecutor;
|
||||
use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory};
|
||||
use serio::channel::MemoryDuplex;
|
||||
|
||||
fn reference_impl(
|
||||
key: &[u8],
|
||||
iv: &[u8],
|
||||
explicit_nonce: &[u8],
|
||||
plaintext: &[u8],
|
||||
aad: &[u8],
|
||||
) -> Vec<u8> {
|
||||
let cipher = Aes128Gcm::new_from_slice(key).unwrap();
|
||||
let nonce = [iv, explicit_nonce].concat();
|
||||
let nonce = Nonce::from_slice(nonce.as_slice());
|
||||
|
||||
let mut ciphertext = plaintext.to_vec();
|
||||
cipher
|
||||
.encrypt_in_place(nonce, aad, &mut ciphertext)
|
||||
.unwrap();
|
||||
|
||||
ciphertext
|
||||
}
|
||||
|
||||
async fn setup_pair(
|
||||
key: Vec<u8>,
|
||||
iv: Vec<u8>,
|
||||
) -> (
|
||||
MpcAesGcm<STExecutor<MemoryDuplex>>,
|
||||
MpcAesGcm<STExecutor<MemoryDuplex>>,
|
||||
) {
|
||||
let (leader_vm, follower_vm) = create_mock_deap_vm();
|
||||
|
||||
let leader_key = leader_vm
|
||||
.new_public_array_input::<u8>("key", key.len())
|
||||
.unwrap();
|
||||
let leader_iv = leader_vm
|
||||
.new_public_array_input::<u8>("iv", iv.len())
|
||||
.unwrap();
|
||||
|
||||
leader_vm.assign(&leader_key, key.clone()).unwrap();
|
||||
leader_vm.assign(&leader_iv, iv.clone()).unwrap();
|
||||
|
||||
let follower_key = follower_vm
|
||||
.new_public_array_input::<u8>("key", key.len())
|
||||
.unwrap();
|
||||
let follower_iv = follower_vm
|
||||
.new_public_array_input::<u8>("iv", iv.len())
|
||||
.unwrap();
|
||||
|
||||
follower_vm.assign(&follower_key, key.clone()).unwrap();
|
||||
follower_vm.assign(&follower_iv, iv.clone()).unwrap();
|
||||
|
||||
let leader_config = AesGcmConfigBuilder::default()
|
||||
.id("test".to_string())
|
||||
.role(Role::Leader)
|
||||
.build()
|
||||
.unwrap();
|
||||
let follower_config = AesGcmConfigBuilder::default()
|
||||
.id("test".to_string())
|
||||
.role(Role::Follower)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let (mut leader, mut follower) = create_mock_aes_gcm_pair(
|
||||
"test",
|
||||
(leader_vm, follower_vm),
|
||||
leader_config,
|
||||
follower_config,
|
||||
)
|
||||
.await;
|
||||
|
||||
futures::try_join!(
|
||||
leader.set_key(leader_key, leader_iv),
|
||||
follower.set_key(follower_key, follower_iv)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
futures::try_join!(leader.setup(), follower.setup()).unwrap();
|
||||
futures::try_join!(leader.start(), follower.start()).unwrap();
|
||||
|
||||
(leader, follower)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_aes_gcm_encrypt_private() {
|
||||
let key = vec![0u8; 16];
|
||||
let iv = vec![0u8; 4];
|
||||
let explicit_nonce = vec![0u8; 8];
|
||||
let plaintext = vec![1u8; 32];
|
||||
let aad = vec![2u8; 12];
|
||||
|
||||
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
|
||||
|
||||
let (leader_ciphertext, follower_ciphertext) = tokio::try_join!(
|
||||
leader.encrypt_private(explicit_nonce.clone(), plaintext.clone(), aad.clone(),),
|
||||
follower.encrypt_blind(explicit_nonce.clone(), plaintext.len(), aad.clone())
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(leader_ciphertext, follower_ciphertext);
|
||||
assert_eq!(
|
||||
leader_ciphertext,
|
||||
reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_aes_gcm_encrypt_public() {
|
||||
let key = vec![0u8; 16];
|
||||
let iv = vec![0u8; 4];
|
||||
let explicit_nonce = vec![0u8; 8];
|
||||
let plaintext = vec![1u8; 32];
|
||||
let aad = vec![2u8; 12];
|
||||
|
||||
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
|
||||
|
||||
let (leader_ciphertext, follower_ciphertext) = tokio::try_join!(
|
||||
leader.encrypt_public(explicit_nonce.clone(), plaintext.clone(), aad.clone(),),
|
||||
follower.encrypt_public(explicit_nonce.clone(), plaintext.clone(), aad.clone(),)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(leader_ciphertext, follower_ciphertext);
|
||||
assert_eq!(
|
||||
leader_ciphertext,
|
||||
reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_aes_gcm_decrypt_private() {
|
||||
let key = vec![0u8; 16];
|
||||
let iv = vec![0u8; 4];
|
||||
let explicit_nonce = vec![0u8; 8];
|
||||
let plaintext = vec![1u8; 32];
|
||||
let aad = vec![2u8; 12];
|
||||
let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad);
|
||||
|
||||
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
|
||||
|
||||
let (leader_plaintext, _) = tokio::try_join!(
|
||||
leader.decrypt_private(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),),
|
||||
follower.decrypt_blind(explicit_nonce.clone(), ciphertext, aad.clone(),)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(leader_plaintext, plaintext);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_aes_gcm_decrypt_private_bad_tag() {
|
||||
let key = vec![0u8; 16];
|
||||
let iv = vec![0u8; 4];
|
||||
let explicit_nonce = vec![0u8; 8];
|
||||
let plaintext = vec![1u8; 32];
|
||||
let aad = vec![2u8; 12];
|
||||
let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad);
|
||||
|
||||
let len = ciphertext.len();
|
||||
|
||||
// corrupt tag
|
||||
let mut corrupted = ciphertext.clone();
|
||||
corrupted[len - 1] -= 1;
|
||||
|
||||
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
|
||||
|
||||
// leader receives corrupted tag
|
||||
let err = tokio::try_join!(
|
||||
leader.decrypt_private(explicit_nonce.clone(), corrupted.clone(), aad.clone(),),
|
||||
follower.decrypt_blind(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),)
|
||||
)
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), ErrorKind::Tag);
|
||||
|
||||
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
|
||||
|
||||
// follower receives corrupted tag
|
||||
let err = tokio::try_join!(
|
||||
leader.decrypt_private(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),),
|
||||
follower.decrypt_blind(explicit_nonce.clone(), corrupted.clone(), aad.clone(),)
|
||||
)
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), ErrorKind::Tag);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_aes_gcm_decrypt_public() {
|
||||
let key = vec![0u8; 16];
|
||||
let iv = vec![0u8; 4];
|
||||
let explicit_nonce = vec![0u8; 8];
|
||||
let plaintext = vec![1u8; 32];
|
||||
let aad = vec![2u8; 12];
|
||||
let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad);
|
||||
|
||||
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
|
||||
|
||||
let (leader_plaintext, follower_plaintext) = tokio::try_join!(
|
||||
leader.decrypt_public(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),),
|
||||
follower.decrypt_public(explicit_nonce.clone(), ciphertext, aad.clone(),)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(leader_plaintext, plaintext);
|
||||
assert_eq!(leader_plaintext, follower_plaintext);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_aes_gcm_decrypt_public_bad_tag() {
|
||||
let key = vec![0u8; 16];
|
||||
let iv = vec![0u8; 4];
|
||||
let explicit_nonce = vec![0u8; 8];
|
||||
let plaintext = vec![1u8; 32];
|
||||
let aad = vec![2u8; 12];
|
||||
let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad);
|
||||
|
||||
let len = ciphertext.len();
|
||||
|
||||
// Corrupt tag.
|
||||
let mut corrupted = ciphertext.clone();
|
||||
corrupted[len - 1] -= 1;
|
||||
|
||||
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
|
||||
|
||||
// Leader receives corrupted tag.
|
||||
let err = tokio::try_join!(
|
||||
leader.decrypt_public(explicit_nonce.clone(), corrupted.clone(), aad.clone(),),
|
||||
follower.decrypt_public(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),)
|
||||
)
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), ErrorKind::Tag);
|
||||
|
||||
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
|
||||
|
||||
// Follower receives corrupted tag.
|
||||
let err = tokio::try_join!(
|
||||
leader.decrypt_public(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),),
|
||||
follower.decrypt_public(explicit_nonce.clone(), corrupted.clone(), aad.clone(),)
|
||||
)
|
||||
.unwrap_err();
|
||||
assert_eq!(err.kind(), ErrorKind::Tag);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_aes_gcm_verify_tag() {
|
||||
let key = vec![0u8; 16];
|
||||
let iv = vec![0u8; 4];
|
||||
let explicit_nonce = vec![0u8; 8];
|
||||
let plaintext = vec![1u8; 32];
|
||||
let aad = vec![2u8; 12];
|
||||
let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad);
|
||||
|
||||
let len = ciphertext.len();
|
||||
|
||||
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
|
||||
|
||||
tokio::try_join!(
|
||||
leader.verify_tag(explicit_nonce.clone(), ciphertext.clone(), aad.clone()),
|
||||
follower.verify_tag(explicit_nonce.clone(), ciphertext.clone(), aad.clone())
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
//Corrupt tag.
|
||||
let mut corrupted = ciphertext.clone();
|
||||
corrupted[len - 1] -= 1;
|
||||
|
||||
let (leader_res, follower_res) = tokio::join!(
|
||||
leader.verify_tag(explicit_nonce.clone(), corrupted.clone(), aad.clone()),
|
||||
follower.verify_tag(explicit_nonce.clone(), corrupted, aad.clone())
|
||||
);
|
||||
|
||||
assert_eq!(leader_res.unwrap_err().kind(), ErrorKind::Tag);
|
||||
assert_eq!(follower_res.unwrap_err().kind(), ErrorKind::Tag);
|
||||
}
|
||||
}
|
||||
@@ -1,179 +0,0 @@
|
||||
use futures::TryFutureExt;
|
||||
use mpz_common::Context;
|
||||
use mpz_core::{
|
||||
commit::{Decommitment, HashCommit},
|
||||
hash::Hash,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serio::{stream::IoStreamExt, SinkExt};
|
||||
use std::ops::Add;
|
||||
use tlsn_stream_cipher::{Aes128Ctr, StreamCipher};
|
||||
use tlsn_universal_hash::UniversalHash;
|
||||
use tracing::instrument;
|
||||
|
||||
use crate::aes_gcm::{AesGcmError, Role};
|
||||
|
||||
pub(crate) const TAG_LEN: usize = 16;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct TagShare([u8; TAG_LEN]);
|
||||
|
||||
impl AsRef<[u8]> for TagShare {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Add for TagShare {
|
||||
type Output = [u8; TAG_LEN];
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
core::array::from_fn(|i| self.0[i] ^ rhs.0[i])
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all, err)]
|
||||
async fn compute_tag_share<C: StreamCipher<Aes128Ctr> + ?Sized, H: UniversalHash + ?Sized>(
|
||||
aes_ctr: &mut C,
|
||||
hasher: &mut H,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<TagShare, AesGcmError> {
|
||||
let (j0, hash) = futures::try_join!(
|
||||
aes_ctr
|
||||
.share_keystream_block(explicit_nonce, 1)
|
||||
.map_err(AesGcmError::from),
|
||||
hasher
|
||||
.finalize(build_ghash_data(aad, ciphertext))
|
||||
.map_err(AesGcmError::from)
|
||||
)?;
|
||||
|
||||
debug_assert!(j0.len() == TAG_LEN);
|
||||
debug_assert!(hash.len() == TAG_LEN);
|
||||
|
||||
let tag_share = core::array::from_fn(|i| j0[i] ^ hash[i]);
|
||||
|
||||
Ok(TagShare(tag_share))
|
||||
}
|
||||
|
||||
/// Computes the tag for a ciphertext and additional data.
|
||||
///
|
||||
/// The commit-reveal step is not required for computing a tag sent to the
|
||||
/// Server, as it will be able to detect if the tag is incorrect.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
pub(crate) async fn compute_tag<
|
||||
Ctx: Context,
|
||||
C: StreamCipher<Aes128Ctr> + ?Sized,
|
||||
H: UniversalHash + ?Sized,
|
||||
>(
|
||||
ctx: &mut Ctx,
|
||||
aes_ctr: &mut C,
|
||||
hasher: &mut H,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<[u8; TAG_LEN], AesGcmError> {
|
||||
let tag_share = compute_tag_share(aes_ctr, hasher, explicit_nonce, ciphertext, aad).await?;
|
||||
|
||||
// TODO: The follower doesn't really need to learn the tag,
|
||||
// we could reduce some latency by not sending it.
|
||||
let io = ctx.io_mut();
|
||||
io.send(tag_share.clone()).await?;
|
||||
let other_tag_share: TagShare = io.expect_next().await?;
|
||||
|
||||
let tag = tag_share + other_tag_share;
|
||||
|
||||
Ok(tag)
|
||||
}
|
||||
|
||||
/// Verifies a purported tag against the ciphertext and additional data.
|
||||
///
|
||||
/// Verifying a tag requires a commit-reveal protocol between the leader and
|
||||
/// follower. Without it, the party which receives the other's tag share first
|
||||
/// could trivially compute a tag share which would cause an invalid message to
|
||||
/// be accepted.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn verify_tag<
|
||||
Ctx: Context,
|
||||
C: StreamCipher<Aes128Ctr> + ?Sized,
|
||||
H: UniversalHash + ?Sized,
|
||||
>(
|
||||
ctx: &mut Ctx,
|
||||
aes_ctr: &mut C,
|
||||
hasher: &mut H,
|
||||
role: Role,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
purported_tag: [u8; TAG_LEN],
|
||||
) -> Result<(), AesGcmError> {
|
||||
let tag_share = compute_tag_share(aes_ctr, hasher, explicit_nonce, ciphertext, aad).await?;
|
||||
|
||||
let io = ctx.io_mut();
|
||||
let tag = match role {
|
||||
Role::Leader => {
|
||||
// Send commitment of tag share to follower.
|
||||
let (tag_share_decommitment, tag_share_commitment) = tag_share.clone().hash_commit();
|
||||
|
||||
io.send(tag_share_commitment).await?;
|
||||
|
||||
let follower_tag_share: TagShare = io.expect_next().await?;
|
||||
|
||||
// Send decommitment (tag share) to follower.
|
||||
io.send(tag_share_decommitment).await?;
|
||||
|
||||
tag_share + follower_tag_share
|
||||
}
|
||||
Role::Follower => {
|
||||
// Wait for commitment from leader.
|
||||
let commitment: Hash = io.expect_next().await?;
|
||||
|
||||
// Send tag share to leader.
|
||||
io.send(tag_share.clone()).await?;
|
||||
|
||||
// Expect decommitment (tag share) from leader.
|
||||
let decommitment: Decommitment<TagShare> = io.expect_next().await?;
|
||||
|
||||
// Verify decommitment.
|
||||
decommitment.verify(&commitment).map_err(|_| {
|
||||
AesGcmError::peer("leader tag share commitment verification failed")
|
||||
})?;
|
||||
|
||||
let leader_tag_share = decommitment.into_inner();
|
||||
|
||||
tag_share + leader_tag_share
|
||||
}
|
||||
};
|
||||
|
||||
// Reject if tag is incorrect.
|
||||
if tag != purported_tag {
|
||||
return Err(AesGcmError::invalid_tag());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Builds padded data for GHASH.
|
||||
fn build_ghash_data(mut aad: Vec<u8>, mut ciphertext: Vec<u8>) -> Vec<u8> {
|
||||
let associated_data_bitlen = (aad.len() as u64) * 8;
|
||||
let text_bitlen = (ciphertext.len() as u64) * 8;
|
||||
|
||||
let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128);
|
||||
|
||||
// Pad data to be a multiple of 16 bytes.
|
||||
let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize;
|
||||
aad.resize(aad_padded_block_count * 16, 0);
|
||||
|
||||
let ciphertext_padded_block_count =
|
||||
(ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize;
|
||||
ciphertext.resize(ciphertext_padded_block_count * 16, 0);
|
||||
|
||||
let mut data: Vec<u8> = Vec::with_capacity(aad.len() + ciphertext.len() + 16);
|
||||
data.extend(aad);
|
||||
data.extend(ciphertext);
|
||||
data.extend_from_slice(&len_block.to_be_bytes());
|
||||
|
||||
data
|
||||
}
|
||||
@@ -1,255 +0,0 @@
|
||||
//! This crate provides implementations of 2PC AEADs for authenticated
|
||||
//! encryption with a shared key.
|
||||
//!
|
||||
//! Both parties can work together to encrypt and decrypt messages with
|
||||
//! different visibility configurations. See [`Aead`] for more information on
|
||||
//! the interface.
|
||||
//!
|
||||
//! For example, one party can privately provide the plaintext to encrypt, while
|
||||
//! both parties can see the ciphertext and the tag. Or, both parties can
|
||||
//! cooperate to decrypt a ciphertext and verify the tag, while only one party
|
||||
//! can see the plaintext.
|
||||
|
||||
#![deny(missing_docs, unreachable_pub, unused_must_use)]
|
||||
#![deny(clippy::all)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
pub mod aes_gcm;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mpz_garble::value::ValueRef;
|
||||
|
||||
/// This trait defines the interface for AEADs.
|
||||
#[async_trait]
|
||||
pub trait Aead: Send {
|
||||
/// The error type for the AEAD.
|
||||
type Error: std::error::Error + Send + Sync + 'static;
|
||||
|
||||
/// Sets the key for the AEAD.
|
||||
async fn set_key(&mut self, key: ValueRef, iv: ValueRef) -> Result<(), Self::Error>;
|
||||
|
||||
/// Decodes the key for the AEAD, revealing it to this party.
|
||||
async fn decode_key_private(&mut self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Decodes the key for the AEAD, revealing it to the other party(s).
|
||||
async fn decode_key_blind(&mut self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Sets the transcript id.
|
||||
///
|
||||
/// The AEAD assigns unique identifiers to each byte of plaintext
|
||||
/// during encryption and decryption.
|
||||
///
|
||||
/// For example, if the transcript id is set to `foo`, then the first byte
|
||||
/// will be assigned the id `foo/0`, the second byte `foo/1`, and so on.
|
||||
///
|
||||
/// Each transcript id has an independent counter.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// The state of a transcript counter is preserved between calls to
|
||||
/// `set_transcript_id`.
|
||||
fn set_transcript_id(&mut self, id: &str);
|
||||
|
||||
/// Performs any necessary one-time setup for the AEAD.
|
||||
async fn setup(&mut self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Preprocesses for the given number of bytes.
|
||||
async fn preprocess(&mut self, len: usize) -> Result<(), Self::Error>;
|
||||
|
||||
/// Starts the AEAD.
|
||||
///
|
||||
/// This method performs initialization for the AEAD after setting the key.
|
||||
async fn start(&mut self) -> Result<(), Self::Error>;
|
||||
|
||||
/// Encrypts a plaintext message, returning the ciphertext and tag.
|
||||
///
|
||||
/// The plaintext is provided by both parties.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for encryption.
|
||||
/// * `plaintext` - The plaintext to encrypt.
|
||||
/// * `aad` - Additional authenticated data.
|
||||
async fn encrypt_public(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
plaintext: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Encrypts a plaintext message, hiding it from the other party, returning
|
||||
/// the ciphertext and tag.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for encryption.
|
||||
/// * `plaintext` - The plaintext to encrypt.
|
||||
/// * `aad` - Additional authenticated data.
|
||||
async fn encrypt_private(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
plaintext: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Encrypts a plaintext message provided by the other party, returning
|
||||
/// the ciphertext and tag.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for encryption.
|
||||
/// * `plaintext_len` - The length of the plaintext to encrypt.
|
||||
/// * `aad` - Additional authenticated data.
|
||||
async fn encrypt_blind(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
plaintext_len: usize,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Decrypts a ciphertext message, returning the plaintext to both parties.
|
||||
///
|
||||
/// This method checks the authenticity of the ciphertext, tag and
|
||||
/// additional data.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for decryption.
|
||||
/// * `payload` - The ciphertext and tag to authenticate and decrypt.
|
||||
/// * `aad` - Additional authenticated data.
|
||||
async fn decrypt_public(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
payload: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Decrypts a ciphertext message, returning the plaintext only to this
|
||||
/// party.
|
||||
///
|
||||
/// This method checks the authenticity of the ciphertext, tag and
|
||||
/// additional data.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for decryption.
|
||||
/// * `payload` - The ciphertext and tag to authenticate and decrypt.
|
||||
/// * `aad` - Additional authenticated data.
|
||||
async fn decrypt_private(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
payload: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Decrypts a ciphertext message, returning the plaintext only to the other
|
||||
/// party.
|
||||
///
|
||||
/// This method checks the authenticity of the ciphertext, tag and
|
||||
/// additional data.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for decryption.
|
||||
/// * `payload` - The ciphertext and tag to authenticate and decrypt.
|
||||
/// * `aad` - Additional authenticated data.
|
||||
async fn decrypt_blind(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
payload: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<(), Self::Error>;
|
||||
|
||||
/// Verifies the tag of a ciphertext message.
|
||||
///
|
||||
/// This method checks the authenticity of the ciphertext, tag and
|
||||
/// additional data.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for decryption.
|
||||
/// * `payload` - The ciphertext and tag to authenticate and decrypt.
|
||||
/// * `aad` - Additional authenticated data.
|
||||
async fn verify_tag(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
payload: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<(), Self::Error>;
|
||||
|
||||
/// Locally decrypts the provided ciphertext and then proves in ZK to the
|
||||
/// other party(s) that the plaintext is correct.
|
||||
///
|
||||
/// Returns the plaintext.
|
||||
///
|
||||
/// This method requires this party to know the encryption key, which can be
|
||||
/// achieved by calling the `decode_key_private` method.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
|
||||
/// * `payload` - The ciphertext and tag to decrypt and prove.
|
||||
/// * `aad` - Additional authenticated data.
|
||||
async fn prove_plaintext(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
payload: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Locally decrypts the provided ciphertext and then proves in ZK to the
|
||||
/// other party(s) that the plaintext is correct.
|
||||
///
|
||||
/// Returns the plaintext.
|
||||
///
|
||||
/// This method requires this party to know the encryption key, which can be
|
||||
/// achieved by calling the `decode_key_private` method.
|
||||
///
|
||||
/// # WARNING
|
||||
///
|
||||
/// This method does not verify the tag of the ciphertext. Only use this if
|
||||
/// you know what you're doing.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
|
||||
/// * `ciphertext` - The ciphertext to decrypt and prove.
|
||||
async fn prove_plaintext_no_tag(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Verifies the other party(s) can prove they know a plaintext which
|
||||
/// encrypts to the given ciphertext.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
|
||||
/// * `payload` - The ciphertext and tag to verify.
|
||||
/// * `aad` - Additional authenticated data.
|
||||
async fn verify_plaintext(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
payload: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
) -> Result<(), Self::Error>;
|
||||
|
||||
/// Verifies the other party(s) can prove they know a plaintext which
|
||||
/// encrypts to the given ciphertext.
|
||||
///
|
||||
/// # WARNING
|
||||
///
|
||||
/// This method does not verify the tag of the ciphertext. Only use this if
|
||||
/// you know what you're doing.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
|
||||
/// * `ciphertext` - The ciphertext to verify.
|
||||
async fn verify_plaintext_no_tag(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<(), Self::Error>;
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
[package]
|
||||
name = "tlsn-block-cipher"
|
||||
authors = ["TLSNotary Team"]
|
||||
description = "2PC block cipher implementation"
|
||||
keywords = ["tls", "mpc", "2pc", "block-cipher"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.8-pre"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "block_cipher"
|
||||
|
||||
[features]
|
||||
default = ["mock"]
|
||||
mock = []
|
||||
|
||||
[dependencies]
|
||||
mpz-circuits = { workspace = true }
|
||||
mpz-garble = { workspace = true }
|
||||
tlsn-utils = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
derive_builder = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
aes = { workspace = true }
|
||||
cipher = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
|
||||
@@ -1,277 +0,0 @@
|
||||
use std::{collections::VecDeque, marker::PhantomData};
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use mpz_garble::{value::ValueRef, Decode, DecodePrivate, Execute, Load, Memory};
|
||||
use tracing::instrument;
|
||||
use utils::id::NestedId;
|
||||
|
||||
use crate::{BlockCipher, BlockCipherCircuit, BlockCipherConfig, BlockCipherError, Visibility};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
private_execution_id: NestedId,
|
||||
public_execution_id: NestedId,
|
||||
preprocessed_private: VecDeque<BlockVars>,
|
||||
preprocessed_public: VecDeque<BlockVars>,
|
||||
key: Option<ValueRef>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BlockVars {
|
||||
msg: ValueRef,
|
||||
ciphertext: ValueRef,
|
||||
}
|
||||
|
||||
/// An MPC block cipher.
|
||||
#[derive(Debug)]
|
||||
pub struct MpcBlockCipher<C, E>
|
||||
where
|
||||
C: BlockCipherCircuit,
|
||||
E: Memory + Execute + Decode + DecodePrivate + Send + Sync,
|
||||
{
|
||||
state: State,
|
||||
|
||||
executor: E,
|
||||
|
||||
_cipher: PhantomData<C>,
|
||||
}
|
||||
|
||||
impl<C, E> MpcBlockCipher<C, E>
|
||||
where
|
||||
C: BlockCipherCircuit,
|
||||
E: Memory + Execute + Decode + DecodePrivate + Send + Sync,
|
||||
{
|
||||
/// Creates a new MPC block cipher.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - The configuration for the block cipher.
|
||||
/// * `executor` - The executor to use for the MPC.
|
||||
pub fn new(config: BlockCipherConfig, executor: E) -> Self {
|
||||
let private_execution_id = NestedId::new(&config.id)
|
||||
.append_string("private")
|
||||
.append_counter();
|
||||
let public_execution_id = NestedId::new(&config.id)
|
||||
.append_string("public")
|
||||
.append_counter();
|
||||
Self {
|
||||
state: State {
|
||||
private_execution_id,
|
||||
public_execution_id,
|
||||
preprocessed_private: VecDeque::new(),
|
||||
preprocessed_public: VecDeque::new(),
|
||||
key: None,
|
||||
},
|
||||
executor,
|
||||
_cipher: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn define_block(&mut self, vis: Visibility) -> BlockVars {
|
||||
let (id, msg) = match vis {
|
||||
Visibility::Private => {
|
||||
let id = self
|
||||
.state
|
||||
.private_execution_id
|
||||
.increment_in_place()
|
||||
.to_string();
|
||||
let msg = self
|
||||
.executor
|
||||
.new_private_input::<C::BLOCK>(&format!("{}/msg", &id))
|
||||
.expect("message is not defined");
|
||||
(id, msg)
|
||||
}
|
||||
Visibility::Blind => {
|
||||
let id = self
|
||||
.state
|
||||
.private_execution_id
|
||||
.increment_in_place()
|
||||
.to_string();
|
||||
let msg = self
|
||||
.executor
|
||||
.new_blind_input::<C::BLOCK>(&format!("{}/msg", &id))
|
||||
.expect("message is not defined");
|
||||
(id, msg)
|
||||
}
|
||||
Visibility::Public => {
|
||||
let id = self
|
||||
.state
|
||||
.public_execution_id
|
||||
.increment_in_place()
|
||||
.to_string();
|
||||
let msg = self
|
||||
.executor
|
||||
.new_public_input::<C::BLOCK>(&format!("{}/msg", &id))
|
||||
.expect("message is not defined");
|
||||
(id, msg)
|
||||
}
|
||||
};
|
||||
|
||||
let ciphertext = self
|
||||
.executor
|
||||
.new_output::<C::BLOCK>(&format!("{}/ciphertext", &id))
|
||||
.expect("message is not defined");
|
||||
|
||||
BlockVars { msg, ciphertext }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<C, E> BlockCipher<C> for MpcBlockCipher<C, E>
|
||||
where
|
||||
C: BlockCipherCircuit,
|
||||
E: Memory + Load + Execute + Decode + DecodePrivate + Send + Sync + Send,
|
||||
{
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
fn set_key(&mut self, key: ValueRef) {
|
||||
self.state.key = Some(key);
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn preprocess(
|
||||
&mut self,
|
||||
visibility: Visibility,
|
||||
count: usize,
|
||||
) -> Result<(), BlockCipherError> {
|
||||
let key = self
|
||||
.state
|
||||
.key
|
||||
.clone()
|
||||
.ok_or_else(BlockCipherError::key_not_set)?;
|
||||
|
||||
for _ in 0..count {
|
||||
let vars = self.define_block(visibility);
|
||||
|
||||
self.executor
|
||||
.load(
|
||||
C::circuit(),
|
||||
&[key.clone(), vars.msg.clone()],
|
||||
&[vars.ciphertext.clone()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
match visibility {
|
||||
Visibility::Private | Visibility::Blind => {
|
||||
self.state.preprocessed_private.push_back(vars)
|
||||
}
|
||||
Visibility::Public => self.state.preprocessed_public.push_back(vars),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn encrypt_private(&mut self, plaintext: Vec<u8>) -> Result<Vec<u8>, BlockCipherError> {
|
||||
let len = plaintext.len();
|
||||
let block: C::BLOCK = plaintext
|
||||
.try_into()
|
||||
.map_err(|_| BlockCipherError::invalid_message_length::<C>(len))?;
|
||||
|
||||
let key = self
|
||||
.state
|
||||
.key
|
||||
.clone()
|
||||
.ok_or_else(BlockCipherError::key_not_set)?;
|
||||
|
||||
let BlockVars { msg, ciphertext } =
|
||||
if let Some(vars) = self.state.preprocessed_private.pop_front() {
|
||||
vars
|
||||
} else {
|
||||
self.define_block(Visibility::Private)
|
||||
};
|
||||
|
||||
self.executor.assign(&msg, block)?;
|
||||
|
||||
self.executor
|
||||
.execute(C::circuit(), &[key, msg], &[ciphertext.clone()])
|
||||
.await?;
|
||||
|
||||
let mut outputs = self.executor.decode(&[ciphertext]).await?;
|
||||
|
||||
let ciphertext: C::BLOCK = if let Ok(ciphertext) = outputs
|
||||
.pop()
|
||||
.expect("ciphertext should be present")
|
||||
.try_into()
|
||||
{
|
||||
ciphertext
|
||||
} else {
|
||||
panic!("ciphertext should be a block")
|
||||
};
|
||||
|
||||
Ok(ciphertext.into())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn encrypt_blind(&mut self) -> Result<Vec<u8>, BlockCipherError> {
|
||||
let key = self
|
||||
.state
|
||||
.key
|
||||
.clone()
|
||||
.ok_or_else(BlockCipherError::key_not_set)?;
|
||||
|
||||
let BlockVars { msg, ciphertext } =
|
||||
if let Some(vars) = self.state.preprocessed_private.pop_front() {
|
||||
vars
|
||||
} else {
|
||||
self.define_block(Visibility::Blind)
|
||||
};
|
||||
|
||||
self.executor
|
||||
.execute(C::circuit(), &[key, msg], &[ciphertext.clone()])
|
||||
.await?;
|
||||
|
||||
let mut outputs = self.executor.decode(&[ciphertext]).await?;
|
||||
|
||||
let ciphertext: C::BLOCK = if let Ok(ciphertext) = outputs
|
||||
.pop()
|
||||
.expect("ciphertext should be present")
|
||||
.try_into()
|
||||
{
|
||||
ciphertext
|
||||
} else {
|
||||
panic!("ciphertext should be a block")
|
||||
};
|
||||
|
||||
Ok(ciphertext.into())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn encrypt_share(&mut self, plaintext: Vec<u8>) -> Result<Vec<u8>, BlockCipherError> {
|
||||
let len = plaintext.len();
|
||||
let block: C::BLOCK = plaintext
|
||||
.try_into()
|
||||
.map_err(|_| BlockCipherError::invalid_message_length::<C>(len))?;
|
||||
|
||||
let key = self
|
||||
.state
|
||||
.key
|
||||
.clone()
|
||||
.ok_or_else(BlockCipherError::key_not_set)?;
|
||||
|
||||
let BlockVars { msg, ciphertext } =
|
||||
if let Some(vars) = self.state.preprocessed_public.pop_front() {
|
||||
vars
|
||||
} else {
|
||||
self.define_block(Visibility::Public)
|
||||
};
|
||||
|
||||
self.executor.assign(&msg, block)?;
|
||||
|
||||
self.executor
|
||||
.execute(C::circuit(), &[key, msg], &[ciphertext.clone()])
|
||||
.await?;
|
||||
|
||||
let mut outputs = self.executor.decode_shared(&[ciphertext]).await?;
|
||||
|
||||
let share: C::BLOCK =
|
||||
if let Ok(share) = outputs.pop().expect("share should be present").try_into() {
|
||||
share
|
||||
} else {
|
||||
panic!("share should be a block")
|
||||
};
|
||||
|
||||
Ok(share.into())
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use mpz_circuits::{
|
||||
circuits::AES128,
|
||||
types::{StaticValueType, Value},
|
||||
Circuit,
|
||||
};
|
||||
|
||||
/// A block cipher circuit.
|
||||
pub trait BlockCipherCircuit: Default + Clone + Send + Sync {
|
||||
/// The key type.
|
||||
type KEY: StaticValueType + Send + Sync;
|
||||
/// The block type.
|
||||
type BLOCK: StaticValueType + TryFrom<Vec<u8>> + TryFrom<Value> + Into<Vec<u8>> + Send + Sync;
|
||||
|
||||
/// The length of the key.
|
||||
const KEY_LEN: usize;
|
||||
/// The length of the block.
|
||||
const BLOCK_LEN: usize;
|
||||
|
||||
/// Returns the circuit of the cipher.
|
||||
fn circuit() -> Arc<Circuit>;
|
||||
}
|
||||
|
||||
/// Aes128 block cipher circuit.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Aes128;
|
||||
|
||||
impl BlockCipherCircuit for Aes128 {
|
||||
type KEY = [u8; 16];
|
||||
type BLOCK = [u8; 16];
|
||||
|
||||
const KEY_LEN: usize = 16;
|
||||
const BLOCK_LEN: usize = 16;
|
||||
|
||||
fn circuit() -> Arc<Circuit> {
|
||||
AES128.clone()
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
use derive_builder::Builder;
|
||||
|
||||
/// Configuration for a block cipher.
|
||||
#[derive(Debug, Clone, Builder)]
|
||||
pub struct BlockCipherConfig {
|
||||
/// The ID of the block cipher.
|
||||
#[builder(setter(into))]
|
||||
pub(crate) id: String,
|
||||
}
|
||||
|
||||
impl BlockCipherConfig {
|
||||
/// Creates a new builder for the block cipher configuration.
|
||||
pub fn builder() -> BlockCipherConfigBuilder {
|
||||
BlockCipherConfigBuilder::default()
|
||||
}
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
use core::fmt;
|
||||
use std::error::Error;
|
||||
|
||||
use crate::BlockCipherCircuit;
|
||||
|
||||
/// A block cipher error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub struct BlockCipherError {
|
||||
kind: ErrorKind,
|
||||
#[source]
|
||||
source: Option<Box<dyn Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl BlockCipherError {
|
||||
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync>>,
|
||||
{
|
||||
Self {
|
||||
kind,
|
||||
source: Some(source.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn key_not_set() -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Key,
|
||||
source: Some("key not set".into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn invalid_message_length<C: BlockCipherCircuit>(len: usize) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Msg,
|
||||
source: Some(
|
||||
format!(
|
||||
"message length does not equal block length: {} != {}",
|
||||
len,
|
||||
C::BLOCK_LEN
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum ErrorKind {
|
||||
Vm,
|
||||
Key,
|
||||
Msg,
|
||||
}
|
||||
|
||||
impl fmt::Display for BlockCipherError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self.kind {
|
||||
ErrorKind::Vm => write!(f, "vm error")?,
|
||||
ErrorKind::Key => write!(f, "key error")?,
|
||||
ErrorKind::Msg => write!(f, "message error")?,
|
||||
}
|
||||
|
||||
if let Some(ref source) = self.source {
|
||||
write!(f, " caused by: {}", source)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::MemoryError> for BlockCipherError {
|
||||
fn from(error: mpz_garble::MemoryError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::LoadError> for BlockCipherError {
|
||||
fn from(error: mpz_garble::LoadError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::ExecutionError> for BlockCipherError {
|
||||
fn from(error: mpz_garble::ExecutionError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::DecodeError> for BlockCipherError {
|
||||
fn from(error: mpz_garble::DecodeError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
@@ -1,236 +0,0 @@
|
||||
//! This crate provides a 2PC block cipher implementation.
|
||||
//!
|
||||
//! Both parties work together to encrypt or share an encrypted block using a
|
||||
//! shared key.
|
||||
|
||||
#![deny(missing_docs, unreachable_pub, unused_must_use)]
|
||||
#![deny(clippy::all)]
|
||||
#![deny(unsafe_code)]
|
||||
|
||||
mod cipher;
|
||||
mod circuit;
|
||||
mod config;
|
||||
mod error;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use mpz_garble::value::ValueRef;
|
||||
|
||||
pub use crate::{
|
||||
cipher::MpcBlockCipher,
|
||||
circuit::{Aes128, BlockCipherCircuit},
|
||||
};
|
||||
pub use config::{BlockCipherConfig, BlockCipherConfigBuilder, BlockCipherConfigBuilderError};
|
||||
pub use error::BlockCipherError;
|
||||
|
||||
/// Visibility of a message plaintext.
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum Visibility {
|
||||
/// Private message.
|
||||
Private,
|
||||
/// Blind message.
|
||||
Blind,
|
||||
/// Public message.
|
||||
Public,
|
||||
}
|
||||
|
||||
/// A trait for MPC block ciphers.
|
||||
#[async_trait]
|
||||
pub trait BlockCipher<Cipher>: Send + Sync
|
||||
where
|
||||
Cipher: BlockCipherCircuit,
|
||||
{
|
||||
/// Sets the key for the block cipher.
|
||||
fn set_key(&mut self, key: ValueRef);
|
||||
|
||||
/// Preprocesses `count` blocks.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `visibility` - The visibility of the plaintext.
|
||||
/// * `count` - The number of blocks to preprocess.
|
||||
async fn preprocess(
|
||||
&mut self,
|
||||
visibility: Visibility,
|
||||
count: usize,
|
||||
) -> Result<(), BlockCipherError>;
|
||||
|
||||
/// Encrypts the given plaintext keeping it hidden from the other party(s).
|
||||
///
|
||||
/// Returns the ciphertext.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `plaintext` - The plaintext to encrypt.
|
||||
async fn encrypt_private(&mut self, plaintext: Vec<u8>) -> Result<Vec<u8>, BlockCipherError>;
|
||||
|
||||
/// Encrypts a plaintext provided by the other party(s).
|
||||
///
|
||||
/// Returns the ciphertext.
|
||||
async fn encrypt_blind(&mut self) -> Result<Vec<u8>, BlockCipherError>;
|
||||
|
||||
/// Encrypts a plaintext provided by both parties. Fails if the
|
||||
/// plaintext provided by both parties does not match.
|
||||
///
|
||||
/// Returns an additive share of the ciphertext.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `plaintext` - The plaintext to encrypt.
|
||||
async fn encrypt_share(&mut self, plaintext: Vec<u8>) -> Result<Vec<u8>, BlockCipherError>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory};
|
||||
|
||||
use crate::circuit::Aes128;
|
||||
|
||||
use ::aes::Aes128 as TestAes128;
|
||||
use ::cipher::{BlockEncrypt, KeyInit};
|
||||
|
||||
fn aes128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] {
|
||||
let mut msg = msg.into();
|
||||
let cipher = TestAes128::new(&key.into());
|
||||
cipher.encrypt_block(&mut msg);
|
||||
msg.into()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_block_cipher_blind() {
|
||||
let leader_config = BlockCipherConfig::builder().id("test").build().unwrap();
|
||||
let follower_config = BlockCipherConfig::builder().id("test").build().unwrap();
|
||||
|
||||
let key = [0u8; 16];
|
||||
|
||||
let (leader_vm, follower_vm) = create_mock_deap_vm();
|
||||
|
||||
// Key is public just for this test, typically it is private.
|
||||
let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap();
|
||||
let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap();
|
||||
|
||||
leader_vm.assign(&leader_key, key).unwrap();
|
||||
follower_vm.assign(&follower_key, key).unwrap();
|
||||
|
||||
let mut leader = MpcBlockCipher::<Aes128, _>::new(leader_config, leader_vm);
|
||||
leader.set_key(leader_key);
|
||||
|
||||
let mut follower = MpcBlockCipher::<Aes128, _>::new(follower_config, follower_vm);
|
||||
follower.set_key(follower_key);
|
||||
|
||||
let plaintext = [0u8; 16];
|
||||
|
||||
let (leader_ciphertext, follower_ciphertext) = tokio::try_join!(
|
||||
leader.encrypt_private(plaintext.to_vec()),
|
||||
follower.encrypt_blind()
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let expected = aes128(key, plaintext);
|
||||
|
||||
assert_eq!(leader_ciphertext, expected.to_vec());
|
||||
assert_eq!(leader_ciphertext, follower_ciphertext);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_block_cipher_share() {
|
||||
let leader_config = BlockCipherConfig::builder().id("test").build().unwrap();
|
||||
let follower_config = BlockCipherConfig::builder().id("test").build().unwrap();
|
||||
|
||||
let key = [0u8; 16];
|
||||
|
||||
let (leader_vm, follower_vm) = create_mock_deap_vm();
|
||||
|
||||
// Key is public just for this test, typically it is private.
|
||||
let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap();
|
||||
let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap();
|
||||
|
||||
leader_vm.assign(&leader_key, key).unwrap();
|
||||
follower_vm.assign(&follower_key, key).unwrap();
|
||||
|
||||
let mut leader = MpcBlockCipher::<Aes128, _>::new(leader_config, leader_vm);
|
||||
leader.set_key(leader_key);
|
||||
|
||||
let mut follower = MpcBlockCipher::<Aes128, _>::new(follower_config, follower_vm);
|
||||
follower.set_key(follower_key);
|
||||
|
||||
let plaintext = [0u8; 16];
|
||||
|
||||
let (leader_share, follower_share) = tokio::try_join!(
|
||||
leader.encrypt_share(plaintext.to_vec()),
|
||||
follower.encrypt_share(plaintext.to_vec())
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let expected = aes128(key, plaintext);
|
||||
|
||||
let result: [u8; 16] = std::array::from_fn(|i| leader_share[i] ^ follower_share[i]);
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_block_cipher_preprocess() {
|
||||
let leader_config = BlockCipherConfig::builder().id("test").build().unwrap();
|
||||
let follower_config = BlockCipherConfig::builder().id("test").build().unwrap();
|
||||
|
||||
let key = [0u8; 16];
|
||||
|
||||
let (leader_vm, follower_vm) = create_mock_deap_vm();
|
||||
|
||||
// Key is public just for this test, typically it is private.
|
||||
let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap();
|
||||
let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap();
|
||||
|
||||
leader_vm.assign(&leader_key, key).unwrap();
|
||||
follower_vm.assign(&follower_key, key).unwrap();
|
||||
|
||||
let mut leader = MpcBlockCipher::<Aes128, _>::new(leader_config, leader_vm);
|
||||
leader.set_key(leader_key);
|
||||
|
||||
let mut follower = MpcBlockCipher::<Aes128, _>::new(follower_config, follower_vm);
|
||||
follower.set_key(follower_key);
|
||||
|
||||
let plaintext = [0u8; 16];
|
||||
|
||||
tokio::try_join!(
|
||||
leader.preprocess(Visibility::Private, 1),
|
||||
follower.preprocess(Visibility::Blind, 1)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (leader_ciphertext, follower_ciphertext) = tokio::try_join!(
|
||||
leader.encrypt_private(plaintext.to_vec()),
|
||||
follower.encrypt_blind()
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let expected = aes128(key, plaintext);
|
||||
|
||||
assert_eq!(leader_ciphertext, expected.to_vec());
|
||||
assert_eq!(leader_ciphertext, follower_ciphertext);
|
||||
|
||||
tokio::try_join!(
|
||||
leader.preprocess(Visibility::Public, 1),
|
||||
follower.preprocess(Visibility::Public, 1)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (leader_share, follower_share) = tokio::try_join!(
|
||||
leader.encrypt_share(plaintext.to_vec()),
|
||||
follower.encrypt_share(plaintext.to_vec())
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let expected = aes128(key, plaintext);
|
||||
|
||||
let result: [u8; 16] = std::array::from_fn(|i| leader_share[i] ^ follower_share[i]);
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
}
|
||||
31
crates/components/cipher/Cargo.toml
Normal file
31
crates/components/cipher/Cargo.toml
Normal file
@@ -0,0 +1,31 @@
|
||||
[package]
|
||||
name = "tlsn-cipher"
|
||||
authors = ["TLSNotary Team"]
|
||||
description = "This crate provides implementations of ciphers for two parties"
|
||||
keywords = ["tls", "mpc", "2pc", "aes"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.8-pre"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "cipher"
|
||||
|
||||
[dependencies]
|
||||
mpz-circuits = { workspace = true }
|
||||
mpz-vm-core = { workspace = true }
|
||||
mpz-memory-core = { workspace = true }
|
||||
|
||||
async-trait = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
aes = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
mpz-garble = { workspace = true }
|
||||
mpz-common = { workspace = true }
|
||||
mpz-ot = { workspace = true }
|
||||
|
||||
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
|
||||
rand = { workspace = true }
|
||||
ctr = { workspace = true }
|
||||
cipher = { workspace = true }
|
||||
@@ -1,26 +1,31 @@
|
||||
use mpz_circuits::{circuits::aes128_trace, once_cell::sync::Lazy, trace, Circuit, CircuitBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// AES encrypts a counter block.
|
||||
///
|
||||
/// # Inputs
|
||||
///
|
||||
/// 0. KEY: 16-byte encryption key
|
||||
/// 1. IV: 4-byte IV
|
||||
/// 2. EXPLICIT_NONCE: 8-byte explicit nonce
|
||||
/// 3. CTR: 4-byte counter
|
||||
///
|
||||
/// # Outputs
|
||||
///
|
||||
/// 0. ECB: 16-byte output
|
||||
pub(crate) static AES_CTR: Lazy<Arc<Circuit>> = Lazy::new(|| {
|
||||
/// `fn(key: [u8; 16], iv: [u8; 4], nonce: [u8; 8], ctr: [u8; 4]) -> [u8; 16]`
|
||||
pub(crate) static AES128_CTR: Lazy<Arc<Circuit>> = Lazy::new(|| {
|
||||
let builder = CircuitBuilder::new();
|
||||
|
||||
let key = builder.add_array_input::<u8, 16>();
|
||||
let iv = builder.add_array_input::<u8, 4>();
|
||||
let nonce = builder.add_array_input::<u8, 8>();
|
||||
let ctr = builder.add_array_input::<u8, 4>();
|
||||
let ecb = aes_ctr_trace(builder.state(), key, iv, nonce, ctr);
|
||||
builder.add_output(ecb);
|
||||
|
||||
let keystream = aes_ctr_trace(builder.state(), key, iv, nonce, ctr);
|
||||
|
||||
builder.add_output(keystream);
|
||||
|
||||
Arc::new(builder.build().unwrap())
|
||||
});
|
||||
|
||||
/// `fn(key: [u8; 16], msg: [u8; 16]) -> [u8; 16]`
|
||||
pub(crate) static AES128_ECB: Lazy<Arc<Circuit>> = Lazy::new(|| {
|
||||
let builder = CircuitBuilder::new();
|
||||
|
||||
let key = builder.add_array_input::<u8, 16>();
|
||||
let message = builder.add_array_input::<u8, 16>();
|
||||
let block = aes128_trace(builder.state(), key, message);
|
||||
|
||||
builder.add_output(block);
|
||||
|
||||
Arc::new(builder.build().unwrap())
|
||||
});
|
||||
@@ -45,13 +50,3 @@ fn aes_128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] {
|
||||
aes.encrypt_block(&mut ciphertext);
|
||||
ciphertext.into()
|
||||
}
|
||||
|
||||
/// Builds a circuit for computing the XOR of two arrays.
|
||||
pub(crate) fn build_array_xor(len: usize) -> Arc<Circuit> {
|
||||
let builder = CircuitBuilder::new();
|
||||
let a = builder.add_vec_input::<u8>(len);
|
||||
let b = builder.add_vec_input::<u8>(len);
|
||||
let c = a.into_iter().zip(b).map(|(a, b)| a ^ b).collect::<Vec<_>>();
|
||||
builder.add_output(c);
|
||||
Arc::new(builder.build().expect("circuit is valid"))
|
||||
}
|
||||
44
crates/components/cipher/src/aes/error.rs
Normal file
44
crates/components/cipher/src/aes/error.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
/// AES error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub struct AesError {
|
||||
kind: ErrorKind,
|
||||
#[source]
|
||||
source: Option<Box<dyn std::error::Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl AesError {
|
||||
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
{
|
||||
Self {
|
||||
kind,
|
||||
source: Some(source.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) enum ErrorKind {
|
||||
Vm,
|
||||
Key,
|
||||
Iv,
|
||||
}
|
||||
|
||||
impl Display for AesError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self.kind {
|
||||
ErrorKind::Vm => write!(f, "vm error")?,
|
||||
ErrorKind::Key => write!(f, "key error")?,
|
||||
ErrorKind::Iv => write!(f, "iv error")?,
|
||||
}
|
||||
|
||||
if let Some(source) = &self.source {
|
||||
write!(f, " caused by: {}", source)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
375
crates/components/cipher/src/aes/mod.rs
Normal file
375
crates/components/cipher/src/aes/mod.rs
Normal file
@@ -0,0 +1,375 @@
|
||||
//! The AES-128 block cipher.
|
||||
|
||||
use crate::{Cipher, CtrBlock, Keystream};
|
||||
use async_trait::async_trait;
|
||||
use mpz_memory_core::binary::{Binary, U8};
|
||||
use mpz_vm_core::{prelude::*, Call, Vm};
|
||||
use std::fmt::Debug;
|
||||
|
||||
mod circuit;
|
||||
mod error;
|
||||
|
||||
pub use error::AesError;
|
||||
use error::ErrorKind;
|
||||
|
||||
/// Computes AES-128.
|
||||
#[derive(Default, Debug)]
|
||||
pub struct Aes128 {
|
||||
key: Option<Array<U8, 16>>,
|
||||
iv: Option<Array<U8, 4>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Cipher for Aes128 {
|
||||
type Error = AesError;
|
||||
type Key = Array<U8, 16>;
|
||||
type Iv = Array<U8, 4>;
|
||||
type Nonce = Array<U8, 8>;
|
||||
type Counter = Array<U8, 4>;
|
||||
type Block = Array<U8, 16>;
|
||||
|
||||
fn set_key(&mut self, key: Array<U8, 16>) {
|
||||
self.key = Some(key);
|
||||
}
|
||||
|
||||
fn set_iv(&mut self, iv: Array<U8, 4>) {
|
||||
self.iv = Some(iv);
|
||||
}
|
||||
|
||||
fn key(&self) -> Option<&Array<U8, 16>> {
|
||||
self.key.as_ref()
|
||||
}
|
||||
|
||||
fn iv(&self) -> Option<&Array<U8, 4>> {
|
||||
self.iv.as_ref()
|
||||
}
|
||||
|
||||
fn alloc_block(
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
input: Array<U8, 16>,
|
||||
) -> Result<Self::Block, Self::Error> {
|
||||
let key = self
|
||||
.key
|
||||
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
|
||||
|
||||
let output = vm
|
||||
.call(
|
||||
Call::new(circuit::AES128_ECB.clone())
|
||||
.arg(key)
|
||||
.arg(input)
|
||||
.build()
|
||||
.expect("call should be valid"),
|
||||
)
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn alloc_ctr_block(
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error> {
|
||||
let key = self
|
||||
.key
|
||||
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
|
||||
let iv = self
|
||||
.iv
|
||||
.ok_or_else(|| AesError::new(ErrorKind::Iv, "iv not set"))?;
|
||||
|
||||
let explicit_nonce: Array<U8, 8> = vm
|
||||
.alloc()
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
vm.mark_public(explicit_nonce)
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
|
||||
let counter: Array<U8, 4> = vm
|
||||
.alloc()
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
vm.mark_public(counter)
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
|
||||
let output = vm
|
||||
.call(
|
||||
Call::new(circuit::AES128_CTR.clone())
|
||||
.arg(key)
|
||||
.arg(iv)
|
||||
.arg(explicit_nonce)
|
||||
.arg(counter)
|
||||
.build()
|
||||
.expect("call should be valid"),
|
||||
)
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
|
||||
Ok(CtrBlock {
|
||||
explicit_nonce,
|
||||
counter,
|
||||
output,
|
||||
})
|
||||
}
|
||||
|
||||
fn alloc_keystream(
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
len: usize,
|
||||
) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error> {
|
||||
let key = self
|
||||
.key
|
||||
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
|
||||
let iv = self
|
||||
.iv
|
||||
.ok_or_else(|| AesError::new(ErrorKind::Iv, "iv not set"))?;
|
||||
|
||||
let block_count = len.div_ceil(16);
|
||||
|
||||
let inputs = (0..block_count)
|
||||
.map(|_| {
|
||||
let explicit_nonce: Array<U8, 8> = vm
|
||||
.alloc()
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
let counter: Array<U8, 4> = vm
|
||||
.alloc()
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
|
||||
vm.mark_public(explicit_nonce)
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
vm.mark_public(counter)
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
|
||||
Ok((explicit_nonce, counter))
|
||||
})
|
||||
.collect::<Result<Vec<_>, AesError>>()?;
|
||||
|
||||
let blocks = inputs
|
||||
.into_iter()
|
||||
.map(|(explicit_nonce, counter)| {
|
||||
let output = vm
|
||||
.call(
|
||||
Call::new(circuit::AES128_CTR.clone())
|
||||
.arg(key)
|
||||
.arg(iv)
|
||||
.arg(explicit_nonce)
|
||||
.arg(counter)
|
||||
.build()
|
||||
.expect("call should be valid"),
|
||||
)
|
||||
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
|
||||
|
||||
Ok(CtrBlock {
|
||||
explicit_nonce,
|
||||
counter,
|
||||
output,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, AesError>>()?;
|
||||
|
||||
Ok(Keystream::new(&blocks))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::Cipher;
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
|
||||
use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
correlated::Delta,
|
||||
Array, MemoryExt, Vector, ViewExt,
|
||||
};
|
||||
use mpz_ot::ideal::cot::ideal_cot;
|
||||
use mpz_vm_core::{Execute, Vm};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_aes_ctr() {
|
||||
let key = [42_u8; 16];
|
||||
let iv = [3_u8; 4];
|
||||
let nonce = [5_u8; 8];
|
||||
let start_counter = 3u32;
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let (mut gen, mut ev) = mock_vm();
|
||||
|
||||
let aes_gen = setup_ctr(key, iv, &mut gen);
|
||||
let aes_ev = setup_ctr(key, iv, &mut ev);
|
||||
|
||||
let msg = vec![42u8; 128];
|
||||
|
||||
let keystream_gen = aes_gen.alloc_keystream(&mut gen, msg.len()).unwrap();
|
||||
let keystream_ev = aes_ev.alloc_keystream(&mut ev, msg.len()).unwrap();
|
||||
|
||||
let msg_ref_gen: Vector<U8> = gen.alloc_vec(msg.len()).unwrap();
|
||||
gen.mark_public(msg_ref_gen).unwrap();
|
||||
gen.assign(msg_ref_gen, msg.clone()).unwrap();
|
||||
gen.commit(msg_ref_gen).unwrap();
|
||||
|
||||
let msg_ref_ev: Vector<U8> = ev.alloc_vec(msg.len()).unwrap();
|
||||
ev.mark_public(msg_ref_ev).unwrap();
|
||||
ev.assign(msg_ref_ev, msg.clone()).unwrap();
|
||||
ev.commit(msg_ref_ev).unwrap();
|
||||
|
||||
let mut ctr = start_counter..;
|
||||
keystream_gen
|
||||
.assign(&mut gen, nonce, move || ctr.next().unwrap().to_be_bytes())
|
||||
.unwrap();
|
||||
let mut ctr = start_counter..;
|
||||
keystream_ev
|
||||
.assign(&mut ev, nonce, move || ctr.next().unwrap().to_be_bytes())
|
||||
.unwrap();
|
||||
|
||||
let cipher_out_gen = keystream_gen.apply(&mut gen, msg_ref_gen).unwrap();
|
||||
let cipher_out_ev = keystream_ev.apply(&mut ev, msg_ref_ev).unwrap();
|
||||
|
||||
let (ct_gen, ct_ev) = tokio::try_join!(
|
||||
async {
|
||||
let out = gen.decode(cipher_out_gen).unwrap();
|
||||
gen.flush(&mut ctx_a).await.unwrap();
|
||||
gen.execute(&mut ctx_a).await.unwrap();
|
||||
gen.flush(&mut ctx_a).await.unwrap();
|
||||
out.await
|
||||
},
|
||||
async {
|
||||
let out = ev.decode(cipher_out_ev).unwrap();
|
||||
ev.flush(&mut ctx_b).await.unwrap();
|
||||
ev.execute(&mut ctx_b).await.unwrap();
|
||||
ev.flush(&mut ctx_b).await.unwrap();
|
||||
out.await
|
||||
}
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(ct_gen, ct_ev);
|
||||
|
||||
let expected = aes_apply_keystream(key, iv, nonce, start_counter as usize, msg);
|
||||
assert_eq!(ct_gen, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_aes_ecb() {
|
||||
let key = [1_u8; 16];
|
||||
let input = [5_u8; 16];
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let (mut gen, mut ev) = mock_vm();
|
||||
|
||||
let aes_gen = setup_block(key, &mut gen);
|
||||
let aes_ev = setup_block(key, &mut ev);
|
||||
|
||||
let block_ref_gen: Array<U8, 16> = gen.alloc().unwrap();
|
||||
gen.mark_public(block_ref_gen).unwrap();
|
||||
gen.assign(block_ref_gen, input).unwrap();
|
||||
gen.commit(block_ref_gen).unwrap();
|
||||
|
||||
let block_ref_ev: Array<U8, 16> = ev.alloc().unwrap();
|
||||
ev.mark_public(block_ref_ev).unwrap();
|
||||
ev.assign(block_ref_ev, input).unwrap();
|
||||
ev.commit(block_ref_ev).unwrap();
|
||||
|
||||
let block_gen = aes_gen.alloc_block(&mut gen, block_ref_gen).unwrap();
|
||||
let block_ev = aes_ev.alloc_block(&mut ev, block_ref_ev).unwrap();
|
||||
|
||||
let (ciphertext_gen, ciphetext_ev) = tokio::try_join!(
|
||||
async {
|
||||
let out = gen.decode(block_gen).unwrap();
|
||||
gen.flush(&mut ctx_a).await.unwrap();
|
||||
gen.execute(&mut ctx_a).await.unwrap();
|
||||
gen.flush(&mut ctx_a).await.unwrap();
|
||||
out.await
|
||||
},
|
||||
async {
|
||||
let out = ev.decode(block_ev).unwrap();
|
||||
ev.flush(&mut ctx_b).await.unwrap();
|
||||
ev.execute(&mut ctx_b).await.unwrap();
|
||||
ev.flush(&mut ctx_b).await.unwrap();
|
||||
out.await
|
||||
}
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(ciphertext_gen, ciphetext_ev);
|
||||
|
||||
let expected = aes128(key, input);
|
||||
assert_eq!(ciphertext_gen, expected);
|
||||
}
|
||||
|
||||
fn mock_vm() -> (impl Vm<Binary>, impl Vm<Binary>) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta = Delta::random(&mut rng);
|
||||
|
||||
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
|
||||
|
||||
let gen = Generator::new(cot_send, [0u8; 16], delta);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
|
||||
(gen, ev)
|
||||
}
|
||||
|
||||
fn setup_ctr(key: [u8; 16], iv: [u8; 4], vm: &mut dyn Vm<Binary>) -> Aes128 {
|
||||
let key_ref: Array<U8, 16> = vm.alloc().unwrap();
|
||||
vm.mark_public(key_ref).unwrap();
|
||||
vm.assign(key_ref, key).unwrap();
|
||||
vm.commit(key_ref).unwrap();
|
||||
|
||||
let iv_ref: Array<U8, 4> = vm.alloc().unwrap();
|
||||
vm.mark_public(iv_ref).unwrap();
|
||||
vm.assign(iv_ref, iv).unwrap();
|
||||
vm.commit(iv_ref).unwrap();
|
||||
|
||||
let mut aes = Aes128::default();
|
||||
|
||||
aes.set_key(key_ref);
|
||||
aes.set_iv(iv_ref);
|
||||
|
||||
aes
|
||||
}
|
||||
|
||||
fn setup_block(key: [u8; 16], vm: &mut dyn Vm<Binary>) -> Aes128 {
|
||||
let key_ref: Array<U8, 16> = vm.alloc().unwrap();
|
||||
vm.mark_public(key_ref).unwrap();
|
||||
vm.assign(key_ref, key).unwrap();
|
||||
vm.commit(key_ref).unwrap();
|
||||
|
||||
let mut aes = Aes128::default();
|
||||
aes.set_key(key_ref);
|
||||
|
||||
aes
|
||||
}
|
||||
|
||||
fn aes_apply_keystream(
|
||||
key: [u8; 16],
|
||||
iv: [u8; 4],
|
||||
explicit_nonce: [u8; 8],
|
||||
start_ctr: usize,
|
||||
msg: Vec<u8>,
|
||||
) -> Vec<u8> {
|
||||
use ::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
|
||||
use aes::Aes128;
|
||||
use ctr::Ctr32BE;
|
||||
|
||||
let mut full_iv = [0u8; 16];
|
||||
full_iv[0..4].copy_from_slice(&iv);
|
||||
full_iv[4..12].copy_from_slice(&explicit_nonce);
|
||||
|
||||
let mut cipher = Ctr32BE::<Aes128>::new(&key.into(), &full_iv.into());
|
||||
let mut out = msg.clone();
|
||||
|
||||
cipher
|
||||
.try_seek(start_ctr * 16)
|
||||
.expect("start counter is less than keystream length");
|
||||
cipher.apply_keystream(&mut out);
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn aes128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] {
|
||||
use ::aes::Aes128 as TestAes128;
|
||||
use ::cipher::{BlockEncrypt, KeyInit};
|
||||
|
||||
let mut msg = msg.into();
|
||||
let cipher = TestAes128::new(&key.into());
|
||||
cipher.encrypt_block(&mut msg);
|
||||
msg.into()
|
||||
}
|
||||
}
|
||||
23
crates/components/cipher/src/circuit.rs
Normal file
23
crates/components/cipher/src/circuit.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
//! Ciphers and circuits.
|
||||
|
||||
use mpz_circuits::{types::ValueType, Circuit, CircuitBuilder, Tracer};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Builds a circuit which XORs the provided values.
|
||||
pub(crate) fn build_xor_circuit(inputs: &[ValueType]) -> Arc<Circuit> {
|
||||
let builder = CircuitBuilder::new();
|
||||
|
||||
for input_ty in inputs {
|
||||
let input_0 = builder.add_input_by_type(input_ty.clone());
|
||||
let input_1 = builder.add_input_by_type(input_ty.clone());
|
||||
|
||||
let input_0 = Tracer::new(builder.state(), input_0);
|
||||
let input_1 = Tracer::new(builder.state(), input_1);
|
||||
let output = input_0 ^ input_1;
|
||||
builder.add_output(output);
|
||||
}
|
||||
|
||||
let circ = builder.build().expect("circuit should be valid");
|
||||
|
||||
Arc::new(circ)
|
||||
}
|
||||
299
crates/components/cipher/src/lib.rs
Normal file
299
crates/components/cipher/src/lib.rs
Normal file
@@ -0,0 +1,299 @@
|
||||
//! This crate provides implementations of 2PC ciphers for encryption with a
|
||||
//! shared key.
|
||||
//!
|
||||
//! Both parties can work together to encrypt and decrypt messages with
|
||||
//! different visibility configurations. See [`Cipher`] and [`Keystream`] for
|
||||
//! more information on the interface.
|
||||
|
||||
#![deny(missing_docs, unreachable_pub, unused_must_use)]
|
||||
#![deny(clippy::all)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
pub mod aes;
|
||||
mod circuit;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use circuit::build_xor_circuit;
|
||||
use mpz_circuits::types::ValueType;
|
||||
use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
FromRaw, MemoryExt, Repr, Slice, StaticSize, ToRaw, Vector,
|
||||
};
|
||||
use mpz_vm_core::{prelude::*, CallBuilder, CallError, Vm};
|
||||
use std::collections::VecDeque;
|
||||
|
||||
/// Provides computation of 2PC ciphers in counter and ECB mode.
|
||||
///
|
||||
/// After setting `key` and `iv` allows to compute the keystream via
|
||||
/// [`Cipher::alloc`] or a single block in ECB mode via
|
||||
/// [`Cipher::assign_block`]. [`Keystream`] provides more tooling to compute the
|
||||
/// final cipher output in counter mode.
|
||||
#[async_trait]
|
||||
pub trait Cipher {
|
||||
/// The error type for the cipher.
|
||||
type Error: std::error::Error + Send + Sync + 'static;
|
||||
/// Cipher key.
|
||||
type Key;
|
||||
/// Cipher IV.
|
||||
type Iv;
|
||||
/// Cipher nonce.
|
||||
type Nonce;
|
||||
/// Cipher counter.
|
||||
type Counter;
|
||||
/// Cipher block.
|
||||
type Block;
|
||||
|
||||
/// Sets the key.
|
||||
fn set_key(&mut self, key: Self::Key);
|
||||
|
||||
/// Sets the initialization vector.
|
||||
fn set_iv(&mut self, iv: Self::Iv);
|
||||
|
||||
/// Returns the key reference.
|
||||
fn key(&self) -> Option<&Self::Key>;
|
||||
|
||||
/// Returns the iv reference.
|
||||
fn iv(&self) -> Option<&Self::Iv>;
|
||||
|
||||
/// Allocates a single block in ECB mode.
|
||||
fn alloc_block(
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
input: Self::Block,
|
||||
) -> Result<Self::Block, Self::Error>;
|
||||
|
||||
/// Allocates a single block in counter mode.
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn alloc_ctr_block(
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error>;
|
||||
|
||||
/// Allocates a keystream in counter mode.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine to allocate into.
|
||||
/// * `len` - Length of the stream in bytes.
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn alloc_keystream(
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
len: usize,
|
||||
) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error>;
|
||||
}
|
||||
|
||||
/// A block in counter mode.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct CtrBlock<N, C, O> {
|
||||
/// Explicit nonce reference.
|
||||
pub explicit_nonce: N,
|
||||
/// Counter reference.
|
||||
pub counter: C,
|
||||
/// Output reference.
|
||||
pub output: O,
|
||||
}
|
||||
|
||||
/// The keystream of the cipher.
|
||||
///
|
||||
/// Can be used to XOR with the cipher input to operate the cipher in counter
|
||||
/// mode.
|
||||
pub struct Keystream<N, C, O> {
|
||||
blocks: VecDeque<CtrBlock<N, C, O>>,
|
||||
}
|
||||
|
||||
impl<N, C, O> Default for Keystream<N, C, O> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
blocks: VecDeque::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, C, O> Keystream<N, C, O>
|
||||
where
|
||||
N: Repr<Binary> + StaticSize<Binary> + Copy,
|
||||
C: Repr<Binary> + StaticSize<Binary> + Copy,
|
||||
O: Repr<Binary> + StaticSize<Binary> + Copy,
|
||||
{
|
||||
/// Creates a new keystream from the provided blocks.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// * If the output of the keystream is not ordered and contiguous in
|
||||
/// memory.
|
||||
pub fn new(blocks: &[CtrBlock<N, C, O>]) -> Self {
|
||||
let mut pos = blocks
|
||||
.first()
|
||||
.map(|block| block.output.to_raw().ptr().as_usize())
|
||||
.unwrap_or(0);
|
||||
|
||||
for block in blocks {
|
||||
if block.output.to_raw().ptr().as_usize() != pos {
|
||||
panic!("output of keystream blocks must be ordered and contiguous in memory");
|
||||
}
|
||||
|
||||
pos += O::SIZE;
|
||||
}
|
||||
|
||||
Self {
|
||||
blocks: VecDeque::from_iter(blocks.iter().copied()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumes keystream material.
|
||||
///
|
||||
/// Returns the consumed keystream material, leaving the remaining material
|
||||
/// in place.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `len` - Length of the keystream in bytes to return.
|
||||
pub fn consume(&mut self, len: usize) -> Result<Self, CipherError> {
|
||||
let block_count = len.div_ceil(self.block_size());
|
||||
|
||||
if block_count > self.blocks.len() {
|
||||
return Err(CipherError::new("insufficient keystream"));
|
||||
}
|
||||
|
||||
let blocks = self.blocks.split_off(self.blocks.len() - block_count);
|
||||
|
||||
Ok(Self { blocks })
|
||||
}
|
||||
|
||||
/// Applies the keystream to the provided input.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `input` - Input data.
|
||||
pub fn apply(
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
input: Vector<U8>,
|
||||
) -> Result<Vector<U8>, CipherError> {
|
||||
if input.len() != self.len() {
|
||||
return Err(CipherError::new("input length must match keystream length"));
|
||||
} else if self.blocks.is_empty() {
|
||||
return Err(CipherError::new("no keystream material available"));
|
||||
}
|
||||
|
||||
let xor = build_xor_circuit(&[ValueType::new_array::<u8>(self.block_size())]);
|
||||
let mut pos = 0;
|
||||
let mut outputs = Vec::with_capacity(self.blocks.len());
|
||||
for block in &self.blocks {
|
||||
let call = CallBuilder::new(xor.clone())
|
||||
.arg(block.output)
|
||||
.arg(
|
||||
input
|
||||
.get(pos..pos + self.block_size())
|
||||
.expect("input length was checked"),
|
||||
)
|
||||
.build()?;
|
||||
let output: Vector<U8> = vm.call(call).map_err(CipherError::new)?;
|
||||
outputs.push(output);
|
||||
pos += self.block_size();
|
||||
}
|
||||
|
||||
// Calls were performed contiguously, so the output data is contiguous.
|
||||
let ptr = outputs
|
||||
.first()
|
||||
.map(|output| output.to_raw().ptr())
|
||||
.expect("keystream is not empty");
|
||||
let size = self.blocks.len() * O::SIZE;
|
||||
|
||||
let output = Vector::<U8>::from_raw(Slice::new_unchecked(ptr, size));
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Returns `len` bytes of the keystream as a vector.
|
||||
pub fn to_vector(&self, len: usize) -> Result<Vector<U8>, CipherError> {
|
||||
if len == 0 {
|
||||
return Err(CipherError::new("length must be greater than 0"));
|
||||
} else if self.blocks.is_empty() {
|
||||
return Err(CipherError::new("no keystream material available"));
|
||||
}
|
||||
|
||||
let block_count = len.div_ceil(self.block_size());
|
||||
if block_count != self.blocks.len() {
|
||||
return Err(CipherError::new("length does not match keystream length"));
|
||||
}
|
||||
|
||||
let ptr = self
|
||||
.blocks
|
||||
.front()
|
||||
.map(|block| block.output.to_raw().ptr())
|
||||
.expect("block count should be greater than 0");
|
||||
let size = block_count * O::SIZE;
|
||||
|
||||
let mut keystream = Vector::<U8>::from_raw(Slice::new_unchecked(ptr, size));
|
||||
keystream.truncate(len);
|
||||
|
||||
Ok(keystream)
|
||||
}
|
||||
|
||||
/// Assigns the keystream inputs.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `explicit_nonce` - Explicit nonce.
|
||||
/// * `ctr` - Counter function. The provided function will be called to
|
||||
/// assign the counter values for each block.
|
||||
pub fn assign(
|
||||
&self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
explicit_nonce: N::Clear,
|
||||
mut ctr: impl FnMut() -> C::Clear,
|
||||
) -> Result<(), CipherError>
|
||||
where
|
||||
N::Clear: Copy,
|
||||
C::Clear: Copy,
|
||||
{
|
||||
for block in &self.blocks {
|
||||
vm.assign(block.explicit_nonce, explicit_nonce)
|
||||
.map_err(CipherError::new)?;
|
||||
vm.commit(block.explicit_nonce).map_err(CipherError::new)?;
|
||||
vm.assign(block.counter, ctr()).map_err(CipherError::new)?;
|
||||
vm.commit(block.counter).map_err(CipherError::new)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the block size in bytes.
|
||||
fn block_size(&self) -> usize {
|
||||
O::SIZE / 8
|
||||
}
|
||||
|
||||
/// Returns the length of the keystream in bytes.
|
||||
fn len(&self) -> usize {
|
||||
self.block_size() * self.blocks.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// A cipher error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{source}")]
|
||||
pub struct CipherError {
|
||||
#[source]
|
||||
source: Box<dyn std::error::Error + Send + Sync>,
|
||||
}
|
||||
|
||||
impl CipherError {
|
||||
pub(crate) fn new<E>(source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
{
|
||||
Self {
|
||||
source: source.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CallError> for CipherError {
|
||||
fn from(value: CallError) -> Self {
|
||||
Self::new(value)
|
||||
}
|
||||
}
|
||||
25
crates/components/deap/Cargo.toml
Normal file
25
crates/components/deap/Cargo.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "tlsn-deap"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
mpz-core = { workspace = true }
|
||||
mpz-common = { workspace = true }
|
||||
mpz-vm-core = { workspace = true }
|
||||
tlsn-utils = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serio = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
tokio = { workspace = true, features = ["sync"] }
|
||||
|
||||
[dev-dependencies]
|
||||
mpz-circuits = { workspace = true }
|
||||
mpz-garble = { workspace = true }
|
||||
mpz-ot = { workspace = true }
|
||||
mpz-zk = { workspace = true }
|
||||
|
||||
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
|
||||
rand = { workspace = true }
|
||||
516
crates/components/deap/src/lib.rs
Normal file
516
crates/components/deap/src/lib.rs
Normal file
@@ -0,0 +1,516 @@
|
||||
//! Dual-execution with Asymmetric Privacy (DEAP) protocol.
|
||||
|
||||
#![deny(missing_docs, unreachable_pub, unused_must_use)]
|
||||
#![deny(clippy::all)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
use std::{
|
||||
mem,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mpz_common::{scoped_futures::ScopedFutureExt as _, Context};
|
||||
use mpz_core::bitvec::BitVec;
|
||||
use mpz_vm_core::{
|
||||
memory::{binary::Binary, DecodeFuture, Memory, Slice, View},
|
||||
Call, Callable, Execute, Vm, VmError,
|
||||
};
|
||||
use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard};
|
||||
use utils::range::{Difference, RangeSet, UnionMut};
|
||||
|
||||
type Error = DeapError;
|
||||
|
||||
/// The role of the DEAP VM.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum Role {
|
||||
Leader,
|
||||
Follower,
|
||||
}
|
||||
|
||||
/// DEAP Vm.
|
||||
#[derive(Debug)]
|
||||
pub struct Deap<Mpc, Zk> {
|
||||
role: Role,
|
||||
mpc: Arc<Mutex<Mpc>>,
|
||||
zk: Arc<Mutex<Zk>>,
|
||||
/// Private inputs of the follower.
|
||||
follower_inputs: RangeSet<usize>,
|
||||
outputs: Vec<(Slice, DecodeFuture<BitVec>)>,
|
||||
/// Whether the memories of the two VMs are potentially desynchronized.
|
||||
desync: AtomicBool,
|
||||
}
|
||||
|
||||
impl<Mpc, Zk> Deap<Mpc, Zk> {
|
||||
/// Create a new DEAP Vm.
|
||||
pub fn new(role: Role, mpc: Mpc, zk: Zk) -> Self {
|
||||
Self {
|
||||
role,
|
||||
mpc: Arc::new(Mutex::new(mpc)),
|
||||
zk: Arc::new(Mutex::new(zk)),
|
||||
follower_inputs: RangeSet::default(),
|
||||
outputs: Vec::default(),
|
||||
desync: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the MPC and ZK VMs.
|
||||
pub fn into_inner(self) -> (Mpc, Zk) {
|
||||
(
|
||||
Arc::into_inner(self.mpc).unwrap().into_inner(),
|
||||
Arc::into_inner(self.zk).unwrap().into_inner(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the ZK VM.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// After calling this method, allocations will no longer be allowed in the
|
||||
/// DEAP VM as the memory will potentially be desynchronized.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the mutex locked by another thread.
|
||||
pub fn zk(&self) -> MutexGuard<'_, Zk> {
|
||||
self.desync.store(true, Ordering::Relaxed);
|
||||
self.zk.try_lock().unwrap()
|
||||
}
|
||||
|
||||
/// Returns an owned mutex guard to the ZK VM.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// After calling this method, allocations will no longer be allowed in the
|
||||
/// DEAP VM as the memory will potentially be desynchronized.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the mutex locked by another thread.
|
||||
pub fn zk_owned(&self) -> OwnedMutexGuard<Zk> {
|
||||
self.desync.store(true, Ordering::Relaxed);
|
||||
self.zk.clone().try_lock_owned().unwrap()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn mpc(&self) -> MutexGuard<'_, Mpc> {
|
||||
self.mpc.try_lock().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Mpc, Zk> Deap<Mpc, Zk>
|
||||
where
|
||||
Mpc: Vm<Binary> + Send + 'static,
|
||||
Zk: Vm<Binary> + Send + 'static,
|
||||
{
|
||||
/// Finalize the DEAP Vm.
|
||||
///
|
||||
/// This reveals all private inputs of the follower.
|
||||
pub async fn finalize(&mut self, ctx: &mut Context) -> Result<(), VmError> {
|
||||
let mut mpc = self.mpc.try_lock().unwrap();
|
||||
let mut zk = self.zk.try_lock().unwrap();
|
||||
|
||||
// Decode the private inputs of the follower.
|
||||
//
|
||||
// # Security
|
||||
//
|
||||
// This assumes that the decoding process is authenticated from the leader's
|
||||
// perspective. In the case of garbled circuits, the leader should be the
|
||||
// generator such that the follower proves their inputs using their committed
|
||||
// MACs.
|
||||
let input_futs = self
|
||||
.follower_inputs
|
||||
.iter_ranges()
|
||||
.map(|input| mpc.decode_raw(Slice::from_range_unchecked(input)))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
mpc.execute_all(ctx).await?;
|
||||
|
||||
// Assign inputs to the ZK VM.
|
||||
for (mut decode, input) in input_futs
|
||||
.into_iter()
|
||||
.zip(self.follower_inputs.iter_ranges())
|
||||
{
|
||||
let input = Slice::from_range_unchecked(input);
|
||||
|
||||
// Follower has already assigned the inputs.
|
||||
if let Role::Leader = self.role {
|
||||
let value = decode
|
||||
.try_recv()
|
||||
.map_err(VmError::memory)?
|
||||
.expect("input should be decoded");
|
||||
zk.assign_raw(input, value)?;
|
||||
}
|
||||
|
||||
// Now the follower's inputs are public.
|
||||
zk.commit_raw(input)?;
|
||||
}
|
||||
|
||||
zk.execute_all(ctx).await.map_err(VmError::execute)?;
|
||||
|
||||
// Follower verifies the outputs are consistent.
|
||||
if let Role::Follower = self.role {
|
||||
for (output, mut value) in mem::take(&mut self.outputs) {
|
||||
// If the output is not available in the MPC VM, we did not execute and decode
|
||||
// it. Therefore, we do not need to check for equality.
|
||||
//
|
||||
// This can occur if some function was preprocessed but ultimately not used.
|
||||
if let Some(mpc_output) = mpc.get_raw(output)? {
|
||||
let zk_output = value
|
||||
.try_recv()
|
||||
.map_err(VmError::memory)?
|
||||
.expect("output should be decoded");
|
||||
|
||||
// Asserts equality of all the output values from both VMs.
|
||||
if zk_output != mpc_output {
|
||||
return Err(VmError::execute(Error::from(ErrorRepr::EqualityCheck)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<Mpc, Zk> Memory<Binary> for Deap<Mpc, Zk>
|
||||
where
|
||||
Mpc: Memory<Binary, Error = VmError>,
|
||||
Zk: Memory<Binary, Error = VmError>,
|
||||
{
|
||||
type Error = VmError;
|
||||
|
||||
fn alloc_raw(&mut self, size: usize) -> Result<Slice, VmError> {
|
||||
if self.desync.load(Ordering::Relaxed) {
|
||||
return Err(VmError::memory(
|
||||
"DEAP VM memories are potentially desynchronized",
|
||||
));
|
||||
}
|
||||
|
||||
self.zk.try_lock().unwrap().alloc_raw(size)?;
|
||||
self.mpc.try_lock().unwrap().alloc_raw(size)
|
||||
}
|
||||
|
||||
fn assign_raw(&mut self, slice: Slice, data: BitVec) -> Result<(), VmError> {
|
||||
self.zk
|
||||
.try_lock()
|
||||
.unwrap()
|
||||
.assign_raw(slice, data.clone())?;
|
||||
self.mpc.try_lock().unwrap().assign_raw(slice, data)
|
||||
}
|
||||
|
||||
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
|
||||
// Follower's private inputs are not committed in the ZK VM until finalization.
|
||||
let input_minus_follower = slice.to_range().difference(&self.follower_inputs);
|
||||
let mut zk = self.zk.try_lock().unwrap();
|
||||
for input in input_minus_follower.iter_ranges() {
|
||||
zk.commit_raw(Slice::from_range_unchecked(input))?;
|
||||
}
|
||||
|
||||
self.mpc.try_lock().unwrap().commit_raw(slice)
|
||||
}
|
||||
|
||||
fn get_raw(&self, slice: Slice) -> Result<Option<BitVec>, VmError> {
|
||||
self.mpc.try_lock().unwrap().get_raw(slice)
|
||||
}
|
||||
|
||||
fn decode_raw(&mut self, slice: Slice) -> Result<DecodeFuture<BitVec>, VmError> {
|
||||
let fut = self.zk.try_lock().unwrap().decode_raw(slice)?;
|
||||
self.outputs.push((slice, fut));
|
||||
|
||||
self.mpc.try_lock().unwrap().decode_raw(slice)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Mpc, Zk> View<Binary> for Deap<Mpc, Zk>
|
||||
where
|
||||
Mpc: View<Binary, Error = VmError>,
|
||||
Zk: View<Binary, Error = VmError>,
|
||||
{
|
||||
type Error = VmError;
|
||||
|
||||
fn mark_public_raw(&mut self, slice: Slice) -> Result<(), VmError> {
|
||||
self.zk.try_lock().unwrap().mark_public_raw(slice)?;
|
||||
self.mpc.try_lock().unwrap().mark_public_raw(slice)
|
||||
}
|
||||
|
||||
fn mark_private_raw(&mut self, slice: Slice) -> Result<(), VmError> {
|
||||
let mut zk = self.zk.try_lock().unwrap();
|
||||
let mut mpc = self.mpc.try_lock().unwrap();
|
||||
match self.role {
|
||||
Role::Leader => {
|
||||
zk.mark_private_raw(slice)?;
|
||||
mpc.mark_private_raw(slice)?;
|
||||
}
|
||||
Role::Follower => {
|
||||
// Follower's private inputs will become public during finalization.
|
||||
zk.mark_public_raw(slice)?;
|
||||
mpc.mark_private_raw(slice)?;
|
||||
self.follower_inputs.union_mut(&slice.to_range());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn mark_blind_raw(&mut self, slice: Slice) -> Result<(), VmError> {
|
||||
let mut zk = self.zk.try_lock().unwrap();
|
||||
let mut mpc = self.mpc.try_lock().unwrap();
|
||||
match self.role {
|
||||
Role::Leader => {
|
||||
// Follower's private inputs will become public during finalization.
|
||||
zk.mark_public_raw(slice)?;
|
||||
mpc.mark_blind_raw(slice)?;
|
||||
self.follower_inputs.union_mut(&slice.to_range());
|
||||
}
|
||||
Role::Follower => {
|
||||
zk.mark_blind_raw(slice)?;
|
||||
mpc.mark_blind_raw(slice)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<Mpc, Zk> Callable<Binary> for Deap<Mpc, Zk>
|
||||
where
|
||||
Mpc: Vm<Binary>,
|
||||
Zk: Vm<Binary>,
|
||||
{
|
||||
fn call_raw(&mut self, call: Call) -> Result<Slice, VmError> {
|
||||
if self.desync.load(Ordering::Relaxed) {
|
||||
return Err(VmError::memory(
|
||||
"DEAP VM memories are potentially desynchronized",
|
||||
));
|
||||
}
|
||||
|
||||
self.zk.try_lock().unwrap().call_raw(call.clone())?;
|
||||
self.mpc.try_lock().unwrap().call_raw(call)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<Mpc, Zk> Execute for Deap<Mpc, Zk>
|
||||
where
|
||||
Mpc: Execute + Send + 'static,
|
||||
Zk: Execute + Send + 'static,
|
||||
{
|
||||
fn wants_flush(&self) -> bool {
|
||||
self.mpc.try_lock().unwrap().wants_flush() || self.zk.try_lock().unwrap().wants_flush()
|
||||
}
|
||||
|
||||
async fn flush(&mut self, ctx: &mut Context) -> Result<(), VmError> {
|
||||
let mut zk = self.zk.clone().try_lock_owned().unwrap();
|
||||
let mut mpc = self.mpc.clone().try_lock_owned().unwrap();
|
||||
ctx.try_join(
|
||||
|ctx| async move { zk.flush(ctx).await }.scope_boxed(),
|
||||
|ctx| async move { mpc.flush(ctx).await }.scope_boxed(),
|
||||
)
|
||||
.await
|
||||
.map_err(VmError::execute)??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn wants_preprocess(&self) -> bool {
|
||||
self.mpc.try_lock().unwrap().wants_preprocess()
|
||||
|| self.zk.try_lock().unwrap().wants_preprocess()
|
||||
}
|
||||
|
||||
async fn preprocess(&mut self, ctx: &mut Context) -> Result<(), VmError> {
|
||||
let mut zk = self.zk.clone().try_lock_owned().unwrap();
|
||||
let mut mpc = self.mpc.clone().try_lock_owned().unwrap();
|
||||
ctx.try_join(
|
||||
|ctx| async move { zk.preprocess(ctx).await }.scope_boxed(),
|
||||
|ctx| async move { mpc.preprocess(ctx).await }.scope_boxed(),
|
||||
)
|
||||
.await
|
||||
.map_err(VmError::execute)??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn wants_execute(&self) -> bool {
|
||||
self.mpc.try_lock().unwrap().wants_execute()
|
||||
}
|
||||
|
||||
async fn execute(&mut self, ctx: &mut Context) -> Result<(), VmError> {
|
||||
// Only MPC VM is executed until finalization.
|
||||
self.mpc.try_lock().unwrap().execute(ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub(crate) struct DeapError(#[from] ErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum ErrorRepr {
|
||||
#[error("equality check failed")]
|
||||
EqualityCheck,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use mpz_circuits::circuits::AES128;
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_core::Block;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
|
||||
use mpz_ot::ideal::{cot::ideal_cot, rcot::ideal_rcot};
|
||||
use mpz_vm_core::{
|
||||
memory::{binary::U8, correlated::Delta, Array},
|
||||
prelude::*,
|
||||
};
|
||||
use mpz_zk::{Prover, Verifier};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deap() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta = Delta::random(&mut rng);
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta.into_inner());
|
||||
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
|
||||
|
||||
let gb = Generator::new(cot_send, [0u8; 16], delta);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
let prover = Prover::new(rcot_recv);
|
||||
let verifier = Verifier::new(delta, rcot_send);
|
||||
|
||||
let mut leader = Deap::new(Role::Leader, gb, prover);
|
||||
let mut follower = Deap::new(Role::Follower, ev, verifier);
|
||||
|
||||
let (ct_leader, ct_follower) = futures::join!(
|
||||
async {
|
||||
let key: Array<U8, 16> = leader.alloc().unwrap();
|
||||
let msg: Array<U8, 16> = leader.alloc().unwrap();
|
||||
|
||||
leader.mark_private(key).unwrap();
|
||||
leader.mark_blind(msg).unwrap();
|
||||
leader.assign(key, [42u8; 16]).unwrap();
|
||||
leader.commit(key).unwrap();
|
||||
leader.commit(msg).unwrap();
|
||||
|
||||
let ct: Array<U8, 16> = leader
|
||||
.call(Call::new(AES128.clone()).arg(key).arg(msg).build().unwrap())
|
||||
.unwrap();
|
||||
let ct = leader.decode(ct).unwrap();
|
||||
|
||||
leader.flush(&mut ctx_a).await.unwrap();
|
||||
leader.execute(&mut ctx_a).await.unwrap();
|
||||
leader.flush(&mut ctx_a).await.unwrap();
|
||||
leader.finalize(&mut ctx_a).await.unwrap();
|
||||
|
||||
ct.await.unwrap()
|
||||
},
|
||||
async {
|
||||
let key: Array<U8, 16> = follower.alloc().unwrap();
|
||||
let msg: Array<U8, 16> = follower.alloc().unwrap();
|
||||
|
||||
follower.mark_blind(key).unwrap();
|
||||
follower.mark_private(msg).unwrap();
|
||||
follower.assign(msg, [69u8; 16]).unwrap();
|
||||
follower.commit(key).unwrap();
|
||||
follower.commit(msg).unwrap();
|
||||
|
||||
let ct: Array<U8, 16> = follower
|
||||
.call(Call::new(AES128.clone()).arg(key).arg(msg).build().unwrap())
|
||||
.unwrap();
|
||||
let ct = follower.decode(ct).unwrap();
|
||||
|
||||
follower.flush(&mut ctx_b).await.unwrap();
|
||||
follower.execute(&mut ctx_b).await.unwrap();
|
||||
follower.flush(&mut ctx_b).await.unwrap();
|
||||
follower.finalize(&mut ctx_b).await.unwrap();
|
||||
|
||||
ct.await.unwrap()
|
||||
}
|
||||
);
|
||||
|
||||
assert_eq!(ct_leader, ct_follower);
|
||||
}
|
||||
|
||||
// Tests that the leader can not use different inputs in each VM without
|
||||
// detection by the follower.
|
||||
#[tokio::test]
|
||||
async fn test_malicious() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta = Delta::random(&mut rng);
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta.into_inner());
|
||||
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
|
||||
|
||||
let gb = Generator::new(cot_send, [0u8; 16], delta);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
let prover = Prover::new(rcot_recv);
|
||||
let verifier = Verifier::new(delta, rcot_send);
|
||||
|
||||
let mut leader = Deap::new(Role::Leader, gb, prover);
|
||||
let mut follower = Deap::new(Role::Follower, ev, verifier);
|
||||
|
||||
let (_, follower_res) = futures::join!(
|
||||
async {
|
||||
let key: Array<U8, 16> = leader.alloc().unwrap();
|
||||
let msg: Array<U8, 16> = leader.alloc().unwrap();
|
||||
|
||||
leader.mark_private(key).unwrap();
|
||||
leader.mark_blind(msg).unwrap();
|
||||
|
||||
// Use different inputs in each VM.
|
||||
leader.mpc().assign(key, [42u8; 16]).unwrap();
|
||||
leader
|
||||
.zk
|
||||
.try_lock()
|
||||
.unwrap()
|
||||
.assign(key, [69u8; 16])
|
||||
.unwrap();
|
||||
|
||||
leader.commit(key).unwrap();
|
||||
leader.commit(msg).unwrap();
|
||||
|
||||
let ct: Array<U8, 16> = leader
|
||||
.call(Call::new(AES128.clone()).arg(key).arg(msg).build().unwrap())
|
||||
.unwrap();
|
||||
let ct = leader.decode(ct).unwrap();
|
||||
|
||||
leader.flush(&mut ctx_a).await.unwrap();
|
||||
leader.execute(&mut ctx_a).await.unwrap();
|
||||
leader.flush(&mut ctx_a).await.unwrap();
|
||||
leader.finalize(&mut ctx_a).await.unwrap();
|
||||
|
||||
ct.await.unwrap()
|
||||
},
|
||||
async {
|
||||
let key: Array<U8, 16> = follower.alloc().unwrap();
|
||||
let msg: Array<U8, 16> = follower.alloc().unwrap();
|
||||
|
||||
follower.mark_blind(key).unwrap();
|
||||
follower.mark_private(msg).unwrap();
|
||||
follower.assign(msg, [69u8; 16]).unwrap();
|
||||
follower.commit(key).unwrap();
|
||||
follower.commit(msg).unwrap();
|
||||
|
||||
let ct: Array<U8, 16> = follower
|
||||
.call(Call::new(AES128.clone()).arg(key).arg(msg).build().unwrap())
|
||||
.unwrap();
|
||||
drop(follower.decode(ct).unwrap());
|
||||
|
||||
follower.flush(&mut ctx_b).await.unwrap();
|
||||
follower.execute(&mut ctx_b).await.unwrap();
|
||||
follower.flush(&mut ctx_b).await.unwrap();
|
||||
follower.finalize(&mut ctx_b).await
|
||||
}
|
||||
);
|
||||
|
||||
assert!(follower_res.is_err());
|
||||
}
|
||||
}
|
||||
@@ -19,21 +19,23 @@ mock = []
|
||||
[dependencies]
|
||||
tlsn-hmac-sha256-circuits = { workspace = true }
|
||||
|
||||
mpz-garble = { workspace = true }
|
||||
mpz-vm-core = { workspace = true }
|
||||
mpz-circuits = { workspace = true }
|
||||
mpz-common = { workspace = true }
|
||||
mpz-common = { workspace = true, features = ["cpu"] }
|
||||
|
||||
async-trait = { workspace = true }
|
||||
derive_builder = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { workspace = true, features = ["async_tokio"] }
|
||||
mpz-common = { workspace = true, features = ["test-utils"] }
|
||||
mpz-ot = { workspace = true, features = ["ideal"] }
|
||||
mpz-garble = { workspace = true }
|
||||
mpz-common = { workspace = true, features = ["test-utils"] }
|
||||
|
||||
criterion = { workspace = true, features = ["async_tokio"] }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
|
||||
rand = { workspace = true }
|
||||
|
||||
[[bench]]
|
||||
name = "prf"
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
#![allow(clippy::let_underscore_future)]
|
||||
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
|
||||
use hmac_sha256::{MpcPrf, Prf, PrfConfig, Role};
|
||||
use mpz_common::executor::test_mt_executor;
|
||||
use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPThread, Memory};
|
||||
use mpz_ot::ideal::ot::ideal_ot;
|
||||
use hmac_sha256::{MpcPrf, PrfConfig, Role};
|
||||
use mpz_common::context::test_mt_context;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
|
||||
use mpz_ot::ideal::cot::ideal_cot;
|
||||
use mpz_vm_core::{
|
||||
memory::{binary::U8, correlated::Delta, Array},
|
||||
prelude::*,
|
||||
};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
#[allow(clippy::unit_arg)]
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
@@ -11,178 +18,113 @@ fn criterion_benchmark(c: &mut Criterion) {
|
||||
group.sample_size(10);
|
||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||
|
||||
group.bench_function("prf_preprocess", |b| b.to_async(&rt).iter(preprocess));
|
||||
group.bench_function("prf", |b| b.to_async(&rt).iter(prf));
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
||||
criterion_main!(benches);
|
||||
|
||||
async fn preprocess() {
|
||||
let (mut leader_exec, mut follower_exec) = test_mt_executor(128);
|
||||
|
||||
let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot();
|
||||
let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot();
|
||||
let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot();
|
||||
let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot();
|
||||
|
||||
let leader_thread_0 = DEAPThread::new(
|
||||
DEAPRole::Leader,
|
||||
[0u8; 32],
|
||||
leader_exec.new_thread().await.unwrap(),
|
||||
leader_ot_send_0,
|
||||
leader_ot_recv_0,
|
||||
);
|
||||
let leader_thread_1 = leader_thread_0
|
||||
.new_thread(
|
||||
leader_exec.new_thread().await.unwrap(),
|
||||
leader_ot_send_1,
|
||||
leader_ot_recv_1,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let follower_thread_0 = DEAPThread::new(
|
||||
DEAPRole::Follower,
|
||||
[0u8; 32],
|
||||
follower_exec.new_thread().await.unwrap(),
|
||||
follower_ot_send_0,
|
||||
follower_ot_recv_0,
|
||||
);
|
||||
let follower_thread_1 = follower_thread_0
|
||||
.new_thread(
|
||||
follower_exec.new_thread().await.unwrap(),
|
||||
follower_ot_send_1,
|
||||
follower_ot_recv_1,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap();
|
||||
let follower_pms = follower_thread_0
|
||||
.new_public_input::<[u8; 32]>("pms")
|
||||
.unwrap();
|
||||
|
||||
let mut leader = MpcPrf::new(
|
||||
PrfConfig::builder().role(Role::Leader).build().unwrap(),
|
||||
leader_thread_0,
|
||||
leader_thread_1,
|
||||
);
|
||||
let mut follower = MpcPrf::new(
|
||||
PrfConfig::builder().role(Role::Follower).build().unwrap(),
|
||||
follower_thread_0,
|
||||
follower_thread_1,
|
||||
);
|
||||
|
||||
futures::join!(
|
||||
async {
|
||||
leader.setup(leader_pms).await.unwrap();
|
||||
leader.set_client_random(Some([0u8; 32])).await.unwrap();
|
||||
leader.preprocess().await.unwrap();
|
||||
},
|
||||
async {
|
||||
follower.setup(follower_pms).await.unwrap();
|
||||
follower.set_client_random(None).await.unwrap();
|
||||
follower.preprocess().await.unwrap();
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
async fn prf() {
|
||||
let (mut leader_exec, mut follower_exec) = test_mt_executor(128);
|
||||
|
||||
let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot();
|
||||
let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot();
|
||||
let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot();
|
||||
let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot();
|
||||
|
||||
let leader_thread_0 = DEAPThread::new(
|
||||
DEAPRole::Leader,
|
||||
[0u8; 32],
|
||||
leader_exec.new_thread().await.unwrap(),
|
||||
leader_ot_send_0,
|
||||
leader_ot_recv_0,
|
||||
);
|
||||
let leader_thread_1 = leader_thread_0
|
||||
.new_thread(
|
||||
leader_exec.new_thread().await.unwrap(),
|
||||
leader_ot_send_1,
|
||||
leader_ot_recv_1,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let follower_thread_0 = DEAPThread::new(
|
||||
DEAPRole::Follower,
|
||||
[0u8; 32],
|
||||
follower_exec.new_thread().await.unwrap(),
|
||||
follower_ot_send_0,
|
||||
follower_ot_recv_0,
|
||||
);
|
||||
let follower_thread_1 = follower_thread_0
|
||||
.new_thread(
|
||||
follower_exec.new_thread().await.unwrap(),
|
||||
follower_ot_send_1,
|
||||
follower_ot_recv_1,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap();
|
||||
let follower_pms = follower_thread_0
|
||||
.new_public_input::<[u8; 32]>("pms")
|
||||
.unwrap();
|
||||
|
||||
let mut leader = MpcPrf::new(
|
||||
PrfConfig::builder().role(Role::Leader).build().unwrap(),
|
||||
leader_thread_0,
|
||||
leader_thread_1,
|
||||
);
|
||||
let mut follower = MpcPrf::new(
|
||||
PrfConfig::builder().role(Role::Follower).build().unwrap(),
|
||||
follower_thread_0,
|
||||
follower_thread_1,
|
||||
);
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let pms = [42u8; 32];
|
||||
let client_random = [0u8; 32];
|
||||
let server_random = [1u8; 32];
|
||||
let cf_hs_hash = [2u8; 32];
|
||||
let sf_hs_hash = [3u8; 32];
|
||||
let client_random = [69u8; 32];
|
||||
let server_random: [u8; 32] = [96u8; 32];
|
||||
|
||||
let (mut leader_exec, mut follower_exec) = test_mt_context(8);
|
||||
let mut leader_ctx = leader_exec.new_context().await.unwrap();
|
||||
let mut follower_ctx = follower_exec.new_context().await.unwrap();
|
||||
|
||||
let delta = Delta::random(&mut rng);
|
||||
let (ot_send, ot_recv) = ideal_cot(delta.into_inner());
|
||||
|
||||
let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta);
|
||||
let mut follower_vm = Evaluator::new(ot_recv);
|
||||
|
||||
let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
|
||||
leader_vm.mark_public(leader_pms).unwrap();
|
||||
leader_vm.assign(leader_pms, pms).unwrap();
|
||||
leader_vm.commit(leader_pms).unwrap();
|
||||
|
||||
let follower_pms: Array<U8, 32> = follower_vm.alloc().unwrap();
|
||||
follower_vm.mark_public(follower_pms).unwrap();
|
||||
follower_vm.assign(follower_pms, pms).unwrap();
|
||||
follower_vm.commit(follower_pms).unwrap();
|
||||
|
||||
let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap());
|
||||
let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap());
|
||||
|
||||
let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap();
|
||||
let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap();
|
||||
|
||||
leader
|
||||
.set_client_random(&mut leader_vm, Some(client_random))
|
||||
.unwrap();
|
||||
follower.set_client_random(&mut follower_vm, None).unwrap();
|
||||
|
||||
leader
|
||||
.set_server_random(&mut leader_vm, server_random)
|
||||
.unwrap();
|
||||
follower
|
||||
.set_server_random(&mut follower_vm, server_random)
|
||||
.unwrap();
|
||||
|
||||
let _ = leader_vm
|
||||
.decode(leader_output.keys.client_write_key)
|
||||
.unwrap();
|
||||
let _ = leader_vm
|
||||
.decode(leader_output.keys.server_write_key)
|
||||
.unwrap();
|
||||
let _ = leader_vm.decode(leader_output.keys.client_iv).unwrap();
|
||||
let _ = leader_vm.decode(leader_output.keys.server_iv).unwrap();
|
||||
|
||||
let _ = follower_vm
|
||||
.decode(follower_output.keys.client_write_key)
|
||||
.unwrap();
|
||||
let _ = follower_vm
|
||||
.decode(follower_output.keys.server_write_key)
|
||||
.unwrap();
|
||||
let _ = follower_vm.decode(follower_output.keys.client_iv).unwrap();
|
||||
let _ = follower_vm.decode(follower_output.keys.server_iv).unwrap();
|
||||
|
||||
futures::join!(
|
||||
async {
|
||||
leader.setup(leader_pms.clone()).await.unwrap();
|
||||
leader.set_client_random(Some(client_random)).await.unwrap();
|
||||
leader.preprocess().await.unwrap();
|
||||
leader_vm.flush(&mut leader_ctx).await.unwrap();
|
||||
leader_vm.execute(&mut leader_ctx).await.unwrap();
|
||||
leader_vm.flush(&mut leader_ctx).await.unwrap();
|
||||
},
|
||||
async {
|
||||
follower.setup(follower_pms.clone()).await.unwrap();
|
||||
follower.set_client_random(None).await.unwrap();
|
||||
follower.preprocess().await.unwrap();
|
||||
follower_vm.flush(&mut follower_ctx).await.unwrap();
|
||||
follower_vm.execute(&mut follower_ctx).await.unwrap();
|
||||
follower_vm.flush(&mut follower_ctx).await.unwrap();
|
||||
}
|
||||
);
|
||||
|
||||
leader.thread_mut().assign(&leader_pms, pms).unwrap();
|
||||
follower.thread_mut().assign(&follower_pms, pms).unwrap();
|
||||
let cf_hs_hash = [1u8; 32];
|
||||
let sf_hs_hash = [2u8; 32];
|
||||
|
||||
let (_leader_keys, _follower_keys) = futures::try_join!(
|
||||
leader.compute_session_keys(server_random),
|
||||
follower.compute_session_keys(server_random)
|
||||
)
|
||||
.unwrap();
|
||||
leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap();
|
||||
leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap();
|
||||
|
||||
let _ = futures::try_join!(
|
||||
leader.compute_client_finished_vd(cf_hs_hash),
|
||||
follower.compute_client_finished_vd(cf_hs_hash)
|
||||
)
|
||||
.unwrap();
|
||||
follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap();
|
||||
follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap();
|
||||
|
||||
let _ = futures::try_join!(
|
||||
leader.compute_server_finished_vd(sf_hs_hash),
|
||||
follower.compute_server_finished_vd(sf_hs_hash)
|
||||
)
|
||||
.unwrap();
|
||||
let _ = leader_vm.decode(leader_output.cf_vd).unwrap();
|
||||
let _ = leader_vm.decode(leader_output.sf_vd).unwrap();
|
||||
|
||||
futures::try_join!(
|
||||
leader.thread_mut().finalize(),
|
||||
follower.thread_mut().finalize()
|
||||
)
|
||||
.unwrap();
|
||||
let _ = follower_vm.decode(follower_output.cf_vd).unwrap();
|
||||
let _ = follower_vm.decode(follower_output.sf_vd).unwrap();
|
||||
|
||||
futures::join!(
|
||||
async {
|
||||
leader_vm.flush(&mut leader_ctx).await.unwrap();
|
||||
leader_vm.execute(&mut leader_ctx).await.unwrap();
|
||||
leader_vm.flush(&mut leader_ctx).await.unwrap();
|
||||
},
|
||||
async {
|
||||
follower_vm.flush(&mut follower_ctx).await.unwrap();
|
||||
follower_vm.execute(&mut follower_ctx).await.unwrap();
|
||||
follower_vm.flush(&mut follower_ctx).await.unwrap();
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -33,6 +33,10 @@ impl PrfError {
|
||||
source: Some(msg.into().into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn vm<E: Into<Box<dyn Error + Send + Sync>>>(err: E) -> Self {
|
||||
Self::new(ErrorKind::Vm, err)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -58,26 +62,8 @@ impl fmt::Display for PrfError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::MemoryError> for PrfError {
|
||||
fn from(error: mpz_garble::MemoryError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::LoadError> for PrfError {
|
||||
fn from(error: mpz_garble::LoadError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::ExecutionError> for PrfError {
|
||||
fn from(error: mpz_garble::ExecutionError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::DecodeError> for PrfError {
|
||||
fn from(error: mpz_garble::DecodeError) -> Self {
|
||||
impl From<mpz_common::ContextError> for PrfError {
|
||||
fn from(error: mpz_common::ContextError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,84 +12,52 @@ pub use config::{PrfConfig, PrfConfigBuilder, PrfConfigBuilderError, Role};
|
||||
pub use error::PrfError;
|
||||
pub use prf::MpcPrf;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use mpz_garble::value::ValueRef;
|
||||
use mpz_vm_core::memory::{binary::U8, Array};
|
||||
|
||||
pub(crate) static CF_LABEL: &[u8] = b"client finished";
|
||||
pub(crate) static SF_LABEL: &[u8] = b"server finished";
|
||||
|
||||
/// Session keys computed by the PRF.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SessionKeys {
|
||||
/// Client write key.
|
||||
pub client_write_key: ValueRef,
|
||||
/// Server write key.
|
||||
pub server_write_key: ValueRef,
|
||||
/// Client IV.
|
||||
pub client_iv: ValueRef,
|
||||
/// Server IV.
|
||||
pub server_iv: ValueRef,
|
||||
/// Builds the circuits for the PRF.
|
||||
///
|
||||
/// This function can be used ahead of time to build the circuits for the PRF,
|
||||
/// which at the moment is CPU and memory intensive.
|
||||
pub async fn build_circuits() {
|
||||
prf::Circuits::get().await;
|
||||
}
|
||||
|
||||
/// PRF trait for computing TLS PRF.
|
||||
#[async_trait]
|
||||
pub trait Prf {
|
||||
/// Sets up the PRF.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pms` - The pre-master secret.
|
||||
async fn setup(&mut self, pms: ValueRef) -> Result<SessionKeys, PrfError>;
|
||||
/// PRF output.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PrfOutput {
|
||||
/// TLS session keys.
|
||||
pub keys: SessionKeys,
|
||||
/// Client finished verify data.
|
||||
pub cf_vd: Array<U8, 12>,
|
||||
/// Server finished verify data.
|
||||
pub sf_vd: Array<U8, 12>,
|
||||
}
|
||||
|
||||
/// Sets the client random.
|
||||
///
|
||||
/// This must be set after calling [`Prf::setup`].
|
||||
///
|
||||
/// Only the leader can provide the client random.
|
||||
async fn set_client_random(&mut self, client_random: Option<[u8; 32]>) -> Result<(), PrfError>;
|
||||
|
||||
/// Preprocesses the PRF.
|
||||
async fn preprocess(&mut self) -> Result<(), PrfError>;
|
||||
|
||||
/// Computes the client finished verify data.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `handshake_hash` - The handshake transcript hash.
|
||||
async fn compute_client_finished_vd(
|
||||
&mut self,
|
||||
handshake_hash: [u8; 32],
|
||||
) -> Result<[u8; 12], PrfError>;
|
||||
|
||||
/// Computes the server finished verify data.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `handshake_hash` - The handshake transcript hash.
|
||||
async fn compute_server_finished_vd(
|
||||
&mut self,
|
||||
handshake_hash: [u8; 32],
|
||||
) -> Result<[u8; 12], PrfError>;
|
||||
|
||||
/// Computes the session keys.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `server_random` - The server random.
|
||||
async fn compute_session_keys(
|
||||
&mut self,
|
||||
server_random: [u8; 32],
|
||||
) -> Result<SessionKeys, PrfError>;
|
||||
/// Session keys computed by the PRF.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct SessionKeys {
|
||||
/// Client write key.
|
||||
pub client_write_key: Array<U8, 16>,
|
||||
/// Server write key.
|
||||
pub server_write_key: Array<U8, 16>,
|
||||
/// Client IV.
|
||||
pub client_iv: Array<U8, 4>,
|
||||
/// Server IV.
|
||||
pub server_iv: Array<U8, 4>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use mpz_common::executor::test_st_executor;
|
||||
use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPThread, Decode, Memory};
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
|
||||
|
||||
use hmac_sha256_circuits::{hmac_sha256_partial, prf, session_keys};
|
||||
use mpz_ot::ideal::ot::ideal_ot;
|
||||
use mpz_ot::ideal::cot::ideal_cot;
|
||||
use mpz_vm_core::{memory::correlated::Delta, prelude::*};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -113,120 +81,89 @@ mod tests {
|
||||
#[ignore = "expensive"]
|
||||
#[tokio::test]
|
||||
async fn test_prf() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let pms = [42u8; 32];
|
||||
let client_random = [69u8; 32];
|
||||
let server_random: [u8; 32] = [96u8; 32];
|
||||
let ms = compute_ms(pms, client_random, server_random);
|
||||
|
||||
let (leader_ctx_0, follower_ctx_0) = test_st_executor(128);
|
||||
let (leader_ctx_1, follower_ctx_1) = test_st_executor(128);
|
||||
let (mut leader_ctx, mut follower_ctx) = test_st_context(128);
|
||||
|
||||
let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot();
|
||||
let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot();
|
||||
let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot();
|
||||
let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot();
|
||||
let delta = Delta::random(&mut rng);
|
||||
let (ot_send, ot_recv) = ideal_cot(delta.into_inner());
|
||||
|
||||
let leader_thread_0 = DEAPThread::new(
|
||||
DEAPRole::Leader,
|
||||
[0u8; 32],
|
||||
leader_ctx_0,
|
||||
leader_ot_send_0,
|
||||
leader_ot_recv_0,
|
||||
);
|
||||
let leader_thread_1 = leader_thread_0
|
||||
.new_thread(leader_ctx_1, leader_ot_send_1, leader_ot_recv_1)
|
||||
let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta);
|
||||
let mut follower_vm = Evaluator::new(ot_recv);
|
||||
|
||||
let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
|
||||
leader_vm.mark_public(leader_pms).unwrap();
|
||||
leader_vm.assign(leader_pms, pms).unwrap();
|
||||
leader_vm.commit(leader_pms).unwrap();
|
||||
|
||||
let follower_pms: Array<U8, 32> = follower_vm.alloc().unwrap();
|
||||
follower_vm.mark_public(follower_pms).unwrap();
|
||||
follower_vm.assign(follower_pms, pms).unwrap();
|
||||
follower_vm.commit(follower_pms).unwrap();
|
||||
|
||||
let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap());
|
||||
let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap());
|
||||
|
||||
let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap();
|
||||
let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap();
|
||||
|
||||
leader
|
||||
.set_client_random(&mut leader_vm, Some(client_random))
|
||||
.unwrap();
|
||||
follower.set_client_random(&mut follower_vm, None).unwrap();
|
||||
|
||||
leader
|
||||
.set_server_random(&mut leader_vm, server_random)
|
||||
.unwrap();
|
||||
follower
|
||||
.set_server_random(&mut follower_vm, server_random)
|
||||
.unwrap();
|
||||
|
||||
let follower_thread_0 = DEAPThread::new(
|
||||
DEAPRole::Follower,
|
||||
[0u8; 32],
|
||||
follower_ctx_0,
|
||||
follower_ot_send_0,
|
||||
follower_ot_recv_0,
|
||||
);
|
||||
let follower_thread_1 = follower_thread_0
|
||||
.new_thread(follower_ctx_1, follower_ot_send_1, follower_ot_recv_1)
|
||||
let leader_cwk = leader_vm
|
||||
.decode(leader_output.keys.client_write_key)
|
||||
.unwrap();
|
||||
|
||||
// Set up public PMS for testing.
|
||||
let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap();
|
||||
let follower_pms = follower_thread_0
|
||||
.new_public_input::<[u8; 32]>("pms")
|
||||
let leader_swk = leader_vm
|
||||
.decode(leader_output.keys.server_write_key)
|
||||
.unwrap();
|
||||
let leader_civ = leader_vm.decode(leader_output.keys.client_iv).unwrap();
|
||||
let leader_siv = leader_vm.decode(leader_output.keys.server_iv).unwrap();
|
||||
|
||||
leader_thread_0.assign(&leader_pms, pms).unwrap();
|
||||
follower_thread_0.assign(&follower_pms, pms).unwrap();
|
||||
|
||||
let mut leader = MpcPrf::new(
|
||||
PrfConfig::builder().role(Role::Leader).build().unwrap(),
|
||||
leader_thread_0,
|
||||
leader_thread_1,
|
||||
);
|
||||
let mut follower = MpcPrf::new(
|
||||
PrfConfig::builder().role(Role::Follower).build().unwrap(),
|
||||
follower_thread_0,
|
||||
follower_thread_1,
|
||||
);
|
||||
let follower_cwk = follower_vm
|
||||
.decode(follower_output.keys.client_write_key)
|
||||
.unwrap();
|
||||
let follower_swk = follower_vm
|
||||
.decode(follower_output.keys.server_write_key)
|
||||
.unwrap();
|
||||
let follower_civ = follower_vm.decode(follower_output.keys.client_iv).unwrap();
|
||||
let follower_siv = follower_vm.decode(follower_output.keys.server_iv).unwrap();
|
||||
|
||||
futures::join!(
|
||||
async {
|
||||
leader.setup(leader_pms).await.unwrap();
|
||||
leader.set_client_random(Some(client_random)).await.unwrap();
|
||||
leader.preprocess().await.unwrap();
|
||||
leader_vm.flush(&mut leader_ctx).await.unwrap();
|
||||
leader_vm.execute(&mut leader_ctx).await.unwrap();
|
||||
leader_vm.flush(&mut leader_ctx).await.unwrap();
|
||||
},
|
||||
async {
|
||||
follower.setup(follower_pms).await.unwrap();
|
||||
follower.set_client_random(None).await.unwrap();
|
||||
follower.preprocess().await.unwrap();
|
||||
follower_vm.flush(&mut follower_ctx).await.unwrap();
|
||||
follower_vm.execute(&mut follower_ctx).await.unwrap();
|
||||
follower_vm.flush(&mut follower_ctx).await.unwrap();
|
||||
}
|
||||
);
|
||||
|
||||
let (leader_session_keys, follower_session_keys) = futures::try_join!(
|
||||
leader.compute_session_keys(server_random),
|
||||
follower.compute_session_keys(server_random)
|
||||
)
|
||||
.unwrap();
|
||||
let leader_cwk = leader_cwk.await.unwrap();
|
||||
let leader_swk = leader_swk.await.unwrap();
|
||||
let leader_civ = leader_civ.await.unwrap();
|
||||
let leader_siv = leader_siv.await.unwrap();
|
||||
|
||||
let SessionKeys {
|
||||
client_write_key: leader_cwk,
|
||||
server_write_key: leader_swk,
|
||||
client_iv: leader_civ,
|
||||
server_iv: leader_siv,
|
||||
} = leader_session_keys;
|
||||
|
||||
let SessionKeys {
|
||||
client_write_key: follower_cwk,
|
||||
server_write_key: follower_swk,
|
||||
client_iv: follower_civ,
|
||||
server_iv: follower_siv,
|
||||
} = follower_session_keys;
|
||||
|
||||
// Decode session keys
|
||||
let (leader_session_keys, follower_session_keys) = futures::try_join!(
|
||||
async {
|
||||
leader
|
||||
.thread_mut()
|
||||
.decode(&[leader_cwk, leader_swk, leader_civ, leader_siv])
|
||||
.await
|
||||
},
|
||||
async {
|
||||
follower
|
||||
.thread_mut()
|
||||
.decode(&[follower_cwk, follower_swk, follower_civ, follower_siv])
|
||||
.await
|
||||
}
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let leader_cwk: [u8; 16] = leader_session_keys[0].clone().try_into().unwrap();
|
||||
let leader_swk: [u8; 16] = leader_session_keys[1].clone().try_into().unwrap();
|
||||
let leader_civ: [u8; 4] = leader_session_keys[2].clone().try_into().unwrap();
|
||||
let leader_siv: [u8; 4] = leader_session_keys[3].clone().try_into().unwrap();
|
||||
|
||||
let follower_cwk: [u8; 16] = follower_session_keys[0].clone().try_into().unwrap();
|
||||
let follower_swk: [u8; 16] = follower_session_keys[1].clone().try_into().unwrap();
|
||||
let follower_civ: [u8; 4] = follower_session_keys[2].clone().try_into().unwrap();
|
||||
let follower_siv: [u8; 4] = follower_session_keys[3].clone().try_into().unwrap();
|
||||
let follower_cwk = follower_cwk.await.unwrap();
|
||||
let follower_swk = follower_swk.await.unwrap();
|
||||
let follower_civ = follower_civ.await.unwrap();
|
||||
let follower_siv = follower_siv.await.unwrap();
|
||||
|
||||
let (expected_cwk, expected_swk, expected_civ, expected_siv) =
|
||||
session_keys(pms, client_random, server_random);
|
||||
@@ -244,24 +181,43 @@ mod tests {
|
||||
let cf_hs_hash = [1u8; 32];
|
||||
let sf_hs_hash = [2u8; 32];
|
||||
|
||||
let (cf_vd, _) = futures::try_join!(
|
||||
leader.compute_client_finished_vd(cf_hs_hash),
|
||||
follower.compute_client_finished_vd(cf_hs_hash)
|
||||
)
|
||||
.unwrap();
|
||||
leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap();
|
||||
leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap();
|
||||
|
||||
follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap();
|
||||
follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap();
|
||||
|
||||
let leader_cf_vd = leader_vm.decode(leader_output.cf_vd).unwrap();
|
||||
let leader_sf_vd = leader_vm.decode(leader_output.sf_vd).unwrap();
|
||||
|
||||
let follower_cf_vd = follower_vm.decode(follower_output.cf_vd).unwrap();
|
||||
let follower_sf_vd = follower_vm.decode(follower_output.sf_vd).unwrap();
|
||||
|
||||
futures::join!(
|
||||
async {
|
||||
leader_vm.flush(&mut leader_ctx).await.unwrap();
|
||||
leader_vm.execute(&mut leader_ctx).await.unwrap();
|
||||
leader_vm.flush(&mut leader_ctx).await.unwrap();
|
||||
},
|
||||
async {
|
||||
follower_vm.flush(&mut follower_ctx).await.unwrap();
|
||||
follower_vm.execute(&mut follower_ctx).await.unwrap();
|
||||
follower_vm.flush(&mut follower_ctx).await.unwrap();
|
||||
}
|
||||
);
|
||||
|
||||
let leader_cf_vd = leader_cf_vd.await.unwrap();
|
||||
let leader_sf_vd = leader_sf_vd.await.unwrap();
|
||||
|
||||
let follower_cf_vd = follower_cf_vd.await.unwrap();
|
||||
let follower_sf_vd = follower_sf_vd.await.unwrap();
|
||||
|
||||
let expected_cf_vd = compute_vd(ms, b"client finished", cf_hs_hash);
|
||||
|
||||
assert_eq!(cf_vd, expected_cf_vd);
|
||||
|
||||
let (sf_vd, _) = futures::try_join!(
|
||||
leader.compute_server_finished_vd(sf_hs_hash),
|
||||
follower.compute_server_finished_vd(sf_hs_hash)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let expected_sf_vd = compute_vd(ms, b"server finished", sf_hs_hash);
|
||||
|
||||
assert_eq!(sf_vd, expected_sf_vd);
|
||||
assert_eq!(leader_cf_vd, expected_cf_vd);
|
||||
assert_eq!(leader_sf_vd, expected_sf_vd);
|
||||
assert_eq!(follower_cf_vd, expected_cf_vd);
|
||||
assert_eq!(follower_sf_vd, expected_sf_vd);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,60 +3,65 @@ use std::{
|
||||
sync::{Arc, OnceLock},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use hmac_sha256_circuits::{build_session_keys, build_verify_data};
|
||||
use mpz_circuits::Circuit;
|
||||
use mpz_common::cpu::CpuBackend;
|
||||
use mpz_garble::{config::Visibility, value::ValueRef, Decode, Execute, Load, Memory};
|
||||
use mpz_vm_core::{
|
||||
memory::{
|
||||
binary::{Binary, U32, U8},
|
||||
Array,
|
||||
},
|
||||
prelude::*,
|
||||
Call, Vm,
|
||||
};
|
||||
use tracing::instrument;
|
||||
|
||||
use crate::{Prf, PrfConfig, PrfError, Role, SessionKeys, CF_LABEL, SF_LABEL};
|
||||
use crate::{PrfConfig, PrfError, PrfOutput, Role, SessionKeys, CF_LABEL, SF_LABEL};
|
||||
|
||||
/// Circuit for computing TLS session keys.
|
||||
static SESSION_KEYS_CIRC: OnceLock<Arc<Circuit>> = OnceLock::new();
|
||||
/// Circuit for computing TLS client verify data.
|
||||
static CLIENT_VD_CIRC: OnceLock<Arc<Circuit>> = OnceLock::new();
|
||||
/// Circuit for computing TLS server verify data.
|
||||
static SERVER_VD_CIRC: OnceLock<Arc<Circuit>> = OnceLock::new();
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Randoms {
|
||||
pub(crate) client_random: ValueRef,
|
||||
pub(crate) server_random: ValueRef,
|
||||
pub(crate) struct Circuits {
|
||||
session_keys: Arc<Circuit>,
|
||||
client_vd: Arc<Circuit>,
|
||||
server_vd: Arc<Circuit>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct HashState {
|
||||
pub(crate) ms_outer_hash_state: ValueRef,
|
||||
pub(crate) ms_inner_hash_state: ValueRef,
|
||||
}
|
||||
impl Circuits {
|
||||
pub(crate) async fn get() -> &'static Self {
|
||||
static CIRCUITS: OnceLock<Circuits> = OnceLock::new();
|
||||
if let Some(circuits) = CIRCUITS.get() {
|
||||
return circuits;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct VerifyData {
|
||||
pub(crate) handshake_hash: ValueRef,
|
||||
pub(crate) vd: ValueRef,
|
||||
let (session_keys, client_vd, server_vd) = futures::join!(
|
||||
CpuBackend::blocking(build_session_keys),
|
||||
CpuBackend::blocking(|| build_verify_data(CF_LABEL)),
|
||||
CpuBackend::blocking(|| build_verify_data(SF_LABEL)),
|
||||
);
|
||||
|
||||
_ = CIRCUITS.set(Circuits {
|
||||
session_keys,
|
||||
client_vd,
|
||||
server_vd,
|
||||
});
|
||||
|
||||
CIRCUITS.get().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum State {
|
||||
Initialized,
|
||||
SessionKeys {
|
||||
pms: ValueRef,
|
||||
randoms: Randoms,
|
||||
hash_state: HashState,
|
||||
keys: crate::SessionKeys,
|
||||
cf_vd: VerifyData,
|
||||
sf_vd: VerifyData,
|
||||
client_random: Array<U8, 32>,
|
||||
server_random: Array<U8, 32>,
|
||||
cf_hash: Array<U8, 32>,
|
||||
sf_hash: Array<U8, 32>,
|
||||
},
|
||||
ClientFinished {
|
||||
hash_state: HashState,
|
||||
cf_vd: VerifyData,
|
||||
sf_vd: VerifyData,
|
||||
cf_hash: Array<U8, 32>,
|
||||
sf_hash: Array<U8, 32>,
|
||||
},
|
||||
ServerFinished {
|
||||
hash_state: HashState,
|
||||
sf_vd: VerifyData,
|
||||
sf_hash: Array<U8, 32>,
|
||||
},
|
||||
Complete,
|
||||
Error,
|
||||
@@ -69,14 +74,12 @@ impl State {
|
||||
}
|
||||
|
||||
/// MPC PRF for computing TLS HMAC-SHA256 PRF.
|
||||
pub struct MpcPrf<E> {
|
||||
pub struct MpcPrf {
|
||||
config: PrfConfig,
|
||||
state: State,
|
||||
thread_0: E,
|
||||
thread_1: E,
|
||||
}
|
||||
|
||||
impl<E> Debug for MpcPrf<E> {
|
||||
impl Debug for MpcPrf {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("MpcPrf")
|
||||
.field("config", &self.config)
|
||||
@@ -85,359 +88,225 @@ impl<E> Debug for MpcPrf<E> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> MpcPrf<E>
|
||||
where
|
||||
E: Load + Memory + Execute + Decode + Send,
|
||||
{
|
||||
impl MpcPrf {
|
||||
/// Creates a new instance of the PRF.
|
||||
pub fn new(config: PrfConfig, thread_0: E, thread_1: E) -> MpcPrf<E> {
|
||||
pub fn new(config: PrfConfig) -> MpcPrf {
|
||||
MpcPrf {
|
||||
config,
|
||||
state: State::Initialized,
|
||||
thread_0,
|
||||
thread_1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the MPC thread.
|
||||
pub fn thread_mut(&mut self) -> &mut E {
|
||||
&mut self.thread_0
|
||||
}
|
||||
|
||||
/// Executes a circuit which computes TLS session keys.
|
||||
/// Allocates resources for the PRF.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `pms` - The pre-master secret.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn execute_session_keys(
|
||||
pub fn alloc(
|
||||
&mut self,
|
||||
server_random: [u8; 32],
|
||||
) -> Result<SessionKeys, PrfError> {
|
||||
let State::SessionKeys {
|
||||
pms,
|
||||
randoms: randoms_refs,
|
||||
hash_state,
|
||||
keys,
|
||||
cf_vd,
|
||||
sf_vd,
|
||||
} = self.state.take()
|
||||
else {
|
||||
return Err(PrfError::state("session keys not initialized"));
|
||||
};
|
||||
|
||||
let circ = SESSION_KEYS_CIRC
|
||||
.get()
|
||||
.expect("session keys circuit is set");
|
||||
|
||||
self.thread_0
|
||||
.assign(&randoms_refs.server_random, server_random)?;
|
||||
|
||||
self.thread_0
|
||||
.execute(
|
||||
circ.clone(),
|
||||
&[pms, randoms_refs.client_random, randoms_refs.server_random],
|
||||
&[
|
||||
keys.client_write_key.clone(),
|
||||
keys.server_write_key.clone(),
|
||||
keys.client_iv.clone(),
|
||||
keys.server_iv.clone(),
|
||||
hash_state.ms_outer_hash_state.clone(),
|
||||
hash_state.ms_inner_hash_state.clone(),
|
||||
],
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.state = State::ClientFinished {
|
||||
hash_state,
|
||||
cf_vd,
|
||||
sf_vd,
|
||||
};
|
||||
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn execute_cf_vd(&mut self, handshake_hash: [u8; 32]) -> Result<[u8; 12], PrfError> {
|
||||
let State::ClientFinished {
|
||||
hash_state,
|
||||
cf_vd,
|
||||
sf_vd,
|
||||
} = self.state.take()
|
||||
else {
|
||||
return Err(PrfError::state("PRF not in client finished state"));
|
||||
};
|
||||
|
||||
let circ = CLIENT_VD_CIRC.get().expect("client vd circuit is set");
|
||||
|
||||
self.thread_0
|
||||
.assign(&cf_vd.handshake_hash, handshake_hash)?;
|
||||
|
||||
self.thread_0
|
||||
.execute(
|
||||
circ.clone(),
|
||||
&[
|
||||
hash_state.ms_outer_hash_state.clone(),
|
||||
hash_state.ms_inner_hash_state.clone(),
|
||||
cf_vd.handshake_hash,
|
||||
],
|
||||
&[cf_vd.vd.clone()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut outputs = self.thread_0.decode(&[cf_vd.vd]).await?;
|
||||
let vd: [u8; 12] = outputs.remove(0).try_into().expect("vd is 12 bytes");
|
||||
|
||||
self.state = State::ServerFinished { hash_state, sf_vd };
|
||||
|
||||
Ok(vd)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn execute_sf_vd(&mut self, handshake_hash: [u8; 32]) -> Result<[u8; 12], PrfError> {
|
||||
let State::ServerFinished { hash_state, sf_vd } = self.state.take() else {
|
||||
return Err(PrfError::state("PRF not in server finished state"));
|
||||
};
|
||||
|
||||
let circ = SERVER_VD_CIRC.get().expect("server vd circuit is set");
|
||||
|
||||
self.thread_0
|
||||
.assign(&sf_vd.handshake_hash, handshake_hash)?;
|
||||
|
||||
self.thread_0
|
||||
.execute(
|
||||
circ.clone(),
|
||||
&[
|
||||
hash_state.ms_outer_hash_state,
|
||||
hash_state.ms_inner_hash_state,
|
||||
sf_vd.handshake_hash,
|
||||
],
|
||||
&[sf_vd.vd.clone()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut outputs = self.thread_0.decode(&[sf_vd.vd]).await?;
|
||||
let vd: [u8; 12] = outputs.remove(0).try_into().expect("vd is 12 bytes");
|
||||
|
||||
self.state = State::Complete;
|
||||
|
||||
Ok(vd)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<E> Prf for MpcPrf<E>
|
||||
where
|
||||
E: Memory + Load + Execute + Decode + Send,
|
||||
{
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn setup(&mut self, pms: ValueRef) -> Result<SessionKeys, PrfError> {
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
pms: Array<U8, 32>,
|
||||
) -> Result<PrfOutput, PrfError> {
|
||||
let State::Initialized = self.state.take() else {
|
||||
return Err(PrfError::state("PRF not in initialized state"));
|
||||
};
|
||||
|
||||
let thread = &mut self.thread_0;
|
||||
let circuits = futures::executor::block_on(Circuits::get());
|
||||
|
||||
let randoms = Randoms {
|
||||
// The client random is kept private so that the handshake transcript
|
||||
// hashes do not leak information about the server's identity.
|
||||
client_random: thread.new_input::<[u8; 32]>(
|
||||
"client_random",
|
||||
match self.config.role {
|
||||
Role::Leader => Visibility::Private,
|
||||
Role::Follower => Visibility::Blind,
|
||||
},
|
||||
)?,
|
||||
server_random: thread.new_input::<[u8; 32]>("server_random", Visibility::Public)?,
|
||||
};
|
||||
let client_random = vm.alloc().map_err(PrfError::vm)?;
|
||||
let server_random = vm.alloc().map_err(PrfError::vm)?;
|
||||
|
||||
// The client random is kept private so that the handshake transcript
|
||||
// hashes do not leak information about the server's identity.
|
||||
match self.config.role {
|
||||
Role::Leader => vm.mark_private(client_random),
|
||||
Role::Follower => vm.mark_blind(client_random),
|
||||
}
|
||||
.map_err(PrfError::vm)?;
|
||||
|
||||
vm.mark_public(server_random).map_err(PrfError::vm)?;
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
let (
|
||||
client_write_key,
|
||||
server_write_key,
|
||||
client_iv,
|
||||
server_iv,
|
||||
ms_outer_hash_state,
|
||||
ms_inner_hash_state,
|
||||
): (
|
||||
Array<U8, 16>,
|
||||
Array<U8, 16>,
|
||||
Array<U8, 4>,
|
||||
Array<U8, 4>,
|
||||
Array<U32, 8>,
|
||||
Array<U32, 8>,
|
||||
) = vm
|
||||
.call(
|
||||
Call::new(circuits.session_keys.clone())
|
||||
.arg(pms)
|
||||
.arg(client_random)
|
||||
.arg(server_random)
|
||||
.build()
|
||||
.map_err(PrfError::vm)?,
|
||||
)
|
||||
.map_err(PrfError::vm)?;
|
||||
|
||||
let keys = SessionKeys {
|
||||
client_write_key: thread.new_output::<[u8; 16]>("client_write_key")?,
|
||||
server_write_key: thread.new_output::<[u8; 16]>("server_write_key")?,
|
||||
client_iv: thread.new_output::<[u8; 4]>("client_write_iv")?,
|
||||
server_iv: thread.new_output::<[u8; 4]>("server_write_iv")?,
|
||||
client_write_key,
|
||||
server_write_key,
|
||||
client_iv,
|
||||
server_iv,
|
||||
};
|
||||
|
||||
let hash_state = HashState {
|
||||
ms_outer_hash_state: thread.new_output::<[u32; 8]>("ms_outer_hash_state")?,
|
||||
ms_inner_hash_state: thread.new_output::<[u32; 8]>("ms_inner_hash_state")?,
|
||||
};
|
||||
let cf_hash = vm.alloc().map_err(PrfError::vm)?;
|
||||
vm.mark_public(cf_hash).map_err(PrfError::vm)?;
|
||||
|
||||
let cf_vd = VerifyData {
|
||||
handshake_hash: thread.new_input::<[u8; 32]>("cf_hash", Visibility::Public)?,
|
||||
vd: thread.new_output::<[u8; 12]>("cf_vd")?,
|
||||
};
|
||||
let cf_vd = vm
|
||||
.call(
|
||||
Call::new(circuits.client_vd.clone())
|
||||
.arg(ms_outer_hash_state)
|
||||
.arg(ms_inner_hash_state)
|
||||
.arg(cf_hash)
|
||||
.build()
|
||||
.map_err(PrfError::vm)?,
|
||||
)
|
||||
.map_err(PrfError::vm)?;
|
||||
|
||||
let sf_vd = VerifyData {
|
||||
handshake_hash: thread.new_input::<[u8; 32]>("sf_hash", Visibility::Public)?,
|
||||
vd: thread.new_output::<[u8; 12]>("sf_vd")?,
|
||||
};
|
||||
let sf_hash = vm.alloc().map_err(PrfError::vm)?;
|
||||
vm.mark_public(sf_hash).map_err(PrfError::vm)?;
|
||||
|
||||
let sf_vd = vm
|
||||
.call(
|
||||
Call::new(circuits.server_vd.clone())
|
||||
.arg(ms_outer_hash_state)
|
||||
.arg(ms_inner_hash_state)
|
||||
.arg(sf_hash)
|
||||
.build()
|
||||
.map_err(PrfError::vm)?,
|
||||
)
|
||||
.map_err(PrfError::vm)?;
|
||||
|
||||
self.state = State::SessionKeys {
|
||||
pms,
|
||||
randoms,
|
||||
hash_state,
|
||||
keys: keys.clone(),
|
||||
cf_vd,
|
||||
sf_vd,
|
||||
client_random,
|
||||
server_random,
|
||||
cf_hash,
|
||||
sf_hash,
|
||||
};
|
||||
|
||||
Ok(keys)
|
||||
Ok(PrfOutput { keys, cf_vd, sf_vd })
|
||||
}
|
||||
|
||||
/// Sets the client random.
|
||||
///
|
||||
/// Only the leader can provide the client random.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `client_random` - The client random.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn set_client_random(&mut self, client_random: Option<[u8; 32]>) -> Result<(), PrfError> {
|
||||
let State::SessionKeys { randoms, .. } = &self.state else {
|
||||
pub fn set_client_random(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
random: Option<[u8; 32]>,
|
||||
) -> Result<(), PrfError> {
|
||||
let State::SessionKeys { client_random, .. } = &self.state else {
|
||||
return Err(PrfError::state("PRF not set up"));
|
||||
};
|
||||
|
||||
if self.config.role == Role::Leader {
|
||||
let Some(client_random) = client_random else {
|
||||
let Some(random) = random else {
|
||||
return Err(PrfError::role("leader must provide client random"));
|
||||
};
|
||||
|
||||
self.thread_0
|
||||
.assign(&randoms.client_random, client_random)?;
|
||||
} else if client_random.is_some() {
|
||||
vm.assign(*client_random, random).map_err(PrfError::vm)?;
|
||||
} else if random.is_some() {
|
||||
return Err(PrfError::role("only leader can set client random"));
|
||||
}
|
||||
|
||||
self.thread_0
|
||||
.commit(&[randoms.client_random.clone()])
|
||||
.await?;
|
||||
vm.commit(*client_random).map_err(PrfError::vm)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sets the server random.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `server_random` - The server random.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn preprocess(&mut self) -> Result<(), PrfError> {
|
||||
pub fn set_server_random(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
random: [u8; 32],
|
||||
) -> Result<(), PrfError> {
|
||||
let State::SessionKeys {
|
||||
pms,
|
||||
randoms,
|
||||
hash_state,
|
||||
keys,
|
||||
cf_vd,
|
||||
sf_vd,
|
||||
server_random,
|
||||
cf_hash,
|
||||
sf_hash,
|
||||
..
|
||||
} = self.state.take()
|
||||
else {
|
||||
return Err(PrfError::state("PRF not set up"));
|
||||
};
|
||||
|
||||
// Builds all circuits in parallel and preprocesses the session keys circuit.
|
||||
futures::try_join!(
|
||||
async {
|
||||
if SESSION_KEYS_CIRC.get().is_none() {
|
||||
_ = SESSION_KEYS_CIRC.set(CpuBackend::blocking(build_session_keys).await);
|
||||
}
|
||||
vm.assign(server_random, random).map_err(PrfError::vm)?;
|
||||
vm.commit(server_random).map_err(PrfError::vm)?;
|
||||
|
||||
let circ = SESSION_KEYS_CIRC
|
||||
.get()
|
||||
.expect("session keys circuit should be built");
|
||||
|
||||
self.thread_0
|
||||
.load(
|
||||
circ.clone(),
|
||||
&[
|
||||
pms.clone(),
|
||||
randoms.client_random.clone(),
|
||||
randoms.server_random.clone(),
|
||||
],
|
||||
&[
|
||||
keys.client_write_key.clone(),
|
||||
keys.server_write_key.clone(),
|
||||
keys.client_iv.clone(),
|
||||
keys.server_iv.clone(),
|
||||
hash_state.ms_outer_hash_state.clone(),
|
||||
hash_state.ms_inner_hash_state.clone(),
|
||||
],
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok::<_, PrfError>(())
|
||||
},
|
||||
async {
|
||||
if CLIENT_VD_CIRC.get().is_none() {
|
||||
_ = CLIENT_VD_CIRC
|
||||
.set(CpuBackend::blocking(move || build_verify_data(CF_LABEL)).await);
|
||||
}
|
||||
|
||||
Ok::<_, PrfError>(())
|
||||
},
|
||||
async {
|
||||
if SERVER_VD_CIRC.get().is_none() {
|
||||
_ = SERVER_VD_CIRC
|
||||
.set(CpuBackend::blocking(move || build_verify_data(SF_LABEL)).await);
|
||||
}
|
||||
|
||||
Ok::<_, PrfError>(())
|
||||
}
|
||||
)?;
|
||||
|
||||
// Finishes preprocessing the verify data circuits.
|
||||
futures::try_join!(
|
||||
async {
|
||||
self.thread_0
|
||||
.load(
|
||||
CLIENT_VD_CIRC
|
||||
.get()
|
||||
.expect("client finished circuit should be built")
|
||||
.clone(),
|
||||
&[
|
||||
hash_state.ms_outer_hash_state.clone(),
|
||||
hash_state.ms_inner_hash_state.clone(),
|
||||
cf_vd.handshake_hash.clone(),
|
||||
],
|
||||
&[cf_vd.vd.clone()],
|
||||
)
|
||||
.await
|
||||
},
|
||||
async {
|
||||
self.thread_1
|
||||
.load(
|
||||
SERVER_VD_CIRC
|
||||
.get()
|
||||
.expect("server finished circuit should be built")
|
||||
.clone(),
|
||||
&[
|
||||
hash_state.ms_outer_hash_state.clone(),
|
||||
hash_state.ms_inner_hash_state.clone(),
|
||||
sf_vd.handshake_hash.clone(),
|
||||
],
|
||||
&[sf_vd.vd.clone()],
|
||||
)
|
||||
.await
|
||||
}
|
||||
)?;
|
||||
|
||||
self.state = State::SessionKeys {
|
||||
pms,
|
||||
randoms,
|
||||
hash_state,
|
||||
keys,
|
||||
cf_vd,
|
||||
sf_vd,
|
||||
};
|
||||
self.state = State::ClientFinished { cf_hash, sf_hash };
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sets the client finished handshake hash.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `handshake_hash` - The handshake transcript hash.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn compute_client_finished_vd(
|
||||
pub fn set_cf_hash(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
handshake_hash: [u8; 32],
|
||||
) -> Result<[u8; 12], PrfError> {
|
||||
self.execute_cf_vd(handshake_hash).await
|
||||
) -> Result<(), PrfError> {
|
||||
let State::ClientFinished { cf_hash, sf_hash } = self.state.take() else {
|
||||
return Err(PrfError::state("PRF not in client finished state"));
|
||||
};
|
||||
|
||||
vm.assign(cf_hash, handshake_hash).map_err(PrfError::vm)?;
|
||||
vm.commit(cf_hash).map_err(PrfError::vm)?;
|
||||
|
||||
self.state = State::ServerFinished { sf_hash };
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sets the server finished handshake hash.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `handshake_hash` - The handshake transcript hash.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn compute_server_finished_vd(
|
||||
pub fn set_sf_hash(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
handshake_hash: [u8; 32],
|
||||
) -> Result<[u8; 12], PrfError> {
|
||||
self.execute_sf_vd(handshake_hash).await
|
||||
}
|
||||
) -> Result<(), PrfError> {
|
||||
let State::ServerFinished { sf_hash } = self.state.take() else {
|
||||
return Err(PrfError::state("PRF not in server finished state"));
|
||||
};
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn compute_session_keys(
|
||||
&mut self,
|
||||
server_random: [u8; 32],
|
||||
) -> Result<SessionKeys, PrfError> {
|
||||
self.execute_session_keys(server_random).await
|
||||
vm.assign(sf_hash, handshake_hash).map_err(PrfError::vm)?;
|
||||
vm.commit(sf_hash).map_err(PrfError::vm)?;
|
||||
|
||||
self.state = State::Complete;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,32 +13,30 @@ name = "key_exchange"
|
||||
|
||||
[features]
|
||||
default = ["mock"]
|
||||
mock = []
|
||||
mock = ["mpz-share-conversion/test-utils", "mpz-common/ideal"]
|
||||
|
||||
[dependencies]
|
||||
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", features = [
|
||||
"ideal",
|
||||
] }
|
||||
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
|
||||
mpz-vm-core = { workspace = true }
|
||||
mpz-memory-core = { workspace = true }
|
||||
mpz-common = { workspace = true }
|
||||
mpz-fields = { workspace = true }
|
||||
mpz-share-conversion = { workspace = true }
|
||||
mpz-circuits = { workspace = true }
|
||||
mpz-core = { workspace = true }
|
||||
|
||||
p256 = { workspace = true, features = ["ecdh", "serde"] }
|
||||
async-trait = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
serio = { workspace = true }
|
||||
derive_builder = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
tokio = { workspace = true, features = ["sync"] }
|
||||
|
||||
[dev-dependencies]
|
||||
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", features = [
|
||||
"ideal",
|
||||
] }
|
||||
mpz-ot = { workspace = true, features = ["ideal"] }
|
||||
mpz-garble = { workspace = true }
|
||||
|
||||
rand_chacha = { workspace = true }
|
||||
rand_core = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
|
||||
rstest = { workspace = true }
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
//! This module provides the circuits used in the key exchange protocol.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use mpz_circuits::{circuits::big_num::nbyte_add_mod_trace, Circuit, CircuitBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// NIST P-256 prime big-endian.
|
||||
static P: [u8; 32] = [
|
||||
|
||||
@@ -1,29 +1 @@
|
||||
use derive_builder::Builder;
|
||||
|
||||
/// Role in the key exchange protocol.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Role {
|
||||
/// Leader.
|
||||
Leader,
|
||||
/// Follower.
|
||||
Follower,
|
||||
}
|
||||
|
||||
/// A config used for [MpcKeyExchange](super::MpcKeyExchange).
|
||||
#[derive(Debug, Clone, Builder)]
|
||||
pub struct KeyExchangeConfig {
|
||||
/// Protocol role.
|
||||
role: Role,
|
||||
}
|
||||
|
||||
impl KeyExchangeConfig {
|
||||
/// Creates a new builder for the key exchange configuration.
|
||||
pub fn builder() -> KeyExchangeConfigBuilder {
|
||||
KeyExchangeConfigBuilder::default()
|
||||
}
|
||||
|
||||
/// Get the role of this instance.
|
||||
pub fn role(&self) -> &Role {
|
||||
&self.role
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,120 +1,87 @@
|
||||
use core::fmt;
|
||||
use std::error::Error;
|
||||
|
||||
/// A key exchange error.
|
||||
/// MPC-TLS protocol error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub struct KeyExchangeError {
|
||||
kind: ErrorKind,
|
||||
#[source]
|
||||
source: Option<Box<dyn Error + Send + Sync>>,
|
||||
#[error(transparent)]
|
||||
pub struct KeyExchangeError(#[from] pub(crate) ErrorRepr);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("key exchange error: {0}")]
|
||||
pub(crate) enum ErrorRepr {
|
||||
#[error("state error: {0}")]
|
||||
State(Box<dyn Error + Send + Sync + 'static>),
|
||||
#[error("context error: {0}")]
|
||||
Ctx(Box<dyn Error + Send + Sync + 'static>),
|
||||
#[error("io error: {0}")]
|
||||
Io(std::io::Error),
|
||||
#[error("vm error: {0}")]
|
||||
Vm(Box<dyn Error + Send + Sync + 'static>),
|
||||
#[error("share conversion error: {0}")]
|
||||
ShareConversion(Box<dyn Error + Send + Sync + 'static>),
|
||||
#[error("role error: {0}")]
|
||||
Role(Box<dyn Error + Send + Sync + 'static>),
|
||||
#[error("key error: {0}")]
|
||||
Key(Box<dyn Error + Send + Sync + 'static>),
|
||||
}
|
||||
|
||||
impl KeyExchangeError {
|
||||
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
|
||||
pub(crate) fn state<E>(err: E) -> KeyExchangeError
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync>>,
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self {
|
||||
kind,
|
||||
source: Some(source.into()),
|
||||
}
|
||||
Self(ErrorRepr::State(err.into()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn kind(&self) -> &ErrorKind {
|
||||
&self.kind
|
||||
pub(crate) fn ctx<E>(err: E) -> KeyExchangeError
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Ctx(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn state(msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::State,
|
||||
source: Some(msg.into().into()),
|
||||
}
|
||||
pub(crate) fn vm<E>(err: E) -> KeyExchangeError
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Vm(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn role(msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Role,
|
||||
source: Some(msg.into().into()),
|
||||
}
|
||||
pub(crate) fn share_conversion<E>(err: E) -> KeyExchangeError
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::ShareConversion(err.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum ErrorKind {
|
||||
Io,
|
||||
Context,
|
||||
Vm,
|
||||
ShareConversion,
|
||||
Key,
|
||||
State,
|
||||
Role,
|
||||
}
|
||||
pub(crate) fn role<E>(err: E) -> KeyExchangeError
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Role(err.into()))
|
||||
}
|
||||
|
||||
impl fmt::Display for KeyExchangeError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self.kind {
|
||||
ErrorKind::Io => write!(f, "io error")?,
|
||||
ErrorKind::Context => write!(f, "context error")?,
|
||||
ErrorKind::Vm => write!(f, "vm error")?,
|
||||
ErrorKind::ShareConversion => write!(f, "share conversion error")?,
|
||||
ErrorKind::Key => write!(f, "key error")?,
|
||||
ErrorKind::State => write!(f, "state error")?,
|
||||
ErrorKind::Role => write!(f, "role error")?,
|
||||
}
|
||||
|
||||
if let Some(ref source) = self.source {
|
||||
write!(f, " caused by: {}", source)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
pub(crate) fn key<E>(err: E) -> KeyExchangeError
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Key(err.into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_common::ContextError> for KeyExchangeError {
|
||||
fn from(error: mpz_common::ContextError) -> Self {
|
||||
Self::new(ErrorKind::Context, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::MemoryError> for KeyExchangeError {
|
||||
fn from(error: mpz_garble::MemoryError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::LoadError> for KeyExchangeError {
|
||||
fn from(error: mpz_garble::LoadError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::ExecutionError> for KeyExchangeError {
|
||||
fn from(error: mpz_garble::ExecutionError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::DecodeError> for KeyExchangeError {
|
||||
fn from(error: mpz_garble::DecodeError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_share_conversion::ShareConversionError> for KeyExchangeError {
|
||||
fn from(error: mpz_share_conversion::ShareConversionError) -> Self {
|
||||
Self::new(ErrorKind::ShareConversion, error)
|
||||
fn from(value: mpz_common::ContextError) -> Self {
|
||||
Self::ctx(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<p256::elliptic_curve::Error> for KeyExchangeError {
|
||||
fn from(error: p256::elliptic_curve::Error) -> Self {
|
||||
Self::new(ErrorKind::Key, error)
|
||||
fn from(value: p256::elliptic_curve::Error) -> Self {
|
||||
Self::key(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for KeyExchangeError {
|
||||
fn from(error: std::io::Error) -> Self {
|
||||
Self::new(ErrorKind::Io, error)
|
||||
fn from(err: std::io::Error) -> Self {
|
||||
Self(ErrorRepr::Io(err))
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,63 +15,64 @@
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
mod circuit;
|
||||
mod config;
|
||||
pub(crate) mod error;
|
||||
mod exchange;
|
||||
#[cfg(feature = "mock")]
|
||||
pub mod mock;
|
||||
pub(crate) mod point_addition;
|
||||
|
||||
pub use config::{
|
||||
KeyExchangeConfig, KeyExchangeConfigBuilder, KeyExchangeConfigBuilderError, Role,
|
||||
};
|
||||
pub use error::KeyExchangeError;
|
||||
pub use exchange::MpcKeyExchange;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mpz_garble::value::ValueRef;
|
||||
use mpz_common::Context;
|
||||
use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
Array,
|
||||
};
|
||||
use mpz_vm_core::Vm;
|
||||
use p256::PublicKey;
|
||||
|
||||
/// Pre-master secret.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Pms(ValueRef);
|
||||
pub type Pms = Array<U8, 32>;
|
||||
|
||||
impl Pms {
|
||||
/// Creates a new PMS.
|
||||
pub fn new(value: ValueRef) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
|
||||
/// Gets the value of the PMS.
|
||||
pub fn into_value(self) -> ValueRef {
|
||||
self.0
|
||||
}
|
||||
/// Role in the key exchange protocol.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Role {
|
||||
/// Leader.
|
||||
Leader,
|
||||
/// Follower.
|
||||
Follower,
|
||||
}
|
||||
|
||||
/// A trait for the 3-party key exchange protocol.
|
||||
#[async_trait]
|
||||
pub trait KeyExchange {
|
||||
/// Gets the server's public key.
|
||||
fn server_key(&self) -> Option<PublicKey>;
|
||||
/// Allocate necessary computational resources.
|
||||
fn alloc(&mut self, vm: &mut dyn Vm<Binary>) -> Result<Pms, KeyExchangeError>;
|
||||
|
||||
/// Sets the server's public key.
|
||||
async fn set_server_key(&mut self, server_key: PublicKey) -> Result<(), KeyExchangeError>;
|
||||
fn set_server_key(&mut self, server_key: PublicKey) -> Result<(), KeyExchangeError>;
|
||||
|
||||
/// Gets the server's public key.
|
||||
fn server_key(&self) -> Option<PublicKey>;
|
||||
|
||||
/// Computes the client's public key.
|
||||
///
|
||||
/// The client's public key in this context is the combined public key (EC
|
||||
/// point addition) of the leader's public key and the follower's public
|
||||
/// key.
|
||||
async fn client_key(&mut self) -> Result<PublicKey, KeyExchangeError>;
|
||||
fn client_key(&self) -> Result<PublicKey, KeyExchangeError>;
|
||||
|
||||
/// Performs any necessary one-time setup, returning a reference to the PMS.
|
||||
///
|
||||
/// The PMS will not be assigned until `compute_pms` is called.
|
||||
async fn setup(&mut self) -> Result<Pms, KeyExchangeError>;
|
||||
/// Performs one-time setup for the key exchange protocol.
|
||||
async fn setup(&mut self, ctx: &mut Context) -> Result<(), KeyExchangeError>;
|
||||
|
||||
/// Preprocesses the key exchange.
|
||||
async fn preprocess(&mut self) -> Result<(), KeyExchangeError>;
|
||||
/// Computes the shares of the PMS.
|
||||
async fn compute_shares(&mut self, ctx: &mut Context) -> Result<(), KeyExchangeError>;
|
||||
|
||||
/// Computes the PMS.
|
||||
async fn compute_pms(&mut self) -> Result<Pms, KeyExchangeError>;
|
||||
/// Assigns the PMS shares to the VM.
|
||||
fn assign(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), KeyExchangeError>;
|
||||
|
||||
/// Finalizes the key exchange protocol.
|
||||
async fn finalize(&mut self) -> Result<(), KeyExchangeError>;
|
||||
}
|
||||
|
||||
@@ -1,71 +1,51 @@
|
||||
//! This module provides mock types for key exchange leader and follower and a
|
||||
//! function to create such a pair.
|
||||
|
||||
use crate::{KeyExchangeConfig, MpcKeyExchange, Role};
|
||||
|
||||
use mpz_common::executor::{test_st_executor, STExecutor};
|
||||
use mpz_garble::{Decode, Execute, Memory};
|
||||
use mpz_share_conversion::ideal::{ideal_share_converter, IdealShareConverter};
|
||||
use serio::channel::MemoryDuplex;
|
||||
use crate::{MpcKeyExchange, Role};
|
||||
use mpz_core::Block;
|
||||
use mpz_fields::p256::P256;
|
||||
use mpz_share_conversion::ideal::{
|
||||
ideal_share_convert, IdealShareConvertReceiver, IdealShareConvertSender,
|
||||
};
|
||||
|
||||
/// A mock key exchange instance.
|
||||
pub type MockKeyExchange<E> =
|
||||
MpcKeyExchange<STExecutor<MemoryDuplex>, IdealShareConverter, IdealShareConverter, E>;
|
||||
pub type MockKeyExchange =
|
||||
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>;
|
||||
|
||||
/// Creates a mock pair of key exchange leader and follower.
|
||||
pub fn create_mock_key_exchange_pair<E: Memory + Execute + Decode + Send>(
|
||||
leader_executor: E,
|
||||
follower_executor: E,
|
||||
) -> (MockKeyExchange<E>, MockKeyExchange<E>) {
|
||||
let (leader_ctx, follower_ctx) = test_st_executor(8);
|
||||
let (leader_converter_0, follower_converter_0) = ideal_share_converter();
|
||||
let (leader_converter_1, follower_converter_1) = ideal_share_converter();
|
||||
pub fn create_mock_key_exchange_pair() -> (MockKeyExchange, MockKeyExchange) {
|
||||
let (leader_converter_0, follower_converter_0) = ideal_share_convert(Block::ZERO);
|
||||
let (follower_converter_1, leader_converter_1) = ideal_share_convert(Block::ZERO);
|
||||
|
||||
let key_exchange_config_leader = KeyExchangeConfig::builder()
|
||||
.role(Role::Leader)
|
||||
.build()
|
||||
.unwrap();
|
||||
let leader = MpcKeyExchange::new(Role::Leader, leader_converter_0, leader_converter_1);
|
||||
|
||||
let key_exchange_config_follower = KeyExchangeConfig::builder()
|
||||
.role(Role::Follower)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let leader = MpcKeyExchange::new(
|
||||
key_exchange_config_leader,
|
||||
leader_ctx,
|
||||
leader_converter_0,
|
||||
leader_converter_1,
|
||||
leader_executor,
|
||||
);
|
||||
|
||||
let follower = MpcKeyExchange::new(
|
||||
key_exchange_config_follower,
|
||||
follower_ctx,
|
||||
follower_converter_0,
|
||||
follower_converter_1,
|
||||
follower_executor,
|
||||
);
|
||||
let follower = MpcKeyExchange::new(Role::Follower, follower_converter_1, follower_converter_0);
|
||||
|
||||
(leader, follower)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use mpz_garble::protocol::deap::mock::create_mock_deap_vm;
|
||||
|
||||
use crate::KeyExchange;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
|
||||
use mpz_ot::ideal::cot::{IdealCOTReceiver, IdealCOTSender};
|
||||
|
||||
use super::*;
|
||||
use crate::KeyExchange;
|
||||
|
||||
#[test]
|
||||
fn test_mock_is_ke() {
|
||||
let (leader_vm, follower_vm) = create_mock_deap_vm();
|
||||
let (leader, follower) = create_mock_key_exchange_pair(leader_vm, follower_vm);
|
||||
let (leader, follower) = create_mock_key_exchange_pair();
|
||||
|
||||
fn is_key_exchange<T: KeyExchange>(_: T) {}
|
||||
fn is_key_exchange<T: KeyExchange, V>(_: T) {}
|
||||
|
||||
is_key_exchange(leader);
|
||||
is_key_exchange(follower);
|
||||
is_key_exchange::<
|
||||
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
|
||||
Generator<IdealCOTSender>,
|
||||
>(leader);
|
||||
|
||||
is_key_exchange::<
|
||||
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
|
||||
Evaluator<IdealCOTReceiver>,
|
||||
>(follower);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
//! This module contains the message types exchanged between the prover and the TLS verifier.
|
||||
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
|
||||
use p256::{elliptic_curve::sec1::ToEncodedPoint, PublicKey as P256PublicKey};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A type for messages exchanged between the prover and the TLS verifier during the key exchange
|
||||
/// protocol.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum KeyExchangeMessage {
|
||||
FollowerPublicKey(PublicKey),
|
||||
ServerPublicKey(PublicKey),
|
||||
}
|
||||
|
||||
/// A wrapper for a serialized public key.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PublicKey {
|
||||
/// The sec1 serialized public key.
|
||||
pub key: Vec<u8>,
|
||||
}
|
||||
|
||||
/// An error that can occur during parsing of a public key.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub struct KeyParseError(#[from] p256::elliptic_curve::Error);
|
||||
|
||||
impl Display for KeyParseError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Unable to parse public key: {}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<P256PublicKey> for PublicKey {
|
||||
fn from(value: P256PublicKey) -> Self {
|
||||
let key = value.to_encoded_point(false).as_bytes().to_vec();
|
||||
PublicKey { key }
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<PublicKey> for P256PublicKey {
|
||||
type Error = KeyParseError;
|
||||
|
||||
fn try_from(value: PublicKey) -> Result<Self, Self::Error> {
|
||||
P256PublicKey::from_sec1_bytes(&value.key).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
@@ -1,27 +1,28 @@
|
||||
//! This module implements a secure two-party computation protocol for adding
|
||||
//! two private EC points and secret-sharing the resulting x coordinate (the
|
||||
//! shares are field elements of the field underlying the elliptic curve).
|
||||
//! This protocol has semi-honest security.
|
||||
//! shares are field elements of the field underlying the elliptic curve). This
|
||||
//! protocol has semi-honest security.
|
||||
//!
|
||||
//! The protocol is described in <https://docs.tlsnotary.org/protocol/notarization/key_exchange.html>
|
||||
//! The protocol is described in
|
||||
//! <https://docs.tlsnotary.org/protocol/notarization/key_exchange.html>
|
||||
|
||||
use mpz_common::Context;
|
||||
use crate::{KeyExchangeError, Role};
|
||||
use mpz_common::{Context, Flush};
|
||||
use mpz_fields::{p256::P256, Field};
|
||||
use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive};
|
||||
use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive, ShareConvert};
|
||||
use p256::EncodedPoint;
|
||||
|
||||
use crate::{config::Role, error::ErrorKind, KeyExchangeError};
|
||||
|
||||
/// Derives the x-coordinate share of an elliptic curve point.
|
||||
pub(crate) async fn derive_x_coord_share<Ctx, C>(
|
||||
pub(crate) async fn derive_x_coord_share<C>(
|
||||
ctx: &mut Context,
|
||||
role: Role,
|
||||
ctx: &mut Ctx,
|
||||
converter: &mut C,
|
||||
share: EncodedPoint,
|
||||
) -> Result<P256, KeyExchangeError>
|
||||
where
|
||||
Ctx: Context,
|
||||
C: AdditiveToMultiplicative<Ctx, P256> + MultiplicativeToAdditive<Ctx, P256>,
|
||||
C: ShareConvert<P256> + Flush + Send,
|
||||
<C as AdditiveToMultiplicative<P256>>::Future: Send,
|
||||
<C as MultiplicativeToAdditive<P256>>::Future: Send,
|
||||
{
|
||||
let [x, y] = decompose_point(share)?;
|
||||
|
||||
@@ -31,16 +32,40 @@ where
|
||||
Role::Follower => vec![-y, -x],
|
||||
};
|
||||
|
||||
let [a, b] = converter
|
||||
.to_multiplicative(ctx, inputs)
|
||||
.await?
|
||||
let a2m = converter
|
||||
.queue_to_multiplicative(&inputs)
|
||||
.map_err(KeyExchangeError::share_conversion)?;
|
||||
|
||||
converter
|
||||
.flush(ctx)
|
||||
.await
|
||||
.map_err(KeyExchangeError::share_conversion)?;
|
||||
|
||||
let [a, b] = a2m
|
||||
.await
|
||||
.map_err(KeyExchangeError::share_conversion)?
|
||||
.shares
|
||||
.try_into()
|
||||
.expect("output is same length as input");
|
||||
|
||||
let c = a * b.inverse();
|
||||
let c = a * b
|
||||
.inverse()
|
||||
.expect("field element should not be zero when inverting");
|
||||
let c = c * c;
|
||||
|
||||
let d = converter.to_additive(ctx, vec![c]).await?[0];
|
||||
let m2a = converter
|
||||
.queue_to_additive(&[c])
|
||||
.map_err(KeyExchangeError::share_conversion)?;
|
||||
|
||||
converter
|
||||
.flush(ctx)
|
||||
.await
|
||||
.map_err(KeyExchangeError::share_conversion)?;
|
||||
|
||||
let d = m2a
|
||||
.await
|
||||
.map_err(KeyExchangeError::share_conversion)?
|
||||
.shares[0];
|
||||
|
||||
let x_r = d + -x;
|
||||
|
||||
@@ -50,13 +75,11 @@ where
|
||||
/// Decomposes the x and y coordinates of a SEC1 encoded point.
|
||||
fn decompose_point(point: EncodedPoint) -> Result<[P256; 2], KeyExchangeError> {
|
||||
// Coordinates are stored as big-endian bytes.
|
||||
let mut x: [u8; 32] = (*point.x().ok_or(KeyExchangeError::new(
|
||||
ErrorKind::Key,
|
||||
"key share is an identity point",
|
||||
))?)
|
||||
let mut x: [u8; 32] = (*point
|
||||
.x()
|
||||
.ok_or(KeyExchangeError::key("key share is an identity point"))?)
|
||||
.into();
|
||||
let mut y: [u8; 32] = (*point.y().ok_or(KeyExchangeError::new(
|
||||
ErrorKind::Key,
|
||||
let mut y: [u8; 32] = (*point.y().ok_or(KeyExchangeError::key(
|
||||
"key share is an identity point or compressed",
|
||||
))?)
|
||||
.into();
|
||||
@@ -75,20 +98,20 @@ fn decompose_point(point: EncodedPoint) -> Result<[P256; 2], KeyExchangeError> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use mpz_common::executor::test_st_executor;
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_core::Block;
|
||||
use mpz_fields::{p256::P256, Field};
|
||||
use mpz_share_conversion::ideal::ideal_share_converter;
|
||||
use mpz_share_conversion::ideal::ideal_share_convert;
|
||||
use p256::{
|
||||
elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint},
|
||||
EncodedPoint, NonZeroScalar, ProjectivePoint, PublicKey,
|
||||
};
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_chacha::ChaCha12Rng;
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_point_addition() {
|
||||
let (mut ctx_a, mut ctx_b) = test_st_executor(8);
|
||||
let mut rng = ChaCha12Rng::from_seed([0u8; 32]);
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let p1: [u8; 32] = rng.gen();
|
||||
let p2: [u8; 32] = rng.gen();
|
||||
@@ -98,11 +121,11 @@ mod tests {
|
||||
|
||||
let p = add_curve_points(&p1, &p2);
|
||||
|
||||
let (mut c_a, mut c_b) = ideal_share_converter();
|
||||
let (mut c_a, mut c_b) = ideal_share_convert(Block::ZERO);
|
||||
|
||||
let (a, b) = tokio::try_join!(
|
||||
derive_x_coord_share(Role::Leader, &mut ctx_a, &mut c_a, p1),
|
||||
derive_x_coord_share(Role::Follower, &mut ctx_b, &mut c_b, p2)
|
||||
derive_x_coord_share(&mut ctx_a, Role::Leader, &mut c_a, p1),
|
||||
derive_x_coord_share(&mut ctx_b, Role::Follower, &mut c_b, p2)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -113,7 +136,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_decompose_point() {
|
||||
let mut rng = ChaCha12Rng::from_seed([0_u8; 32]);
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let p_expected: [u8; 32] = rng.gen();
|
||||
let p_expected = curve_point_from_be_bytes(p_expected);
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
[package]
|
||||
name = "tlsn-stream-cipher"
|
||||
authors = ["TLSNotary Team"]
|
||||
description = "2PC stream cipher implementation"
|
||||
keywords = ["tls", "mpc", "2pc", "stream-cipher"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.8-pre"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["mock"]
|
||||
rayon = ["mpz-garble/rayon"]
|
||||
mock = []
|
||||
|
||||
[dependencies]
|
||||
mpz-circuits = { workspace = true }
|
||||
mpz-garble = { workspace = true }
|
||||
tlsn-utils = { workspace = true }
|
||||
aes = { workspace = true }
|
||||
ctr = { workspace = true }
|
||||
cipher = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
derive_builder = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
opaque-debug = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
futures = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
|
||||
rstest = { workspace = true, features = ["async-timeout"] }
|
||||
criterion = { workspace = true, features = ["async_tokio"] }
|
||||
|
||||
[[bench]]
|
||||
name = "mock"
|
||||
harness = false
|
||||
@@ -1,132 +0,0 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
|
||||
use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory};
|
||||
use tlsn_stream_cipher::{
|
||||
Aes128Ctr, CtrCircuit, MpcStreamCipher, StreamCipher, StreamCipherConfigBuilder,
|
||||
};
|
||||
|
||||
async fn bench_stream_cipher_encrypt(len: usize) {
|
||||
let (leader_vm, follower_vm) = create_mock_deap_vm();
|
||||
|
||||
let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap();
|
||||
let leader_iv = leader_vm.new_public_input::<[u8; 4]>("iv").unwrap();
|
||||
|
||||
leader_vm.assign(&leader_key, [0u8; 16]).unwrap();
|
||||
leader_vm.assign(&leader_iv, [0u8; 4]).unwrap();
|
||||
|
||||
let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap();
|
||||
let follower_iv = follower_vm.new_public_input::<[u8; 4]>("iv").unwrap();
|
||||
|
||||
follower_vm.assign(&follower_key, [0u8; 16]).unwrap();
|
||||
follower_vm.assign(&follower_iv, [0u8; 4]).unwrap();
|
||||
|
||||
let leader_config = StreamCipherConfigBuilder::default()
|
||||
.id("test".to_string())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let follower_config = StreamCipherConfigBuilder::default()
|
||||
.id("test".to_string())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let mut leader = MpcStreamCipher::<Aes128Ctr, _>::new(leader_config, leader_vm);
|
||||
leader.set_key(leader_key, leader_iv);
|
||||
|
||||
let mut follower = MpcStreamCipher::<Aes128Ctr, _>::new(follower_config, follower_vm);
|
||||
follower.set_key(follower_key, follower_iv);
|
||||
|
||||
let plaintext = vec![0u8; len];
|
||||
let explicit_nonce = vec![0u8; 8];
|
||||
|
||||
_ = tokio::try_join!(
|
||||
leader.encrypt_private(explicit_nonce.clone(), plaintext),
|
||||
follower.encrypt_blind(explicit_nonce, len)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
_ = tokio::try_join!(
|
||||
leader.thread_mut().finalize(),
|
||||
follower.thread_mut().finalize()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn bench_stream_cipher_zk(len: usize) {
|
||||
let (leader_vm, follower_vm) = create_mock_deap_vm();
|
||||
|
||||
let key = [0u8; 16];
|
||||
let iv = [0u8; 4];
|
||||
|
||||
let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap();
|
||||
let leader_iv = leader_vm.new_public_input::<[u8; 4]>("iv").unwrap();
|
||||
|
||||
leader_vm.assign(&leader_key, key).unwrap();
|
||||
leader_vm.assign(&leader_iv, iv).unwrap();
|
||||
|
||||
let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap();
|
||||
let follower_iv = follower_vm.new_public_input::<[u8; 4]>("iv").unwrap();
|
||||
|
||||
follower_vm.assign(&follower_key, key).unwrap();
|
||||
follower_vm.assign(&follower_iv, iv).unwrap();
|
||||
|
||||
let leader_config = StreamCipherConfigBuilder::default()
|
||||
.id("test".to_string())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let follower_config = StreamCipherConfigBuilder::default()
|
||||
.id("test".to_string())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let mut leader = MpcStreamCipher::<Aes128Ctr, _>::new(leader_config, leader_vm);
|
||||
leader.set_key(leader_key, leader_iv);
|
||||
|
||||
let mut follower = MpcStreamCipher::<Aes128Ctr, _>::new(follower_config, follower_vm);
|
||||
follower.set_key(follower_key, follower_iv);
|
||||
|
||||
futures::try_join!(leader.decode_key_private(), follower.decode_key_blind()).unwrap();
|
||||
|
||||
let plaintext = vec![0u8; len];
|
||||
let explicit_nonce = [0u8; 8];
|
||||
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &plaintext).unwrap();
|
||||
|
||||
_ = tokio::try_join!(
|
||||
leader.prove_plaintext(explicit_nonce.to_vec(), plaintext),
|
||||
follower.verify_plaintext(explicit_nonce.to_vec(), ciphertext)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
_ = tokio::try_join!(
|
||||
leader.thread_mut().finalize(),
|
||||
follower.thread_mut().finalize()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||
let len = 1024;
|
||||
|
||||
let mut group = c.benchmark_group("stream_cipher/encrypt_private");
|
||||
group.throughput(Throughput::Bytes(len as u64));
|
||||
group.bench_function(BenchmarkId::from_parameter(len), |b| {
|
||||
b.to_async(&rt)
|
||||
.iter(|| async { bench_stream_cipher_encrypt(len).await })
|
||||
});
|
||||
|
||||
drop(group);
|
||||
|
||||
let mut group = c.benchmark_group("stream_cipher/zk");
|
||||
group.throughput(Throughput::Bytes(len as u64));
|
||||
group.bench_function(BenchmarkId::from_parameter(len), |b| {
|
||||
b.to_async(&rt)
|
||||
.iter(|| async { bench_stream_cipher_zk(len).await })
|
||||
});
|
||||
|
||||
drop(group);
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
||||
criterion_main!(benches);
|
||||
@@ -1,118 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use mpz_circuits::{
|
||||
types::{StaticValueType, Value},
|
||||
Circuit,
|
||||
};
|
||||
|
||||
use crate::{circuit::AES_CTR, StreamCipherError};
|
||||
|
||||
/// A counter-mode block cipher circuit.
|
||||
pub trait CtrCircuit: Default + Clone + Send + Sync + 'static {
|
||||
/// The key type.
|
||||
type KEY: StaticValueType + TryFrom<Vec<u8>> + Send + Sync + 'static;
|
||||
/// The block type.
|
||||
type BLOCK: StaticValueType
|
||||
+ TryFrom<Vec<u8>>
|
||||
+ TryFrom<Value>
|
||||
+ Into<Vec<u8>>
|
||||
+ Default
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static;
|
||||
/// The IV type.
|
||||
type IV: StaticValueType
|
||||
+ TryFrom<Vec<u8>>
|
||||
+ TryFrom<Value>
|
||||
+ Into<Vec<u8>>
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static;
|
||||
/// The nonce type.
|
||||
type NONCE: StaticValueType
|
||||
+ TryFrom<Vec<u8>>
|
||||
+ TryFrom<Value>
|
||||
+ Into<Vec<u8>>
|
||||
+ Clone
|
||||
+ Copy
|
||||
+ Send
|
||||
+ Sync
|
||||
+ std::fmt::Debug
|
||||
+ 'static;
|
||||
|
||||
/// The length of the key.
|
||||
const KEY_LEN: usize;
|
||||
/// The length of the block.
|
||||
const BLOCK_LEN: usize;
|
||||
/// The length of the IV.
|
||||
const IV_LEN: usize;
|
||||
/// The length of the nonce.
|
||||
const NONCE_LEN: usize;
|
||||
|
||||
/// Returns the circuit of the cipher.
|
||||
fn circuit() -> Arc<Circuit>;
|
||||
|
||||
/// Applies the keystream to the message.
|
||||
fn apply_keystream(
|
||||
key: &[u8],
|
||||
iv: &[u8],
|
||||
start_ctr: usize,
|
||||
explicit_nonce: &[u8],
|
||||
msg: &[u8],
|
||||
) -> Result<Vec<u8>, StreamCipherError>;
|
||||
}
|
||||
|
||||
/// A circuit for AES-128 in counter mode.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Aes128Ctr;
|
||||
|
||||
impl CtrCircuit for Aes128Ctr {
|
||||
type KEY = [u8; 16];
|
||||
type BLOCK = [u8; 16];
|
||||
type IV = [u8; 4];
|
||||
type NONCE = [u8; 8];
|
||||
|
||||
const KEY_LEN: usize = 16;
|
||||
const BLOCK_LEN: usize = 16;
|
||||
const IV_LEN: usize = 4;
|
||||
const NONCE_LEN: usize = 8;
|
||||
|
||||
fn circuit() -> Arc<Circuit> {
|
||||
AES_CTR.clone()
|
||||
}
|
||||
|
||||
fn apply_keystream(
|
||||
key: &[u8],
|
||||
iv: &[u8],
|
||||
start_ctr: usize,
|
||||
explicit_nonce: &[u8],
|
||||
msg: &[u8],
|
||||
) -> Result<Vec<u8>, StreamCipherError> {
|
||||
use ::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
|
||||
use aes::Aes128;
|
||||
use ctr::Ctr32BE;
|
||||
|
||||
let key: &[u8; 16] = key
|
||||
.try_into()
|
||||
.map_err(|_| StreamCipherError::key_len::<Self>(key.len()))?;
|
||||
let iv: &[u8; 4] = iv
|
||||
.try_into()
|
||||
.map_err(|_| StreamCipherError::iv_len::<Self>(iv.len()))?;
|
||||
let explicit_nonce: &[u8; 8] = explicit_nonce
|
||||
.try_into()
|
||||
.map_err(|_| StreamCipherError::explicit_nonce_len::<Self>(explicit_nonce.len()))?;
|
||||
|
||||
let mut full_iv = [0u8; 16];
|
||||
full_iv[0..4].copy_from_slice(iv);
|
||||
full_iv[4..12].copy_from_slice(explicit_nonce);
|
||||
let mut cipher = Ctr32BE::<Aes128>::new(key.into(), &full_iv.into());
|
||||
let mut buf = msg.to_vec();
|
||||
|
||||
cipher
|
||||
.try_seek(start_ctr * Self::BLOCK_LEN)
|
||||
.expect("start counter is less than keystream length");
|
||||
cipher.apply_keystream(&mut buf);
|
||||
|
||||
Ok(buf)
|
||||
}
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
use derive_builder::Builder;
|
||||
use std::fmt::Debug;
|
||||
|
||||
/// Configuration for a stream cipher.
|
||||
#[derive(Debug, Clone, Builder)]
|
||||
pub struct StreamCipherConfig {
|
||||
/// The ID of the stream cipher.
|
||||
#[builder(setter(into))]
|
||||
pub(crate) id: String,
|
||||
/// The start block counter value.
|
||||
#[builder(default = "2")]
|
||||
pub(crate) start_ctr: usize,
|
||||
/// Transcript ID used to determine the unique identifiers
|
||||
/// for the plaintext bytes during encryption and decryption.
|
||||
#[builder(setter(into), default = "\"transcript\".to_string()")]
|
||||
pub(crate) transcript_id: String,
|
||||
}
|
||||
|
||||
impl StreamCipherConfig {
|
||||
/// Creates a new builder for the stream cipher configuration.
|
||||
pub fn builder() -> StreamCipherConfigBuilder {
|
||||
StreamCipherConfigBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum InputText {
|
||||
Public { ids: Vec<String>, text: Vec<u8> },
|
||||
Private { ids: Vec<String>, text: Vec<u8> },
|
||||
Blind { ids: Vec<String> },
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for InputText {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Public { ids, .. } => f
|
||||
.debug_struct("Public")
|
||||
.field("ids", ids)
|
||||
.field("text", &"{{ ... }}")
|
||||
.finish(),
|
||||
Self::Private { ids, .. } => f
|
||||
.debug_struct("Private")
|
||||
.field("ids", ids)
|
||||
.field("text", &"{{ ... }}")
|
||||
.finish(),
|
||||
Self::Blind { ids, .. } => f.debug_struct("Blind").field("ids", ids).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The mode of execution.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub(crate) enum ExecutionMode {
|
||||
/// Computes either the plaintext or the ciphertext.
|
||||
Mpc,
|
||||
/// Computes the ciphertext and proves its authenticity and correctness.
|
||||
Prove,
|
||||
/// Computes the ciphertext and verifies its authenticity and correctness.
|
||||
Verify,
|
||||
}
|
||||
|
||||
pub(crate) fn is_valid_mode(mode: &ExecutionMode, input_text: &InputText) -> bool {
|
||||
matches!(
|
||||
(mode, input_text),
|
||||
(ExecutionMode::Mpc, _)
|
||||
| (ExecutionMode::Prove, InputText::Private { .. })
|
||||
| (ExecutionMode::Verify, InputText::Blind { .. })
|
||||
)
|
||||
}
|
||||
@@ -1,122 +0,0 @@
|
||||
use core::fmt;
|
||||
use std::error::Error;
|
||||
|
||||
use crate::CtrCircuit;
|
||||
|
||||
/// A stream cipher error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub struct StreamCipherError {
|
||||
kind: ErrorKind,
|
||||
#[source]
|
||||
source: Option<Box<dyn Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl StreamCipherError {
|
||||
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync>>,
|
||||
{
|
||||
Self {
|
||||
kind,
|
||||
source: Some(source.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn key_len<C: CtrCircuit>(len: usize) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Key,
|
||||
source: Some(
|
||||
format!("invalid key length: expected {}, got {}", C::KEY_LEN, len).into(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn iv_len<C: CtrCircuit>(len: usize) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Iv,
|
||||
source: Some(format!("invalid iv length: expected {}, got {}", C::IV_LEN, len).into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn explicit_nonce_len<C: CtrCircuit>(len: usize) -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::ExplicitNonce,
|
||||
source: Some(
|
||||
format!(
|
||||
"invalid explicit nonce length: expected {}, got {}",
|
||||
C::NONCE_LEN,
|
||||
len
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn key_not_set() -> Self {
|
||||
Self {
|
||||
kind: ErrorKind::Key,
|
||||
source: Some("key not set".into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum ErrorKind {
|
||||
Vm,
|
||||
Key,
|
||||
Iv,
|
||||
ExplicitNonce,
|
||||
}
|
||||
|
||||
impl fmt::Display for StreamCipherError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self.kind {
|
||||
ErrorKind::Vm => write!(f, "vm error")?,
|
||||
ErrorKind::Key => write!(f, "key error")?,
|
||||
ErrorKind::Iv => write!(f, "iv error")?,
|
||||
ErrorKind::ExplicitNonce => write!(f, "explicit nonce error")?,
|
||||
}
|
||||
|
||||
if let Some(ref source) = self.source {
|
||||
write!(f, " caused by: {}", source)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::MemoryError> for StreamCipherError {
|
||||
fn from(error: mpz_garble::MemoryError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::LoadError> for StreamCipherError {
|
||||
fn from(error: mpz_garble::LoadError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::ExecutionError> for StreamCipherError {
|
||||
fn from(error: mpz_garble::ExecutionError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::ProveError> for StreamCipherError {
|
||||
fn from(error: mpz_garble::ProveError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::VerifyError> for StreamCipherError {
|
||||
fn from(error: mpz_garble::VerifyError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::DecodeError> for StreamCipherError {
|
||||
fn from(error: mpz_garble::DecodeError) -> Self {
|
||||
Self::new(ErrorKind::Vm, error)
|
||||
}
|
||||
}
|
||||
@@ -1,216 +0,0 @@
|
||||
use std::{collections::VecDeque, marker::PhantomData};
|
||||
|
||||
use mpz_garble::{value::ValueRef, Execute, Load, Memory, Prove, Thread, Verify};
|
||||
use tracing::instrument;
|
||||
use utils::id::NestedId;
|
||||
|
||||
use crate::{config::ExecutionMode, CtrCircuit, StreamCipherError};
|
||||
|
||||
pub(crate) struct KeyStream<C> {
|
||||
block_counter: NestedId,
|
||||
preprocessed: BlockVars,
|
||||
_pd: PhantomData<C>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct BlockVars {
|
||||
blocks: VecDeque<ValueRef>,
|
||||
nonces: VecDeque<ValueRef>,
|
||||
ctrs: VecDeque<ValueRef>,
|
||||
}
|
||||
|
||||
impl BlockVars {
|
||||
fn is_empty(&self) -> bool {
|
||||
self.blocks.is_empty()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.blocks.len()
|
||||
}
|
||||
|
||||
fn drain(&mut self, count: usize) -> BlockVars {
|
||||
let blocks = self.blocks.drain(0..count).collect();
|
||||
let nonces = self.nonces.drain(0..count).collect();
|
||||
let ctrs = self.ctrs.drain(0..count).collect();
|
||||
|
||||
BlockVars {
|
||||
blocks,
|
||||
nonces,
|
||||
ctrs,
|
||||
}
|
||||
}
|
||||
|
||||
fn extend(&mut self, vars: BlockVars) {
|
||||
self.blocks.extend(vars.blocks);
|
||||
self.nonces.extend(vars.nonces);
|
||||
self.ctrs.extend(vars.ctrs);
|
||||
}
|
||||
|
||||
fn iter(&self) -> impl Iterator<Item = (&ValueRef, &ValueRef, &ValueRef)> {
|
||||
self.blocks
|
||||
.iter()
|
||||
.zip(self.nonces.iter())
|
||||
.zip(self.ctrs.iter())
|
||||
.map(|((block, nonce), ctr)| (block, nonce, ctr))
|
||||
}
|
||||
|
||||
fn flatten(&self, len: usize) -> Vec<ValueRef> {
|
||||
self.blocks
|
||||
.iter()
|
||||
.flat_map(|block| block.iter())
|
||||
.take(len)
|
||||
.cloned()
|
||||
.map(|byte| ValueRef::Value { id: byte })
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: CtrCircuit> KeyStream<C> {
|
||||
pub(crate) fn new(id: &str) -> Self {
|
||||
let block_counter = NestedId::new(id).append_counter();
|
||||
Self {
|
||||
block_counter,
|
||||
preprocessed: BlockVars::default(),
|
||||
_pd: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn define_vars(
|
||||
&mut self,
|
||||
mem: &mut impl Memory,
|
||||
count: usize,
|
||||
) -> Result<BlockVars, StreamCipherError> {
|
||||
let mut vars = BlockVars::default();
|
||||
for _ in 0..count {
|
||||
let block_id = self.block_counter.increment_in_place();
|
||||
let block = mem.new_output::<C::BLOCK>(&block_id.to_string())?;
|
||||
let nonce =
|
||||
mem.new_public_input::<C::NONCE>(&block_id.append_string("nonce").to_string())?;
|
||||
let ctr =
|
||||
mem.new_public_input::<[u8; 4]>(&block_id.append_string("ctr").to_string())?;
|
||||
|
||||
vars.blocks.push_back(block);
|
||||
vars.nonces.push_back(nonce);
|
||||
vars.ctrs.push_back(ctr);
|
||||
}
|
||||
|
||||
Ok(vars)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
pub(crate) async fn preprocess<T>(
|
||||
&mut self,
|
||||
thread: &mut T,
|
||||
key: &ValueRef,
|
||||
iv: &ValueRef,
|
||||
len: usize,
|
||||
) -> Result<(), StreamCipherError>
|
||||
where
|
||||
T: Thread + Memory + Load + Send + 'static,
|
||||
{
|
||||
let block_count = (len / C::BLOCK_LEN) + (len % C::BLOCK_LEN != 0) as usize;
|
||||
let vars = self.define_vars(thread, block_count)?;
|
||||
|
||||
let calls = vars
|
||||
.iter()
|
||||
.map(|(block, nonce, ctr)| {
|
||||
(
|
||||
C::circuit(),
|
||||
vec![key.clone(), iv.clone(), nonce.clone(), ctr.clone()],
|
||||
vec![block.clone()],
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for (circ, inputs, outputs) in calls {
|
||||
thread.load(circ, &inputs, &outputs).await?;
|
||||
}
|
||||
|
||||
self.preprocessed.extend(vars);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn compute<T>(
|
||||
&mut self,
|
||||
thread: &mut T,
|
||||
mode: ExecutionMode,
|
||||
key: &ValueRef,
|
||||
iv: &ValueRef,
|
||||
explicit_nonce: Vec<u8>,
|
||||
start_ctr: usize,
|
||||
len: usize,
|
||||
) -> Result<ValueRef, StreamCipherError>
|
||||
where
|
||||
T: Thread + Memory + Execute + Prove + Verify + Send + 'static,
|
||||
{
|
||||
let block_count = (len / C::BLOCK_LEN) + (len % C::BLOCK_LEN != 0) as usize;
|
||||
let explicit_nonce_len = explicit_nonce.len();
|
||||
let explicit_nonce: C::NONCE = explicit_nonce
|
||||
.try_into()
|
||||
.map_err(|_| StreamCipherError::explicit_nonce_len::<C>(explicit_nonce_len))?;
|
||||
|
||||
// Take any preprocessed blocks if available, and define new ones if needed.
|
||||
let vars = if !self.preprocessed.is_empty() {
|
||||
let mut vars = self
|
||||
.preprocessed
|
||||
.drain(block_count.min(self.preprocessed.len()));
|
||||
if vars.len() < block_count {
|
||||
vars.extend(self.define_vars(thread, block_count - vars.len())?)
|
||||
}
|
||||
vars
|
||||
} else {
|
||||
self.define_vars(thread, block_count)?
|
||||
};
|
||||
|
||||
let mut calls = Vec::with_capacity(vars.len());
|
||||
let mut inputs = Vec::with_capacity(vars.len() * 4);
|
||||
for (i, (block, nonce_ref, ctr_ref)) in vars.iter().enumerate() {
|
||||
thread.assign(nonce_ref, explicit_nonce)?;
|
||||
thread.assign(ctr_ref, ((start_ctr + i) as u32).to_be_bytes())?;
|
||||
|
||||
inputs.push(key.clone());
|
||||
inputs.push(iv.clone());
|
||||
inputs.push(nonce_ref.clone());
|
||||
inputs.push(ctr_ref.clone());
|
||||
|
||||
calls.push((
|
||||
C::circuit(),
|
||||
vec![key.clone(), iv.clone(), nonce_ref.clone(), ctr_ref.clone()],
|
||||
vec![block.clone()],
|
||||
));
|
||||
}
|
||||
|
||||
match mode {
|
||||
ExecutionMode::Mpc => {
|
||||
thread.commit(&inputs).await?;
|
||||
for (circ, inputs, outputs) in calls {
|
||||
thread.execute(circ, &inputs, &outputs).await?;
|
||||
}
|
||||
}
|
||||
ExecutionMode::Prove => {
|
||||
// Note that after the circuit execution, the value of `block` can be considered
|
||||
// as implicitly authenticated since `key` and `iv` have already
|
||||
// been authenticated earlier and `nonce_ref` and `ctr_ref` are
|
||||
// public. [Prove::prove] will **not** be called on `block` at
|
||||
// any later point.
|
||||
thread.commit_prove(&inputs).await?;
|
||||
for (circ, inputs, outputs) in calls {
|
||||
thread.execute_prove(circ, &inputs, &outputs).await?;
|
||||
}
|
||||
}
|
||||
ExecutionMode::Verify => {
|
||||
thread.commit_verify(&inputs).await?;
|
||||
for (circ, inputs, outputs) in calls {
|
||||
thread.execute_verify(circ, &inputs, &outputs).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let keystream = thread.array_from_values(&vars.flatten(len))?;
|
||||
|
||||
Ok(keystream)
|
||||
}
|
||||
}
|
||||
@@ -1,474 +0,0 @@
|
||||
//! This crate provides a 2PC stream cipher implementation using a block cipher
|
||||
//! in counter mode.
|
||||
//!
|
||||
//! Each party plays a specific role, either the `StreamCipherLeader` or the
|
||||
//! `StreamCipherFollower`. Both parties work together to encrypt and decrypt
|
||||
//! messages using a shared key.
|
||||
//!
|
||||
//! # Transcript
|
||||
//!
|
||||
//! Using the `record` flag, the `StreamCipherFollower` can optionally use a
|
||||
//! dedicated stream when encoding the plaintext labels, which allows the
|
||||
//! `StreamCipherLeader` to build a transcript of active labels which are pushed
|
||||
//! to the provided `TranscriptSink`.
|
||||
//!
|
||||
//! Afterwards, the `StreamCipherLeader` can create commitments to the
|
||||
//! transcript which can be used in a selective disclosure protocol.
|
||||
|
||||
#![deny(missing_docs, unreachable_pub, unused_must_use)]
|
||||
#![deny(clippy::all)]
|
||||
#![deny(unsafe_code)]
|
||||
|
||||
mod cipher;
|
||||
mod circuit;
|
||||
mod config;
|
||||
pub(crate) mod error;
|
||||
pub(crate) mod keystream;
|
||||
mod stream_cipher;
|
||||
|
||||
pub use self::cipher::{Aes128Ctr, CtrCircuit};
|
||||
pub use config::{StreamCipherConfig, StreamCipherConfigBuilder, StreamCipherConfigBuilderError};
|
||||
pub use error::StreamCipherError;
|
||||
pub use stream_cipher::MpcStreamCipher;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mpz_garble::value::ValueRef;
|
||||
|
||||
/// A trait for MPC stream ciphers.
|
||||
#[async_trait]
|
||||
pub trait StreamCipher<Cipher>: Send + Sync
|
||||
where
|
||||
Cipher: cipher::CtrCircuit,
|
||||
{
|
||||
/// Sets the key and iv for the stream cipher.
|
||||
fn set_key(&mut self, key: ValueRef, iv: ValueRef);
|
||||
|
||||
/// Decodes the key for the stream cipher, revealing it to this party.
|
||||
async fn decode_key_private(&mut self) -> Result<(), StreamCipherError>;
|
||||
|
||||
/// Decodes the key for the stream cipher, revealing it to the other
|
||||
/// party(s).
|
||||
async fn decode_key_blind(&mut self) -> Result<(), StreamCipherError>;
|
||||
|
||||
/// Sets the transcript id
|
||||
///
|
||||
/// The stream cipher assigns unique identifiers to each byte of plaintext
|
||||
/// during encryption and decryption.
|
||||
///
|
||||
/// For example, if the transcript id is set to `foo`, then the first byte
|
||||
/// will be assigned the id `foo/0`, the second byte `foo/1`, and so on.
|
||||
///
|
||||
/// Each transcript id has an independent counter.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// The state of a transcript counter is preserved between calls to
|
||||
/// `set_transcript_id`.
|
||||
fn set_transcript_id(&mut self, id: &str);
|
||||
|
||||
/// Preprocesses the keystream for the given number of bytes.
|
||||
async fn preprocess(&mut self, len: usize) -> Result<(), StreamCipherError>;
|
||||
|
||||
/// Applies the keystream to the given plaintext, where all parties
|
||||
/// provide the plaintext as an input.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
|
||||
/// * `plaintext` - The message to apply the keystream to.
|
||||
async fn encrypt_public(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
plaintext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, StreamCipherError>;
|
||||
|
||||
/// Applies the keystream to the given plaintext without revealing it
|
||||
/// to the other party(s).
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
|
||||
/// * `plaintext` - The message to apply the keystream to.
|
||||
async fn encrypt_private(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
plaintext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, StreamCipherError>;
|
||||
|
||||
/// Applies the keystream to a plaintext provided by another party.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
|
||||
/// * `len` - The length of the plaintext provided by another party.
|
||||
async fn encrypt_blind(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
len: usize,
|
||||
) -> Result<Vec<u8>, StreamCipherError>;
|
||||
|
||||
/// Decrypts a ciphertext by removing the keystream, where the plaintext
|
||||
/// is revealed to all parties.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
|
||||
/// * `ciphertext` - The ciphertext to decrypt.
|
||||
async fn decrypt_public(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, StreamCipherError>;
|
||||
|
||||
/// Decrypts a ciphertext by removing the keystream, where the plaintext
|
||||
/// is only revealed to this party.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
|
||||
/// * `ciphertext` - The ciphertext to decrypt.
|
||||
async fn decrypt_private(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, StreamCipherError>;
|
||||
|
||||
/// Decrypts a ciphertext by removing the keystream, where the plaintext
|
||||
/// is not revealed to this party.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
|
||||
/// * `ciphertext` - The ciphertext to decrypt.
|
||||
async fn decrypt_blind(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<(), StreamCipherError>;
|
||||
|
||||
/// Locally decrypts the provided ciphertext and then proves in ZK to the
|
||||
/// other party(s) that the plaintext is correct.
|
||||
///
|
||||
/// Returns the plaintext.
|
||||
///
|
||||
/// This method requires this party to know the encryption key, which can be
|
||||
/// achieved by calling the `decode_key_private` method.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
|
||||
/// * `ciphertext` - The ciphertext to decrypt and prove.
|
||||
async fn prove_plaintext(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, StreamCipherError>;
|
||||
|
||||
/// Verifies the other party(s) can prove they know a plaintext which
|
||||
/// encrypts to the given ciphertext.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
|
||||
/// * `ciphertext` - The ciphertext to verify.
|
||||
async fn verify_plaintext(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<(), StreamCipherError>;
|
||||
|
||||
/// Returns an additive share of the keystream block for the given explicit
|
||||
/// nonce and counter.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `explicit_nonce` - The explicit nonce to use for the keystream block.
|
||||
/// * `ctr` - The counter to use for the keystream block.
|
||||
async fn share_keystream_block(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ctr: usize,
|
||||
) -> Result<Vec<u8>, StreamCipherError>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::cipher::Aes128Ctr;
|
||||
|
||||
use super::*;
|
||||
|
||||
use mpz_garble::{
|
||||
protocol::deap::mock::{create_mock_deap_vm, MockFollower, MockLeader},
|
||||
Memory,
|
||||
};
|
||||
use rstest::*;
|
||||
|
||||
async fn create_test_pair<C: CtrCircuit>(
|
||||
start_ctr: usize,
|
||||
key: [u8; 16],
|
||||
iv: [u8; 4],
|
||||
) -> (
|
||||
MpcStreamCipher<C, MockLeader>,
|
||||
MpcStreamCipher<C, MockFollower>,
|
||||
) {
|
||||
let (leader_vm, follower_vm) = create_mock_deap_vm();
|
||||
|
||||
let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap();
|
||||
let leader_iv = leader_vm.new_public_input::<[u8; 4]>("iv").unwrap();
|
||||
|
||||
leader_vm.assign(&leader_key, key).unwrap();
|
||||
leader_vm.assign(&leader_iv, iv).unwrap();
|
||||
|
||||
let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap();
|
||||
let follower_iv = follower_vm.new_public_input::<[u8; 4]>("iv").unwrap();
|
||||
|
||||
follower_vm.assign(&follower_key, key).unwrap();
|
||||
follower_vm.assign(&follower_iv, iv).unwrap();
|
||||
|
||||
let leader_config = StreamCipherConfig::builder()
|
||||
.id("test")
|
||||
.start_ctr(start_ctr)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let follower_config = StreamCipherConfig::builder()
|
||||
.id("test")
|
||||
.start_ctr(start_ctr)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let mut leader = MpcStreamCipher::<C, _>::new(leader_config, leader_vm);
|
||||
leader.set_key(leader_key, leader_iv);
|
||||
|
||||
let mut follower = MpcStreamCipher::<C, _>::new(follower_config, follower_vm);
|
||||
follower.set_key(follower_key, follower_iv);
|
||||
|
||||
(leader, follower)
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[timeout(Duration::from_millis(10000))]
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_stream_cipher_public() {
|
||||
let key = [0u8; 16];
|
||||
let iv = [0u8; 4];
|
||||
let explicit_nonce = [0u8; 8];
|
||||
|
||||
let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec();
|
||||
|
||||
let (mut leader, mut follower) = create_test_pair::<Aes128Ctr>(1, key, iv).await;
|
||||
|
||||
let leader_fut = async {
|
||||
let leader_encrypted_msg = leader
|
||||
.encrypt_public(explicit_nonce.to_vec(), msg.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let leader_decrypted_msg = leader
|
||||
.decrypt_public(explicit_nonce.to_vec(), leader_encrypted_msg.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
(leader_encrypted_msg, leader_decrypted_msg)
|
||||
};
|
||||
|
||||
let follower_fut = async {
|
||||
let follower_encrypted_msg = follower
|
||||
.encrypt_public(explicit_nonce.to_vec(), msg.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let follower_decrypted_msg = follower
|
||||
.decrypt_public(explicit_nonce.to_vec(), follower_encrypted_msg.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
(follower_encrypted_msg, follower_decrypted_msg)
|
||||
};
|
||||
|
||||
let (
|
||||
(leader_encrypted_msg, leader_decrypted_msg),
|
||||
(follower_encrypted_msg, follower_decrypted_msg),
|
||||
) = futures::join!(leader_fut, follower_fut);
|
||||
|
||||
let reference = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap();
|
||||
|
||||
assert_eq!(leader_encrypted_msg, reference);
|
||||
assert_eq!(leader_decrypted_msg, msg);
|
||||
assert_eq!(follower_encrypted_msg, reference);
|
||||
assert_eq!(follower_decrypted_msg, msg);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[timeout(Duration::from_millis(10000))]
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_stream_cipher_private() {
|
||||
let key = [0u8; 16];
|
||||
let iv = [0u8; 4];
|
||||
let explicit_nonce = [1u8; 8];
|
||||
|
||||
let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec();
|
||||
|
||||
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap();
|
||||
|
||||
let (mut leader, mut follower) = create_test_pair::<Aes128Ctr>(1, key, iv).await;
|
||||
|
||||
let leader_fut = async {
|
||||
let leader_decrypted_msg = leader
|
||||
.decrypt_private(explicit_nonce.to_vec(), ciphertext.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let leader_encrypted_msg = leader
|
||||
.encrypt_private(explicit_nonce.to_vec(), leader_decrypted_msg.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
(leader_encrypted_msg, leader_decrypted_msg)
|
||||
};
|
||||
|
||||
let follower_fut = async {
|
||||
follower
|
||||
.decrypt_blind(explicit_nonce.to_vec(), ciphertext.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
follower
|
||||
.encrypt_blind(explicit_nonce.to_vec(), msg.len())
|
||||
.await
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let ((leader_encrypted_msg, leader_decrypted_msg), follower_encrypted_msg) =
|
||||
futures::join!(leader_fut, follower_fut);
|
||||
|
||||
assert_eq!(leader_encrypted_msg, ciphertext);
|
||||
assert_eq!(leader_decrypted_msg, msg);
|
||||
assert_eq!(follower_encrypted_msg, ciphertext);
|
||||
|
||||
futures::try_join!(
|
||||
leader.thread_mut().finalize(),
|
||||
follower.thread_mut().finalize()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[timeout(Duration::from_millis(10000))]
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_stream_cipher_share_key_block() {
|
||||
let key = [0u8; 16];
|
||||
let iv = [0u8; 4];
|
||||
let explicit_nonce = [0u8; 8];
|
||||
|
||||
let (mut leader, mut follower) = create_test_pair::<Aes128Ctr>(1, key, iv).await;
|
||||
|
||||
let leader_fut = async {
|
||||
leader
|
||||
.share_keystream_block(explicit_nonce.to_vec(), 1)
|
||||
.await
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let follower_fut = async {
|
||||
follower
|
||||
.share_keystream_block(explicit_nonce.to_vec(), 1)
|
||||
.await
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let (leader_share, follower_share) = futures::join!(leader_fut, follower_fut);
|
||||
|
||||
let key_block = leader_share
|
||||
.into_iter()
|
||||
.zip(follower_share)
|
||||
.map(|(a, b)| a ^ b)
|
||||
.collect::<Vec<u8>>();
|
||||
|
||||
let reference =
|
||||
Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &[0u8; 16]).unwrap();
|
||||
|
||||
assert_eq!(reference, key_block);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[timeout(Duration::from_millis(10000))]
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_stream_cipher_zk() {
|
||||
let key = [0u8; 16];
|
||||
let iv = [0u8; 4];
|
||||
let explicit_nonce = [1u8; 8];
|
||||
|
||||
let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec();
|
||||
|
||||
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &msg).unwrap();
|
||||
|
||||
let (mut leader, mut follower) = create_test_pair::<Aes128Ctr>(2, key, iv).await;
|
||||
|
||||
futures::try_join!(leader.decode_key_private(), follower.decode_key_blind()).unwrap();
|
||||
|
||||
futures::try_join!(
|
||||
leader.prove_plaintext(explicit_nonce.to_vec(), ciphertext.clone()),
|
||||
follower.verify_plaintext(explicit_nonce.to_vec(), ciphertext)
|
||||
)
|
||||
.unwrap();
|
||||
futures::try_join!(
|
||||
leader.thread_mut().finalize(),
|
||||
follower.thread_mut().finalize()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::one_block(16)]
|
||||
#[case::partial(17)]
|
||||
#[case::extra(128)]
|
||||
#[timeout(Duration::from_millis(10000))]
|
||||
#[tokio::test]
|
||||
#[ignore = "expensive"]
|
||||
async fn test_stream_cipher_preprocess(#[case] len: usize) {
|
||||
let key = [0u8; 16];
|
||||
let iv = [0u8; 4];
|
||||
let explicit_nonce = [1u8; 8];
|
||||
|
||||
let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec();
|
||||
|
||||
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap();
|
||||
|
||||
let (mut leader, mut follower) = create_test_pair::<Aes128Ctr>(1, key, iv).await;
|
||||
|
||||
let leader_fut = async {
|
||||
leader.preprocess(len).await.unwrap();
|
||||
|
||||
leader
|
||||
.decrypt_private(explicit_nonce.to_vec(), ciphertext.clone())
|
||||
.await
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let follower_fut = async {
|
||||
follower.preprocess(len).await.unwrap();
|
||||
|
||||
follower
|
||||
.decrypt_blind(explicit_nonce.to_vec(), ciphertext.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
};
|
||||
|
||||
let (leader_decrypted_msg, _) = futures::join!(leader_fut, follower_fut);
|
||||
|
||||
assert_eq!(leader_decrypted_msg, msg);
|
||||
|
||||
futures::try_join!(
|
||||
leader.thread_mut().finalize(),
|
||||
follower.thread_mut().finalize()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -1,671 +0,0 @@
|
||||
use async_trait::async_trait;
|
||||
use mpz_circuits::types::Value;
|
||||
use std::collections::HashMap;
|
||||
use tracing::instrument;
|
||||
|
||||
use mpz_garble::{value::ValueRef, Decode, DecodePrivate, Execute, Load, Prove, Thread, Verify};
|
||||
use utils::id::NestedId;
|
||||
|
||||
use crate::{
|
||||
cipher::CtrCircuit,
|
||||
circuit::build_array_xor,
|
||||
config::{is_valid_mode, ExecutionMode, InputText, StreamCipherConfig},
|
||||
keystream::KeyStream,
|
||||
StreamCipher, StreamCipherError,
|
||||
};
|
||||
|
||||
/// An MPC stream cipher.
|
||||
#[derive(Debug)]
|
||||
pub struct MpcStreamCipher<C, E>
|
||||
where
|
||||
C: CtrCircuit,
|
||||
E: Thread + Execute + Decode + DecodePrivate + Send + Sync,
|
||||
{
|
||||
config: StreamCipherConfig,
|
||||
state: State<C>,
|
||||
thread: E,
|
||||
}
|
||||
|
||||
struct State<C> {
|
||||
/// Encoded key and IV for the cipher.
|
||||
encoded_key_iv: Option<EncodedKeyAndIv>,
|
||||
/// Key and IV for the cipher.
|
||||
key_iv: Option<KeyAndIv>,
|
||||
/// Keystream state.
|
||||
keystream: KeyStream<C>,
|
||||
/// Current transcript.
|
||||
transcript: Transcript,
|
||||
/// Maps a transcript ID to the corresponding transcript.
|
||||
transcripts: HashMap<String, Transcript>,
|
||||
/// Number of messages operated on.
|
||||
counter: usize,
|
||||
}
|
||||
|
||||
opaque_debug::implement!(State<C>);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct EncodedKeyAndIv {
|
||||
key: ValueRef,
|
||||
iv: ValueRef,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct KeyAndIv {
|
||||
key: Vec<u8>,
|
||||
iv: Vec<u8>,
|
||||
}
|
||||
|
||||
/// A subset of plaintext bytes processed by the stream cipher.
|
||||
///
|
||||
/// Note that `Transcript` does not store the actual bytes. Instead, it provides
|
||||
/// IDs which are assigned to plaintext bytes of the stream cipher.
|
||||
struct Transcript {
|
||||
/// The ID of this transcript.
|
||||
id: String,
|
||||
/// The ID for the next plaintext byte.
|
||||
plaintext: NestedId,
|
||||
}
|
||||
|
||||
impl Transcript {
|
||||
fn new(id: &str) -> Self {
|
||||
Self {
|
||||
id: id.to_string(),
|
||||
plaintext: NestedId::new(id).append_counter(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns unique identifiers for the next plaintext bytes in the
|
||||
/// transcript.
|
||||
fn extend_plaintext(&mut self, len: usize) -> Vec<String> {
|
||||
(0..len)
|
||||
.map(|_| self.plaintext.increment_in_place().to_string())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, E> MpcStreamCipher<C, E>
|
||||
where
|
||||
C: CtrCircuit,
|
||||
E: Thread + Execute + Load + Prove + Verify + Decode + DecodePrivate + Send + Sync + 'static,
|
||||
{
|
||||
/// Creates a new counter-mode cipher.
|
||||
pub fn new(config: StreamCipherConfig, thread: E) -> Self {
|
||||
let keystream = KeyStream::new(&config.id);
|
||||
let transcript = Transcript::new(&config.transcript_id);
|
||||
Self {
|
||||
config,
|
||||
state: State {
|
||||
encoded_key_iv: None,
|
||||
key_iv: None,
|
||||
keystream,
|
||||
transcript,
|
||||
transcripts: HashMap::new(),
|
||||
counter: 0,
|
||||
},
|
||||
thread,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the underlying thread.
|
||||
pub fn thread_mut(&mut self) -> &mut E {
|
||||
&mut self.thread
|
||||
}
|
||||
|
||||
/// Computes a keystream of the given length.
|
||||
async fn compute_keystream(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
start_ctr: usize,
|
||||
len: usize,
|
||||
mode: ExecutionMode,
|
||||
) -> Result<ValueRef, StreamCipherError> {
|
||||
let EncodedKeyAndIv { key, iv } = self
|
||||
.state
|
||||
.encoded_key_iv
|
||||
.as_ref()
|
||||
.ok_or_else(StreamCipherError::key_not_set)?;
|
||||
|
||||
let keystream = self
|
||||
.state
|
||||
.keystream
|
||||
.compute(
|
||||
&mut self.thread,
|
||||
mode,
|
||||
key,
|
||||
iv,
|
||||
explicit_nonce,
|
||||
start_ctr,
|
||||
len,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.state.counter += 1;
|
||||
|
||||
Ok(keystream)
|
||||
}
|
||||
|
||||
/// Applies the keystream to the provided input text.
|
||||
async fn apply_keystream(
|
||||
&mut self,
|
||||
mode: ExecutionMode,
|
||||
input_text: InputText,
|
||||
keystream: ValueRef,
|
||||
) -> Result<ValueRef, StreamCipherError> {
|
||||
debug_assert!(
|
||||
is_valid_mode(&mode, &input_text),
|
||||
"invalid execution mode for input text"
|
||||
);
|
||||
|
||||
let input_text = match input_text {
|
||||
InputText::Public { ids, text } => {
|
||||
let refs = text
|
||||
.into_iter()
|
||||
.zip(ids)
|
||||
.map(|(byte, id)| {
|
||||
let value_ref = self.thread.new_public_input::<u8>(&id)?;
|
||||
self.thread.assign(&value_ref, byte)?;
|
||||
|
||||
Ok::<_, StreamCipherError>(value_ref)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
self.thread.array_from_values(&refs)?
|
||||
}
|
||||
InputText::Private { ids, text } => {
|
||||
let refs = text
|
||||
.into_iter()
|
||||
.zip(ids)
|
||||
.map(|(byte, id)| {
|
||||
let value_ref = self.thread.new_private_input::<u8>(&id)?;
|
||||
self.thread.assign(&value_ref, byte)?;
|
||||
|
||||
Ok::<_, StreamCipherError>(value_ref)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
self.thread.array_from_values(&refs)?
|
||||
}
|
||||
InputText::Blind { ids } => {
|
||||
let refs = ids
|
||||
.into_iter()
|
||||
.map(|id| self.thread.new_blind_input::<u8>(&id))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
self.thread.array_from_values(&refs)?
|
||||
}
|
||||
};
|
||||
|
||||
let output_text = self.thread.new_array_output::<u8>(
|
||||
&format!("{}/out/{}", self.config.id, self.state.counter),
|
||||
input_text.len(),
|
||||
)?;
|
||||
|
||||
let circ = build_array_xor(input_text.len());
|
||||
|
||||
match mode {
|
||||
ExecutionMode::Mpc => {
|
||||
self.thread
|
||||
.execute(circ, &[input_text, keystream], &[output_text.clone()])
|
||||
.await?;
|
||||
}
|
||||
ExecutionMode::Prove => {
|
||||
self.thread
|
||||
.execute_prove(circ, &[input_text, keystream], &[output_text.clone()])
|
||||
.await?;
|
||||
}
|
||||
ExecutionMode::Verify => {
|
||||
self.thread
|
||||
.execute_verify(circ, &[input_text, keystream], &[output_text.clone()])
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output_text)
|
||||
}
|
||||
|
||||
async fn decode_public(&mut self, value: ValueRef) -> Result<Value, StreamCipherError> {
|
||||
self.thread
|
||||
.decode(&[value])
|
||||
.await
|
||||
.map_err(StreamCipherError::from)
|
||||
.map(|mut output| output.pop().unwrap())
|
||||
}
|
||||
|
||||
async fn decode_shared(&mut self, value: ValueRef) -> Result<Value, StreamCipherError> {
|
||||
self.thread
|
||||
.decode_shared(&[value])
|
||||
.await
|
||||
.map_err(StreamCipherError::from)
|
||||
.map(|mut output| output.pop().unwrap())
|
||||
}
|
||||
|
||||
async fn decode_private(&mut self, value: ValueRef) -> Result<Value, StreamCipherError> {
|
||||
self.thread
|
||||
.decode_private(&[value])
|
||||
.await
|
||||
.map_err(StreamCipherError::from)
|
||||
.map(|mut output| output.pop().unwrap())
|
||||
}
|
||||
|
||||
async fn decode_blind(&mut self, value: ValueRef) -> Result<(), StreamCipherError> {
|
||||
self.thread.decode_blind(&[value]).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn prove(&mut self, value: ValueRef) -> Result<(), StreamCipherError> {
|
||||
self.thread.prove(&[value]).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn verify(&mut self, value: ValueRef, expected: Value) -> Result<(), StreamCipherError> {
|
||||
self.thread.verify(&[value], &[expected]).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<C, E> StreamCipher<C> for MpcStreamCipher<C, E>
|
||||
where
|
||||
C: CtrCircuit,
|
||||
E: Thread + Execute + Load + Prove + Verify + Decode + DecodePrivate + Send + Sync + 'static,
|
||||
{
|
||||
fn set_key(&mut self, key: ValueRef, iv: ValueRef) {
|
||||
self.state.encoded_key_iv = Some(EncodedKeyAndIv { key, iv });
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn decode_key_private(&mut self) -> Result<(), StreamCipherError> {
|
||||
let EncodedKeyAndIv { key, iv } = self
|
||||
.state
|
||||
.encoded_key_iv
|
||||
.clone()
|
||||
.ok_or_else(StreamCipherError::key_not_set)?;
|
||||
|
||||
let [key, iv]: [_; 2] = self
|
||||
.thread
|
||||
.decode_private(&[key, iv])
|
||||
.await?
|
||||
.try_into()
|
||||
.expect("decoded 2 values");
|
||||
|
||||
let key: Vec<u8> = key.try_into().expect("key is an array");
|
||||
let iv: Vec<u8> = iv.try_into().expect("iv is an array");
|
||||
|
||||
self.state.key_iv = Some(KeyAndIv { key, iv });
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn decode_key_blind(&mut self) -> Result<(), StreamCipherError> {
|
||||
let EncodedKeyAndIv { key, iv } = self
|
||||
.state
|
||||
.encoded_key_iv
|
||||
.clone()
|
||||
.ok_or_else(StreamCipherError::key_not_set)?;
|
||||
|
||||
self.thread.decode_blind(&[key, iv]).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_transcript_id(&mut self, id: &str) {
|
||||
if id == self.state.transcript.id {
|
||||
return;
|
||||
}
|
||||
|
||||
let transcript = self
|
||||
.state
|
||||
.transcripts
|
||||
.remove(id)
|
||||
.unwrap_or_else(|| Transcript::new(id));
|
||||
let old_transcript = std::mem::replace(&mut self.state.transcript, transcript);
|
||||
self.state
|
||||
.transcripts
|
||||
.insert(old_transcript.id.clone(), old_transcript);
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn preprocess(&mut self, len: usize) -> Result<(), StreamCipherError> {
|
||||
let EncodedKeyAndIv { key, iv } = self
|
||||
.state
|
||||
.encoded_key_iv
|
||||
.as_ref()
|
||||
.ok_or_else(StreamCipherError::key_not_set)?;
|
||||
|
||||
self.state
|
||||
.keystream
|
||||
.preprocess(&mut self.thread, key, iv, len)
|
||||
.await
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn encrypt_public(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
plaintext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, StreamCipherError> {
|
||||
let keystream = self
|
||||
.compute_keystream(
|
||||
explicit_nonce,
|
||||
self.config.start_ctr,
|
||||
plaintext.len(),
|
||||
ExecutionMode::Mpc,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let plaintext_ids = self.state.transcript.extend_plaintext(plaintext.len());
|
||||
let ciphertext = self
|
||||
.apply_keystream(
|
||||
ExecutionMode::Mpc,
|
||||
InputText::Public {
|
||||
ids: plaintext_ids,
|
||||
text: plaintext,
|
||||
},
|
||||
keystream,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let ciphertext: Vec<u8> = self
|
||||
.decode_public(ciphertext)
|
||||
.await?
|
||||
.try_into()
|
||||
.expect("ciphertext is array");
|
||||
|
||||
Ok(ciphertext)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn encrypt_private(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
plaintext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, StreamCipherError> {
|
||||
let keystream = self
|
||||
.compute_keystream(
|
||||
explicit_nonce,
|
||||
self.config.start_ctr,
|
||||
plaintext.len(),
|
||||
ExecutionMode::Mpc,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let plaintext_ids = self.state.transcript.extend_plaintext(plaintext.len());
|
||||
let ciphertext = self
|
||||
.apply_keystream(
|
||||
ExecutionMode::Mpc,
|
||||
InputText::Private {
|
||||
ids: plaintext_ids,
|
||||
text: plaintext,
|
||||
},
|
||||
keystream,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let ciphertext: Vec<u8> = self
|
||||
.decode_public(ciphertext)
|
||||
.await?
|
||||
.try_into()
|
||||
.expect("ciphertext is array");
|
||||
|
||||
Ok(ciphertext)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn encrypt_blind(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
len: usize,
|
||||
) -> Result<Vec<u8>, StreamCipherError> {
|
||||
let keystream = self
|
||||
.compute_keystream(
|
||||
explicit_nonce,
|
||||
self.config.start_ctr,
|
||||
len,
|
||||
ExecutionMode::Mpc,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let plaintext_ids = self.state.transcript.extend_plaintext(len);
|
||||
let ciphertext = self
|
||||
.apply_keystream(
|
||||
ExecutionMode::Mpc,
|
||||
InputText::Blind { ids: plaintext_ids },
|
||||
keystream,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let ciphertext: Vec<u8> = self
|
||||
.decode_public(ciphertext)
|
||||
.await?
|
||||
.try_into()
|
||||
.expect("ciphertext is array");
|
||||
|
||||
Ok(ciphertext)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn decrypt_public(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, StreamCipherError> {
|
||||
// TODO: We may want to support writing to the transcript when decrypting
|
||||
// in public mode.
|
||||
let keystream = self
|
||||
.compute_keystream(
|
||||
explicit_nonce,
|
||||
self.config.start_ctr,
|
||||
ciphertext.len(),
|
||||
ExecutionMode::Mpc,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let ciphertext_ids = (0..ciphertext.len())
|
||||
.map(|i| format!("ct/{}/{}", self.state.counter, i))
|
||||
.collect();
|
||||
let plaintext = self
|
||||
.apply_keystream(
|
||||
ExecutionMode::Mpc,
|
||||
InputText::Public {
|
||||
ids: ciphertext_ids,
|
||||
text: ciphertext,
|
||||
},
|
||||
keystream,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let plaintext: Vec<u8> = self
|
||||
.decode_public(plaintext)
|
||||
.await?
|
||||
.try_into()
|
||||
.expect("plaintext is array");
|
||||
|
||||
Ok(plaintext)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn decrypt_private(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, StreamCipherError> {
|
||||
let keystream_ref = self
|
||||
.compute_keystream(
|
||||
explicit_nonce,
|
||||
self.config.start_ctr,
|
||||
ciphertext.len(),
|
||||
ExecutionMode::Mpc,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let keystream: Vec<u8> = self
|
||||
.decode_private(keystream_ref.clone())
|
||||
.await?
|
||||
.try_into()
|
||||
.expect("keystream is array");
|
||||
|
||||
let plaintext = ciphertext
|
||||
.into_iter()
|
||||
.zip(keystream)
|
||||
.map(|(c, k)| c ^ k)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Prove plaintext encrypts back to ciphertext.
|
||||
let plaintext_ids = self.state.transcript.extend_plaintext(plaintext.len());
|
||||
let ciphertext = self
|
||||
.apply_keystream(
|
||||
ExecutionMode::Prove,
|
||||
InputText::Private {
|
||||
ids: plaintext_ids,
|
||||
text: plaintext.clone(),
|
||||
},
|
||||
keystream_ref,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.prove(ciphertext).await?;
|
||||
|
||||
Ok(plaintext)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn decrypt_blind(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<(), StreamCipherError> {
|
||||
let keystream_ref = self
|
||||
.compute_keystream(
|
||||
explicit_nonce,
|
||||
self.config.start_ctr,
|
||||
ciphertext.len(),
|
||||
ExecutionMode::Mpc,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.decode_blind(keystream_ref.clone()).await?;
|
||||
|
||||
// Verify the plaintext encrypts back to ciphertext.
|
||||
let plaintext_ids = self.state.transcript.extend_plaintext(ciphertext.len());
|
||||
let ciphertext_ref = self
|
||||
.apply_keystream(
|
||||
ExecutionMode::Verify,
|
||||
InputText::Blind { ids: plaintext_ids },
|
||||
keystream_ref,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.verify(ciphertext_ref, ciphertext.into()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn prove_plaintext(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, StreamCipherError> {
|
||||
let KeyAndIv { key, iv } = self
|
||||
.state
|
||||
.key_iv
|
||||
.clone()
|
||||
.ok_or_else(StreamCipherError::key_not_set)?;
|
||||
|
||||
let plaintext = C::apply_keystream(
|
||||
&key,
|
||||
&iv,
|
||||
self.config.start_ctr,
|
||||
&explicit_nonce,
|
||||
&ciphertext,
|
||||
)?;
|
||||
|
||||
// Prove plaintext encrypts back to ciphertext.
|
||||
let keystream = self
|
||||
.compute_keystream(
|
||||
explicit_nonce,
|
||||
self.config.start_ctr,
|
||||
plaintext.len(),
|
||||
ExecutionMode::Prove,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let plaintext_ids = self.state.transcript.extend_plaintext(plaintext.len());
|
||||
let ciphertext = self
|
||||
.apply_keystream(
|
||||
ExecutionMode::Prove,
|
||||
InputText::Private {
|
||||
ids: plaintext_ids,
|
||||
text: plaintext.clone(),
|
||||
},
|
||||
keystream,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.prove(ciphertext).await?;
|
||||
|
||||
Ok(plaintext)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn verify_plaintext(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<(), StreamCipherError> {
|
||||
let keystream = self
|
||||
.compute_keystream(
|
||||
explicit_nonce,
|
||||
self.config.start_ctr,
|
||||
ciphertext.len(),
|
||||
ExecutionMode::Verify,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let plaintext_ids = self.state.transcript.extend_plaintext(ciphertext.len());
|
||||
let ciphertext_ref = self
|
||||
.apply_keystream(
|
||||
ExecutionMode::Verify,
|
||||
InputText::Blind { ids: plaintext_ids },
|
||||
keystream,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.verify(ciphertext_ref, ciphertext.into()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn share_keystream_block(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ctr: usize,
|
||||
) -> Result<Vec<u8>, StreamCipherError> {
|
||||
let EncodedKeyAndIv { key, iv } = self
|
||||
.state
|
||||
.encoded_key_iv
|
||||
.as_ref()
|
||||
.ok_or_else(StreamCipherError::key_not_set)?;
|
||||
|
||||
let key_block = self
|
||||
.state
|
||||
.keystream
|
||||
.compute(
|
||||
&mut self.thread,
|
||||
ExecutionMode::Mpc,
|
||||
key,
|
||||
iv,
|
||||
explicit_nonce,
|
||||
ctr,
|
||||
C::BLOCK_LEN,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let share = self
|
||||
.decode_shared(key_block)
|
||||
.await?
|
||||
.try_into()
|
||||
.expect("key block is array");
|
||||
|
||||
Ok(share)
|
||||
}
|
||||
}
|
||||
@@ -38,7 +38,7 @@ pub struct Ghash<C, Ctx> {
|
||||
|
||||
impl<C, Ctx> Ghash<C, Ctx>
|
||||
where
|
||||
Ctx: Context,
|
||||
|
||||
C: ShareConvert<Ctx, Gf2_128>,
|
||||
{
|
||||
/// Creates a new instance.
|
||||
@@ -91,7 +91,7 @@ impl<C, Ctx> Debug for Ghash<C, Ctx> {
|
||||
#[async_trait]
|
||||
impl<Ctx, C> UniversalHash for Ghash<C, Ctx>
|
||||
where
|
||||
Ctx: Context,
|
||||
|
||||
C: Preprocess<Ctx, Error = ShareConversionError> + ShareConvert<Ctx, Gf2_128> + Send,
|
||||
{
|
||||
#[instrument(level = "info", fields(thread = %self.context.id()), skip_all, err)]
|
||||
|
||||
@@ -30,6 +30,7 @@ opaque-debug = { workspace = true }
|
||||
p256 = { workspace = true, features = ["serde"] }
|
||||
rand = { workspace = true }
|
||||
rand_core = { workspace = true }
|
||||
rand_chacha = { workspace = true }
|
||||
rs_merkle = { workspace = true, features = ["serde"] }
|
||||
rstest = { workspace = true, optional = true }
|
||||
serde = { workspace = true }
|
||||
@@ -38,6 +39,7 @@ thiserror = { workspace = true }
|
||||
tiny-keccak = { version = "2.0", features = ["keccak"] }
|
||||
web-time = { workspace = true }
|
||||
webpki-roots = { workspace = true }
|
||||
itybity = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
bincode = { workspace = true }
|
||||
|
||||
@@ -12,6 +12,7 @@ use crate::{
|
||||
request::Request,
|
||||
serialize::CanonicalSerialize,
|
||||
signing::SignatureAlgId,
|
||||
transcript::encoding::EncoderSecret,
|
||||
CryptoProvider,
|
||||
};
|
||||
|
||||
@@ -25,7 +26,7 @@ pub struct Sign {
|
||||
server_ephemeral_key: Option<ServerEphemKey>,
|
||||
cert_commitment: ServerCertCommitment,
|
||||
encoding_commitment_root: Option<TypedHash>,
|
||||
encoding_seed: Option<Vec<u8>>,
|
||||
encoder_secret: Option<EncoderSecret>,
|
||||
}
|
||||
|
||||
/// An attestation builder.
|
||||
@@ -91,7 +92,7 @@ impl<'a> AttestationBuilder<'a, Accept> {
|
||||
server_ephemeral_key: None,
|
||||
cert_commitment,
|
||||
encoding_commitment_root,
|
||||
encoding_seed: None,
|
||||
encoder_secret: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -110,9 +111,9 @@ impl AttestationBuilder<'_, Sign> {
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the encoding seed.
|
||||
pub fn encoding_seed(&mut self, seed: Vec<u8>) -> &mut Self {
|
||||
self.state.encoding_seed = Some(seed);
|
||||
/// Sets the encoder secret.
|
||||
pub fn encoder_secret(&mut self, secret: EncoderSecret) -> &mut Self {
|
||||
self.state.encoder_secret = Some(secret);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -125,7 +126,7 @@ impl AttestationBuilder<'_, Sign> {
|
||||
server_ephemeral_key,
|
||||
cert_commitment,
|
||||
encoding_commitment_root,
|
||||
encoding_seed,
|
||||
encoder_secret,
|
||||
} = self.state;
|
||||
|
||||
let hasher = provider.hash.get(&hash_alg).map_err(|_| {
|
||||
@@ -144,14 +145,14 @@ impl AttestationBuilder<'_, Sign> {
|
||||
})?;
|
||||
|
||||
let encoding_commitment = if let Some(root) = encoding_commitment_root {
|
||||
let Some(seed) = encoding_seed else {
|
||||
let Some(secret) = encoder_secret else {
|
||||
return Err(AttestationBuilderError::new(
|
||||
ErrorKind::Field,
|
||||
"encoding commitment requested but seed was not set",
|
||||
"encoding commitment requested but encoder_secret was not set",
|
||||
));
|
||||
};
|
||||
|
||||
Some(EncodingCommitment { root, seed })
|
||||
Some(EncodingCommitment { root, secret })
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -246,7 +247,7 @@ mod test {
|
||||
use crate::{
|
||||
connection::{HandshakeData, HandshakeDataV1_2},
|
||||
fixtures::{
|
||||
encoder_seed, encoding_provider, request_fixture, ConnectionFixture, RequestFixture,
|
||||
encoder_secret, encoding_provider, request_fixture, ConnectionFixture, RequestFixture,
|
||||
},
|
||||
hash::Blake3,
|
||||
transcript::Transcript,
|
||||
@@ -435,7 +436,7 @@ mod test {
|
||||
|
||||
attestation_builder
|
||||
.connection_info(connection_info)
|
||||
.encoding_seed(encoder_seed().to_vec());
|
||||
.encoder_secret(encoder_secret());
|
||||
|
||||
let err = attestation_builder.build(crypto_provider).err().unwrap();
|
||||
assert!(matches!(err.kind, ErrorKind::Field));
|
||||
@@ -471,7 +472,7 @@ mod test {
|
||||
|
||||
attestation_builder
|
||||
.server_ephemeral_key(server_ephemeral_key)
|
||||
.encoding_seed(encoder_seed().to_vec());
|
||||
.encoder_secret(encoder_secret());
|
||||
|
||||
let err = attestation_builder.build(crypto_provider).err().unwrap();
|
||||
assert!(matches!(err.kind, ErrorKind::Field));
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
mod provider;
|
||||
|
||||
pub use provider::ChaChaProvider;
|
||||
pub use provider::FixtureEncodingProvider;
|
||||
|
||||
use hex::FromHex;
|
||||
use p256::ecdsa::SigningKey;
|
||||
@@ -17,7 +17,7 @@ use crate::{
|
||||
request::{Request, RequestConfig},
|
||||
signing::SignatureAlgId,
|
||||
transcript::{
|
||||
encoding::{EncodingProvider, EncodingTree},
|
||||
encoding::{EncoderSecret, EncodingProvider, EncodingTree},
|
||||
Transcript, TranscriptCommitConfigBuilder,
|
||||
},
|
||||
CryptoProvider,
|
||||
@@ -131,12 +131,26 @@ impl ConnectionFixture {
|
||||
|
||||
/// Returns an encoding provider fixture.
|
||||
pub fn encoding_provider(tx: &[u8], rx: &[u8]) -> impl EncodingProvider {
|
||||
ChaChaProvider::new(encoder_seed(), Transcript::new(tx, rx))
|
||||
let secret = encoder_secret();
|
||||
FixtureEncodingProvider::new(&secret, Transcript::new(tx, rx))
|
||||
}
|
||||
|
||||
/// Returns an encoder seed fixture.
|
||||
pub fn encoder_seed() -> [u8; 32] {
|
||||
[0u8; 32]
|
||||
/// Seed fixture.
|
||||
const SEED: [u8; 32] = [0; 32];
|
||||
|
||||
/// Delta fixture.
|
||||
const DELTA: [u8; 16] = [1; 16];
|
||||
|
||||
/// Returns an encoder secret fixture.
|
||||
pub fn encoder_secret() -> EncoderSecret {
|
||||
EncoderSecret::new(SEED, DELTA)
|
||||
}
|
||||
|
||||
/// Returns a tampered encoder secret fixture.
|
||||
pub fn encoder_secret_tampered_seed() -> EncoderSecret {
|
||||
let mut seed = SEED;
|
||||
seed[0] += 1;
|
||||
EncoderSecret::new(seed, DELTA)
|
||||
}
|
||||
|
||||
/// Returns a notary signing key fixture.
|
||||
@@ -205,7 +219,7 @@ pub fn attestation_fixture(
|
||||
request: Request,
|
||||
connection: ConnectionFixture,
|
||||
signature_alg: SignatureAlgId,
|
||||
encoding_seed: Vec<u8>,
|
||||
secret: EncoderSecret,
|
||||
) -> Attestation {
|
||||
let ConnectionFixture {
|
||||
connection_info,
|
||||
@@ -237,7 +251,7 @@ pub fn attestation_fixture(
|
||||
attestation_builder
|
||||
.connection_info(connection_info)
|
||||
.server_ephemeral_key(server_ephemeral_key)
|
||||
.encoding_seed(encoding_seed);
|
||||
.encoder_secret(secret);
|
||||
|
||||
attestation_builder.build(&provider).unwrap()
|
||||
}
|
||||
|
||||
@@ -1,27 +1,25 @@
|
||||
use mpz_garble_core::ChaChaEncoder;
|
||||
|
||||
use crate::transcript::{
|
||||
encoding::{Encoder, EncodingProvider},
|
||||
encoding::{new_encoder, Encoder, EncoderSecret, EncodingProvider},
|
||||
Direction, Idx, Transcript,
|
||||
};
|
||||
|
||||
/// A ChaCha encoding provider fixture.
|
||||
pub struct ChaChaProvider {
|
||||
encoder: ChaChaEncoder,
|
||||
/// A encoding provider fixture.
|
||||
pub struct FixtureEncodingProvider {
|
||||
encoder: Box<dyn Encoder>,
|
||||
transcript: Transcript,
|
||||
}
|
||||
|
||||
impl ChaChaProvider {
|
||||
/// Creates a new ChaCha encoding provider.
|
||||
pub(crate) fn new(seed: [u8; 32], transcript: Transcript) -> Self {
|
||||
impl FixtureEncodingProvider {
|
||||
/// Creates a new encoding provider fixture.
|
||||
pub(crate) fn new(secret: &EncoderSecret, transcript: Transcript) -> Self {
|
||||
Self {
|
||||
encoder: ChaChaEncoder::new(seed),
|
||||
encoder: Box::new(new_encoder(secret)),
|
||||
transcript,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EncodingProvider for ChaChaProvider {
|
||||
impl EncodingProvider for FixtureEncodingProvider {
|
||||
fn provide_encoding(&self, direction: Direction, idx: &Idx) -> Option<Vec<u8>> {
|
||||
let seq = self.transcript.get(direction, idx)?;
|
||||
Some(self.encoder.encode_subsequence(direction, &seq))
|
||||
|
||||
@@ -97,7 +97,7 @@ mod test {
|
||||
use crate::{
|
||||
connection::{ServerCertOpening, TranscriptLength},
|
||||
fixtures::{
|
||||
attestation_fixture, encoder_seed, encoding_provider, request_fixture,
|
||||
attestation_fixture, encoder_secret, encoding_provider, request_fixture,
|
||||
ConnectionFixture, RequestFixture,
|
||||
},
|
||||
hash::{Blake3, Hash, HashAlgId},
|
||||
@@ -122,7 +122,7 @@ mod test {
|
||||
request.clone(),
|
||||
connection,
|
||||
SignatureAlgId::SECP256K1,
|
||||
encoder_seed().to_vec(),
|
||||
encoder_secret(),
|
||||
);
|
||||
|
||||
assert!(request.validate(&attestation).is_ok())
|
||||
@@ -144,7 +144,7 @@ mod test {
|
||||
request.clone(),
|
||||
connection,
|
||||
SignatureAlgId::SECP256K1,
|
||||
encoder_seed().to_vec(),
|
||||
encoder_secret(),
|
||||
);
|
||||
|
||||
request.signature_alg = SignatureAlgId::SECP256R1;
|
||||
@@ -169,7 +169,7 @@ mod test {
|
||||
request.clone(),
|
||||
connection,
|
||||
SignatureAlgId::SECP256K1,
|
||||
encoder_seed().to_vec(),
|
||||
encoder_secret(),
|
||||
);
|
||||
|
||||
request.hash_alg = HashAlgId::SHA256;
|
||||
@@ -194,7 +194,7 @@ mod test {
|
||||
request.clone(),
|
||||
connection,
|
||||
SignatureAlgId::SECP256K1,
|
||||
encoder_seed().to_vec(),
|
||||
encoder_secret(),
|
||||
);
|
||||
|
||||
let ConnectionFixture {
|
||||
@@ -229,7 +229,7 @@ mod test {
|
||||
request.clone(),
|
||||
connection,
|
||||
SignatureAlgId::SECP256K1,
|
||||
encoder_seed().to_vec(),
|
||||
encoder_secret(),
|
||||
);
|
||||
|
||||
request.encoding_commitment_root = Some(TypedHash {
|
||||
|
||||
@@ -52,11 +52,6 @@ pub use proof::{
|
||||
TranscriptProof, TranscriptProofBuilder, TranscriptProofBuilderError, TranscriptProofError,
|
||||
};
|
||||
|
||||
/// Sent data transcript ID.
|
||||
pub static TX_TRANSCRIPT_ID: &str = "tx";
|
||||
/// Received data transcript ID.
|
||||
pub static RX_TRANSCRIPT_ID: &str = "rx";
|
||||
|
||||
/// A transcript contains all the data communicated over a TLS connection.
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct Transcript {
|
||||
@@ -588,17 +583,6 @@ impl Subsequence {
|
||||
#[error("invalid subsequence: {0}")]
|
||||
pub struct InvalidSubsequence(&'static str);
|
||||
|
||||
/// Returns the value ID for each byte in the provided range set.
|
||||
#[doc(hidden)]
|
||||
pub fn get_value_ids(direction: Direction, idx: &Idx) -> impl Iterator<Item = String> + '_ {
|
||||
let id = match direction {
|
||||
Direction::Sent => TX_TRANSCRIPT_ID,
|
||||
Direction::Received => RX_TRANSCRIPT_ID,
|
||||
};
|
||||
|
||||
idx.iter().map(move |idx| format!("{}/{}", id, idx))
|
||||
}
|
||||
|
||||
mod validation {
|
||||
use super::*;
|
||||
|
||||
@@ -716,7 +700,7 @@ mod validation {
|
||||
.sent_idx
|
||||
.0
|
||||
.iter_ranges()
|
||||
.last()
|
||||
.next_back()
|
||||
.unwrap()
|
||||
.end;
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ mod proof;
|
||||
mod provider;
|
||||
mod tree;
|
||||
|
||||
pub(crate) use encoder::{new_encoder, Encoder};
|
||||
pub use encoder::{new_encoder, Encoder, EncoderSecret};
|
||||
pub use proof::{EncodingProof, EncodingProofError};
|
||||
pub use provider::EncodingProvider;
|
||||
pub use tree::EncodingTree;
|
||||
@@ -32,7 +32,7 @@ pub struct EncodingCommitment {
|
||||
/// Merkle root of the encoding commitments.
|
||||
pub root: TypedHash,
|
||||
/// Seed used to generate the encodings.
|
||||
pub seed: Vec<u8>,
|
||||
pub secret: EncoderSecret,
|
||||
}
|
||||
|
||||
impl_domain_separator!(EncodingCommitment);
|
||||
|
||||
@@ -1,18 +1,79 @@
|
||||
use mpz_circuits::types::ValueType;
|
||||
use mpz_core::serialize::CanonicalSerialize;
|
||||
use mpz_garble_core::ChaChaEncoder;
|
||||
use crate::transcript::{Direction, Idx, Subsequence};
|
||||
use itybity::ToBits;
|
||||
use rand::{RngCore, SeedableRng};
|
||||
use rand_chacha::ChaCha12Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::transcript::{Direction, Subsequence, RX_TRANSCRIPT_ID, TX_TRANSCRIPT_ID};
|
||||
/// The size of the encoding for 1 bit, in bytes.
|
||||
const BIT_ENCODING_SIZE: usize = 16;
|
||||
/// The size of the encoding for 1 byte, in bytes.
|
||||
const BYTE_ENCODING_SIZE: usize = 128;
|
||||
|
||||
pub(crate) fn new_encoder(seed: [u8; 32]) -> impl Encoder {
|
||||
ChaChaEncoder::new(seed)
|
||||
/// Secret used by an encoder to generate encodings.
|
||||
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct EncoderSecret {
|
||||
seed: [u8; 32],
|
||||
delta: [u8; BIT_ENCODING_SIZE],
|
||||
}
|
||||
|
||||
opaque_debug::implement!(EncoderSecret);
|
||||
|
||||
impl EncoderSecret {
|
||||
/// Creates a new secret.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `seed` - The seed for the PRG.
|
||||
/// * `delta` - Delta for deriving the one-encodings.
|
||||
pub fn new(seed: [u8; 32], delta: [u8; 16]) -> Self {
|
||||
Self { seed, delta }
|
||||
}
|
||||
|
||||
/// Returns the seed.
|
||||
pub fn seed(&self) -> &[u8; 32] {
|
||||
&self.seed
|
||||
}
|
||||
|
||||
/// Returns the delta.
|
||||
pub fn delta(&self) -> &[u8; 16] {
|
||||
&self.delta
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new encoder.
|
||||
pub fn new_encoder(secret: &EncoderSecret) -> impl Encoder {
|
||||
ChaChaEncoder::new(secret)
|
||||
}
|
||||
|
||||
pub(crate) struct ChaChaEncoder {
|
||||
seed: [u8; 32],
|
||||
delta: [u8; 16],
|
||||
}
|
||||
|
||||
impl ChaChaEncoder {
|
||||
pub(crate) fn new(secret: &EncoderSecret) -> Self {
|
||||
let seed = *secret.seed();
|
||||
let delta = *secret.delta();
|
||||
|
||||
Self { seed, delta }
|
||||
}
|
||||
|
||||
pub(crate) fn new_prg(&self, stream_id: u64) -> ChaCha12Rng {
|
||||
let mut prg = ChaCha12Rng::from_seed(self.seed);
|
||||
prg.set_stream(stream_id);
|
||||
prg.set_word_pos(0);
|
||||
prg
|
||||
}
|
||||
}
|
||||
|
||||
/// A transcript encoder.
|
||||
///
|
||||
/// This is an internal implementation detail that should not be exposed to the
|
||||
/// public API.
|
||||
pub(crate) trait Encoder {
|
||||
pub trait Encoder {
|
||||
/// Returns the zero encoding for the given index.
|
||||
fn encode_idx(&self, direction: Direction, idx: &Idx) -> Vec<u8>;
|
||||
|
||||
/// Returns the encoding for the given subsequence of the transcript.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -22,28 +83,45 @@ pub(crate) trait Encoder {
|
||||
}
|
||||
|
||||
impl Encoder for ChaChaEncoder {
|
||||
fn encode_subsequence(&self, direction: Direction, seq: &Subsequence) -> Vec<u8> {
|
||||
let id = match direction {
|
||||
Direction::Sent => TX_TRANSCRIPT_ID,
|
||||
Direction::Received => RX_TRANSCRIPT_ID,
|
||||
fn encode_idx(&self, direction: Direction, idx: &Idx) -> Vec<u8> {
|
||||
// ChaCha20 encoder works with 32-bit words. Each encoded bit is 128 bits long.
|
||||
const WORDS_PER_BYTE: u128 = 8 * 128 / 32;
|
||||
|
||||
let stream_id: u64 = match direction {
|
||||
Direction::Sent => 0,
|
||||
Direction::Received => 1,
|
||||
};
|
||||
|
||||
let mut encoding = Vec::with_capacity(seq.len() * 16);
|
||||
for (byte_id, &byte) in seq.index().iter().zip(seq.data()) {
|
||||
let id_hash = mpz_core::utils::blake3(format!("{}/{}", id, byte_id).as_bytes());
|
||||
let id = u64::from_be_bytes(id_hash[..8].try_into().unwrap());
|
||||
let mut prg = self.new_prg(stream_id);
|
||||
let mut encoding: Vec<u8> = vec![0u8; idx.len() * BYTE_ENCODING_SIZE];
|
||||
|
||||
encoding.extend(
|
||||
<ChaChaEncoder as mpz_garble_core::Encoder>::encode_by_type(
|
||||
self,
|
||||
id,
|
||||
&ValueType::U8,
|
||||
)
|
||||
.select(byte)
|
||||
.expect("encoding is a byte encoding")
|
||||
.to_bytes(),
|
||||
)
|
||||
let mut pos = 0;
|
||||
for range in idx.iter_ranges() {
|
||||
let len = range.len() * BYTE_ENCODING_SIZE;
|
||||
prg.set_word_pos(range.start as u128 * WORDS_PER_BYTE);
|
||||
prg.fill_bytes(&mut encoding[pos..pos + len]);
|
||||
pos += len;
|
||||
}
|
||||
|
||||
encoding
|
||||
}
|
||||
|
||||
fn encode_subsequence(&self, direction: Direction, seq: &Subsequence) -> Vec<u8> {
|
||||
const ZERO: [u8; 16] = [0; BIT_ENCODING_SIZE];
|
||||
let mut encoding = self.encode_idx(direction, seq.index());
|
||||
for (byte_idx, &byte) in seq.data().iter().enumerate() {
|
||||
let start = byte_idx * BYTE_ENCODING_SIZE;
|
||||
for (bit_idx, bit) in byte.iter_lsb0().enumerate() {
|
||||
let pos = start + (bit_idx * BIT_ENCODING_SIZE);
|
||||
let delta = if bit { &self.delta } else { &ZERO };
|
||||
|
||||
encoding[pos..pos + BIT_ENCODING_SIZE]
|
||||
.iter_mut()
|
||||
.zip(delta)
|
||||
.for_each(|(a, b)| *a ^= *b);
|
||||
}
|
||||
}
|
||||
|
||||
encoding
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,11 +51,7 @@ impl EncodingProof {
|
||||
) -> Result<PartialTranscript, EncodingProofError> {
|
||||
let hasher = provider.hash.get(&commitment.root.alg)?;
|
||||
|
||||
let seed: [u8; 32] = commitment.seed.clone().try_into().map_err(|_| {
|
||||
EncodingProofError::new(ErrorKind::Commitment, "encoding seed not 32 bytes")
|
||||
})?;
|
||||
|
||||
let encoder = new_encoder(seed);
|
||||
let encoder = new_encoder(&commitment.secret);
|
||||
let Self {
|
||||
inclusion_proof,
|
||||
openings,
|
||||
@@ -149,7 +145,6 @@ impl EncodingProofError {
|
||||
#[derive(Debug)]
|
||||
enum ErrorKind {
|
||||
Provider,
|
||||
Commitment,
|
||||
Proof,
|
||||
}
|
||||
|
||||
@@ -159,7 +154,6 @@ impl fmt::Display for EncodingProofError {
|
||||
|
||||
match self.kind {
|
||||
ErrorKind::Provider => f.write_str("provider error")?,
|
||||
ErrorKind::Commitment => f.write_str("commitment error")?,
|
||||
ErrorKind::Proof => f.write_str("proof error")?,
|
||||
}
|
||||
|
||||
@@ -188,9 +182,12 @@ mod test {
|
||||
use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};
|
||||
|
||||
use crate::{
|
||||
fixtures::{encoder_seed, encoding_provider},
|
||||
fixtures::{encoder_secret, encoder_secret_tampered_seed, encoding_provider},
|
||||
hash::Blake3,
|
||||
transcript::{encoding::EncodingTree, Idx, Transcript},
|
||||
transcript::{
|
||||
encoding::{EncoderSecret, EncodingTree},
|
||||
Idx, Transcript,
|
||||
},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@@ -201,7 +198,7 @@ mod test {
|
||||
commitment: EncodingCommitment,
|
||||
}
|
||||
|
||||
fn new_encoding_fixture(seed: Vec<u8>) -> EncodingFixture {
|
||||
fn new_encoding_fixture(secret: EncoderSecret) -> EncodingFixture {
|
||||
let transcript = Transcript::new(POST_JSON, OK_JSON);
|
||||
|
||||
let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len()));
|
||||
@@ -226,7 +223,7 @@ mod test {
|
||||
|
||||
let commitment = EncodingCommitment {
|
||||
root: tree.root(),
|
||||
seed,
|
||||
secret,
|
||||
};
|
||||
|
||||
EncodingFixture {
|
||||
@@ -237,12 +234,12 @@ mod test {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_encoding_proof_invalid_seed() {
|
||||
fn test_verify_encoding_proof_tampered_seed() {
|
||||
let EncodingFixture {
|
||||
transcript,
|
||||
proof,
|
||||
commitment,
|
||||
} = new_encoding_fixture(encoder_seed().to_vec().split_off(1));
|
||||
} = new_encoding_fixture(encoder_secret_tampered_seed());
|
||||
|
||||
let err = proof
|
||||
.verify_with_provider(
|
||||
@@ -252,7 +249,7 @@ mod test {
|
||||
)
|
||||
.unwrap_err();
|
||||
|
||||
assert!(matches!(err.kind, ErrorKind::Commitment));
|
||||
assert!(matches!(err.kind, ErrorKind::Proof));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -261,7 +258,7 @@ mod test {
|
||||
transcript,
|
||||
proof,
|
||||
commitment,
|
||||
} = new_encoding_fixture(encoder_seed().to_vec());
|
||||
} = new_encoding_fixture(encoder_secret());
|
||||
|
||||
let err = proof
|
||||
.verify_with_provider(
|
||||
@@ -283,7 +280,7 @@ mod test {
|
||||
transcript,
|
||||
mut proof,
|
||||
commitment,
|
||||
} = new_encoding_fixture(encoder_seed().to_vec());
|
||||
} = new_encoding_fixture(encoder_secret());
|
||||
|
||||
let Opening { seq, .. } = proof.openings.values_mut().next().unwrap();
|
||||
|
||||
@@ -306,7 +303,7 @@ mod test {
|
||||
transcript,
|
||||
mut proof,
|
||||
commitment,
|
||||
} = new_encoding_fixture(encoder_seed().to_vec());
|
||||
} = new_encoding_fixture(encoder_secret());
|
||||
|
||||
let Opening { blinder, .. } = proof.openings.values_mut().next().unwrap();
|
||||
|
||||
|
||||
@@ -198,7 +198,7 @@ impl EncodingTree {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
fixtures::{encoder_seed, encoding_provider},
|
||||
fixtures::{encoder_secret, encoding_provider},
|
||||
hash::Blake3,
|
||||
transcript::encoding::EncodingCommitment,
|
||||
CryptoProvider,
|
||||
@@ -235,7 +235,7 @@ mod tests {
|
||||
|
||||
let commitment = EncodingCommitment {
|
||||
root: tree.root(),
|
||||
seed: encoder_seed().to_vec(),
|
||||
secret: encoder_secret(),
|
||||
};
|
||||
|
||||
let partial_transcript = proof
|
||||
@@ -272,7 +272,7 @@ mod tests {
|
||||
|
||||
let commitment = EncodingCommitment {
|
||||
root: tree.root(),
|
||||
seed: encoder_seed().to_vec(),
|
||||
secret: encoder_secret(),
|
||||
};
|
||||
|
||||
let partial_transcript = proof
|
||||
|
||||
@@ -361,7 +361,7 @@ mod tests {
|
||||
|
||||
use crate::{
|
||||
fixtures::{
|
||||
attestation_fixture, encoder_seed, encoding_provider, request_fixture,
|
||||
attestation_fixture, encoder_secret, encoding_provider, request_fixture,
|
||||
ConnectionFixture, RequestFixture,
|
||||
},
|
||||
hash::Blake3,
|
||||
@@ -448,7 +448,7 @@ mod tests {
|
||||
request,
|
||||
connection,
|
||||
SignatureAlgId::SECP256K1,
|
||||
encoder_seed().to_vec(),
|
||||
encoder_secret(),
|
||||
);
|
||||
|
||||
let provider = CryptoProvider::default();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use tlsn_core::{
|
||||
attestation::{Attestation, AttestationConfig},
|
||||
connection::{HandshakeData, HandshakeDataV1_2},
|
||||
fixtures::{self, encoder_seed, ConnectionFixture},
|
||||
fixtures::{self, encoder_secret, ConnectionFixture},
|
||||
hash::Blake3,
|
||||
presentation::PresentationOutput,
|
||||
request::{Request, RequestConfig},
|
||||
@@ -84,7 +84,7 @@ fn test_api() {
|
||||
.connection_info(connection_info.clone())
|
||||
// Server key Notary received during handshake
|
||||
.server_ephemeral_key(server_ephemeral_key)
|
||||
.encoding_seed(encoder_seed().to_vec());
|
||||
.encoder_secret(encoder_secret());
|
||||
|
||||
let attestation = attestation_builder.build(&provider).unwrap();
|
||||
|
||||
|
||||
72
crates/mpc-tls/Cargo.toml
Normal file
72
crates/mpc-tls/Cargo.toml
Normal file
@@ -0,0 +1,72 @@
|
||||
[package]
|
||||
name = "tlsn-mpc-tls"
|
||||
authors = ["TLSNotary Team"]
|
||||
description = "Implementation of the backend trait for 2PC"
|
||||
keywords = ["tls", "mpc", "2pc"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.0.0"
|
||||
publish = false
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "mpc_tls"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
||||
[dependencies]
|
||||
tlsn-cipher = { workspace = true }
|
||||
tlsn-common = { workspace = true }
|
||||
tlsn-hmac-sha256 = { workspace = true }
|
||||
tlsn-key-exchange = { workspace = true }
|
||||
tlsn-tls-backend = { workspace = true }
|
||||
tlsn-tls-core = { workspace = true, features = ["serde"] }
|
||||
tlsn-utils-aio = { workspace = true }
|
||||
|
||||
mpz-common = { workspace = true }
|
||||
mpz-core = { workspace = true }
|
||||
mpz-fields = { workspace = true }
|
||||
mpz-ot = { workspace = true }
|
||||
mpz-ole = { workspace = true }
|
||||
mpz-share-conversion = { workspace = true }
|
||||
mpz-vm-core = { workspace = true }
|
||||
mpz-memory-core = { workspace = true }
|
||||
mpz-circuits = { workspace = true }
|
||||
|
||||
ludi = { git = "https://github.com/sinui0/ludi", rev = "e511c3b", default-features = false }
|
||||
serio = { workspace = true }
|
||||
|
||||
async-trait = { workspace = true }
|
||||
derive_builder = { workspace = true }
|
||||
enum-try-as-inner = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
p256 = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
opaque-debug = { workspace = true }
|
||||
aes = { workspace = true }
|
||||
aes-gcm = { workspace = true }
|
||||
ctr = { workspace = true }
|
||||
ghash_rc = { package = "ghash", version = "0.5" }
|
||||
cipher-crate = { package = "cipher", version = "0.4" }
|
||||
tokio = { workspace = true, features = ["sync"] }
|
||||
pin-project-lite = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
mpz-ole = { workspace = true, features = ["test-utils"] }
|
||||
mpz-ot = { workspace = true }
|
||||
mpz-garble = { workspace = true }
|
||||
|
||||
tls-server-fixture = { workspace = true }
|
||||
tlsn-tls-client = { workspace = true }
|
||||
tlsn-tls-client-async = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
|
||||
tokio-util = { workspace = true, features = ["compat"] }
|
||||
tracing-subscriber = { workspace = true }
|
||||
rand_chacha = { workspace = true }
|
||||
generic-array = { workspace = true }
|
||||
uid-mux = { workspace = true, features = ["serio", "test-utils"] }
|
||||
rstest = { workspace = true }
|
||||
67
crates/mpc-tls/src/config.rs
Normal file
67
crates/mpc-tls/src/config.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use derive_builder::Builder;
|
||||
|
||||
const MIN_SENT: usize = 32;
|
||||
const MIN_SENT_RECORDS: usize = 8;
|
||||
const MIN_RECV: usize = 32;
|
||||
const MIN_RECV_RECORDS: usize = 8;
|
||||
|
||||
/// MPC-TLS configuration.
|
||||
#[derive(Debug, Clone, Builder)]
|
||||
#[builder(build_fn(skip))]
|
||||
pub struct Config {
|
||||
/// Defers decryption of received data until after the MPC-TLS connection is
|
||||
/// closed.
|
||||
///
|
||||
/// The received data will be decrypted locally without MPC, thus improving
|
||||
/// bandwidth usage and performance.
|
||||
pub(crate) defer_decryption: bool,
|
||||
/// Maximum number of sent TLS records. Data is transmitted in records up to
|
||||
/// 16KB long.
|
||||
pub(crate) max_sent_records: usize,
|
||||
/// Maximum number of sent bytes.
|
||||
pub(crate) max_sent: usize,
|
||||
/// Maximum number of received TLS records. Data is transmitted in records
|
||||
/// up to 16KB long.
|
||||
pub(crate) max_recv_records: usize,
|
||||
/// Maximum number of received bytes which will be decrypted while
|
||||
/// the TLS connection is active. Data which can be decrypted after the TLS
|
||||
/// connection will be decrypted for free.
|
||||
pub(crate) max_recv_online: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Creates a new builder.
|
||||
pub fn builder() -> ConfigBuilder {
|
||||
ConfigBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl ConfigBuilder {
|
||||
/// Builds the configuration.
|
||||
pub fn build(&self) -> Result<Config, ConfigBuilderError> {
|
||||
let defer_decryption = self.defer_decryption.unwrap_or(true);
|
||||
let max_sent = MIN_SENT
|
||||
+ self
|
||||
.max_sent
|
||||
.ok_or(ConfigBuilderError::UninitializedField("max_sent"))?;
|
||||
let max_recv_online = MIN_RECV
|
||||
+ self
|
||||
.max_recv_online
|
||||
.ok_or(ConfigBuilderError::UninitializedField("max_recv_online"))?;
|
||||
|
||||
let max_sent_records = self
|
||||
.max_sent_records
|
||||
.unwrap_or_else(|| MIN_SENT_RECORDS + max_sent.div_ceil(16384));
|
||||
let max_recv_records = self
|
||||
.max_recv_records
|
||||
.unwrap_or_else(|| MIN_RECV_RECORDS + max_recv_online.div_ceil(16384));
|
||||
|
||||
Ok(Config {
|
||||
defer_decryption,
|
||||
max_sent_records,
|
||||
max_sent,
|
||||
max_recv_records,
|
||||
max_recv_online,
|
||||
})
|
||||
}
|
||||
}
|
||||
72
crates/mpc-tls/src/decode.rs
Normal file
72
crates/mpc-tls/src/decode.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use std::{
|
||||
array::from_fn,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
|
||||
use crate::Role;
|
||||
use mpz_core::bitvec::BitVec;
|
||||
use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
DecodeError, DecodeFutureTyped,
|
||||
};
|
||||
use mpz_vm_core::{prelude::*, Vm, VmError};
|
||||
use pin_project_lite::pin_project;
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
pin_project! {
|
||||
/// Supports decoding into additive shares.
|
||||
#[project = OneTimePadSharedProj]
|
||||
pub(crate) enum OneTimePadShared<T> {
|
||||
Leader {
|
||||
otp: T,
|
||||
},
|
||||
Follower {
|
||||
#[pin] value: DecodeFutureTyped<BitVec, T>,
|
||||
otp: T,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> OneTimePadShared<[u8; N]> {
|
||||
pub(crate) fn new(
|
||||
role: Role,
|
||||
value: Array<U8, N>,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
) -> Result<Self, VmError> {
|
||||
let mut rng = thread_rng();
|
||||
let otp: [u8; N] = from_fn(|_| rng.gen());
|
||||
match role {
|
||||
Role::Leader => {
|
||||
let masked = vm.mask_private(value, otp)?;
|
||||
let masked = vm.mask_blind(masked)?;
|
||||
_ = vm.decode(masked)?;
|
||||
|
||||
Ok(Self::Leader { otp })
|
||||
}
|
||||
Role::Follower => {
|
||||
let masked = vm.mask_blind(value)?;
|
||||
let masked = vm.mask_private(masked, otp)?;
|
||||
let value = vm.decode(masked)?;
|
||||
|
||||
Ok(Self::Follower { value, otp })
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for OneTimePadShared<[u8; 16]> {
|
||||
type Output = Result<[u8; 16], DecodeError>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.project() {
|
||||
OneTimePadSharedProj::Leader { otp } => Poll::Ready(Ok(*otp)),
|
||||
OneTimePadSharedProj::Follower { value, otp } => {
|
||||
let mut value = ready!(value.poll(cx))?;
|
||||
value.iter_mut().zip(otp).for_each(|(a, b)| *a ^= *b);
|
||||
Poll::Ready(Ok(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
113
crates/mpc-tls/src/error.rs
Normal file
113
crates/mpc-tls/src/error.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
use hmac_sha256::PrfError;
|
||||
use key_exchange::KeyExchangeError;
|
||||
use tls_backend::BackendError;
|
||||
|
||||
/// MPC-TLS error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct MpcTlsError(#[from] ErrorRepr);
|
||||
|
||||
impl MpcTlsError {
|
||||
pub(crate) fn peer<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Peer(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn actor<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Actor(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn state<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::State(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn alloc<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Alloc(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn preprocess<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Preprocess(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn hs<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Handshake(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn record_layer<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::RecordLayer(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn other<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Other(err.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("mpc-tls error: {0}")]
|
||||
enum ErrorRepr {
|
||||
#[error("peer misbehaved")]
|
||||
Peer(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("I/O error: {0}")]
|
||||
Io(std::io::Error),
|
||||
#[error("actor error: {0}")]
|
||||
Actor(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("state error: {0}")]
|
||||
State(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("allocation error: {0}")]
|
||||
Alloc(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("preprocess error: {0}")]
|
||||
Preprocess(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("handshake error: {0}")]
|
||||
Handshake(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("record layer error: {0}")]
|
||||
RecordLayer(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("other: {0}")]
|
||||
Other(Box<dyn std::error::Error + Send + Sync>),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for MpcTlsError {
|
||||
fn from(value: std::io::Error) -> Self {
|
||||
MpcTlsError(ErrorRepr::Io(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MpcTlsError> for BackendError {
|
||||
fn from(value: MpcTlsError) -> Self {
|
||||
BackendError::InternalError(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<KeyExchangeError> for MpcTlsError {
|
||||
fn from(value: KeyExchangeError) -> Self {
|
||||
MpcTlsError::hs(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PrfError> for MpcTlsError {
|
||||
fn from(value: PrfError) -> Self {
|
||||
MpcTlsError::hs(value)
|
||||
}
|
||||
}
|
||||
582
crates/mpc-tls/src/follower.rs
Normal file
582
crates/mpc-tls/src/follower.rs
Normal file
@@ -0,0 +1,582 @@
|
||||
use crate::{
|
||||
msg::Message,
|
||||
record_layer::{aead::MpcAesGcm, RecordLayer},
|
||||
Config, FollowerData, MpcTlsError, Role, SessionKeys, Vm,
|
||||
};
|
||||
use hmac_sha256::{MpcPrf, PrfConfig, PrfOutput};
|
||||
use ke::KeyExchange;
|
||||
use key_exchange::{self as ke, MpcKeyExchange};
|
||||
use mpz_common::{scoped_futures::ScopedFutureExt, Context, Flush};
|
||||
use mpz_core::{bitvec::BitVec, Block};
|
||||
use mpz_memory_core::{DecodeFutureTyped, MemoryExt};
|
||||
use mpz_ole::{Receiver as OLEReceiver, Sender as OLESender};
|
||||
use mpz_ot::{
|
||||
rcot::{RCOTReceiver, RCOTSender},
|
||||
rot::{
|
||||
any::{AnyReceiver, AnySender},
|
||||
randomize::{RandomizeRCOTReceiver, RandomizeRCOTSender},
|
||||
},
|
||||
};
|
||||
use mpz_share_conversion::{ShareConversionReceiver, ShareConversionSender};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serio::stream::IoStreamExt;
|
||||
use std::mem;
|
||||
use tls_core::msgs::{
|
||||
alert::AlertMessagePayload,
|
||||
codec::{Codec, Reader},
|
||||
enums::{AlertDescription, ContentType, NamedGroup, ProtocolVersion},
|
||||
handshake::{HandshakeMessagePayload, HandshakePayload},
|
||||
};
|
||||
use tlsn_common::transcript::TlsTranscript;
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
/// MPC-TLS follower.
|
||||
#[derive(Debug)]
|
||||
pub struct MpcTlsFollower {
|
||||
config: Config,
|
||||
ctx: Context,
|
||||
state: State,
|
||||
}
|
||||
|
||||
impl MpcTlsFollower {
|
||||
/// Creates a new follower.
|
||||
pub fn new<CS, CR>(
|
||||
config: Config,
|
||||
ctx: Context,
|
||||
vm: Vm,
|
||||
cot_send: CS,
|
||||
cot_recv: (CR, CR, CR),
|
||||
) -> Self
|
||||
where
|
||||
CS: RCOTSender<Block> + Flush + Send + Sync + 'static,
|
||||
CR: RCOTReceiver<bool, Block> + Flush + Send + Sync + 'static,
|
||||
{
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let ke = Box::new(MpcKeyExchange::new(
|
||||
key_exchange::Role::Follower,
|
||||
ShareConversionReceiver::new(OLEReceiver::new(AnyReceiver::new(
|
||||
RandomizeRCOTReceiver::new(cot_recv.0),
|
||||
))),
|
||||
ShareConversionSender::new(OLESender::new(
|
||||
rng.gen(),
|
||||
AnySender::new(RandomizeRCOTSender::new(cot_send)),
|
||||
)),
|
||||
)) as Box<dyn KeyExchange + Send + Sync>;
|
||||
|
||||
let prf = MpcPrf::new(
|
||||
PrfConfig::builder()
|
||||
.role(hmac_sha256::Role::Follower)
|
||||
.build()
|
||||
.expect("PRF config is valid"),
|
||||
);
|
||||
|
||||
let encrypter = MpcAesGcm::new(
|
||||
ShareConversionReceiver::new(OLEReceiver::new(AnyReceiver::new(
|
||||
RandomizeRCOTReceiver::new(cot_recv.1),
|
||||
))),
|
||||
Role::Follower,
|
||||
);
|
||||
let decrypter = MpcAesGcm::new(
|
||||
ShareConversionReceiver::new(OLEReceiver::new(AnyReceiver::new(
|
||||
RandomizeRCOTReceiver::new(cot_recv.2),
|
||||
))),
|
||||
Role::Follower,
|
||||
);
|
||||
|
||||
let record_layer = RecordLayer::new(Role::Follower, encrypter, decrypter);
|
||||
|
||||
Self {
|
||||
config,
|
||||
ctx,
|
||||
state: State::Init {
|
||||
vm,
|
||||
ke,
|
||||
prf,
|
||||
record_layer,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocates resources for the connection.
|
||||
pub fn alloc(&mut self) -> Result<SessionKeys, MpcTlsError> {
|
||||
let State::Init {
|
||||
vm,
|
||||
mut ke,
|
||||
mut prf,
|
||||
mut record_layer,
|
||||
} = self.state.take()
|
||||
else {
|
||||
return Err(MpcTlsError::state("must be in init state to allocate"));
|
||||
};
|
||||
|
||||
let (keys, cf_vd, sf_vd) = {
|
||||
let vm = &mut (*vm
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::other("VM lock is held"))?);
|
||||
|
||||
let pms = ke.alloc(vm)?;
|
||||
let PrfOutput { keys, cf_vd, sf_vd } = prf.alloc(vm, pms)?;
|
||||
record_layer.set_keys(
|
||||
keys.client_write_key,
|
||||
keys.client_iv,
|
||||
keys.server_write_key,
|
||||
keys.server_iv,
|
||||
)?;
|
||||
|
||||
prf.set_client_random(vm, None)?;
|
||||
|
||||
let cf_vd = vm.decode(cf_vd).map_err(MpcTlsError::alloc)?;
|
||||
let sf_vd = vm.decode(sf_vd).map_err(MpcTlsError::alloc)?;
|
||||
|
||||
record_layer.alloc(
|
||||
vm,
|
||||
self.config.max_sent_records,
|
||||
self.config.max_recv_records,
|
||||
self.config.max_sent,
|
||||
self.config.max_recv_online,
|
||||
)?;
|
||||
|
||||
(keys, cf_vd, sf_vd)
|
||||
};
|
||||
|
||||
self.state = State::Setup {
|
||||
vm,
|
||||
keys: keys.into(),
|
||||
ke,
|
||||
prf,
|
||||
record_layer,
|
||||
cf_vd,
|
||||
sf_vd,
|
||||
};
|
||||
|
||||
Ok(keys.into())
|
||||
}
|
||||
|
||||
/// Preprocesses the connection.
|
||||
#[instrument(skip_all, err)]
|
||||
pub async fn preprocess(&mut self) -> Result<(), MpcTlsError> {
|
||||
let State::Setup {
|
||||
vm,
|
||||
keys,
|
||||
mut ke,
|
||||
prf,
|
||||
mut record_layer,
|
||||
cf_vd,
|
||||
sf_vd,
|
||||
} = self.state.take()
|
||||
else {
|
||||
return Err(MpcTlsError::state("must be in setup state to preprocess"));
|
||||
};
|
||||
|
||||
let (ke, record_layer, _) = {
|
||||
let mut vm = vm
|
||||
.clone()
|
||||
.try_lock_owned()
|
||||
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
|
||||
self.ctx
|
||||
.try_join3(
|
||||
|ctx| {
|
||||
async move {
|
||||
ke.setup(ctx)
|
||||
.await
|
||||
.map(|_| ke)
|
||||
.map_err(MpcTlsError::preprocess)
|
||||
}
|
||||
.scope_boxed()
|
||||
},
|
||||
|ctx| {
|
||||
async move {
|
||||
record_layer
|
||||
.preprocess(ctx)
|
||||
.await
|
||||
.map(|_| record_layer)
|
||||
.map_err(MpcTlsError::preprocess)
|
||||
}
|
||||
.scope_boxed()
|
||||
},
|
||||
|ctx| {
|
||||
async move {
|
||||
vm.flush(ctx).await.map_err(MpcTlsError::preprocess)?;
|
||||
vm.preprocess(ctx).await.map_err(MpcTlsError::preprocess)?;
|
||||
vm.flush(ctx).await.map_err(MpcTlsError::preprocess)?;
|
||||
|
||||
Ok::<_, MpcTlsError>(())
|
||||
}
|
||||
.scope_boxed()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(MpcTlsError::hs)??
|
||||
};
|
||||
|
||||
self.state = State::Ready {
|
||||
vm,
|
||||
keys,
|
||||
ke,
|
||||
prf,
|
||||
record_layer,
|
||||
cf_vd,
|
||||
sf_vd,
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Runs the follower.
|
||||
#[instrument(skip_all, err)]
|
||||
pub async fn run(mut self) -> Result<(Context, FollowerData), MpcTlsError> {
|
||||
let State::Ready {
|
||||
vm,
|
||||
keys,
|
||||
mut ke,
|
||||
mut prf,
|
||||
mut record_layer,
|
||||
cf_vd: mut cf_vd_fut,
|
||||
sf_vd: mut sf_vd_fut,
|
||||
} = self.state.take()
|
||||
else {
|
||||
return Err(MpcTlsError::state("must be in ready state to run"));
|
||||
};
|
||||
|
||||
let mut server_random = None;
|
||||
let mut server_key = None;
|
||||
let mut cf_vd = None;
|
||||
let mut sf_vd = None;
|
||||
loop {
|
||||
let msg: Message = self.ctx.io_mut().expect_next().await?;
|
||||
match msg {
|
||||
Message::SetServerRandom(random) => {
|
||||
if server_random.is_some() {
|
||||
return Err(MpcTlsError::hs("server random already set"));
|
||||
}
|
||||
|
||||
let mut vm = vm
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
|
||||
|
||||
prf.set_server_random(&mut (*vm), random.random)?;
|
||||
|
||||
server_random = Some(random);
|
||||
}
|
||||
Message::SetServerKey(key) => {
|
||||
if server_key.is_some() {
|
||||
return Err(MpcTlsError::hs("server key already set"));
|
||||
}
|
||||
|
||||
let key = key.key;
|
||||
let NamedGroup::secp256r1 = key.group else {
|
||||
return Err(MpcTlsError::hs("unsupported server key group"));
|
||||
};
|
||||
|
||||
ke.set_server_key(
|
||||
p256::PublicKey::from_sec1_bytes(&key.key)
|
||||
.map_err(|_| MpcTlsError::hs("failed to parse server key"))?,
|
||||
)?;
|
||||
|
||||
server_key = Some(key);
|
||||
|
||||
let mut vm = vm
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
|
||||
|
||||
ke.compute_shares(&mut self.ctx).await?;
|
||||
ke.assign(&mut (*vm))?;
|
||||
|
||||
vm.execute_all(&mut self.ctx)
|
||||
.await
|
||||
.map_err(MpcTlsError::hs)?;
|
||||
|
||||
ke.finalize().await?;
|
||||
record_layer.setup(&mut self.ctx).await?;
|
||||
}
|
||||
Message::ClientFinishedVd(vd) => {
|
||||
if cf_vd.is_some() {
|
||||
return Err(MpcTlsError::hs("client finished VD already computed"));
|
||||
}
|
||||
|
||||
let mut vm = vm
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
|
||||
|
||||
prf.set_cf_hash(&mut (*vm), vd.handshake_hash)?;
|
||||
|
||||
vm.execute_all(&mut self.ctx)
|
||||
.await
|
||||
.map_err(MpcTlsError::hs)?;
|
||||
|
||||
cf_vd = Some(
|
||||
cf_vd_fut
|
||||
.try_recv()
|
||||
.map_err(MpcTlsError::hs)?
|
||||
.ok_or(MpcTlsError::hs("client finished VD not computed"))?,
|
||||
);
|
||||
}
|
||||
Message::ServerFinishedVd(vd) => {
|
||||
if sf_vd.is_some() {
|
||||
return Err(MpcTlsError::hs("server finished VD already computed"));
|
||||
}
|
||||
|
||||
let mut vm = vm
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
|
||||
|
||||
prf.set_sf_hash(&mut (*vm), vd.handshake_hash)?;
|
||||
|
||||
vm.execute_all(&mut self.ctx)
|
||||
.await
|
||||
.map_err(MpcTlsError::hs)?;
|
||||
|
||||
sf_vd = Some(
|
||||
sf_vd_fut
|
||||
.try_recv()
|
||||
.map_err(MpcTlsError::hs)?
|
||||
.ok_or(MpcTlsError::hs("server finished VD not computed"))?,
|
||||
);
|
||||
}
|
||||
Message::Encrypt(encrypt) => {
|
||||
record_layer
|
||||
.push_encrypt(
|
||||
encrypt.typ,
|
||||
encrypt.version,
|
||||
encrypt.len,
|
||||
encrypt.plaintext,
|
||||
encrypt.mode,
|
||||
)
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
}
|
||||
Message::Decrypt(decrypt) => {
|
||||
record_layer
|
||||
.push_decrypt(
|
||||
decrypt.typ,
|
||||
decrypt.version,
|
||||
decrypt.explicit_nonce,
|
||||
decrypt.ciphertext,
|
||||
decrypt.tag,
|
||||
decrypt.mode,
|
||||
)
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
}
|
||||
Message::Flush { is_decrypting } => {
|
||||
record_layer
|
||||
.flush(&mut self.ctx, vm.clone(), is_decrypting)
|
||||
.await?;
|
||||
debug!("flushed record layer");
|
||||
}
|
||||
Message::CloseConnection => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("committing");
|
||||
|
||||
let transcript = record_layer.commit(&mut self.ctx, vm).await?;
|
||||
|
||||
debug!("committed");
|
||||
|
||||
let server_key = server_key.ok_or(MpcTlsError::hs("server key not set"))?;
|
||||
let cf_vd = cf_vd.ok_or(MpcTlsError::hs("client finished VD not computed"))?;
|
||||
let sf_vd = sf_vd.ok_or(MpcTlsError::hs("server finished VD not computed"))?;
|
||||
|
||||
validate_transcript(cf_vd, sf_vd, &transcript)?;
|
||||
|
||||
Ok((
|
||||
self.ctx,
|
||||
FollowerData {
|
||||
server_key,
|
||||
transcript,
|
||||
keys,
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
enum State {
|
||||
Init {
|
||||
vm: Vm,
|
||||
ke: Box<dyn KeyExchange + Send + Sync + 'static>,
|
||||
prf: MpcPrf,
|
||||
record_layer: RecordLayer,
|
||||
},
|
||||
Setup {
|
||||
vm: Vm,
|
||||
keys: SessionKeys,
|
||||
ke: Box<dyn KeyExchange + Send + Sync + 'static>,
|
||||
prf: MpcPrf,
|
||||
record_layer: RecordLayer,
|
||||
cf_vd: DecodeFutureTyped<BitVec, [u8; 12]>,
|
||||
sf_vd: DecodeFutureTyped<BitVec, [u8; 12]>,
|
||||
},
|
||||
Ready {
|
||||
vm: Vm,
|
||||
keys: SessionKeys,
|
||||
ke: Box<dyn KeyExchange + Send + Sync + 'static>,
|
||||
prf: MpcPrf,
|
||||
record_layer: RecordLayer,
|
||||
cf_vd: DecodeFutureTyped<BitVec, [u8; 12]>,
|
||||
sf_vd: DecodeFutureTyped<BitVec, [u8; 12]>,
|
||||
},
|
||||
Error,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn take(&mut self) -> Self {
|
||||
mem::replace(self, State::Error)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for State {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Init { .. } => f.debug_struct("Init").finish_non_exhaustive(),
|
||||
Self::Setup { .. } => f.debug_struct("Setup").finish_non_exhaustive(),
|
||||
Self::Ready { .. } => f.debug_struct("Ready").finish_non_exhaustive(),
|
||||
Self::Error => write!(f, "Error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_transcript(
|
||||
cf_vd: [u8; 12],
|
||||
sf_vd: [u8; 12],
|
||||
transcript: &TlsTranscript,
|
||||
) -> Result<(), MpcTlsError> {
|
||||
let mut sent = transcript.sent.iter();
|
||||
let mut recv = transcript.recv.iter();
|
||||
|
||||
// Make sure the client finished verify data message was consistent.
|
||||
if let Some(record) = sent.next() {
|
||||
let payload = record.plaintext.as_ref().ok_or(MpcTlsError::record_layer(
|
||||
"client finished message was hidden from the follower",
|
||||
))?;
|
||||
|
||||
let mut reader = Reader::init(payload);
|
||||
let payload = HandshakeMessagePayload::read_version(&mut reader, ProtocolVersion::TLSv1_2)
|
||||
.ok_or(MpcTlsError::record_layer(
|
||||
"first record sent was not a handshake message",
|
||||
))?;
|
||||
|
||||
let HandshakePayload::Finished(actual_cf_vd) = payload.payload else {
|
||||
return Err(MpcTlsError::record_layer(
|
||||
"first record sent was not a client finished message",
|
||||
));
|
||||
};
|
||||
|
||||
if cf_vd != actual_cf_vd.0.as_slice() {
|
||||
return Err(MpcTlsError::record_layer(format!(
|
||||
"client finished verify data does not match output from PRF: {:?} != {:?}",
|
||||
cf_vd, actual_cf_vd
|
||||
)));
|
||||
}
|
||||
} else {
|
||||
return Err(MpcTlsError::record_layer("no records were sent"));
|
||||
}
|
||||
|
||||
// Make sure the server finished verify data message was consistent.
|
||||
if let Some(record) = recv.next() {
|
||||
let payload = record.plaintext.as_ref().ok_or(MpcTlsError::record_layer(
|
||||
"server finished message was hidden from the follower",
|
||||
))?;
|
||||
|
||||
let mut reader = Reader::init(payload);
|
||||
let payload = HandshakeMessagePayload::read_version(&mut reader, ProtocolVersion::TLSv1_2)
|
||||
.ok_or(MpcTlsError::record_layer(
|
||||
"first record received was not a handshake message",
|
||||
))?;
|
||||
|
||||
let HandshakePayload::Finished(actual_sf_vd) = payload.payload else {
|
||||
return Err(MpcTlsError::record_layer(
|
||||
"first record received was not a server finished message",
|
||||
));
|
||||
};
|
||||
|
||||
if sf_vd != actual_sf_vd.0.as_slice() {
|
||||
return Err(MpcTlsError::record_layer(format!(
|
||||
"server finished verify data does not match output from PRF: {:?} != {:?}",
|
||||
sf_vd, actual_sf_vd
|
||||
)));
|
||||
}
|
||||
} else {
|
||||
return Err(MpcTlsError::record_layer("no records were received"));
|
||||
}
|
||||
|
||||
// Verify last record sent was either application data or close notify.
|
||||
if let Some(record) = sent.next_back() {
|
||||
match record.typ {
|
||||
ContentType::ApplicationData => {}
|
||||
ContentType::Alert => {
|
||||
// Ensure the alert is a close notify.
|
||||
let payload = record.plaintext.as_ref().ok_or(MpcTlsError::record_layer(
|
||||
"alert content was hidden from the follower",
|
||||
))?;
|
||||
|
||||
let mut reader = Reader::init(payload);
|
||||
let payload = AlertMessagePayload::read(&mut reader)
|
||||
.ok_or(MpcTlsError::record_layer("alert message was malformed"))?;
|
||||
|
||||
let AlertDescription::CloseNotify = payload.description else {
|
||||
return Err(MpcTlsError::record_layer(
|
||||
"sent alert that is not close notify",
|
||||
));
|
||||
};
|
||||
}
|
||||
typ => {
|
||||
return Err(MpcTlsError::record_layer(format!(
|
||||
"sent unexpected record content type: {:?}",
|
||||
typ
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify last record received was either application data or close notify.
|
||||
if let Some(record) = recv.next_back() {
|
||||
match record.typ {
|
||||
ContentType::ApplicationData => {}
|
||||
ContentType::Alert => {
|
||||
// Ensure the alert is a close notify.
|
||||
let payload = record.plaintext.as_ref().ok_or(MpcTlsError::record_layer(
|
||||
"alert content was hidden from the follower",
|
||||
))?;
|
||||
|
||||
let mut reader = Reader::init(payload);
|
||||
let payload = AlertMessagePayload::read(&mut reader)
|
||||
.ok_or(MpcTlsError::record_layer("alert message was malformed"))?;
|
||||
|
||||
let AlertDescription::CloseNotify = payload.description else {
|
||||
return Err(MpcTlsError::record_layer(
|
||||
"received alert that is not close notify",
|
||||
));
|
||||
};
|
||||
}
|
||||
typ => {
|
||||
return Err(MpcTlsError::record_layer(format!(
|
||||
"received unexpected record content type: {:?}",
|
||||
typ
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure all other records were application data.
|
||||
for record in sent {
|
||||
if record.typ != ContentType::ApplicationData {
|
||||
return Err(MpcTlsError::record_layer(format!(
|
||||
"sent unexpected record content type: {:?}",
|
||||
record.typ
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
for record in recv {
|
||||
if record.typ != ContentType::ApplicationData {
|
||||
return Err(MpcTlsError::record_layer(format!(
|
||||
"received unexpected record content type: {:?}",
|
||||
record.typ
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
1004
crates/mpc-tls/src/leader.rs
Normal file
1004
crates/mpc-tls/src/leader.rs
Normal file
File diff suppressed because it is too large
Load Diff
1714
crates/mpc-tls/src/leader/actor.rs
Normal file
1714
crates/mpc-tls/src/leader/actor.rs
Normal file
File diff suppressed because it is too large
Load Diff
106
crates/mpc-tls/src/lib.rs
Normal file
106
crates/mpc-tls/src/lib.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
//! TLSNotary MPC-TLS protocol implementation.
|
||||
|
||||
#![deny(missing_docs, unreachable_pub, unused_must_use)]
|
||||
#![deny(clippy::all)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
mod config;
|
||||
mod decode;
|
||||
mod error;
|
||||
pub(crate) mod follower;
|
||||
pub(crate) mod leader;
|
||||
mod msg;
|
||||
mod record_layer;
|
||||
pub(crate) mod utils;
|
||||
|
||||
pub use config::{Config, ConfigBuilder, ConfigBuilderError};
|
||||
pub use error::MpcTlsError;
|
||||
pub use follower::MpcTlsFollower;
|
||||
pub use leader::{LeaderCtrl, MpcTlsLeader};
|
||||
|
||||
use std::{future::Future, pin::Pin, sync::Arc};
|
||||
|
||||
use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
Array,
|
||||
};
|
||||
use mpz_vm_core::Vm as VmTrait;
|
||||
use tls_core::{
|
||||
cert::ServerCertDetails,
|
||||
ke::ServerKxDetails,
|
||||
key::PublicKey,
|
||||
msgs::{
|
||||
enums::{CipherSuite, ProtocolVersion},
|
||||
handshake::Random,
|
||||
},
|
||||
};
|
||||
use tlsn_common::transcript::TlsTranscript;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub(crate) type BoxFut<T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'static>>;
|
||||
/// Virtual machine type.
|
||||
pub type Vm = Arc<Mutex<dyn VmTrait<Binary> + Send + Sync + 'static>>;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum Role {
|
||||
Leader,
|
||||
Follower,
|
||||
}
|
||||
|
||||
/// TLS session keys.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SessionKeys {
|
||||
/// Client write key.
|
||||
pub client_write_key: Array<U8, 16>,
|
||||
/// Client write IV.
|
||||
pub client_write_iv: Array<U8, 4>,
|
||||
/// Server write key.
|
||||
pub server_write_key: Array<U8, 16>,
|
||||
/// Server write IV.
|
||||
pub server_write_iv: Array<U8, 4>,
|
||||
}
|
||||
|
||||
impl From<hmac_sha256::SessionKeys> for SessionKeys {
|
||||
fn from(keys: hmac_sha256::SessionKeys) -> Self {
|
||||
Self {
|
||||
client_write_key: keys.client_write_key,
|
||||
client_write_iv: keys.client_iv,
|
||||
server_write_key: keys.server_write_key,
|
||||
server_write_iv: keys.server_iv,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MPC-TLS Leader output.
|
||||
#[derive(Debug)]
|
||||
pub struct LeaderOutput {
|
||||
/// TLS protocol version.
|
||||
pub protocol_version: ProtocolVersion,
|
||||
/// TLS cipher suite.
|
||||
pub cipher_suite: CipherSuite,
|
||||
/// Server ephemeral public key.
|
||||
pub server_key: PublicKey,
|
||||
/// Server certificate chain and related details.
|
||||
pub server_cert_details: ServerCertDetails,
|
||||
/// Key exchange details.
|
||||
pub server_kx_details: ServerKxDetails,
|
||||
/// Client random
|
||||
pub client_random: Random,
|
||||
/// Server random
|
||||
pub server_random: Random,
|
||||
/// TLS transcript.
|
||||
pub transcript: TlsTranscript,
|
||||
/// TLS session keys.
|
||||
pub keys: SessionKeys,
|
||||
}
|
||||
|
||||
/// MPC-TLS Follower output.
|
||||
#[derive(Debug)]
|
||||
pub struct FollowerData {
|
||||
/// Server ephemeral public key.
|
||||
pub server_key: PublicKey,
|
||||
/// TLS transcript.
|
||||
pub transcript: TlsTranscript,
|
||||
/// TLS session keys.
|
||||
pub keys: SessionKeys,
|
||||
}
|
||||
62
crates/mpc-tls/src/msg.rs
Normal file
62
crates/mpc-tls/src/msg.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tls_core::{
|
||||
key::PublicKey,
|
||||
msgs::enums::{ContentType, ProtocolVersion},
|
||||
};
|
||||
|
||||
use crate::record_layer::{DecryptMode, EncryptMode};
|
||||
|
||||
/// MPC-TLS protocol message.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) enum Message {
|
||||
SetServerRandom(SetServerRandom),
|
||||
SetServerKey(SetServerKey),
|
||||
ClientFinishedVd(ClientFinishedVd),
|
||||
ServerFinishedVd(ServerFinishedVd),
|
||||
Encrypt(Encrypt),
|
||||
Decrypt(Decrypt),
|
||||
Flush { is_decrypting: bool },
|
||||
CloseConnection,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) struct SetServerRandom {
|
||||
pub(crate) random: [u8; 32],
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) struct SetServerKey {
|
||||
pub(crate) key: PublicKey,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) struct Decrypt {
|
||||
pub(crate) typ: ContentType,
|
||||
pub(crate) version: ProtocolVersion,
|
||||
pub(crate) explicit_nonce: Vec<u8>,
|
||||
pub(crate) ciphertext: Vec<u8>,
|
||||
pub(crate) tag: Vec<u8>,
|
||||
pub(crate) mode: DecryptMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) struct Encrypt {
|
||||
pub(crate) typ: ContentType,
|
||||
pub(crate) version: ProtocolVersion,
|
||||
pub(crate) len: usize,
|
||||
pub(crate) plaintext: Option<Vec<u8>>,
|
||||
pub(crate) mode: EncryptMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) struct ClientFinishedVd {
|
||||
pub handshake_hash: [u8; 32],
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) struct ServerFinishedVd {
|
||||
pub handshake_hash: [u8; 32],
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) struct CloseConnection;
|
||||
558
crates/mpc-tls/src/record_layer.rs
Normal file
558
crates/mpc-tls/src/record_layer.rs
Normal file
@@ -0,0 +1,558 @@
|
||||
//! TLS record layer.
|
||||
|
||||
pub(crate) mod aead;
|
||||
mod aes_ctr;
|
||||
mod decrypt;
|
||||
mod encrypt;
|
||||
|
||||
use std::{collections::VecDeque, mem::take, sync::Arc};
|
||||
|
||||
use aead::MpcAesGcm;
|
||||
use futures::TryFutureExt;
|
||||
use mpz_common::{scoped_futures::ScopedFutureExt, Context, Task};
|
||||
use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
Array,
|
||||
};
|
||||
use mpz_vm_core::Vm as VmTrait;
|
||||
use rand::{thread_rng, RngCore};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tls_core::{
|
||||
cipher::make_tls12_aad,
|
||||
msgs::enums::{ContentType, ProtocolVersion},
|
||||
};
|
||||
use tlsn_common::transcript::{Record, TlsTranscript};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::{
|
||||
record_layer::{aes_ctr::AesCtr, decrypt::DecryptOp, encrypt::EncryptOp},
|
||||
MpcTlsError, Role, Vm,
|
||||
};
|
||||
pub(crate) use decrypt::DecryptMode;
|
||||
pub(crate) use encrypt::EncryptMode;
|
||||
|
||||
const MAX_RECORD_SIZE: usize = 1026 * 16;
|
||||
// This limits how much the leader can cause the follower to allocate.
|
||||
const MAX_BUFFER_SIZE: usize = (16 * (1 << 20)) / MAX_RECORD_SIZE;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) struct PlainRecord {
|
||||
pub(crate) typ: ContentType,
|
||||
pub(crate) version: ProtocolVersion,
|
||||
pub(crate) plaintext: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) struct EncryptedRecord {
|
||||
pub(crate) typ: ContentType,
|
||||
pub(crate) version: ProtocolVersion,
|
||||
pub(crate) explicit_nonce: Vec<u8>,
|
||||
pub(crate) ciphertext: Vec<u8>,
|
||||
pub(crate) tag: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
enum State {
|
||||
Init,
|
||||
Online {
|
||||
recv_otp: Option<Vec<u8>>,
|
||||
sent_records: Vec<Record>,
|
||||
recv_records: Vec<Record>,
|
||||
},
|
||||
Complete,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn take(&mut self) -> Self {
|
||||
std::mem::replace(self, State::Error)
|
||||
}
|
||||
}
|
||||
|
||||
/// MPC-TLS record layer.
|
||||
pub(crate) struct RecordLayer {
|
||||
role: Role,
|
||||
write_seq: u64,
|
||||
read_seq: u64,
|
||||
encrypter: Arc<Mutex<MpcAesGcm>>,
|
||||
decrypt: Arc<Mutex<MpcAesGcm>>,
|
||||
aes_ctr: AesCtr,
|
||||
state: State,
|
||||
|
||||
encrypt_buffer: Vec<EncryptOp>,
|
||||
decrypt_buffer: Vec<DecryptOp>,
|
||||
encrypted_buffer: VecDeque<EncryptedRecord>,
|
||||
decrypted_buffer: VecDeque<PlainRecord>,
|
||||
}
|
||||
|
||||
impl RecordLayer {
|
||||
/// Creates a new record layer.
|
||||
pub(crate) fn new(role: Role, encrypt: MpcAesGcm, decrypt: MpcAesGcm) -> Self {
|
||||
Self {
|
||||
role,
|
||||
write_seq: 0,
|
||||
read_seq: 0,
|
||||
encrypter: Arc::new(Mutex::new(encrypt)),
|
||||
decrypt: Arc::new(Mutex::new(decrypt)),
|
||||
aes_ctr: AesCtr::new(role),
|
||||
state: State::Init,
|
||||
encrypt_buffer: Vec::new(),
|
||||
decrypt_buffer: Vec::new(),
|
||||
encrypted_buffer: VecDeque::new(),
|
||||
decrypted_buffer: VecDeque::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocates resources for the record layer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `sent_records` - Number of sent records to allocate.
|
||||
/// * `recv_records` - Number of received records to allocate.
|
||||
/// * `sent_len` - Total length of sent records to allocate.
|
||||
/// * `recv_len` - Total length of received records to allocate.
|
||||
pub(crate) fn alloc(
|
||||
&mut self,
|
||||
vm: &mut dyn VmTrait<Binary>,
|
||||
sent_records: usize,
|
||||
recv_records: usize,
|
||||
sent_len: usize,
|
||||
recv_len: usize,
|
||||
) -> Result<(), MpcTlsError> {
|
||||
let State::Init = self.state.take() else {
|
||||
return Err(MpcTlsError::other("record layer is already allocated"));
|
||||
};
|
||||
|
||||
let mut encrypt = self
|
||||
.encrypter
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::other("encrypt lock is held"))?;
|
||||
|
||||
let mut decrypt = self
|
||||
.decrypt
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::other("decrypt lock is held"))?;
|
||||
|
||||
encrypt
|
||||
.alloc(vm, sent_records, sent_len)
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
decrypt
|
||||
.alloc(vm, recv_records, recv_len)
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
let recv_otp = match self.role {
|
||||
Role::Leader => {
|
||||
let mut recv_otp = vec![0u8; recv_len];
|
||||
thread_rng().fill_bytes(&mut recv_otp);
|
||||
|
||||
Some(recv_otp)
|
||||
}
|
||||
Role::Follower => None,
|
||||
};
|
||||
|
||||
self.aes_ctr.alloc(vm)?;
|
||||
|
||||
self.state = State::Online {
|
||||
recv_otp,
|
||||
sent_records: Vec::new(),
|
||||
recv_records: Vec::new(),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn preprocess(&mut self, ctx: &mut Context) -> Result<(), MpcTlsError> {
|
||||
let mut encrypt = self
|
||||
.encrypter
|
||||
.clone()
|
||||
.try_lock_owned()
|
||||
.map_err(|_| MpcTlsError::other("encrypt lock is held"))?;
|
||||
let mut decrypt = self
|
||||
.decrypt
|
||||
.clone()
|
||||
.try_lock_owned()
|
||||
.map_err(|_| MpcTlsError::other("decrypt lock is held"))?;
|
||||
|
||||
// Computes GHASH keys in parallel.
|
||||
ctx.try_join(
|
||||
|ctx| async move { encrypt.preprocess(ctx).await }.scope_boxed(),
|
||||
|ctx| async move { decrypt.preprocess(ctx).await }.scope_boxed(),
|
||||
)
|
||||
.await
|
||||
.map_err(MpcTlsError::record_layer)?
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sets the keys for the record layer.
|
||||
pub(crate) fn set_keys(
|
||||
&mut self,
|
||||
client_write_key: Array<U8, 16>,
|
||||
client_iv: Array<U8, 4>,
|
||||
server_write_key: Array<U8, 16>,
|
||||
server_iv: Array<U8, 4>,
|
||||
) -> Result<(), MpcTlsError> {
|
||||
let mut encrypt = self
|
||||
.encrypter
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::other("encrypt lock is held"))?;
|
||||
let mut decrypt = self
|
||||
.decrypt
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::other("decrypt lock is held"))?;
|
||||
|
||||
encrypt.set_key(client_write_key);
|
||||
encrypt.set_iv(client_iv);
|
||||
decrypt.set_key(server_write_key);
|
||||
decrypt.set_iv(server_iv);
|
||||
self.aes_ctr.set_key(server_write_key, server_iv);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sets up the record layer.
|
||||
pub(crate) async fn setup(&mut self, ctx: &mut Context) -> Result<(), MpcTlsError> {
|
||||
let mut encrypt = self
|
||||
.encrypter
|
||||
.clone()
|
||||
.try_lock_owned()
|
||||
.map_err(|_| MpcTlsError::other("encrypt lock is held"))?;
|
||||
let mut decrypt = self
|
||||
.decrypt
|
||||
.clone()
|
||||
.try_lock_owned()
|
||||
.map_err(|_| MpcTlsError::other("decrypt lock is held"))?;
|
||||
|
||||
// Computes GHASH keys in parallel.
|
||||
ctx.try_join(
|
||||
|ctx| async move { encrypt.setup(ctx).await }.scope_boxed(),
|
||||
|ctx| async move { decrypt.setup(ctx).await }.scope_boxed(),
|
||||
)
|
||||
.await
|
||||
.map_err(MpcTlsError::record_layer)?
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.encrypt_buffer.is_empty()
|
||||
&& self.decrypt_buffer.is_empty()
|
||||
&& self.encrypted_buffer.is_empty()
|
||||
&& self.decrypted_buffer.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn wants_flush(&self) -> bool {
|
||||
!self.encrypt_buffer.is_empty() || !self.decrypt_buffer.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn push_encrypt(
|
||||
&mut self,
|
||||
typ: ContentType,
|
||||
version: ProtocolVersion,
|
||||
len: usize,
|
||||
plaintext: Option<Vec<u8>>,
|
||||
mode: EncryptMode,
|
||||
) -> Result<(), MpcTlsError> {
|
||||
if self.encrypt_buffer.len() >= MAX_BUFFER_SIZE {
|
||||
return Err(MpcTlsError::peer("encrypt buffer is full"));
|
||||
}
|
||||
|
||||
let (seq, explicit_nonce, aad) = self.next_write(typ, version, len);
|
||||
self.encrypt_buffer.push(EncryptOp::new(
|
||||
seq,
|
||||
typ,
|
||||
version,
|
||||
len,
|
||||
plaintext,
|
||||
explicit_nonce,
|
||||
aad,
|
||||
mode,
|
||||
)?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn push_decrypt(
|
||||
&mut self,
|
||||
typ: ContentType,
|
||||
version: ProtocolVersion,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
tag: Vec<u8>,
|
||||
mode: DecryptMode,
|
||||
) -> Result<(), MpcTlsError> {
|
||||
if self.decrypt_buffer.len() >= MAX_BUFFER_SIZE {
|
||||
return Err(MpcTlsError::peer("decrypt buffer is full"));
|
||||
}
|
||||
|
||||
let (seq, aad) = self.next_read(typ, version, ciphertext.len());
|
||||
self.decrypt_buffer.push(DecryptOp::new(
|
||||
seq,
|
||||
typ,
|
||||
version,
|
||||
explicit_nonce,
|
||||
ciphertext,
|
||||
aad,
|
||||
tag,
|
||||
mode,
|
||||
));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the next encrypted record.
|
||||
pub(crate) fn next_encrypted(&mut self) -> Option<EncryptedRecord> {
|
||||
self.encrypted_buffer.pop_front()
|
||||
}
|
||||
|
||||
/// Returns the next decrypted record.
|
||||
pub(crate) fn next_decrypted(&mut self) -> Option<PlainRecord> {
|
||||
self.decrypted_buffer.pop_front()
|
||||
}
|
||||
|
||||
pub(crate) async fn flush(
|
||||
&mut self,
|
||||
ctx: &mut Context,
|
||||
vm: Vm,
|
||||
is_decrypting: bool,
|
||||
) -> Result<(), MpcTlsError> {
|
||||
let State::Online {
|
||||
recv_otp,
|
||||
sent_records,
|
||||
recv_records,
|
||||
..
|
||||
} = &mut self.state
|
||||
else {
|
||||
return Err(MpcTlsError::state(
|
||||
"record layer must be in online state to flush",
|
||||
));
|
||||
};
|
||||
|
||||
let mut vm = vm
|
||||
.try_lock_owned()
|
||||
.map_err(|_| MpcTlsError::record_layer("VM lock is held"))?;
|
||||
|
||||
let mut encrypter = self
|
||||
.encrypter
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::record_layer("encrypt lock is held"))?;
|
||||
|
||||
let mut decrypter = self
|
||||
.decrypt
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::record_layer("decrypt lock is held"))?;
|
||||
|
||||
let encrypt_ops = take(&mut self.encrypt_buffer);
|
||||
|
||||
let decrypt_end = if is_decrypting {
|
||||
self.decrypt_buffer.len()
|
||||
} else {
|
||||
// Position of the first application data in the decrypt buffer.
|
||||
self.decrypt_buffer
|
||||
.iter()
|
||||
.position(|op| op.typ == ContentType::ApplicationData)
|
||||
.unwrap_or(self.decrypt_buffer.len())
|
||||
};
|
||||
|
||||
let decrypt_ops: Vec<_> = self.decrypt_buffer.drain(..decrypt_end).collect();
|
||||
|
||||
let (pending_encrypt, compute_tags) =
|
||||
encrypt::encrypt(&mut (*vm), &mut encrypter, &encrypt_ops)?;
|
||||
|
||||
let pending_decrypt =
|
||||
decrypt::decrypt_mpc(&mut (*vm), &mut decrypter, recv_otp.as_mut(), &decrypt_ops)?;
|
||||
let verify_tags = decrypt::verify_tags(&mut (*vm), &mut decrypter, &decrypt_ops)?;
|
||||
|
||||
// Run tag computation and VM in parallel.
|
||||
let (mut tags, _, _) = ctx
|
||||
.try_join3(
|
||||
|ctx| {
|
||||
compute_tags
|
||||
.run(ctx)
|
||||
.map_err(MpcTlsError::record_layer)
|
||||
.scope_boxed()
|
||||
},
|
||||
|ctx| {
|
||||
verify_tags
|
||||
.run(ctx)
|
||||
.map_err(MpcTlsError::record_layer)
|
||||
.scope_boxed()
|
||||
},
|
||||
|ctx| {
|
||||
async move { vm.execute_all(ctx).map_err(MpcTlsError::record_layer).await }
|
||||
.scope_boxed()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(MpcTlsError::record_layer)??;
|
||||
|
||||
// Reverse tags, as we will be popping from the back.
|
||||
if let Some(tags) = tags.as_mut() {
|
||||
tags.reverse();
|
||||
}
|
||||
|
||||
for (op, pending) in encrypt_ops.into_iter().zip(pending_encrypt) {
|
||||
let ciphertext = pending.output.try_encrypt()?;
|
||||
self.encrypted_buffer.push_back(EncryptedRecord {
|
||||
typ: op.typ,
|
||||
version: op.version,
|
||||
explicit_nonce: op.explicit_nonce.clone(),
|
||||
ciphertext: ciphertext.clone(),
|
||||
tag: tags.as_mut().and_then(Vec::pop),
|
||||
});
|
||||
|
||||
sent_records.push(Record {
|
||||
seq: op.seq,
|
||||
typ: op.typ,
|
||||
plaintext: op.plaintext,
|
||||
plaintext_ref: pending.plaintext_ref,
|
||||
explicit_nonce: op.explicit_nonce,
|
||||
ciphertext,
|
||||
});
|
||||
}
|
||||
|
||||
for (op, pending) in decrypt_ops.into_iter().zip(pending_decrypt) {
|
||||
let plaintext = pending.output.try_decrypt()?;
|
||||
self.decrypted_buffer.push_back(PlainRecord {
|
||||
typ: op.typ,
|
||||
version: op.version,
|
||||
plaintext: plaintext.clone(),
|
||||
});
|
||||
|
||||
recv_records.push(Record {
|
||||
seq: op.seq,
|
||||
typ: op.typ,
|
||||
plaintext,
|
||||
plaintext_ref: None,
|
||||
explicit_nonce: op.explicit_nonce,
|
||||
ciphertext: op.ciphertext,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn commit(
|
||||
&mut self,
|
||||
ctx: &mut Context,
|
||||
vm: Vm,
|
||||
) -> Result<TlsTranscript, MpcTlsError> {
|
||||
let State::Online {
|
||||
sent_records,
|
||||
mut recv_records,
|
||||
..
|
||||
} = self.state.take()
|
||||
else {
|
||||
return Err(MpcTlsError::state(
|
||||
"record layer must be in online state to commit",
|
||||
));
|
||||
};
|
||||
|
||||
if !self.encrypt_buffer.is_empty() {
|
||||
return Err(MpcTlsError::state(
|
||||
"record layer can not commit with pending encrypt operations",
|
||||
));
|
||||
}
|
||||
|
||||
let mut vm = vm
|
||||
.try_lock_owned()
|
||||
.map_err(|_| MpcTlsError::record_layer("VM lock is held"))?;
|
||||
|
||||
let mut decrypter = self
|
||||
.decrypt
|
||||
.try_lock()
|
||||
.map_err(|_| MpcTlsError::record_layer("decrypt lock is held"))?;
|
||||
|
||||
let buffered_ops = take(&mut self.decrypt_buffer);
|
||||
|
||||
// Verify tags of buffered ciphertexts.
|
||||
let verify_tags = decrypt::verify_tags(&mut (*vm), &mut decrypter, &buffered_ops)?;
|
||||
|
||||
vm.execute_all(ctx)
|
||||
.await
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
verify_tags
|
||||
.run(ctx)
|
||||
.await
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
// Reveal decrypt key to the leader.
|
||||
self.aes_ctr.decode_key(&mut (*vm))?;
|
||||
vm.flush(ctx).await.map_err(MpcTlsError::record_layer)?;
|
||||
self.aes_ctr.finish_decode()?;
|
||||
|
||||
let pending_decrypts = decrypt::decrypt_local(
|
||||
self.role,
|
||||
&mut (*vm),
|
||||
&mut decrypter,
|
||||
&mut self.aes_ctr,
|
||||
&buffered_ops,
|
||||
)?;
|
||||
|
||||
vm.execute_all(ctx)
|
||||
.await
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
for (op, pending) in buffered_ops.into_iter().zip(pending_decrypts) {
|
||||
let plaintext = pending.output.try_decrypt()?;
|
||||
self.decrypted_buffer.push_back(PlainRecord {
|
||||
typ: op.typ,
|
||||
version: op.version,
|
||||
plaintext: plaintext.clone(),
|
||||
});
|
||||
|
||||
recv_records.push(Record {
|
||||
seq: op.seq,
|
||||
typ: op.typ,
|
||||
plaintext,
|
||||
plaintext_ref: None,
|
||||
explicit_nonce: op.explicit_nonce,
|
||||
ciphertext: op.ciphertext,
|
||||
});
|
||||
}
|
||||
|
||||
self.state = State::Complete;
|
||||
|
||||
Ok(TlsTranscript {
|
||||
sent: sent_records,
|
||||
recv: recv_records,
|
||||
})
|
||||
}
|
||||
|
||||
fn next_write(
|
||||
&mut self,
|
||||
typ: ContentType,
|
||||
version: ProtocolVersion,
|
||||
len: usize,
|
||||
) -> (u64, Vec<u8>, Vec<u8>) {
|
||||
let seq = self.write_seq;
|
||||
self.write_seq += 1;
|
||||
let explicit_nonce = seq.to_be_bytes().to_vec();
|
||||
let aad = make_tls12_aad(seq, typ, version, len).to_vec();
|
||||
|
||||
(seq, explicit_nonce, aad)
|
||||
}
|
||||
|
||||
fn next_read(
|
||||
&mut self,
|
||||
typ: ContentType,
|
||||
version: ProtocolVersion,
|
||||
len: usize,
|
||||
) -> (u64, Vec<u8>) {
|
||||
let seq = self.read_seq;
|
||||
self.read_seq += 1;
|
||||
let aad = make_tls12_aad(seq, typ, version, len).to_vec();
|
||||
|
||||
(seq, aad)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct TagData {
|
||||
pub(crate) explicit_nonce: Vec<u8>,
|
||||
pub(crate) aad: Vec<u8>,
|
||||
}
|
||||
69
crates/mpc-tls/src/record_layer/aead.rs
Normal file
69
crates/mpc-tls/src/record_layer/aead.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
mod aes_gcm;
|
||||
mod ghash;
|
||||
|
||||
pub(crate) use aes_gcm::MpcAesGcm;
|
||||
use cipher::{aes::AesError, CipherError};
|
||||
pub(crate) use ghash::{ComputeTags, VerifyTags};
|
||||
|
||||
use mpz_memory_core::{binary::U8, Array};
|
||||
use mpz_vm_core::VmError;
|
||||
|
||||
type Nonce = Array<U8, 8>;
|
||||
type Ctr = Array<U8, 4>;
|
||||
type Block = Array<U8, 16>;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub(crate) struct AeadError(ErrorRepr);
|
||||
|
||||
impl AeadError {
|
||||
pub(crate) fn state<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::State(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn cipher<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Cipher(err.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn tag<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::Tag(err.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("aead error: {0}")]
|
||||
enum ErrorRepr {
|
||||
#[error("state error: {0}")]
|
||||
State(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
#[error("cipher error: {0}")]
|
||||
Cipher(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
#[error("tag error: {0}")]
|
||||
Tag(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
}
|
||||
|
||||
impl From<VmError> for AeadError {
|
||||
fn from(err: VmError) -> Self {
|
||||
Self(ErrorRepr::Cipher(Box::new(err)))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CipherError> for AeadError {
|
||||
fn from(err: CipherError) -> Self {
|
||||
Self(ErrorRepr::Cipher(Box::new(err)))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AesError> for AeadError {
|
||||
fn from(err: AesError) -> Self {
|
||||
Self(ErrorRepr::Cipher(Box::new(err)))
|
||||
}
|
||||
}
|
||||
664
crates/mpc-tls/src/record_layer/aead/aes_gcm.rs
Normal file
664
crates/mpc-tls/src/record_layer/aead/aes_gcm.rs
Normal file
@@ -0,0 +1,664 @@
|
||||
use std::{future::Future, sync::Arc};
|
||||
|
||||
use crate::{
|
||||
decode::OneTimePadShared,
|
||||
record_layer::{
|
||||
aead::{
|
||||
ghash::{ComputeTagData, ComputeTags, Ghash, MpcGhash, VerifyTagData, VerifyTags},
|
||||
AeadError, Block, Ctr, Nonce,
|
||||
},
|
||||
TagData,
|
||||
},
|
||||
Role,
|
||||
};
|
||||
use cipher::{aes::Aes128, Cipher, CtrBlock, Keystream};
|
||||
use mpz_common::{Context, Flush};
|
||||
use mpz_fields::gf2_128::Gf2_128;
|
||||
use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
Vector,
|
||||
};
|
||||
use mpz_share_conversion::ShareConvert;
|
||||
use mpz_vm_core::{prelude::*, Vm};
|
||||
use tracing::instrument;
|
||||
|
||||
const START_CTR: u32 = 2;
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
enum State {
|
||||
Init {
|
||||
ghash: Box<dyn Ghash + Send + Sync>,
|
||||
},
|
||||
Setup {
|
||||
input: Vector<U8>,
|
||||
keystream: Keystream<Nonce, Ctr, Block>,
|
||||
j0s: Vec<(CtrBlock<Nonce, Ctr, Block>, OneTimePadShared<[u8; 16]>)>,
|
||||
output: Vector<U8>,
|
||||
ghash_key: OneTimePadShared<[u8; 16]>,
|
||||
ghash: Box<dyn Ghash + Send + Sync>,
|
||||
},
|
||||
Ready {
|
||||
input: Vector<U8>,
|
||||
keystream: Keystream<Nonce, Ctr, Block>,
|
||||
j0s: Vec<(CtrBlock<Nonce, Ctr, Block>, OneTimePadShared<[u8; 16]>)>,
|
||||
output: Vector<U8>,
|
||||
ghash: Arc<dyn Ghash + Send + Sync>,
|
||||
},
|
||||
Error,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn take(&mut self) -> Self {
|
||||
std::mem::replace(self, State::Error)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct MpcAesGcm {
|
||||
role: Role,
|
||||
aes: Aes128,
|
||||
state: State,
|
||||
}
|
||||
|
||||
impl MpcAesGcm {
|
||||
/// Creates a new AES-GCM instance.
|
||||
pub(crate) fn new<C>(converter: C, role: Role) -> Self
|
||||
where
|
||||
C: ShareConvert<Gf2_128> + Flush + Send + Sync + 'static,
|
||||
{
|
||||
Self {
|
||||
role,
|
||||
aes: Aes128::default(),
|
||||
state: State::Init {
|
||||
ghash: Box::new(MpcGhash::new(converter)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocates resources.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine to allocate in.
|
||||
/// * `records` - Number of records to allocate.
|
||||
/// * `len` - Length of the input text in bytes.
|
||||
pub(crate) fn alloc(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
records: usize,
|
||||
len: usize,
|
||||
) -> Result<(), AeadError> {
|
||||
let State::Init { mut ghash } = self.state.take() else {
|
||||
return Err(AeadError::state("must be in init state to allocate"));
|
||||
};
|
||||
|
||||
let zero_block: Array<U8, 16> = vm.alloc()?;
|
||||
vm.mark_public(zero_block)?;
|
||||
vm.assign(zero_block, [0u8; 16])?;
|
||||
vm.commit(zero_block)?;
|
||||
|
||||
ghash.alloc()?;
|
||||
let ghash_key = self.aes.alloc_block(vm, zero_block)?;
|
||||
let ghash_key = OneTimePadShared::<[u8; 16]>::new(self.role, ghash_key, vm)?;
|
||||
|
||||
// Allocate J0 secret sharing for GHASH.
|
||||
let mut j0s = Vec::with_capacity(records);
|
||||
for _ in 0..records {
|
||||
let j0 = self.aes.alloc_ctr_block(vm)?;
|
||||
let j0_shared = OneTimePadShared::<[u8; 16]>::new(self.role, j0.output, vm)?;
|
||||
|
||||
j0s.push((j0, j0_shared));
|
||||
}
|
||||
|
||||
// Allocate encryption/decryption.
|
||||
|
||||
// Round up the length to the nearest multiple of the block count.
|
||||
let len = 16 * len.div_ceil(16);
|
||||
|
||||
let input = vm.alloc_vec::<U8>(len)?;
|
||||
match self.role {
|
||||
Role::Leader => {
|
||||
vm.mark_private(input)?;
|
||||
}
|
||||
Role::Follower => {
|
||||
vm.mark_blind(input)?;
|
||||
}
|
||||
}
|
||||
|
||||
let keystream = self.aes.alloc_keystream(vm, len)?;
|
||||
let output = keystream.apply(vm, input)?;
|
||||
|
||||
self.state = State::Setup {
|
||||
input,
|
||||
keystream,
|
||||
j0s,
|
||||
output,
|
||||
ghash,
|
||||
ghash_key,
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn preprocess(&mut self, ctx: &mut Context) -> Result<(), AeadError> {
|
||||
let State::Setup { ghash, .. } = &mut self.state else {
|
||||
return Err(AeadError::state("must be in setup state to allocate"));
|
||||
};
|
||||
|
||||
ghash.preprocess(ctx).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn set_key(&mut self, key: Array<U8, 16>) {
|
||||
self.aes.set_key(key);
|
||||
}
|
||||
|
||||
pub(crate) fn set_iv(&mut self, iv: Array<U8, 4>) {
|
||||
self.aes.set_iv(iv);
|
||||
}
|
||||
|
||||
pub(crate) async fn setup(&mut self, ctx: &mut Context) -> Result<(), AeadError> {
|
||||
let State::Setup {
|
||||
input,
|
||||
keystream,
|
||||
j0s,
|
||||
output,
|
||||
mut ghash,
|
||||
ghash_key,
|
||||
} = self.state.take()
|
||||
else {
|
||||
return Err(AeadError::state("must be in setup state to set up"));
|
||||
};
|
||||
|
||||
let key = ghash_key.await.map_err(AeadError::tag)?;
|
||||
ghash.set_key(key.to_vec())?;
|
||||
ghash.setup(ctx).await?;
|
||||
|
||||
self.state = State::Ready {
|
||||
input,
|
||||
keystream,
|
||||
j0s,
|
||||
output,
|
||||
ghash: Arc::from(ghash),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns `len` bytes of input and output text.
|
||||
///
|
||||
/// The outer context is responsible for assigning to the input text.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `explicit_nonce` - Explicit nonce.
|
||||
/// * `len` - Number of bytes to take.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
pub(crate) fn apply_keystream(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
explicit_nonce: Vec<u8>,
|
||||
len: usize,
|
||||
) -> Result<(Vector<U8>, Vector<U8>), AeadError> {
|
||||
let State::Ready {
|
||||
input,
|
||||
keystream,
|
||||
output,
|
||||
..
|
||||
} = &mut self.state
|
||||
else {
|
||||
return Err(AeadError::state(
|
||||
"must be in ready state to apply keystream",
|
||||
));
|
||||
};
|
||||
|
||||
let explicit_nonce: [u8; 8] = explicit_nonce.try_into().map_err(|nonce: Vec<_>| {
|
||||
AeadError::cipher(format!(
|
||||
"explicit nonce length: expected {}, got {}",
|
||||
16,
|
||||
nonce.len()
|
||||
))
|
||||
})?;
|
||||
|
||||
let block_count = len.div_ceil(16);
|
||||
let padded_len = block_count * 16;
|
||||
let padding_len = padded_len - len;
|
||||
|
||||
if padded_len > input.len() {
|
||||
return Err(AeadError::cipher(format!(
|
||||
"input length exceeds allocated: {} > {}",
|
||||
padded_len,
|
||||
input.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut input = input.split_off(input.len() - padded_len);
|
||||
let keystream = keystream.consume(padded_len)?;
|
||||
let mut output = output.split_off(output.len() - padded_len);
|
||||
|
||||
// Assign counter block inputs.
|
||||
let mut ctr = START_CTR..;
|
||||
keystream.assign(vm, explicit_nonce, move || {
|
||||
ctr.next().expect("range is unbounded").to_be_bytes()
|
||||
})?;
|
||||
|
||||
// Assign zeroes to the padding.
|
||||
if padding_len > 0 {
|
||||
let padding = input.split_off(input.len() - padding_len);
|
||||
if let Role::Leader = self.role {
|
||||
vm.assign(padding, vec![0; padding_len])?;
|
||||
}
|
||||
vm.commit(padding)?;
|
||||
output.truncate(len);
|
||||
}
|
||||
|
||||
Ok((input, output))
|
||||
}
|
||||
|
||||
/// Returns `len` bytes of keystream.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `explicit_nonce` - Explicit nonce.
|
||||
/// * `len` - Number of bytes to take.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
pub(crate) fn take_keystream(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
explicit_nonce: Vec<u8>,
|
||||
len: usize,
|
||||
) -> Result<Vector<U8>, AeadError> {
|
||||
let State::Ready {
|
||||
input,
|
||||
keystream,
|
||||
output,
|
||||
..
|
||||
} = &mut self.state
|
||||
else {
|
||||
return Err(AeadError::state("must be in ready state to take keystream"));
|
||||
};
|
||||
|
||||
let explicit_nonce: [u8; 8] = explicit_nonce.try_into().map_err(|nonce: Vec<_>| {
|
||||
AeadError::cipher(format!(
|
||||
"explicit nonce length: expected {}, got {}",
|
||||
16,
|
||||
nonce.len()
|
||||
))
|
||||
})?;
|
||||
|
||||
let block_count = len.div_ceil(16);
|
||||
let padded_len = block_count * 16;
|
||||
|
||||
if padded_len > input.len() {
|
||||
return Err(AeadError::cipher(format!(
|
||||
"input length exceeds allocated: {} > {}",
|
||||
padded_len,
|
||||
input.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Drop the input and output text, we won't be needing them.
|
||||
// This leaves them allocated but unassigned in the VM.
|
||||
_ = input.split_off(input.len() - padded_len);
|
||||
_ = output.split_off(output.len() - padded_len);
|
||||
|
||||
let keystream = keystream.consume(len)?;
|
||||
|
||||
// Assign counter block inputs.
|
||||
let mut ctr = START_CTR..;
|
||||
keystream.assign(vm, explicit_nonce, move || {
|
||||
ctr.next().expect("range is unbounded").to_be_bytes()
|
||||
})?;
|
||||
|
||||
Ok(keystream.to_vector(len)?)
|
||||
}
|
||||
|
||||
/// Computes tags for the provided ciphertext. See
|
||||
/// [`verify_tags`](MpcAesGcm::verify_tags) for a method that verifies an
|
||||
/// tags instead.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `explicit_nonce` - Explicit nonce.
|
||||
/// * `ciphertext` - Ciphertext to compute the tag for.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
pub(crate) fn compute_tags<C>(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
ciphertexts: Vec<C>,
|
||||
data: Vec<TagData>,
|
||||
) -> Result<ComputeTags, AeadError>
|
||||
where
|
||||
C: Future<Output = Result<Vec<u8>, AeadError>> + Send + Sync + 'static,
|
||||
{
|
||||
let State::Ready { j0s, ghash, .. } = &mut self.state else {
|
||||
return Err(AeadError::state("must be in ready state to compute tags"));
|
||||
};
|
||||
|
||||
if ciphertexts.len() != data.len() {
|
||||
return Err(AeadError::tag("ciphertext and data length mismatch"));
|
||||
} else if ciphertexts.len() > j0s.len() {
|
||||
return Err(AeadError::tag("ciphertext length exceeds allocated"));
|
||||
}
|
||||
|
||||
let mut tag_data = Vec::with_capacity(ciphertexts.len());
|
||||
for (ciphertext, data) in ciphertexts.into_iter().zip(data) {
|
||||
let explicit_nonce: [u8; 8] =
|
||||
data.explicit_nonce.try_into().map_err(|nonce: Vec<_>| {
|
||||
AeadError::cipher(format!(
|
||||
"explicit nonce length: expected {}, got {}",
|
||||
8,
|
||||
nonce.len()
|
||||
))
|
||||
})?;
|
||||
let (j0, j0_shared) = j0s.pop().expect("j0 length was checked");
|
||||
|
||||
assign_j0(vm, j0, explicit_nonce)?;
|
||||
|
||||
tag_data.push(ComputeTagData {
|
||||
j0: j0_shared,
|
||||
ciphertext: Box::pin(ciphertext),
|
||||
aad: data.aad,
|
||||
});
|
||||
}
|
||||
|
||||
let tags = ComputeTags::new(self.role, tag_data, ghash.clone());
|
||||
|
||||
Ok(tags)
|
||||
}
|
||||
|
||||
/// Verifies the tags for the provided ciphertexts.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vm` - Virtual machine.
|
||||
/// * `inputs` - Data to verify the tags for.
|
||||
pub(crate) fn verify_tags(
|
||||
&mut self,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
data: Vec<TagData>,
|
||||
ciphertexts: Vec<Vec<u8>>,
|
||||
tags: Vec<Vec<u8>>,
|
||||
) -> Result<VerifyTags, AeadError> {
|
||||
let State::Ready { j0s, ghash, .. } = &mut self.state else {
|
||||
return Err(AeadError::state("must be in ready state to verify tags"));
|
||||
};
|
||||
|
||||
if ciphertexts.len() != data.len() {
|
||||
return Err(AeadError::tag("ciphertext and data length mismatch"));
|
||||
} else if ciphertexts.len() != tags.len() {
|
||||
return Err(AeadError::tag("ciphertext and tag length mismatch"));
|
||||
} else if ciphertexts.len() > j0s.len() {
|
||||
return Err(AeadError::tag("ciphertext length exceeds allocated"));
|
||||
}
|
||||
|
||||
let mut tag_data = Vec::with_capacity(ciphertexts.len());
|
||||
for ((ciphertext, data), tag) in ciphertexts.into_iter().zip(data).zip(tags) {
|
||||
let explicit_nonce: [u8; 8] =
|
||||
data.explicit_nonce.try_into().map_err(|nonce: Vec<_>| {
|
||||
AeadError::cipher(format!(
|
||||
"explicit nonce length: expected {}, got {}",
|
||||
8,
|
||||
nonce.len()
|
||||
))
|
||||
})?;
|
||||
let (j0, j0_shared) = j0s.pop().expect("j0 length was checked");
|
||||
|
||||
assign_j0(vm, j0, explicit_nonce)?;
|
||||
|
||||
tag_data.push(VerifyTagData {
|
||||
j0: j0_shared,
|
||||
ciphertext,
|
||||
aad: data.aad,
|
||||
tag,
|
||||
});
|
||||
}
|
||||
|
||||
let tags = VerifyTags::new(self.role, tag_data, ghash.clone());
|
||||
|
||||
Ok(tags)
|
||||
}
|
||||
}
|
||||
|
||||
fn assign_j0(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
j0: CtrBlock<Nonce, Ctr, Block>,
|
||||
explicit_nonce: [u8; 8],
|
||||
) -> Result<(), AeadError> {
|
||||
vm.assign(j0.explicit_nonce, explicit_nonce)?;
|
||||
vm.commit(j0.explicit_nonce)?;
|
||||
vm.assign(j0.counter, 1u32.to_be_bytes())?;
|
||||
vm.commit(j0.counter)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use aes_gcm::{
|
||||
aead::{AeadInPlace, NewAead},
|
||||
Aes128Gcm,
|
||||
};
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_core::Block;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
|
||||
use mpz_memory_core::{binary::U8, correlated::Delta};
|
||||
use mpz_ot::ideal::cot::ideal_cot;
|
||||
use mpz_share_conversion::ideal::ideal_share_convert;
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use rstest::*;
|
||||
|
||||
static SHORT_MSG: &[u8] = b"hello world";
|
||||
static LONG_MSG: &[u8] = b"this message exceeds one block in length";
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct Vars {
|
||||
key: Array<U8, 16>,
|
||||
iv: Array<U8, 4>,
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::short(SHORT_MSG, 1)]
|
||||
#[case::long(LONG_MSG, 1)]
|
||||
#[case::short_multiple(SHORT_MSG, 3)]
|
||||
#[case::long_multiple(LONG_MSG, 3)]
|
||||
#[tokio::test]
|
||||
async fn test_aes_gcm_encrypt(#[case] msg: &[u8], #[case] count: usize) {
|
||||
let (mut ctx_0, mut ctx_1) = test_st_context(8);
|
||||
|
||||
let key = [42u8; 16];
|
||||
let iv = [0u8; 4];
|
||||
|
||||
let ((mut vm_0, vars_0), (mut vm_1, vars_1)) = create_vm(key, iv);
|
||||
let (mut leader, mut follower) = create_pair(vars_0, vars_1);
|
||||
|
||||
leader.alloc(&mut vm_0, count, 256).unwrap();
|
||||
follower.alloc(&mut vm_1, count, 256).unwrap();
|
||||
|
||||
run_vms(&mut vm_0, &mut ctx_0, &mut vm_1, &mut ctx_1).await;
|
||||
|
||||
tokio::try_join!(leader.setup(&mut ctx_0), follower.setup(&mut ctx_1)).unwrap();
|
||||
|
||||
for i in 0u64..count as u64 {
|
||||
let explicit_nonce = i.to_be_bytes().to_vec();
|
||||
let (msg_0, ct_0) = leader
|
||||
.apply_keystream(&mut vm_0, explicit_nonce.clone(), msg.len())
|
||||
.unwrap();
|
||||
let (msg_1, ct_1) = follower
|
||||
.apply_keystream(&mut vm_1, explicit_nonce.clone(), msg.len())
|
||||
.unwrap();
|
||||
|
||||
vm_0.assign(msg_0, msg.to_vec()).unwrap();
|
||||
vm_0.commit(msg_0).unwrap();
|
||||
|
||||
vm_1.commit(msg_1).unwrap();
|
||||
|
||||
let ct_0 = vm_0.decode(ct_0).unwrap();
|
||||
let ct_1 = vm_1.decode(ct_1).unwrap();
|
||||
|
||||
run_vms(&mut vm_0, &mut ctx_0, &mut vm_1, &mut ctx_1).await;
|
||||
|
||||
let ct_0 = ct_0.await.unwrap();
|
||||
let ct_1 = ct_1.await.unwrap();
|
||||
|
||||
let (expected, _) = expected(&key, &iv, &explicit_nonce, msg, &[]);
|
||||
assert_eq!(ct_0, expected);
|
||||
assert_eq!(ct_1, expected);
|
||||
}
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::short(SHORT_MSG, 1)]
|
||||
#[case::long(LONG_MSG, 1)]
|
||||
#[case::short_multiple(SHORT_MSG, 3)]
|
||||
#[case::long_multiple(LONG_MSG, 3)]
|
||||
#[tokio::test]
|
||||
async fn test_aes_gcm_decrypt(#[case] msg: &[u8], #[case] count: usize) {
|
||||
let (mut ctx_0, mut ctx_1) = test_st_context(8);
|
||||
|
||||
let key = [42u8; 16];
|
||||
let iv = [0u8; 4];
|
||||
|
||||
let ((mut vm_0, vars_0), (mut vm_1, vars_1)) = create_vm(key, iv);
|
||||
let (mut leader, mut follower) = create_pair(vars_0, vars_1);
|
||||
|
||||
leader.alloc(&mut vm_0, count, 256).unwrap();
|
||||
follower.alloc(&mut vm_1, count, 256).unwrap();
|
||||
|
||||
run_vms(&mut vm_0, &mut ctx_0, &mut vm_1, &mut ctx_1).await;
|
||||
|
||||
tokio::try_join!(leader.setup(&mut ctx_0), follower.setup(&mut ctx_1)).unwrap();
|
||||
|
||||
for i in 0u64..count as u64 {
|
||||
let explicit_nonce = i.to_be_bytes().to_vec();
|
||||
let (ct, _) = expected(&key, &iv, &explicit_nonce, msg, &[]);
|
||||
|
||||
let (ct_0, msg_0) = leader
|
||||
.apply_keystream(&mut vm_0, explicit_nonce.clone(), ct.len())
|
||||
.unwrap();
|
||||
let (ct_1, msg_1) = follower
|
||||
.apply_keystream(&mut vm_1, explicit_nonce.clone(), ct.len())
|
||||
.unwrap();
|
||||
|
||||
vm_0.assign(ct_0, ct.clone()).unwrap();
|
||||
vm_0.commit(ct_0).unwrap();
|
||||
|
||||
vm_1.commit(ct_1).unwrap();
|
||||
|
||||
let msg_0 = vm_0.decode(msg_0).unwrap();
|
||||
let msg_1 = vm_1.decode(msg_1).unwrap();
|
||||
|
||||
run_vms(&mut vm_0, &mut ctx_0, &mut vm_1, &mut ctx_1).await;
|
||||
|
||||
let msg_0 = msg_0.await.unwrap();
|
||||
let msg_1 = msg_1.await.unwrap();
|
||||
|
||||
assert_eq!(&msg_0, msg);
|
||||
assert_eq!(&msg_1, msg);
|
||||
}
|
||||
}
|
||||
|
||||
fn create_vm(key: [u8; 16], iv: [u8; 4]) -> ((impl Vm<Binary>, Vars), (impl Vm<Binary>, Vars)) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let block = Block::random(&mut rng);
|
||||
let (sender, receiver) = ideal_cot(block);
|
||||
|
||||
let delta = Delta::new(block);
|
||||
let mut vm_0 = Generator::new(sender, [0u8; 16], delta);
|
||||
let mut vm_1 = Evaluator::new(receiver);
|
||||
|
||||
let key_ref_0 = vm_0.alloc::<Array<U8, 16>>().unwrap();
|
||||
vm_0.mark_public(key_ref_0).unwrap();
|
||||
vm_0.assign(key_ref_0, key).unwrap();
|
||||
vm_0.commit(key_ref_0).unwrap();
|
||||
|
||||
let key_ref_1 = vm_1.alloc::<Array<U8, 16>>().unwrap();
|
||||
vm_1.mark_public(key_ref_1).unwrap();
|
||||
vm_1.assign(key_ref_1, key).unwrap();
|
||||
vm_1.commit(key_ref_1).unwrap();
|
||||
|
||||
let iv_ref_0 = vm_0.alloc::<Array<U8, 4>>().unwrap();
|
||||
vm_0.mark_public(iv_ref_0).unwrap();
|
||||
vm_0.assign(iv_ref_0, iv).unwrap();
|
||||
vm_0.commit(iv_ref_0).unwrap();
|
||||
|
||||
let iv_ref_1 = vm_1.alloc::<Array<U8, 4>>().unwrap();
|
||||
vm_1.mark_public(iv_ref_1).unwrap();
|
||||
vm_1.assign(iv_ref_1, iv).unwrap();
|
||||
vm_1.commit(iv_ref_1).unwrap();
|
||||
|
||||
(
|
||||
(
|
||||
vm_0,
|
||||
Vars {
|
||||
key: key_ref_0,
|
||||
iv: iv_ref_0,
|
||||
},
|
||||
),
|
||||
(
|
||||
vm_1,
|
||||
Vars {
|
||||
key: key_ref_1,
|
||||
iv: iv_ref_1,
|
||||
},
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn create_pair(vars_0: Vars, vars_1: Vars) -> (MpcAesGcm, MpcAesGcm) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let (c_0, c_1) = ideal_share_convert(rng.gen());
|
||||
let mut leader = MpcAesGcm::new(c_0, Role::Leader);
|
||||
let mut follower = MpcAesGcm::new(c_1, Role::Follower);
|
||||
|
||||
leader.set_key(vars_0.key);
|
||||
leader.set_iv(vars_0.iv);
|
||||
|
||||
follower.set_key(vars_1.key);
|
||||
follower.set_iv(vars_1.iv);
|
||||
|
||||
(leader, follower)
|
||||
}
|
||||
|
||||
async fn run_vms(
|
||||
vm_0: &mut (dyn Vm<Binary> + Send),
|
||||
ctx_0: &mut Context,
|
||||
vm_1: &mut (dyn Vm<Binary> + Send),
|
||||
ctx_1: &mut Context,
|
||||
) {
|
||||
tokio::join!(
|
||||
async {
|
||||
vm_0.execute_all(ctx_0).await.unwrap();
|
||||
},
|
||||
async {
|
||||
vm_1.execute_all(ctx_1).await.unwrap();
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
fn expected(
|
||||
key: &[u8],
|
||||
iv: &[u8],
|
||||
explicit_nonce: &[u8],
|
||||
msg: &[u8],
|
||||
aad: &[u8],
|
||||
) -> (Vec<u8>, Vec<u8>) {
|
||||
let key: [u8; 16] = key.try_into().unwrap();
|
||||
let aes = Aes128Gcm::new(&key.into());
|
||||
|
||||
let mut nonce = [0u8; 12];
|
||||
nonce[..4].copy_from_slice(iv);
|
||||
nonce[4..].copy_from_slice(explicit_nonce);
|
||||
|
||||
let mut payload = msg.to_vec();
|
||||
let tag = aes
|
||||
.encrypt_in_place_detached(&nonce.into(), aad, &mut payload)
|
||||
.unwrap();
|
||||
|
||||
(payload, tag.to_vec())
|
||||
}
|
||||
}
|
||||
531
crates/mpc-tls/src/record_layer/aead/ghash.rs
Normal file
531
crates/mpc-tls/src/record_layer/aead/ghash.rs
Normal file
@@ -0,0 +1,531 @@
|
||||
mod compute;
|
||||
mod verify;
|
||||
|
||||
pub(crate) use compute::{ComputeTagData, ComputeTags};
|
||||
pub(crate) use verify::{VerifyTagData, VerifyTags};
|
||||
|
||||
use std::{fmt::Debug, ops::Add};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mpz_common::{future::Output, Context, Flush};
|
||||
use mpz_core::Block;
|
||||
use mpz_fields::{gf2_128::Gf2_128, Field};
|
||||
use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::record_layer::aead::AeadError;
|
||||
|
||||
/// Maximum key share power.
|
||||
const MAX_POWER: usize = 1026;
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait Ghash {
|
||||
/// Allocates resources needed for GHASH.
|
||||
fn alloc(&mut self) -> Result<(), GhashError>;
|
||||
|
||||
/// Preprocesses GHASH.
|
||||
async fn preprocess(&mut self, ctx: &mut Context) -> Result<(), GhashError>;
|
||||
|
||||
/// Sets the key for the hash function.
|
||||
fn set_key(&mut self, key: Vec<u8>) -> Result<(), GhashError>;
|
||||
|
||||
/// Sets up GHASH, computing the key shares.
|
||||
async fn setup(&mut self, ctx: &mut Context) -> Result<(), GhashError>;
|
||||
|
||||
/// Computes the GHASH tag.
|
||||
fn compute(&self, input: &[u8]) -> Result<Vec<u8>, GhashError>;
|
||||
}
|
||||
|
||||
/// MPC GHASH implementation.
|
||||
pub(crate) struct MpcGhash<C> {
|
||||
state: State,
|
||||
converter: C,
|
||||
alloc: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum State {
|
||||
Init,
|
||||
SetKey { key: Gf2_128 },
|
||||
Ready { shares: Vec<Gf2_128> },
|
||||
Error,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn take(&mut self) -> Self {
|
||||
std::mem::replace(self, State::Error)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> MpcGhash<C> {
|
||||
/// Creates a new instance.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `converter` - GF2_128 share converter.
|
||||
pub(crate) fn new(converter: C) -> Self {
|
||||
Self {
|
||||
state: State::Init,
|
||||
converter,
|
||||
alloc: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<C> Ghash for MpcGhash<C>
|
||||
where
|
||||
C: AdditiveToMultiplicative<Gf2_128> + Flush + Send,
|
||||
C: MultiplicativeToAdditive<Gf2_128> + Flush + Send,
|
||||
{
|
||||
fn alloc(&mut self) -> Result<(), GhashError> {
|
||||
if !self.alloc {
|
||||
// We need only half the number of `MAX_POWER` M2As because of the free
|
||||
// squaring trick and we need one extra A2M conversion in the beginning.
|
||||
// Both M2A and A2M, each require a single OLE.
|
||||
AdditiveToMultiplicative::<Gf2_128>::alloc(&mut self.converter, 1)
|
||||
.map_err(GhashError::conversion)?;
|
||||
|
||||
MultiplicativeToAdditive::<Gf2_128>::alloc(&mut self.converter, (MAX_POWER / 2) - 1)
|
||||
.map_err(GhashError::conversion)?;
|
||||
|
||||
self.alloc = true;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn preprocess(&mut self, ctx: &mut Context) -> Result<(), GhashError> {
|
||||
self.converter
|
||||
.flush(ctx)
|
||||
.await
|
||||
.map_err(GhashError::conversion)
|
||||
}
|
||||
|
||||
fn set_key(&mut self, key: Vec<u8>) -> Result<(), GhashError> {
|
||||
if key.len() != 16 {
|
||||
return Err(ErrorRepr::KeyLength {
|
||||
expected: 16,
|
||||
actual: key.len(),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
let State::Init = self.state.take() else {
|
||||
return Err(GhashError::state("Key already set"));
|
||||
};
|
||||
|
||||
let mut h_additive = [0u8; 16];
|
||||
h_additive.copy_from_slice(key.as_slice());
|
||||
|
||||
// GHASH reflects the bits of the key.
|
||||
let h_additive = Gf2_128::new(u128::from_be_bytes(h_additive).reverse_bits());
|
||||
|
||||
self.state = State::SetKey { key: h_additive };
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn setup(&mut self, ctx: &mut Context) -> Result<(), GhashError> {
|
||||
let State::SetKey { key: add_key } = self.state.take() else {
|
||||
return Err(GhashError::state("can not setup before key is set"));
|
||||
};
|
||||
|
||||
let mut mult_key = self
|
||||
.converter
|
||||
.queue_to_multiplicative(&[add_key])
|
||||
.map_err(GhashError::conversion)?;
|
||||
|
||||
self.converter
|
||||
.flush(ctx)
|
||||
.await
|
||||
.map_err(GhashError::conversion)?;
|
||||
|
||||
let mult_key = mult_key
|
||||
.try_recv()
|
||||
.map_err(GhashError::conversion)?
|
||||
.expect("share should be computed")
|
||||
.shares[0];
|
||||
|
||||
// Compute the odd powers of the multiplicative key share.
|
||||
//
|
||||
// Resulting vector contains odd powers of H from H^3 to H^1025.
|
||||
let odd_shares: Vec<_> = (0..MAX_POWER)
|
||||
.scan(mult_key, |acc, _| {
|
||||
let power_n = *acc;
|
||||
*acc = power_n * mult_key;
|
||||
Some(power_n)
|
||||
})
|
||||
// Start from H^3
|
||||
.skip(2)
|
||||
// Skip even powers
|
||||
.step_by(2)
|
||||
.collect();
|
||||
|
||||
// Compute the additive shares of the odd powers.
|
||||
let mut add_shares_odd = self
|
||||
.converter
|
||||
.queue_to_additive(&odd_shares)
|
||||
.map_err(GhashError::conversion)?;
|
||||
|
||||
self.converter
|
||||
.flush(ctx)
|
||||
.await
|
||||
.map_err(GhashError::conversion)?;
|
||||
|
||||
let add_shares_odd = add_shares_odd
|
||||
.try_recv()
|
||||
.map_err(GhashError::conversion)?
|
||||
.expect("share should be computed")
|
||||
.shares;
|
||||
|
||||
let shares = compute_shares(add_key, &add_shares_odd);
|
||||
|
||||
self.state = State::Ready { shares };
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute(&self, input: &[u8]) -> Result<Vec<u8>, GhashError> {
|
||||
let State::Ready { shares } = &self.state else {
|
||||
return Err(GhashError::state("key shares are not computed"));
|
||||
};
|
||||
|
||||
// Divide by block length and round up.
|
||||
let block_count = input.len() / 16 + (input.len() % 16 != 0) as usize;
|
||||
|
||||
if block_count > MAX_POWER {
|
||||
return Err(ErrorRepr::InputLength {
|
||||
len: block_count,
|
||||
max: MAX_POWER * 16,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
let mut input = input.to_vec();
|
||||
|
||||
// Pad input to a multiple of 16 bytes.
|
||||
input.resize(block_count * 16, 0);
|
||||
|
||||
// Convert input to blocks.
|
||||
let blocks = input
|
||||
.chunks_exact(16)
|
||||
.map(|chunk| {
|
||||
let mut block = [0u8; 16];
|
||||
block.copy_from_slice(chunk);
|
||||
Block::from(block)
|
||||
})
|
||||
.collect::<Vec<Block>>();
|
||||
|
||||
let offset = shares.len() - blocks.len();
|
||||
let tag: Block = blocks
|
||||
.iter()
|
||||
.zip(shares.iter().rev().skip(offset))
|
||||
.fold(Gf2_128::zero(), |acc, (block, share)| {
|
||||
acc + Gf2_128::from(block.reverse_bits()) * *share
|
||||
})
|
||||
.into();
|
||||
|
||||
Ok(tag.reverse_bits().to_bytes().to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes shares of powers of H.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `key` - Additive share of H.
|
||||
/// * `odd_powers` - Additive shares of odd powers of H starting at H^3.
|
||||
fn compute_shares(key: Gf2_128, odd_powers: &[Gf2_128]) -> Vec<Gf2_128> {
|
||||
let mut shares = Vec::with_capacity(MAX_POWER);
|
||||
|
||||
// H^1
|
||||
shares.push(key);
|
||||
|
||||
let mut odd_idx = 0;
|
||||
for i in 2..=MAX_POWER {
|
||||
if i % 2 == 0 {
|
||||
// Even power, compute by squaring the square root power.
|
||||
let base = shares[i / 2 - 1];
|
||||
shares.push(base * base);
|
||||
} else {
|
||||
// Odd power
|
||||
shares.push(odd_powers[odd_idx]);
|
||||
odd_idx += 1;
|
||||
}
|
||||
}
|
||||
|
||||
shares
|
||||
}
|
||||
|
||||
/// Builds padded data for GHASH.
|
||||
pub(crate) fn build_ghash_data(mut aad: Vec<u8>, mut ciphertext: Vec<u8>) -> Vec<u8> {
|
||||
let associated_data_bitlen = (aad.len() as u64) * 8;
|
||||
let text_bitlen = (ciphertext.len() as u64) * 8;
|
||||
|
||||
let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128);
|
||||
|
||||
// Pad data to be a multiple of 16 bytes.
|
||||
let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize;
|
||||
aad.resize(aad_padded_block_count * 16, 0);
|
||||
|
||||
let ciphertext_padded_block_count =
|
||||
(ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize;
|
||||
ciphertext.resize(ciphertext_padded_block_count * 16, 0);
|
||||
|
||||
let mut data: Vec<u8> = Vec::with_capacity(aad.len() + ciphertext.len() + 16);
|
||||
data.extend(aad);
|
||||
data.extend(ciphertext);
|
||||
data.extend_from_slice(&len_block.to_be_bytes());
|
||||
|
||||
data
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct TagShare([u8; 16]);
|
||||
|
||||
impl Add for TagShare {
|
||||
type Output = Vec<u8>;
|
||||
|
||||
fn add(mut self, rhs: Self) -> Self::Output {
|
||||
self.0.iter_mut().zip(rhs.0).for_each(|(a, b)| *a ^= b);
|
||||
self.0.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub(crate) struct GhashError(#[from] ErrorRepr);
|
||||
|
||||
impl GhashError {
|
||||
fn conversion<E>(error: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self(ErrorRepr::ShareConversion(error.into()))
|
||||
}
|
||||
|
||||
fn state(reason: impl ToString) -> Self {
|
||||
Self(ErrorRepr::State(reason.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("ghash error: {0}")]
|
||||
enum ErrorRepr {
|
||||
#[error("share conversion error: {0}")]
|
||||
ShareConversion(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
#[error("invalid state: {0}")]
|
||||
State(String),
|
||||
#[error("incorrect key length, expected: {expected}, actual: {actual}")]
|
||||
KeyLength { expected: usize, actual: usize },
|
||||
#[error("input length exceeds maximum: {len} > {max}")]
|
||||
InputLength { len: usize, max: usize },
|
||||
}
|
||||
|
||||
impl From<GhashError> for AeadError {
|
||||
fn from(value: GhashError) -> Self {
|
||||
AeadError::tag(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ghash_rc::{
|
||||
universal_hash::{KeyInit, UniversalHash as UniversalHashReference},
|
||||
GHash as GhashReference,
|
||||
};
|
||||
use mpz_common::context::test_st_context;
|
||||
use mpz_core::Block;
|
||||
use mpz_fields::{gf2_128::Gf2_128, UniformRand};
|
||||
use mpz_share_conversion::ideal::{
|
||||
ideal_share_convert, IdealShareConvertReceiver, IdealShareConvertSender,
|
||||
};
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
fn create_pair() -> (
|
||||
MpcGhash<IdealShareConvertSender<Gf2_128>>,
|
||||
MpcGhash<IdealShareConvertReceiver<Gf2_128>>,
|
||||
) {
|
||||
let (convert_a, convert_b) = ideal_share_convert(Block::ZERO);
|
||||
|
||||
let (mut sender, mut receiver) = (MpcGhash::new(convert_a), MpcGhash::new(convert_b));
|
||||
sender.alloc().unwrap();
|
||||
receiver.alloc().unwrap();
|
||||
|
||||
(sender, receiver)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_shares() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let key = Gf2_128::rand(&mut rng);
|
||||
let expected_powers: Vec<_> = (0..MAX_POWER)
|
||||
.scan(key, |acc, _| {
|
||||
let power_n = *acc;
|
||||
*acc = power_n * key;
|
||||
Some(power_n)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let odd_powers = expected_powers
|
||||
.iter()
|
||||
.skip(2)
|
||||
.step_by(2)
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let powers = compute_shares(key, &odd_powers);
|
||||
|
||||
assert_eq!(powers, expected_powers);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ghash_output() {
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let h: u128 = rng.gen();
|
||||
let sender_key: u128 = rng.gen();
|
||||
let receiver_key: u128 = h ^ sender_key;
|
||||
|
||||
let message: Vec<u8> = (0..16).map(|_| rng.gen()).collect();
|
||||
|
||||
let (mut sender, mut receiver) = create_pair();
|
||||
sender.set_key(sender_key.to_be_bytes().to_vec()).unwrap();
|
||||
receiver
|
||||
.set_key(receiver_key.to_be_bytes().to_vec())
|
||||
.unwrap();
|
||||
|
||||
tokio::try_join!(sender.setup(&mut ctx_a), receiver.setup(&mut ctx_b)).unwrap();
|
||||
|
||||
let sender_share = sender.compute(&message).unwrap();
|
||||
let receiver_share = receiver.compute(&message).unwrap();
|
||||
|
||||
let tag = sender_share
|
||||
.iter()
|
||||
.zip(receiver_share.iter())
|
||||
.map(|(a, b)| a ^ b)
|
||||
.collect::<Vec<u8>>();
|
||||
|
||||
assert_eq!(tag, ghash_reference_impl(h, &message));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ghash_output_padded() {
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let h: u128 = rng.gen();
|
||||
let sender_key: u128 = rng.gen();
|
||||
let receiver_key: u128 = h ^ sender_key;
|
||||
|
||||
// Message length is not a multiple of the block length
|
||||
let message: Vec<u8> = (0..14).map(|_| rng.gen()).collect();
|
||||
|
||||
let (mut sender, mut receiver) = create_pair();
|
||||
|
||||
sender.set_key(sender_key.to_be_bytes().to_vec()).unwrap();
|
||||
receiver
|
||||
.set_key(receiver_key.to_be_bytes().to_vec())
|
||||
.unwrap();
|
||||
|
||||
tokio::try_join!(sender.setup(&mut ctx_a), receiver.setup(&mut ctx_b)).unwrap();
|
||||
|
||||
let sender_share = sender.compute(&message).unwrap();
|
||||
let receiver_share = receiver.compute(&message).unwrap();
|
||||
|
||||
let tag = sender_share
|
||||
.iter()
|
||||
.zip(receiver_share.iter())
|
||||
.map(|(a, b)| a ^ b)
|
||||
.collect::<Vec<u8>>();
|
||||
|
||||
assert_eq!(tag, ghash_reference_impl(h, &message));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ghash_long_message() {
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let h: u128 = rng.gen();
|
||||
let sender_key: u128 = rng.gen();
|
||||
let receiver_key: u128 = h ^ sender_key;
|
||||
|
||||
// A longer message.
|
||||
let long_message: Vec<u8> = (0..30).map(|_| rng.gen()).collect();
|
||||
|
||||
let (mut sender, mut receiver) = create_pair();
|
||||
|
||||
sender.set_key(sender_key.to_be_bytes().to_vec()).unwrap();
|
||||
receiver
|
||||
.set_key(receiver_key.to_be_bytes().to_vec())
|
||||
.unwrap();
|
||||
|
||||
tokio::try_join!(sender.setup(&mut ctx_a), receiver.setup(&mut ctx_b)).unwrap();
|
||||
|
||||
let sender_share = sender.compute(&long_message).unwrap();
|
||||
let receiver_share = receiver.compute(&long_message).unwrap();
|
||||
|
||||
let tag = sender_share
|
||||
.iter()
|
||||
.zip(receiver_share.iter())
|
||||
.map(|(a, b)| a ^ b)
|
||||
.collect::<Vec<u8>>();
|
||||
|
||||
assert_eq!(tag, ghash_reference_impl(h, &long_message));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ghash_repeated() {
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let h: u128 = rng.gen();
|
||||
let sender_key: u128 = rng.gen();
|
||||
let receiver_key: u128 = h ^ sender_key;
|
||||
|
||||
// Two messages.
|
||||
let first_message: Vec<u8> = (0..14).map(|_| rng.gen()).collect();
|
||||
let second_message: Vec<u8> = (0..32).map(|_| rng.gen()).collect();
|
||||
|
||||
let (mut sender, mut receiver) = create_pair();
|
||||
|
||||
sender.set_key(sender_key.to_be_bytes().to_vec()).unwrap();
|
||||
receiver
|
||||
.set_key(receiver_key.to_be_bytes().to_vec())
|
||||
.unwrap();
|
||||
|
||||
tokio::try_join!(sender.setup(&mut ctx_a), receiver.setup(&mut ctx_b)).unwrap();
|
||||
|
||||
// Compute and check first message.
|
||||
let sender_share = sender.compute(&first_message).unwrap();
|
||||
let receiver_share = receiver.compute(&first_message).unwrap();
|
||||
|
||||
let tag = sender_share
|
||||
.iter()
|
||||
.zip(receiver_share.iter())
|
||||
.map(|(a, b)| a ^ b)
|
||||
.collect::<Vec<u8>>();
|
||||
|
||||
assert_eq!(tag, ghash_reference_impl(h, &first_message));
|
||||
|
||||
// Compute and check second message.
|
||||
let sender_share = sender.compute(&second_message).unwrap();
|
||||
let receiver_share = receiver.compute(&second_message).unwrap();
|
||||
|
||||
let tag = sender_share
|
||||
.iter()
|
||||
.zip(receiver_share.iter())
|
||||
.map(|(a, b)| a ^ b)
|
||||
.collect::<Vec<u8>>();
|
||||
|
||||
assert_eq!(tag, ghash_reference_impl(h, &second_message));
|
||||
}
|
||||
|
||||
fn ghash_reference_impl(h: u128, message: &[u8]) -> Vec<u8> {
|
||||
let mut ghash = GhashReference::new(&h.to_be_bytes().into());
|
||||
ghash.update_padded(message);
|
||||
let mac = ghash.finalize();
|
||||
mac.to_vec()
|
||||
}
|
||||
}
|
||||
119
crates/mpc-tls/src/record_layer/aead/ghash/compute.rs
Normal file
119
crates/mpc-tls/src/record_layer/aead/ghash/compute.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use std::{future::Future, pin::Pin, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{stream::FuturesOrdered, StreamExt as _};
|
||||
use mpz_common::{Context, Task};
|
||||
use serio::{stream::IoStreamExt, SinkExt};
|
||||
|
||||
use crate::{
|
||||
decode::OneTimePadShared,
|
||||
record_layer::aead::{
|
||||
ghash::{build_ghash_data, Ghash, TagShare},
|
||||
AeadError,
|
||||
},
|
||||
Role,
|
||||
};
|
||||
|
||||
pub(crate) struct ComputeTagData {
|
||||
pub(crate) j0: OneTimePadShared<[u8; 16]>,
|
||||
pub(crate) ciphertext: Pin<Box<dyn Future<Output = Result<Vec<u8>, AeadError>> + Send + Sync>>,
|
||||
pub(crate) aad: Vec<u8>,
|
||||
}
|
||||
|
||||
#[must_use = "compute tags operation must be awaited"]
|
||||
pub(crate) struct ComputeTags {
|
||||
role: Role,
|
||||
data: Vec<ComputeTagData>,
|
||||
ghash: Arc<dyn Ghash + Send + Sync>,
|
||||
}
|
||||
|
||||
impl ComputeTags {
|
||||
pub(crate) fn new(
|
||||
role: Role,
|
||||
data: Vec<ComputeTagData>,
|
||||
ghash: Arc<dyn Ghash + Send + Sync>,
|
||||
) -> Self {
|
||||
Self { role, data, ghash }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Task for ComputeTags {
|
||||
type Output = Result<Option<Vec<Vec<u8>>>, AeadError>;
|
||||
|
||||
async fn run(self, ctx: &mut Context) -> Self::Output {
|
||||
let Self {
|
||||
role,
|
||||
mut data,
|
||||
ghash,
|
||||
} = self;
|
||||
|
||||
if data.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut j0_shares = Vec::with_capacity(data.len());
|
||||
{
|
||||
let mut futs = FuturesOrdered::from_iter(data.iter_mut().map(|data| &mut data.j0));
|
||||
while let Some(j0_share) = futs.next().await.transpose().map_err(AeadError::tag)? {
|
||||
j0_shares.push(j0_share);
|
||||
}
|
||||
}
|
||||
|
||||
let mut ciphertexts = Vec::with_capacity(data.len());
|
||||
{
|
||||
let mut futs =
|
||||
FuturesOrdered::from_iter(data.iter_mut().map(|data| &mut data.ciphertext));
|
||||
while let Some(ciphertext) = futs.next().await.transpose().map_err(AeadError::tag)? {
|
||||
ciphertexts.push(ciphertext);
|
||||
}
|
||||
}
|
||||
|
||||
let mut tag_shares = Vec::with_capacity(data.len());
|
||||
for ((mut tag_share, ciphertext), data) in j0_shares.into_iter().zip(ciphertexts).zip(data)
|
||||
{
|
||||
let ghash_share = ghash
|
||||
.compute(&build_ghash_data(data.aad, ciphertext))
|
||||
.map_err(AeadError::tag)?;
|
||||
tag_share
|
||||
.iter_mut()
|
||||
.zip(ghash_share)
|
||||
.for_each(|(a, b)| *a ^= b);
|
||||
|
||||
tag_shares.push(TagShare(tag_share));
|
||||
}
|
||||
|
||||
let tags = match role {
|
||||
Role::Leader => {
|
||||
let follower_tag_shares: Vec<TagShare> =
|
||||
ctx.io_mut().expect_next().await.map_err(AeadError::tag)?;
|
||||
|
||||
if follower_tag_shares.len() != tag_shares.len() {
|
||||
return Err(AeadError::tag("follower tag shares length mismatch"));
|
||||
}
|
||||
|
||||
let tags = tag_shares
|
||||
.into_iter()
|
||||
.zip(follower_tag_shares)
|
||||
.map(|(a, b)| (a + b).to_vec())
|
||||
.collect();
|
||||
|
||||
Some(tags)
|
||||
}
|
||||
Role::Follower => {
|
||||
ctx.io_mut()
|
||||
.send(tag_shares)
|
||||
.await
|
||||
.map_err(AeadError::tag)?;
|
||||
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
Ok(tags)
|
||||
}
|
||||
|
||||
async fn run_boxed(self: Box<Self>, ctx: &mut Context) -> Self::Output {
|
||||
self.run(ctx).await
|
||||
}
|
||||
}
|
||||
134
crates/mpc-tls/src/record_layer/aead/ghash/verify.rs
Normal file
134
crates/mpc-tls/src/record_layer/aead/ghash/verify.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{stream::FuturesOrdered, StreamExt};
|
||||
use mpz_common::{Context, Task};
|
||||
use mpz_core::commit::{Decommitment, HashCommit};
|
||||
use serio::{stream::IoStreamExt, SinkExt};
|
||||
|
||||
use crate::{
|
||||
decode::OneTimePadShared,
|
||||
record_layer::aead::{
|
||||
ghash::{build_ghash_data, Ghash, TagShare},
|
||||
AeadError,
|
||||
},
|
||||
Role,
|
||||
};
|
||||
|
||||
pub(crate) struct VerifyTagData {
|
||||
pub(crate) j0: OneTimePadShared<[u8; 16]>,
|
||||
pub(crate) ciphertext: Vec<u8>,
|
||||
pub(crate) aad: Vec<u8>,
|
||||
pub(crate) tag: Vec<u8>,
|
||||
}
|
||||
|
||||
#[must_use = "verify tags operation must be awaited"]
|
||||
pub(crate) struct VerifyTags {
|
||||
role: Role,
|
||||
data: Vec<VerifyTagData>,
|
||||
ghash: Arc<dyn Ghash + Send + Sync>,
|
||||
}
|
||||
|
||||
impl VerifyTags {
|
||||
pub(crate) fn new(
|
||||
role: Role,
|
||||
data: Vec<VerifyTagData>,
|
||||
ghash: Arc<dyn Ghash + Send + Sync>,
|
||||
) -> Self {
|
||||
Self { role, data, ghash }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Task for VerifyTags {
|
||||
type Output = Result<(), AeadError>;
|
||||
|
||||
async fn run(self, ctx: &mut Context) -> Self::Output {
|
||||
let Self {
|
||||
role,
|
||||
mut data,
|
||||
ghash,
|
||||
} = self;
|
||||
|
||||
if data.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut j0_shares = Vec::with_capacity(data.len());
|
||||
{
|
||||
let mut futs = FuturesOrdered::from_iter(data.iter_mut().map(|data| &mut data.j0));
|
||||
while let Some(j0_share) = futs.next().await.transpose().map_err(AeadError::tag)? {
|
||||
j0_shares.push(j0_share);
|
||||
}
|
||||
}
|
||||
|
||||
let mut tag_shares = Vec::with_capacity(data.len());
|
||||
let mut tags = Vec::with_capacity(data.len());
|
||||
for (mut tag_share, data) in j0_shares.into_iter().zip(data) {
|
||||
let ghash_share = ghash
|
||||
.compute(&build_ghash_data(data.aad, data.ciphertext))
|
||||
.map_err(AeadError::tag)?;
|
||||
tag_share
|
||||
.iter_mut()
|
||||
.zip(ghash_share)
|
||||
.for_each(|(a, b)| *a ^= b);
|
||||
|
||||
tag_shares.push(TagShare(tag_share));
|
||||
tags.push(data.tag);
|
||||
}
|
||||
|
||||
let io = ctx.io_mut();
|
||||
let peer_tag_shares = match role {
|
||||
Role::Leader => {
|
||||
// Send commitment to follower.
|
||||
let (decommitment, commitment) = tag_shares.clone().hash_commit();
|
||||
|
||||
io.send(commitment).await.map_err(AeadError::tag)?;
|
||||
|
||||
let follower_tag_shares: Vec<TagShare> =
|
||||
io.expect_next().await.map_err(AeadError::tag)?;
|
||||
|
||||
if follower_tag_shares.len() != tag_shares.len() {
|
||||
return Err(AeadError::tag("follower tag shares length mismatch"));
|
||||
}
|
||||
|
||||
// Send decommitment to follower.
|
||||
io.send(decommitment).await.map_err(AeadError::tag)?;
|
||||
|
||||
follower_tag_shares
|
||||
}
|
||||
Role::Follower => {
|
||||
// Wait for commitment from leader.
|
||||
let commitment = io.expect_next().await.map_err(AeadError::tag)?;
|
||||
|
||||
// Send tag shares to leader.
|
||||
io.send(tag_shares.clone()).await.map_err(AeadError::tag)?;
|
||||
|
||||
// Expect decommitment from leader.
|
||||
let decommitment: Decommitment<Vec<TagShare>> =
|
||||
io.expect_next().await.map_err(AeadError::tag)?;
|
||||
|
||||
// Verify decommitment.
|
||||
decommitment.verify(&commitment).map_err(AeadError::tag)?;
|
||||
|
||||
decommitment.into_inner()
|
||||
}
|
||||
};
|
||||
|
||||
let expected_tags = tag_shares
|
||||
.into_iter()
|
||||
.zip(peer_tag_shares)
|
||||
.map(|(tag_share, peer_tag_share)| tag_share + peer_tag_share)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if tags != expected_tags {
|
||||
return Err(AeadError::tag("failed to verify tags"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_boxed(self: Box<Self>, ctx: &mut Context) -> Self::Output {
|
||||
self.run(ctx).await
|
||||
}
|
||||
}
|
||||
263
crates/mpc-tls/src/record_layer/aes_ctr.rs
Normal file
263
crates/mpc-tls/src/record_layer/aes_ctr.rs
Normal file
@@ -0,0 +1,263 @@
|
||||
use cipher_crate::{KeyIvInit, StreamCipher as _, StreamCipherSeek};
|
||||
use mpz_core::bitvec::BitVec;
|
||||
use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
Array, DecodeFutureTyped,
|
||||
};
|
||||
use mpz_vm_core::{prelude::*, Vm};
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use crate::{MpcTlsError, Role};
|
||||
|
||||
type LocalAesCtr = ctr::Ctr32BE<aes::Aes128>;
|
||||
|
||||
enum State {
|
||||
Init,
|
||||
Alloc {
|
||||
masked_key: Array<U8, 16>,
|
||||
masked_iv: Array<U8, 4>,
|
||||
key_otp: Option<[u8; 16]>,
|
||||
iv_otp: Option<[u8; 4]>,
|
||||
},
|
||||
Decode {
|
||||
masked_key: DecodeFutureTyped<BitVec, [u8; 16]>,
|
||||
masked_iv: DecodeFutureTyped<BitVec, [u8; 4]>,
|
||||
key_otp: Option<[u8; 16]>,
|
||||
iv_otp: Option<[u8; 4]>,
|
||||
},
|
||||
Ready {
|
||||
key: Option<[u8; 16]>,
|
||||
iv: Option<[u8; 4]>,
|
||||
},
|
||||
Error,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn take(&mut self) -> Self {
|
||||
std::mem::replace(self, State::Error)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct AesCtr {
|
||||
role: Role,
|
||||
key: Option<Array<U8, 16>>,
|
||||
iv: Option<Array<U8, 4>>,
|
||||
state: State,
|
||||
}
|
||||
|
||||
impl AesCtr {
|
||||
pub(crate) fn new(role: Role) -> Self {
|
||||
Self {
|
||||
role,
|
||||
key: None,
|
||||
iv: None,
|
||||
state: State::Init,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn set_key(&mut self, key: Array<U8, 16>, iv: Array<U8, 4>) {
|
||||
self.key = Some(key);
|
||||
self.iv = Some(iv);
|
||||
}
|
||||
|
||||
pub(crate) fn alloc(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), MpcTlsError> {
|
||||
let State::Init = self.state.take() else {
|
||||
Err(MpcTlsError::record_layer(
|
||||
"aes-ctr must be in initialized state to allocate",
|
||||
))?
|
||||
};
|
||||
|
||||
let key = self
|
||||
.key
|
||||
.ok_or_else(|| MpcTlsError::record_layer("key not set in aes-ctr"))?;
|
||||
let iv = self
|
||||
.iv
|
||||
.ok_or_else(|| MpcTlsError::record_layer("iv not set in aes-ctr"))?;
|
||||
|
||||
let (masked_key, key_otp, masked_iv, iv_otp) = match self.role {
|
||||
Role::Leader => {
|
||||
let mut key_otp = [0u8; 16];
|
||||
thread_rng().fill_bytes(&mut key_otp);
|
||||
let mut iv_otp = [0u8; 4];
|
||||
thread_rng().fill_bytes(&mut iv_otp);
|
||||
let masked_key = vm
|
||||
.mask_private(key, key_otp)
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
let masked_iv = vm
|
||||
.mask_private(iv, iv_otp)
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
(masked_key, Some(key_otp), masked_iv, Some(iv_otp))
|
||||
}
|
||||
Role::Follower => {
|
||||
let masked_key = vm.mask_blind(key).map_err(MpcTlsError::record_layer)?;
|
||||
let masked_iv = vm.mask_blind(iv).map_err(MpcTlsError::record_layer)?;
|
||||
(masked_key, None, masked_iv, None)
|
||||
}
|
||||
};
|
||||
|
||||
self.state = State::Alloc {
|
||||
masked_key,
|
||||
masked_iv,
|
||||
key_otp,
|
||||
iv_otp,
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn decode_key(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), MpcTlsError> {
|
||||
let State::Alloc {
|
||||
masked_key,
|
||||
masked_iv,
|
||||
key_otp,
|
||||
iv_otp,
|
||||
} = self.state.take()
|
||||
else {
|
||||
Err(MpcTlsError::record_layer(
|
||||
"aes-ctr must be in allocated state to decode key",
|
||||
))?
|
||||
};
|
||||
|
||||
let masked_key = vm.decode(masked_key).map_err(MpcTlsError::record_layer)?;
|
||||
let masked_iv = vm.decode(masked_iv).map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
self.state = State::Decode {
|
||||
masked_key,
|
||||
masked_iv,
|
||||
key_otp,
|
||||
iv_otp,
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn finish_decode(&mut self) -> Result<(), MpcTlsError> {
|
||||
let State::Decode {
|
||||
mut masked_key,
|
||||
mut masked_iv,
|
||||
key_otp,
|
||||
iv_otp,
|
||||
} = self.state.take()
|
||||
else {
|
||||
Err(MpcTlsError::record_layer(
|
||||
"aes-ctr must be in decode state to finish decode",
|
||||
))?
|
||||
};
|
||||
|
||||
let (key, iv) = if let Role::Leader = self.role {
|
||||
let key_otp = key_otp.expect("leader knows key otp");
|
||||
let iv_otp = iv_otp.expect("leader knows iv otp");
|
||||
|
||||
let masked_key = masked_key
|
||||
.try_recv()
|
||||
.map_err(MpcTlsError::record_layer)?
|
||||
.ok_or_else(|| MpcTlsError::record_layer("masked key is not decoded"))?;
|
||||
let masked_iv = masked_iv
|
||||
.try_recv()
|
||||
.map_err(MpcTlsError::record_layer)?
|
||||
.ok_or_else(|| MpcTlsError::record_layer("masked iv is not decoded"))?;
|
||||
|
||||
let mut key = masked_key;
|
||||
let mut iv = masked_iv;
|
||||
|
||||
key.iter_mut().zip(key_otp).for_each(|(key, otp)| {
|
||||
*key ^= otp;
|
||||
});
|
||||
|
||||
iv.iter_mut().zip(iv_otp).for_each(|(iv, otp)| {
|
||||
*iv ^= otp;
|
||||
});
|
||||
|
||||
(Some(key), Some(iv))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
self.state = State::Ready { key, iv };
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn decrypt(
|
||||
&mut self,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
) -> Result<Vec<u8>, MpcTlsError> {
|
||||
let State::Ready { key, iv, .. } = &self.state else {
|
||||
Err(MpcTlsError::record_layer(
|
||||
"aes-ctr must be in ready state to decrypt",
|
||||
))?
|
||||
};
|
||||
|
||||
if let Role::Follower = self.role {
|
||||
return Err(MpcTlsError::record_layer(
|
||||
"aes-ctr must be in leader role to decrypt",
|
||||
));
|
||||
}
|
||||
|
||||
let key = key.as_ref().expect("leader knows key");
|
||||
let iv = iv.as_ref().expect("leader knows iv");
|
||||
|
||||
let explicit_nonce: [u8; 8] =
|
||||
explicit_nonce
|
||||
.try_into()
|
||||
.map_err(|explicit_nonce: Vec<_>| {
|
||||
MpcTlsError::record_layer(format!(
|
||||
"incorrect explicit nonce length: {} != 8",
|
||||
explicit_nonce.len()
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut full_iv = [0u8; 16];
|
||||
full_iv[..4].copy_from_slice(iv);
|
||||
full_iv[4..12].copy_from_slice(&explicit_nonce);
|
||||
|
||||
let mut aes = LocalAesCtr::new(key.into(), &full_iv.into());
|
||||
|
||||
// Skip the first 32 bytes of the keystream to match the AES-GCM implementation.
|
||||
aes.seek(32);
|
||||
|
||||
let mut plaintext = ciphertext;
|
||||
aes.apply_keystream(&mut plaintext);
|
||||
|
||||
Ok(plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use aes_gcm::{aead::AeadMutInPlace, Aes128Gcm, NewAead};
|
||||
|
||||
#[test]
|
||||
fn test_aes_ctr_local() {
|
||||
let key = [0u8; 16];
|
||||
let iv = [42u8; 4];
|
||||
let explicit_nonce = [69u8; 8];
|
||||
|
||||
let mut nonce = [0u8; 12];
|
||||
nonce[..4].copy_from_slice(&iv);
|
||||
nonce[4..].copy_from_slice(&explicit_nonce);
|
||||
|
||||
let mut aes_ctr = AesCtr::new(Role::Leader);
|
||||
aes_ctr.state = State::Ready {
|
||||
key: Some(key),
|
||||
iv: Some(iv),
|
||||
};
|
||||
|
||||
let mut aes_gcm = Aes128Gcm::new(&key.into());
|
||||
|
||||
let msg = b"hello world";
|
||||
|
||||
let mut ciphertext = msg.to_vec();
|
||||
_ = aes_gcm
|
||||
.encrypt_in_place_detached(&nonce.into(), &[], &mut ciphertext)
|
||||
.unwrap();
|
||||
|
||||
let decrypted = aes_ctr
|
||||
.decrypt(explicit_nonce.to_vec(), ciphertext)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(msg, decrypted.as_slice());
|
||||
}
|
||||
}
|
||||
296
crates/mpc-tls/src/record_layer/decrypt.rs
Normal file
296
crates/mpc-tls/src/record_layer/decrypt.rs
Normal file
@@ -0,0 +1,296 @@
|
||||
use mpz_core::bitvec::BitVec;
|
||||
use mpz_memory_core::{binary::Binary, DecodeFutureTyped};
|
||||
use mpz_vm_core::{prelude::*, Vm};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tls_core::msgs::enums::{ContentType, ProtocolVersion};
|
||||
|
||||
use crate::{
|
||||
record_layer::{
|
||||
aead::{MpcAesGcm, VerifyTags},
|
||||
aes_ctr::AesCtr,
|
||||
TagData,
|
||||
},
|
||||
MpcTlsError, Role,
|
||||
};
|
||||
|
||||
pub(crate) fn private_mpc(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
decrypter: &mut MpcAesGcm,
|
||||
otp: Option<&mut Vec<u8>>,
|
||||
op: &DecryptOp,
|
||||
) -> Result<DecryptOutput, MpcTlsError> {
|
||||
if let Some(otp) = otp.as_ref() {
|
||||
if op.ciphertext.len() > otp.len() {
|
||||
return Err(MpcTlsError::record_layer(format!(
|
||||
"ciphertext length exceeds allocated: {} > {}",
|
||||
op.ciphertext.len(),
|
||||
otp.len()
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let (otp_ref, masked_keystream) = decrypter
|
||||
.apply_keystream(vm, op.explicit_nonce.clone(), op.ciphertext.len())
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
let otp = otp.map(|otp| otp.split_off(otp.len() - op.ciphertext.len()));
|
||||
if let Some(otp) = otp.clone() {
|
||||
vm.assign(otp_ref, otp).map_err(MpcTlsError::record_layer)?;
|
||||
}
|
||||
vm.commit(otp_ref).map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
// Decode the masked keystream.
|
||||
let masked_keystream = vm
|
||||
.decode(masked_keystream)
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
Ok(DecryptOutput::Private(DecryptPrivate {
|
||||
masked_keystream,
|
||||
otp,
|
||||
ciphertext: op.ciphertext.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn public(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
decrypter: &mut MpcAesGcm,
|
||||
op: &DecryptOp,
|
||||
) -> Result<DecryptOutput, MpcTlsError> {
|
||||
// Instead of computing the plaintext in MPC, we only compute the keystream and
|
||||
// decode it for both parties. Each party then locally computes the plaintext.
|
||||
|
||||
let keystream = decrypter
|
||||
.take_keystream(vm, op.explicit_nonce.clone(), op.ciphertext.len())
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
Ok(DecryptOutput::Public(DecryptPublic {
|
||||
keystream: vm.decode(keystream).map_err(MpcTlsError::record_layer)?,
|
||||
ciphertext: op.ciphertext.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn decrypt_mpc(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
decrypter: &mut MpcAesGcm,
|
||||
mut otp: Option<&mut Vec<u8>>,
|
||||
ops: &[DecryptOp],
|
||||
) -> Result<Vec<PendingDecrypt>, MpcTlsError> {
|
||||
let mut pending_decrypt = Vec::with_capacity(ops.len());
|
||||
for op in ops {
|
||||
match op.mode {
|
||||
DecryptMode::Private => {
|
||||
pending_decrypt.push(PendingDecrypt {
|
||||
output: private_mpc(vm, decrypter, otp.as_deref_mut(), op)?,
|
||||
});
|
||||
}
|
||||
DecryptMode::Public => {
|
||||
pending_decrypt.push(PendingDecrypt {
|
||||
output: public(vm, decrypter, op)?,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(pending_decrypt)
|
||||
}
|
||||
|
||||
pub(crate) fn decrypt_local(
|
||||
role: Role,
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
mpc_decrypter: &mut MpcAesGcm,
|
||||
local_decrypter: &mut AesCtr,
|
||||
ops: &[DecryptOp],
|
||||
) -> Result<Vec<PendingDecrypt>, MpcTlsError> {
|
||||
let mut pending_decrypt = Vec::with_capacity(ops.len());
|
||||
for op in ops {
|
||||
match op.mode {
|
||||
DecryptMode::Private => {
|
||||
let plaintext = if let Role::Leader = role {
|
||||
let plaintext = local_decrypter
|
||||
.decrypt(op.explicit_nonce.clone(), op.ciphertext.clone())?;
|
||||
Some(plaintext)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
pending_decrypt.push(PendingDecrypt {
|
||||
output: DecryptOutput::Local(DecryptLocal {
|
||||
plaintext: plaintext.clone(),
|
||||
}),
|
||||
});
|
||||
}
|
||||
DecryptMode::Public => {
|
||||
pending_decrypt.push(PendingDecrypt {
|
||||
output: public(vm, mpc_decrypter, op)?,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(pending_decrypt)
|
||||
}
|
||||
|
||||
pub(crate) fn verify_tags(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
decrypter: &mut MpcAesGcm,
|
||||
ops: &[DecryptOp],
|
||||
) -> Result<VerifyTags, MpcTlsError> {
|
||||
let mut ciphertexts = Vec::with_capacity(ops.len());
|
||||
let mut tags = Vec::with_capacity(ops.len());
|
||||
let mut tags_data = Vec::with_capacity(ops.len());
|
||||
for DecryptOp {
|
||||
ciphertext,
|
||||
tag,
|
||||
explicit_nonce,
|
||||
aad,
|
||||
..
|
||||
} in ops
|
||||
{
|
||||
ciphertexts.push(ciphertext.clone());
|
||||
tags.push(tag.clone());
|
||||
tags_data.push(TagData {
|
||||
explicit_nonce: explicit_nonce.clone(),
|
||||
aad: aad.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
decrypter
|
||||
.verify_tags(vm, tags_data, ciphertexts, tags)
|
||||
.map_err(MpcTlsError::record_layer)
|
||||
}
|
||||
|
||||
pub(crate) struct DecryptOp {
|
||||
pub(crate) seq: u64,
|
||||
pub(crate) typ: ContentType,
|
||||
pub(crate) version: ProtocolVersion,
|
||||
pub(crate) explicit_nonce: Vec<u8>,
|
||||
pub(crate) ciphertext: Vec<u8>,
|
||||
pub(crate) aad: Vec<u8>,
|
||||
pub(crate) tag: Vec<u8>,
|
||||
pub(crate) mode: DecryptMode,
|
||||
}
|
||||
|
||||
impl DecryptOp {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
seq: u64,
|
||||
typ: ContentType,
|
||||
version: ProtocolVersion,
|
||||
explicit_nonce: Vec<u8>,
|
||||
ciphertext: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
tag: Vec<u8>,
|
||||
mode: DecryptMode,
|
||||
) -> Self {
|
||||
Self {
|
||||
seq,
|
||||
typ,
|
||||
version,
|
||||
explicit_nonce,
|
||||
ciphertext,
|
||||
aad,
|
||||
tag,
|
||||
mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub(crate) enum DecryptMode {
|
||||
Private,
|
||||
Public,
|
||||
}
|
||||
|
||||
pub(crate) enum DecryptOutput {
|
||||
Private(DecryptPrivate),
|
||||
Public(DecryptPublic),
|
||||
Local(DecryptLocal),
|
||||
}
|
||||
|
||||
impl DecryptOutput {
|
||||
pub(crate) fn try_decrypt(self) -> Result<Option<Vec<u8>>, MpcTlsError> {
|
||||
match self {
|
||||
DecryptOutput::Private(decrypt) => decrypt.try_decrypt(),
|
||||
DecryptOutput::Public(decrypt) => decrypt.try_decrypt().map(Some),
|
||||
DecryptOutput::Local(decrypt) => decrypt.try_decrypt(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PendingDecrypt {
|
||||
pub(crate) output: DecryptOutput,
|
||||
}
|
||||
|
||||
pub(crate) struct DecryptPrivate {
|
||||
masked_keystream: DecodeFutureTyped<BitVec, Vec<u8>>,
|
||||
otp: Option<Vec<u8>>,
|
||||
ciphertext: Vec<u8>,
|
||||
}
|
||||
|
||||
impl DecryptPrivate {
|
||||
pub(crate) fn try_decrypt(mut self) -> Result<Option<Vec<u8>>, MpcTlsError> {
|
||||
let masked_keystream = self
|
||||
.masked_keystream
|
||||
.try_recv()
|
||||
.map_err(MpcTlsError::record_layer)?
|
||||
.ok_or_else(|| MpcTlsError::record_layer("masked keystream is not ready"))?;
|
||||
|
||||
let Some(otp) = self.otp else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Recover the plaintext by removing the OTP from the masked keystream and
|
||||
// applying the ciphertext.
|
||||
let mut plaintext = self.ciphertext;
|
||||
plaintext
|
||||
.iter_mut()
|
||||
.zip(otp)
|
||||
.zip(masked_keystream)
|
||||
.for_each(|((a, b), c)| *a ^= b ^ c);
|
||||
|
||||
Ok(Some(plaintext))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct DecryptPublic {
|
||||
keystream: DecodeFutureTyped<BitVec, Vec<u8>>,
|
||||
ciphertext: Vec<u8>,
|
||||
}
|
||||
|
||||
impl DecryptPublic {
|
||||
/// Decrypts the ciphertext.
|
||||
pub(crate) fn try_decrypt(mut self) -> Result<Vec<u8>, MpcTlsError> {
|
||||
let keystream = self
|
||||
.keystream
|
||||
.try_recv()
|
||||
.map_err(MpcTlsError::record_layer)?
|
||||
.ok_or_else(|| MpcTlsError::record_layer("keystream is not ready"))?;
|
||||
|
||||
if keystream.len() != self.ciphertext.len() {
|
||||
return Err(MpcTlsError::record_layer(format!(
|
||||
"keystream length does not match ciphertext: {} != {}",
|
||||
keystream.len(),
|
||||
self.ciphertext.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut plaintext = self.ciphertext;
|
||||
plaintext
|
||||
.iter_mut()
|
||||
.zip(keystream)
|
||||
.for_each(|(a, b)| *a ^= b);
|
||||
|
||||
Ok(plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct DecryptLocal {
|
||||
pub(crate) plaintext: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl DecryptLocal {
|
||||
pub(crate) fn try_decrypt(self) -> Result<Option<Vec<u8>>, MpcTlsError> {
|
||||
Ok(self.plaintext)
|
||||
}
|
||||
}
|
||||
264
crates/mpc-tls/src/record_layer/encrypt.rs
Normal file
264
crates/mpc-tls/src/record_layer/encrypt.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
use futures::TryFutureExt as _;
|
||||
use mpz_core::bitvec::BitVec;
|
||||
use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
DecodeFutureTyped, Vector,
|
||||
};
|
||||
use mpz_vm_core::{prelude::*, Vm};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tls_core::msgs::enums::{ContentType, ProtocolVersion};
|
||||
|
||||
use crate::{
|
||||
record_layer::{
|
||||
aead::{AeadError, ComputeTags, MpcAesGcm},
|
||||
TagData,
|
||||
},
|
||||
BoxFut, MpcTlsError,
|
||||
};
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn private(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
encrypter: &mut MpcAesGcm,
|
||||
op: &EncryptOp,
|
||||
) -> Result<
|
||||
(
|
||||
Vector<U8>,
|
||||
EncryptOutput,
|
||||
BoxFut<Result<Vec<u8>, AeadError>>,
|
||||
),
|
||||
MpcTlsError,
|
||||
> {
|
||||
let (plaintext, ciphertext) = encrypter
|
||||
.apply_keystream(vm, op.explicit_nonce.clone(), op.len)
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
if let Some(data) = op.plaintext.clone() {
|
||||
vm.assign(plaintext, data)
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
}
|
||||
vm.commit(plaintext).map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
let ciphertext_fut = Box::pin(
|
||||
vm.decode(ciphertext)
|
||||
.map_err(MpcTlsError::record_layer)?
|
||||
.map_err(AeadError::tag),
|
||||
);
|
||||
|
||||
Ok((
|
||||
plaintext,
|
||||
EncryptOutput::Private(EncryptPrivate {
|
||||
ciphertext: vm.decode(ciphertext).map_err(MpcTlsError::record_layer)?,
|
||||
}),
|
||||
ciphertext_fut,
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn public(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
encrypter: &mut MpcAesGcm,
|
||||
op: &EncryptOp,
|
||||
) -> Result<(EncryptOutput, BoxFut<Result<Vec<u8>, AeadError>>), MpcTlsError> {
|
||||
// Instead of computing the ciphertext in MPC, we only compute the keystream and
|
||||
// decode it for both parties. Each party then locally computes the ciphertext.
|
||||
|
||||
let Some(plaintext) = op.plaintext.clone() else {
|
||||
return Err(MpcTlsError::record_layer(
|
||||
"plaintext must be provided in public mode",
|
||||
));
|
||||
};
|
||||
|
||||
let keystream = encrypter
|
||||
.take_keystream(vm, op.explicit_nonce.clone(), op.len)
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
let keystream_fut = vm.decode(keystream).map_err(MpcTlsError::record_layer)?;
|
||||
let ciphertext_fut = {
|
||||
let plaintext = plaintext.clone();
|
||||
Box::pin(async move {
|
||||
let mut ciphertext = keystream_fut.await.map_err(AeadError::tag)?;
|
||||
ciphertext
|
||||
.iter_mut()
|
||||
.zip(plaintext)
|
||||
.for_each(|(a, b)| *a ^= b);
|
||||
|
||||
Ok(ciphertext)
|
||||
})
|
||||
};
|
||||
|
||||
Ok((
|
||||
EncryptOutput::Public(EncryptPublic {
|
||||
keystream: vm.decode(keystream).map_err(MpcTlsError::record_layer)?,
|
||||
plaintext,
|
||||
}),
|
||||
ciphertext_fut,
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn encrypt(
|
||||
vm: &mut dyn Vm<Binary>,
|
||||
encrypter: &mut MpcAesGcm,
|
||||
ops: &[EncryptOp],
|
||||
) -> Result<(Vec<PendingEncrypt>, ComputeTags), MpcTlsError> {
|
||||
let mut outputs = Vec::new();
|
||||
let mut ciphertext_futs = Vec::new();
|
||||
let mut tags_data = Vec::new();
|
||||
for op in ops {
|
||||
match op.mode {
|
||||
EncryptMode::Private => {
|
||||
let (plaintext_ref, output, ciphertext_fut) = private(vm, encrypter, op)?;
|
||||
|
||||
outputs.push(PendingEncrypt {
|
||||
plaintext_ref: Some(plaintext_ref),
|
||||
output,
|
||||
});
|
||||
ciphertext_futs.push(ciphertext_fut);
|
||||
}
|
||||
EncryptMode::Public => {
|
||||
let (output, ciphertext_fut) = public(vm, encrypter, op)?;
|
||||
|
||||
outputs.push(PendingEncrypt {
|
||||
plaintext_ref: None,
|
||||
output,
|
||||
});
|
||||
ciphertext_futs.push(ciphertext_fut);
|
||||
}
|
||||
}
|
||||
tags_data.push(TagData {
|
||||
explicit_nonce: op.explicit_nonce.clone(),
|
||||
aad: op.aad.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let compute_tags = encrypter
|
||||
.compute_tags(vm, ciphertext_futs, tags_data)
|
||||
.map_err(MpcTlsError::record_layer)?;
|
||||
|
||||
Ok((outputs, compute_tags))
|
||||
}
|
||||
|
||||
pub(crate) struct EncryptOp {
|
||||
pub(crate) seq: u64,
|
||||
pub(crate) typ: ContentType,
|
||||
pub(crate) version: ProtocolVersion,
|
||||
pub(crate) len: usize,
|
||||
pub(crate) plaintext: Option<Vec<u8>>,
|
||||
pub(crate) explicit_nonce: Vec<u8>,
|
||||
pub(crate) aad: Vec<u8>,
|
||||
pub(crate) mode: EncryptMode,
|
||||
}
|
||||
|
||||
impl EncryptOp {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
seq: u64,
|
||||
typ: ContentType,
|
||||
version: ProtocolVersion,
|
||||
len: usize,
|
||||
plaintext: Option<Vec<u8>>,
|
||||
explicit_nonce: Vec<u8>,
|
||||
aad: Vec<u8>,
|
||||
mode: EncryptMode,
|
||||
) -> Result<Self, MpcTlsError> {
|
||||
if let Some(plaintext) = &plaintext {
|
||||
if plaintext.len() != len {
|
||||
return Err(MpcTlsError::record_layer(format!(
|
||||
"inconsistent plaintext length: {} != {}",
|
||||
plaintext.len(),
|
||||
len
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if mode == EncryptMode::Public && plaintext.is_none() {
|
||||
return Err(MpcTlsError::record_layer(
|
||||
"plaintext must be provided in public mode",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
seq,
|
||||
typ,
|
||||
version,
|
||||
len,
|
||||
plaintext,
|
||||
explicit_nonce,
|
||||
aad,
|
||||
mode,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub(crate) enum EncryptMode {
|
||||
Private,
|
||||
Public,
|
||||
}
|
||||
|
||||
pub(crate) enum EncryptOutput {
|
||||
Private(EncryptPrivate),
|
||||
Public(EncryptPublic),
|
||||
}
|
||||
|
||||
impl EncryptOutput {
|
||||
pub(crate) fn try_encrypt(self) -> Result<Vec<u8>, MpcTlsError> {
|
||||
match self {
|
||||
EncryptOutput::Private(encrypt) => encrypt.try_encrypt(),
|
||||
EncryptOutput::Public(encrypt) => encrypt.try_encrypt(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PendingEncrypt {
|
||||
pub(crate) plaintext_ref: Option<Vector<U8>>,
|
||||
pub(crate) output: EncryptOutput,
|
||||
}
|
||||
|
||||
pub(crate) struct EncryptPrivate {
|
||||
ciphertext: DecodeFutureTyped<BitVec, Vec<u8>>,
|
||||
}
|
||||
|
||||
impl EncryptPrivate {
|
||||
/// Encrypts the plaintext.
|
||||
pub(crate) fn try_encrypt(mut self) -> Result<Vec<u8>, MpcTlsError> {
|
||||
let ciphertext = self
|
||||
.ciphertext
|
||||
.try_recv()
|
||||
.map_err(MpcTlsError::record_layer)?
|
||||
.ok_or_else(|| MpcTlsError::record_layer("ciphertext is not ready"))?;
|
||||
|
||||
Ok(ciphertext)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct EncryptPublic {
|
||||
keystream: DecodeFutureTyped<BitVec, Vec<u8>>,
|
||||
plaintext: Vec<u8>,
|
||||
}
|
||||
|
||||
impl EncryptPublic {
|
||||
pub(crate) fn try_encrypt(mut self) -> Result<Vec<u8>, MpcTlsError> {
|
||||
let keystream = self
|
||||
.keystream
|
||||
.try_recv()
|
||||
.map_err(MpcTlsError::record_layer)?
|
||||
.ok_or_else(|| MpcTlsError::record_layer("keystream is not ready"))?;
|
||||
|
||||
if keystream.len() != self.plaintext.len() {
|
||||
return Err(MpcTlsError::record_layer(format!(
|
||||
"keystream length does not match plaintext length: {} != {}",
|
||||
keystream.len(),
|
||||
self.plaintext.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut ciphertext = self.plaintext;
|
||||
ciphertext
|
||||
.iter_mut()
|
||||
.zip(keystream)
|
||||
.for_each(|(a, b)| *a ^= b);
|
||||
|
||||
Ok(ciphertext)
|
||||
}
|
||||
}
|
||||
19
crates/mpc-tls/src/utils.rs
Normal file
19
crates/mpc-tls/src/utils.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
use crate::MpcTlsError;
|
||||
|
||||
/// Split an opaque message into its constituent parts.
|
||||
///
|
||||
/// Returns the explicit nonce, ciphertext, and tag, respectively.
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) fn opaque_into_parts(
|
||||
mut msg: Vec<u8>,
|
||||
) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>), MpcTlsError> {
|
||||
let tag = msg.split_off(msg.len() - 16);
|
||||
let ciphertext = msg.split_off(8);
|
||||
let explicit_nonce = msg;
|
||||
|
||||
if explicit_nonce.len() != 8 {
|
||||
return Err(MpcTlsError::other("explicit nonce length is not 8"));
|
||||
}
|
||||
|
||||
Ok((explicit_nonce, ciphertext, tag))
|
||||
}
|
||||
169
crates/mpc-tls/tests/test.rs
Normal file
169
crates/mpc-tls/tests/test.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader};
|
||||
use mpz_common::context::test_mt_context;
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
|
||||
use mpz_memory_core::correlated::Delta;
|
||||
use mpz_ot::{
|
||||
cot::{DerandCOTReceiver, DerandCOTSender},
|
||||
ideal::rcot::ideal_rcot,
|
||||
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
|
||||
};
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use tls_client::Certificate;
|
||||
use tls_client_async::bind_client;
|
||||
use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_util::compat::TokioAsyncReadCompatExt;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[ignore = "expensive"]
|
||||
async fn mpc_tls_test() {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let config = Config::builder()
|
||||
.defer_decryption(false)
|
||||
.max_sent(1 << 13)
|
||||
.max_recv_online(1 << 13)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let (leader, follower) = build_pair(config);
|
||||
|
||||
tokio::try_join!(
|
||||
tokio::spawn(leader_task(leader)),
|
||||
tokio::spawn(follower_task(follower))
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn leader_task(mut leader: MpcTlsLeader) {
|
||||
leader.alloc().unwrap();
|
||||
leader.preprocess().await.unwrap();
|
||||
|
||||
let (leader_ctrl, leader_fut) = leader.run();
|
||||
tokio::spawn(async { leader_fut.await.unwrap() });
|
||||
|
||||
let mut root_store = tls_client::RootCertStore::empty();
|
||||
root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap();
|
||||
let config = tls_client::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
let server_name = SERVER_DOMAIN.try_into().unwrap();
|
||||
|
||||
let client = tls_client::ClientConnection::new(
|
||||
Arc::new(config),
|
||||
Box::new(leader_ctrl.clone()),
|
||||
server_name,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
|
||||
tokio::spawn(bind_test_server_hyper(server_socket.compat()));
|
||||
|
||||
let (mut conn, conn_fut) = bind_client(client_socket.compat(), client);
|
||||
let handle = tokio::spawn(async { conn_fut.await.unwrap() });
|
||||
|
||||
let msg = concat!(
|
||||
"POST /echo HTTP/1.1\r\n",
|
||||
"Host: test-server.io\r\n",
|
||||
"Connection: keep-alive\r\n",
|
||||
"Accept-Encoding: identity\r\n",
|
||||
"Content-Length: 5\r\n",
|
||||
"\r\n",
|
||||
"hello",
|
||||
"\r\n"
|
||||
);
|
||||
|
||||
conn.write_all(msg.as_bytes()).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 48];
|
||||
conn.read_exact(&mut buf).await.unwrap();
|
||||
|
||||
leader_ctrl.defer_decryption().await.unwrap();
|
||||
|
||||
let msg = concat!(
|
||||
"POST /echo HTTP/1.1\r\n",
|
||||
"Host: test-server.io\r\n",
|
||||
"Connection: close\r\n",
|
||||
"Accept-Encoding: identity\r\n",
|
||||
"Content-Length: 5\r\n",
|
||||
"\r\n",
|
||||
"hello",
|
||||
"\r\n"
|
||||
);
|
||||
|
||||
conn.write_all(msg.as_bytes()).await.unwrap();
|
||||
conn.close().await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
conn.read_to_end(&mut buf).await.unwrap();
|
||||
|
||||
leader_ctrl.stop().await.unwrap();
|
||||
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
async fn follower_task(mut follower: MpcTlsFollower) {
|
||||
follower.alloc().unwrap();
|
||||
follower.preprocess().await.unwrap();
|
||||
follower.run().await.unwrap();
|
||||
}
|
||||
|
||||
fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let (mut mt_a, mut mt_b) = test_mt_context(8);
|
||||
|
||||
let ctx_a = futures::executor::block_on(mt_a.new_context()).unwrap();
|
||||
let ctx_b = futures::executor::block_on(mt_b.new_context()).unwrap();
|
||||
|
||||
let delta_a = Delta::new(rng.gen());
|
||||
let delta_b = Delta::new(rng.gen());
|
||||
|
||||
let (rcot_send_a, rcot_recv_b) = ideal_rcot(rng.gen(), delta_a.into_inner());
|
||||
let (rcot_send_b, rcot_recv_a) = ideal_rcot(rng.gen(), delta_b.into_inner());
|
||||
|
||||
let mut rcot_send_a = SharedRCOTSender::new(4, rcot_send_a);
|
||||
let mut rcot_send_b = SharedRCOTSender::new(1, rcot_send_b);
|
||||
let mut rcot_recv_a = SharedRCOTReceiver::new(1, rcot_recv_a);
|
||||
let mut rcot_recv_b = SharedRCOTReceiver::new(4, rcot_recv_b);
|
||||
|
||||
let mpc_a = Arc::new(Mutex::new(Generator::new(
|
||||
DerandCOTSender::new(rcot_send_a.next().unwrap()),
|
||||
rng.gen(),
|
||||
delta_a,
|
||||
)));
|
||||
let mpc_b = Arc::new(Mutex::new(Evaluator::new(DerandCOTReceiver::new(
|
||||
rcot_recv_b.next().unwrap(),
|
||||
))));
|
||||
|
||||
let leader = MpcTlsLeader::new(
|
||||
config.clone(),
|
||||
ctx_a,
|
||||
mpc_a,
|
||||
(
|
||||
rcot_send_a.next().unwrap(),
|
||||
rcot_send_a.next().unwrap(),
|
||||
rcot_send_a.next().unwrap(),
|
||||
),
|
||||
rcot_recv_a.next().unwrap(),
|
||||
);
|
||||
|
||||
let follower = MpcTlsFollower::new(
|
||||
config,
|
||||
ctx_b,
|
||||
mpc_b,
|
||||
rcot_send_b.next().unwrap(),
|
||||
(
|
||||
rcot_recv_b.next().unwrap(),
|
||||
rcot_recv_b.next().unwrap(),
|
||||
rcot_recv_b.next().unwrap(),
|
||||
),
|
||||
);
|
||||
|
||||
(leader, follower)
|
||||
}
|
||||
@@ -115,7 +115,7 @@ String
|
||||
## Logging
|
||||
The default logging strategy of this server is set to `DEBUG` verbosity level for the crates that are useful for most debugging scenarios, i.e. using the following filtering logic:
|
||||
|
||||
`notary_server=DEBUG,tlsn_verifier=DEBUG,tls_mpc=DEBUG,tls_client_async=DEBUG`
|
||||
`notary_server=DEBUG,tlsn_verifier=DEBUG,mpc_tls=DEBUG,tls_client_async=DEBUG`
|
||||
|
||||
In the config [file](./config/config.yaml), one can toggle the verbosity level for these crates using the `level` field under `logging`.
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ pub struct NotarySigningKeyProperties {
|
||||
#[derive(Clone, Debug, Deserialize, Default)]
|
||||
pub struct LoggingProperties {
|
||||
/// Log verbosity level of the default filtering logic, which is
|
||||
/// notary_server=<level>,tlsn_verifier=<level>,tls_mpc=<level> Must be either of <https://docs.rs/tracing/latest/tracing/struct.Level.html#implementations>
|
||||
/// notary_server=<level>,tlsn_verifier=<level>,mpc_tls=<level> Must be either of <https://docs.rs/tracing/latest/tracing/struct.Level.html#implementations>
|
||||
pub level: String,
|
||||
/// Custom filtering logic, refer to the syntax here https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#example-syntax
|
||||
/// This will override the default filtering logic above
|
||||
|
||||
@@ -13,7 +13,7 @@ pub fn init_tracing(config: &NotaryServerProperties) -> Result<()> {
|
||||
// Use the default filter when only verbosity level is provided
|
||||
None => {
|
||||
let level = Level::from_str(&config.logging.level)?;
|
||||
format!("notary_server={level},tlsn_verifier={level},tls_mpc={level}")
|
||||
format!("notary_server={level},tlsn_verifier={level},mpc_tls={level}")
|
||||
}
|
||||
};
|
||||
let filter_layer = EnvFilter::builder().parse(directives)?;
|
||||
|
||||
@@ -16,9 +16,11 @@ force-st = ["mpz-common/force-st"]
|
||||
[dependencies]
|
||||
tlsn-common = { workspace = true }
|
||||
tlsn-core = { workspace = true }
|
||||
tlsn-deap = { workspace = true }
|
||||
tlsn-tls-client = { workspace = true }
|
||||
tlsn-tls-client-async = { workspace = true }
|
||||
tlsn-tls-mpc = { workspace = true }
|
||||
tlsn-tls-core = { workspace = true }
|
||||
tlsn-mpc-tls = { workspace = true }
|
||||
|
||||
serio = { workspace = true, features = ["compat"] }
|
||||
uid-mux = { workspace = true, features = ["serio"] }
|
||||
@@ -27,8 +29,11 @@ mpz-common = { workspace = true }
|
||||
mpz-core = { workspace = true }
|
||||
mpz-garble = { workspace = true }
|
||||
mpz-garble-core = { workspace = true }
|
||||
mpz-memory-core = { workspace = true }
|
||||
mpz-ole = { workspace = true }
|
||||
mpz-ot = { workspace = true }
|
||||
mpz-vm-core = { workspace = true }
|
||||
mpz-zk = { workspace = true }
|
||||
|
||||
derive_builder = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
@@ -37,4 +42,4 @@ rand = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
web-time = { workspace = true }
|
||||
|
||||
tokio = { workspace = true, features = ["sync"] }
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use mpz_ot::{chou_orlandi, kos};
|
||||
use tls_mpc::{MpcTlsCommonConfig, MpcTlsLeaderConfig, TranscriptConfig};
|
||||
use mpc_tls::Config;
|
||||
use tlsn_common::config::ProtocolConfig;
|
||||
use tlsn_core::{connection::ServerName, CryptoProvider};
|
||||
|
||||
@@ -15,8 +14,6 @@ pub struct ProverConfig {
|
||||
protocol_config: ProtocolConfig,
|
||||
/// Whether the `deferred decryption` feature is toggled on from the start
|
||||
/// of the MPC-TLS connection.
|
||||
///
|
||||
/// See `defer_decryption_from_start` in [tls_mpc::MpcTlsLeaderConfig].
|
||||
#[builder(default = "true")]
|
||||
defer_decryption_from_start: bool,
|
||||
/// Cryptography provider.
|
||||
@@ -51,52 +48,11 @@ impl ProverConfig {
|
||||
self.defer_decryption_from_start
|
||||
}
|
||||
|
||||
pub(crate) fn build_mpc_tls_config(&self) -> MpcTlsLeaderConfig {
|
||||
MpcTlsLeaderConfig::builder()
|
||||
.common(
|
||||
MpcTlsCommonConfig::builder()
|
||||
.tx_config(
|
||||
TranscriptConfig::default_tx()
|
||||
.max_online_size(self.protocol_config.max_sent_data())
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.rx_config(
|
||||
TranscriptConfig::default_rx()
|
||||
.max_online_size(self.protocol_config.max_recv_data_online())
|
||||
.max_offline_size(
|
||||
self.protocol_config.max_recv_data()
|
||||
- self.protocol_config.max_recv_data_online(),
|
||||
)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.handshake_commit(true)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub(crate) fn build_base_ot_sender_config(&self) -> chou_orlandi::SenderConfig {
|
||||
chou_orlandi::SenderConfig::builder()
|
||||
.receiver_commit()
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub(crate) fn build_base_ot_receiver_config(&self) -> chou_orlandi::ReceiverConfig {
|
||||
chou_orlandi::ReceiverConfig::default()
|
||||
}
|
||||
|
||||
pub(crate) fn build_ot_sender_config(&self) -> kos::SenderConfig {
|
||||
kos::SenderConfig::default()
|
||||
}
|
||||
|
||||
pub(crate) fn build_ot_receiver_config(&self) -> kos::ReceiverConfig {
|
||||
kos::ReceiverConfig::builder()
|
||||
.sender_commit()
|
||||
pub(crate) fn build_mpc_tls_config(&self) -> Config {
|
||||
Config::builder()
|
||||
.defer_decryption(self.defer_decryption_from_start)
|
||||
.max_sent(self.protocol_config.max_sent_data())
|
||||
.max_recv_online(self.protocol_config.max_recv_data_online())
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use mpc_tls::MpcTlsError;
|
||||
use std::{error::Error, fmt};
|
||||
use tls_mpc::MpcTlsError;
|
||||
use tlsn_common::{encoding::EncodingError, zk_aes::ZkAesCtrError};
|
||||
|
||||
/// Error for [`Prover`](crate::Prover).
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -26,6 +27,27 @@ impl ProverError {
|
||||
Self::new(ErrorKind::Config, source)
|
||||
}
|
||||
|
||||
pub(crate) fn mpc<E>(source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self::new(ErrorKind::Mpc, source)
|
||||
}
|
||||
|
||||
pub(crate) fn zk<E>(source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self::new(ErrorKind::Zk, source)
|
||||
}
|
||||
|
||||
pub(crate) fn commit<E>(source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
Self::new(ErrorKind::Commit, source)
|
||||
}
|
||||
|
||||
pub(crate) fn attestation<E>(source: E) -> Self
|
||||
where
|
||||
E: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
@@ -38,7 +60,9 @@ impl ProverError {
|
||||
enum ErrorKind {
|
||||
Io,
|
||||
Mpc,
|
||||
Zk,
|
||||
Config,
|
||||
Commit,
|
||||
Attestation,
|
||||
}
|
||||
|
||||
@@ -49,7 +73,9 @@ impl fmt::Display for ProverError {
|
||||
match self.kind {
|
||||
ErrorKind::Io => f.write_str("io error")?,
|
||||
ErrorKind::Mpc => f.write_str("mpc error")?,
|
||||
ErrorKind::Zk => f.write_str("zk error")?,
|
||||
ErrorKind::Config => f.write_str("config error")?,
|
||||
ErrorKind::Commit => f.write_str("commit error")?,
|
||||
ErrorKind::Attestation => f.write_str("attestation error")?,
|
||||
}
|
||||
|
||||
@@ -91,50 +117,14 @@ impl From<MpcTlsError> for ProverError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_ot::OTError> for ProverError {
|
||||
fn from(e: mpz_ot::OTError) -> Self {
|
||||
Self::new(ErrorKind::Mpc, e)
|
||||
impl From<ZkAesCtrError> for ProverError {
|
||||
fn from(e: ZkAesCtrError) -> Self {
|
||||
Self::new(ErrorKind::Zk, e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_ot::kos::SenderError> for ProverError {
|
||||
fn from(e: mpz_ot::kos::SenderError) -> Self {
|
||||
Self::new(ErrorKind::Mpc, e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_ole::OLEError> for ProverError {
|
||||
fn from(e: mpz_ole::OLEError) -> Self {
|
||||
Self::new(ErrorKind::Mpc, e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_ot::kos::ReceiverError> for ProverError {
|
||||
fn from(e: mpz_ot::kos::ReceiverError) -> Self {
|
||||
Self::new(ErrorKind::Mpc, e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::VmError> for ProverError {
|
||||
fn from(e: mpz_garble::VmError) -> Self {
|
||||
Self::new(ErrorKind::Mpc, e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::protocol::deap::DEAPError> for ProverError {
|
||||
fn from(e: mpz_garble::protocol::deap::DEAPError) -> Self {
|
||||
Self::new(ErrorKind::Mpc, e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::MemoryError> for ProverError {
|
||||
fn from(e: mpz_garble::MemoryError) -> Self {
|
||||
Self::new(ErrorKind::Mpc, e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpz_garble::ProveError> for ProverError {
|
||||
fn from(e: mpz_garble::ProveError) -> Self {
|
||||
Self::new(ErrorKind::Mpc, e)
|
||||
impl From<EncodingError> for ProverError {
|
||||
fn from(e: EncodingError) -> Self {
|
||||
Self::new(ErrorKind::Commit, e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,21 +14,20 @@ pub mod state;
|
||||
pub use config::{ProverConfig, ProverConfigBuilder, ProverConfigBuilderError};
|
||||
pub use error::ProverError;
|
||||
pub use future::ProverFuture;
|
||||
use mpz_common::Context;
|
||||
use mpz_garble_core::Delta;
|
||||
use state::{Notarize, Prove};
|
||||
|
||||
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
|
||||
use mpz_common::Allocate;
|
||||
use mpz_garble::config::Role as DEAPRole;
|
||||
use mpz_ot::{chou_orlandi, kos};
|
||||
use rand::Rng;
|
||||
use serio::{SinkExt, StreamExt};
|
||||
use mpc_tls::{LeaderCtrl, MpcTlsLeader};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serio::SinkExt;
|
||||
use std::sync::Arc;
|
||||
use tls_client::{ClientConnection, ServerName as TlsServerName};
|
||||
use tls_client_async::{bind_client, ClosedConnection, TlsConnection};
|
||||
use tls_mpc::{build_components, LeaderCtrl, MpcTlsLeader, TlsRole};
|
||||
use tls_client_async::{bind_client, TlsConnection};
|
||||
use tls_core::msgs::enums::ContentType;
|
||||
use tlsn_common::{
|
||||
mux::{attach_mux, MuxControl},
|
||||
DEAPThread, Executor, OTReceiver, OTSender, Role,
|
||||
commit::commit_records, context::build_mt_context, mux::attach_mux, zk_aes::ZkAesCtr, Role,
|
||||
};
|
||||
use tlsn_core::{
|
||||
connection::{
|
||||
@@ -37,10 +36,24 @@ use tlsn_core::{
|
||||
},
|
||||
transcript::Transcript,
|
||||
};
|
||||
use uid_mux::FramedUidMux as _;
|
||||
use tlsn_deap::Deap;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use tracing::{debug, info_span, instrument, Instrument, Span};
|
||||
|
||||
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
|
||||
mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
|
||||
mpz_core::Block,
|
||||
>;
|
||||
pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
|
||||
mpz_ot::ferret::Receiver<mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>>,
|
||||
bool,
|
||||
mpz_core::Block,
|
||||
>;
|
||||
pub(crate) type Mpc =
|
||||
mpz_garble::protocol::semihonest::Generator<mpz_ot::cot::DerandCOTSender<RCOTSender>>;
|
||||
pub(crate) type Zk = mpz_zk::Prover<RCOTReceiver>;
|
||||
|
||||
/// A prover instance.
|
||||
#[derive(Debug)]
|
||||
pub struct Prover<T: state::ProverState> {
|
||||
@@ -78,37 +91,43 @@ impl Prover<state::Initialized> {
|
||||
socket: S,
|
||||
) -> Result<Prover<state::Setup>, ProverError> {
|
||||
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
|
||||
|
||||
let mut io = mux_fut
|
||||
.poll_with(mux_ctrl.open_framed(b"tlsnotary"))
|
||||
.await?;
|
||||
let mut mt = build_mt_context(mux_ctrl.clone());
|
||||
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
|
||||
|
||||
// Sends protocol configuration to verifier for compatibility check.
|
||||
mux_fut
|
||||
.poll_with(io.send(self.config.protocol_config().clone()))
|
||||
.poll_with(ctx.io_mut().send(self.config.protocol_config().clone()))
|
||||
.await?;
|
||||
|
||||
// Maximum thread forking concurrency of 8.
|
||||
// TODO: Determine the optimal number of threads.
|
||||
let mut exec = Executor::new(mux_ctrl.clone(), 8);
|
||||
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
|
||||
|
||||
let (mpc_tls, vm, ot_recv) = mux_fut
|
||||
.poll_with(setup_mpc_backend(&self.config, &mux_ctrl, &mut exec))
|
||||
.await?;
|
||||
// Allocate resources for MPC-TLS in VM.
|
||||
let keys = mpc_tls.alloc()?;
|
||||
// Allocate for committing to plaintext.
|
||||
let mut zk_aes = ZkAesCtr::new(Role::Prover);
|
||||
zk_aes.set_key(keys.server_write_key, keys.server_write_iv);
|
||||
zk_aes.alloc(
|
||||
&mut (*vm.try_lock().expect("VM is not locked").zk()),
|
||||
self.config.protocol_config().max_recv_data(),
|
||||
)?;
|
||||
|
||||
let ctx = mux_fut.poll_with(exec.new_thread()).await?;
|
||||
debug!("setting up mpc-tls");
|
||||
|
||||
mux_fut.poll_with(mpc_tls.preprocess()).await?;
|
||||
|
||||
debug!("mpc-tls setup complete");
|
||||
|
||||
Ok(Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Setup {
|
||||
io,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
mt,
|
||||
mpc_tls,
|
||||
zk_aes,
|
||||
keys,
|
||||
vm,
|
||||
ot_recv,
|
||||
ctx,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -129,13 +148,13 @@ impl Prover<state::Setup> {
|
||||
socket: S,
|
||||
) -> Result<(TlsConnection, ProverFuture), ProverError> {
|
||||
let state::Setup {
|
||||
io,
|
||||
mux_ctrl,
|
||||
mut mux_fut,
|
||||
mt,
|
||||
mpc_tls,
|
||||
mut zk_aes,
|
||||
keys,
|
||||
vm,
|
||||
ot_recv,
|
||||
ctx,
|
||||
} = self.state;
|
||||
|
||||
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
|
||||
@@ -158,79 +177,122 @@ impl Prover<state::Setup> {
|
||||
|
||||
let (conn, conn_fut) = bind_client(socket, client);
|
||||
|
||||
let start_time = web_time::UNIX_EPOCH.elapsed().unwrap().as_secs();
|
||||
let start_time = web_time::UNIX_EPOCH
|
||||
.elapsed()
|
||||
.expect("system time is available")
|
||||
.as_secs();
|
||||
|
||||
let fut = Box::pin({
|
||||
let span = self.span.clone();
|
||||
let mpc_ctrl = mpc_ctrl.clone();
|
||||
async move {
|
||||
let conn_fut = async {
|
||||
let ClosedConnection { sent, recv, .. } = mux_fut
|
||||
mux_fut
|
||||
.poll_with(conn_fut.map_err(ProverError::from))
|
||||
.await?;
|
||||
|
||||
mpc_ctrl.close_connection().await?;
|
||||
mpc_ctrl.stop().await?;
|
||||
|
||||
Ok::<_, ProverError>((sent, recv))
|
||||
Ok::<_, ProverError>(())
|
||||
};
|
||||
|
||||
let ((sent, recv), mpc_tls_data) = futures::try_join!(
|
||||
let (_, (mut ctx, mut data)) = futures::try_join!(
|
||||
conn_fut,
|
||||
mpc_fut.in_current_span().map_err(ProverError::from)
|
||||
)?;
|
||||
|
||||
{
|
||||
let mut vm = vm.try_lock().expect("VM should not be locked");
|
||||
|
||||
// Prove received plaintext. Prover drops the proof output, as they trust
|
||||
// themselves.
|
||||
_ = commit_records(
|
||||
&mut (*vm.zk()),
|
||||
&mut zk_aes,
|
||||
data.transcript
|
||||
.recv
|
||||
.iter_mut()
|
||||
.filter(|record| record.typ == ContentType::ApplicationData),
|
||||
)
|
||||
.map_err(ProverError::zk)?;
|
||||
|
||||
debug!("finalizing mpc");
|
||||
|
||||
// Finalize DEAP and execute the plaintext proofs.
|
||||
mux_fut
|
||||
.poll_with(vm.finalize(&mut ctx))
|
||||
.await
|
||||
.map_err(ProverError::mpc)?;
|
||||
|
||||
debug!("mpc finalized");
|
||||
}
|
||||
|
||||
let transcript = data
|
||||
.transcript
|
||||
.to_transcript()
|
||||
.expect("transcript is complete");
|
||||
let transcript_refs = data
|
||||
.transcript
|
||||
.to_transcript_refs()
|
||||
.expect("transcript is complete");
|
||||
|
||||
let connection_info = ConnectionInfo {
|
||||
time: start_time,
|
||||
version: mpc_tls_data
|
||||
version: data
|
||||
.protocol_version
|
||||
.try_into()
|
||||
.expect("only supported version should have been accepted"),
|
||||
transcript_length: TranscriptLength {
|
||||
sent: sent.len() as u32,
|
||||
received: recv.len() as u32,
|
||||
sent: transcript.sent().len() as u32,
|
||||
received: transcript.received().len() as u32,
|
||||
},
|
||||
};
|
||||
|
||||
let server_cert_data = ServerCertData {
|
||||
certs: mpc_tls_data
|
||||
.server_cert_details
|
||||
.cert_chain()
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|c| c.into())
|
||||
.collect(),
|
||||
sig: ServerSignature {
|
||||
scheme: mpc_tls_data
|
||||
.server_kx_details
|
||||
.kx_sig()
|
||||
.scheme
|
||||
.try_into()
|
||||
.expect("only supported signature scheme should have been accepted"),
|
||||
sig: mpc_tls_data.server_kx_details.kx_sig().sig.0.clone(),
|
||||
},
|
||||
handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
|
||||
client_random: mpc_tls_data.client_random.0,
|
||||
server_random: mpc_tls_data.server_random.0,
|
||||
server_ephemeral_key: mpc_tls_data
|
||||
.server_public_key
|
||||
.try_into()
|
||||
.expect("only supported key scheme should have been accepted"),
|
||||
}),
|
||||
};
|
||||
let server_cert_data =
|
||||
ServerCertData {
|
||||
certs: data
|
||||
.server_cert_details
|
||||
.cert_chain()
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|c| c.into())
|
||||
.collect(),
|
||||
sig: ServerSignature {
|
||||
scheme: data.server_kx_details.kx_sig().scheme.try_into().expect(
|
||||
"only supported signature scheme should have been accepted",
|
||||
),
|
||||
sig: data.server_kx_details.kx_sig().sig.0.clone(),
|
||||
},
|
||||
handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
|
||||
client_random: data.client_random.0,
|
||||
server_random: data.server_random.0,
|
||||
server_ephemeral_key: data
|
||||
.server_key
|
||||
.try_into()
|
||||
.expect("only supported key scheme should have been accepted"),
|
||||
}),
|
||||
};
|
||||
|
||||
// Pull out ZK VM
|
||||
let (_, vm) = Arc::into_inner(vm)
|
||||
.expect("vm should have only 1 reference")
|
||||
.into_inner()
|
||||
.into_inner();
|
||||
|
||||
Ok(Prover {
|
||||
config: self.config,
|
||||
span: self.span,
|
||||
state: state::Closed {
|
||||
io,
|
||||
mux_ctrl,
|
||||
mux_fut,
|
||||
vm,
|
||||
ot_recv,
|
||||
mt,
|
||||
ctx,
|
||||
_keys: keys,
|
||||
vm,
|
||||
connection_info,
|
||||
server_cert_data,
|
||||
transcript: Transcript::new(sent, recv),
|
||||
transcript,
|
||||
transcript_refs,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -279,123 +341,55 @@ impl Prover<state::Closed> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs a setup of the various MPC subprotocols.
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn setup_mpc_backend(
|
||||
config: &ProverConfig,
|
||||
mux: &MuxControl,
|
||||
exec: &mut Executor,
|
||||
) -> Result<(MpcTlsLeader, DEAPThread, OTReceiver), ProverError> {
|
||||
debug!("starting MPC backend setup");
|
||||
fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsLeader) {
|
||||
let mut rng = thread_rng();
|
||||
let delta = Delta::new(rng.gen());
|
||||
|
||||
let mut ot_sender = kos::Sender::new(
|
||||
config.build_ot_sender_config(),
|
||||
chou_orlandi::Receiver::new(config.build_base_ot_receiver_config()),
|
||||
let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
|
||||
let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
|
||||
let rcot_send = mpz_ot::kos::Sender::new(
|
||||
mpz_ot::kos::SenderConfig::default(),
|
||||
delta.into_inner(),
|
||||
base_ot_recv,
|
||||
);
|
||||
ot_sender.alloc(config.protocol_config().ot_sender_setup_count(Role::Prover));
|
||||
|
||||
let mut ot_receiver = kos::Receiver::new(
|
||||
config.build_ot_receiver_config(),
|
||||
chou_orlandi::Sender::new(config.build_base_ot_sender_config()),
|
||||
);
|
||||
ot_receiver.alloc(
|
||||
config
|
||||
.protocol_config()
|
||||
.ot_receiver_setup_count(Role::Prover),
|
||||
let rcot_recv =
|
||||
mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
|
||||
let rcot_recv = mpz_ot::ferret::Receiver::new(
|
||||
mpz_ot::ferret::FerretConfig::builder()
|
||||
.lpn_type(mpz_ot::ferret::LpnType::Regular)
|
||||
.build()
|
||||
.expect("ferret config is valid"),
|
||||
rng.gen(),
|
||||
rcot_recv,
|
||||
);
|
||||
|
||||
let ot_sender = OTSender::new(ot_sender);
|
||||
let ot_receiver = OTReceiver::new(ot_receiver);
|
||||
let mut rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(4, rcot_send);
|
||||
let mut rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(2, rcot_recv);
|
||||
|
||||
let (
|
||||
ctx_vm,
|
||||
ctx_ke_0,
|
||||
ctx_ke_1,
|
||||
ctx_prf_0,
|
||||
ctx_prf_1,
|
||||
ctx_encrypter_block_cipher,
|
||||
ctx_encrypter_stream_cipher,
|
||||
ctx_encrypter_ghash,
|
||||
ctx_encrypter,
|
||||
ctx_decrypter_block_cipher,
|
||||
ctx_decrypter_stream_cipher,
|
||||
ctx_decrypter_ghash,
|
||||
ctx_decrypter,
|
||||
) = futures::try_join!(
|
||||
exec.new_thread(),
|
||||
exec.new_thread(),
|
||||
exec.new_thread(),
|
||||
exec.new_thread(),
|
||||
exec.new_thread(),
|
||||
exec.new_thread(),
|
||||
exec.new_thread(),
|
||||
exec.new_thread(),
|
||||
exec.new_thread(),
|
||||
exec.new_thread(),
|
||||
exec.new_thread(),
|
||||
exec.new_thread(),
|
||||
exec.new_thread(),
|
||||
)?;
|
||||
|
||||
let vm = DEAPThread::new(
|
||||
DEAPRole::Leader,
|
||||
rand::rngs::OsRng.gen(),
|
||||
ctx_vm,
|
||||
ot_sender.clone(),
|
||||
ot_receiver.clone(),
|
||||
let mpc = Mpc::new(
|
||||
mpz_ot::cot::DerandCOTSender::new(rcot_send.next().expect("enough senders are available")),
|
||||
rng.gen(),
|
||||
delta,
|
||||
);
|
||||
|
||||
let mpc_tls_config = config.build_mpc_tls_config();
|
||||
let (ke, prf, encrypter, decrypter) = build_components(
|
||||
TlsRole::Leader,
|
||||
mpc_tls_config.common(),
|
||||
ctx_ke_0,
|
||||
ctx_encrypter,
|
||||
ctx_decrypter,
|
||||
ctx_encrypter_ghash,
|
||||
ctx_decrypter_ghash,
|
||||
vm.new_thread(ctx_ke_1, ot_sender.clone(), ot_receiver.clone())?,
|
||||
vm.new_thread(ctx_prf_0, ot_sender.clone(), ot_receiver.clone())?,
|
||||
vm.new_thread(ctx_prf_1, ot_sender.clone(), ot_receiver.clone())?,
|
||||
vm.new_thread(
|
||||
ctx_encrypter_block_cipher,
|
||||
ot_sender.clone(),
|
||||
ot_receiver.clone(),
|
||||
)?,
|
||||
vm.new_thread(
|
||||
ctx_decrypter_block_cipher,
|
||||
ot_sender.clone(),
|
||||
ot_receiver.clone(),
|
||||
)?,
|
||||
vm.new_thread(
|
||||
ctx_encrypter_stream_cipher,
|
||||
ot_sender.clone(),
|
||||
ot_receiver.clone(),
|
||||
)?,
|
||||
vm.new_thread(
|
||||
ctx_decrypter_stream_cipher,
|
||||
ot_sender.clone(),
|
||||
ot_receiver.clone(),
|
||||
)?,
|
||||
ot_sender.clone(),
|
||||
ot_receiver.clone(),
|
||||
);
|
||||
let zk = Zk::new(rcot_recv.next().expect("enough receivers are available"));
|
||||
|
||||
let channel = mux.open_framed(b"mpc_tls").await?;
|
||||
let mut mpc_tls = MpcTlsLeader::new(
|
||||
mpc_tls_config,
|
||||
Box::new(StreamExt::compat_stream(channel)),
|
||||
ke,
|
||||
prf,
|
||||
encrypter,
|
||||
decrypter,
|
||||
);
|
||||
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
|
||||
|
||||
mpc_tls.setup().await?;
|
||||
|
||||
debug!("MPC backend setup complete");
|
||||
|
||||
Ok((mpc_tls, vm, ot_receiver))
|
||||
(
|
||||
vm.clone(),
|
||||
MpcTlsLeader::new(
|
||||
config.build_mpc_tls_config(),
|
||||
ctx,
|
||||
vm,
|
||||
(
|
||||
rcot_send.next().expect("enough senders are available"),
|
||||
rcot_send.next().expect("enough senders are available"),
|
||||
rcot_send.next().expect("enough senders are available"),
|
||||
),
|
||||
rcot_recv.next().expect("enough receivers are available"),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
/// A controller for the prover.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user