perf: MPC-TLS upgrade (#698)

* fix: add new Cargo.toml

* (alpha.8) - Refactor key-exchange crate (#685)

* refactor(key-exchange): adapt key-exchange to new vm

* fix: fix feature flags

* simplify

* delete old msg module

* clean up error

---------

Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com>

* (alpha.8) - Refactor prf crate (#684)

* refactor(prf): adapt prf to new mpz vm

Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com>

* refactor: remove preprocessing bench

* fix: fix feature flags

* clean up attributes

---------

Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com>

* refactor: key exchange interface (#688)

* refactor: prf interface (#689)

* (alpha.8) - Create cipher crate (#683)

* feat(cipher): add cipher crate, replacing stream/block cipher and aead

* delete old config module

* remove mpz generics

---------

Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com>

* refactor(core): decouple encoder from mpz (#692)

* WIP: Adding new encoding logic...

* feat: add new encoder

* add feedback

* rename conversions

* feat: DEAP VM (#690)

* feat: DEAP VM

* use rangeset, add desync guard

* move MPC execution up in finalization

* refactor: MPC-TLS (#693)

* refactor: MPC-TLS

Co-authored-by: th4s <th4s@metavoid.xyz>

* output key references

* bump deps

---------

Co-authored-by: th4s <th4s@metavoid.xyz>

* refactor: prover + verifier (#696)

* refactor: wasm crates (#697)

* chore: appease clippy (#699)

* chore: rustfmt

* chore: appease clippy more

* chore: more rustfmt!

* chore: clippy is stubborn

* chore: rustfmt sorting change is annoying!

* fix: remove wasm bundling hack

* fix: aes ctr test

* chore: clippy

* fix: flush client when sending close notify

* fix: failing tests

---------

Co-authored-by: th4s <th4s@metavoid.xyz>
This commit is contained in:
sinu.eth
2025-02-25 13:51:28 -08:00
committed by GitHub
parent 25d65734c0
commit cb13169b82
138 changed files with 11394 additions and 9785 deletions

View File

@@ -6,13 +6,11 @@ members = [
"crates/benches/browser/wasm",
"crates/benches/library",
"crates/common",
"crates/components/aead",
"crates/components/block-cipher",
"crates/components/deap",
"crates/components/cipher",
"crates/components/hmac-sha256",
"crates/components/hmac-sha256-circuits",
"crates/components/key-exchange",
"crates/components/stream-cipher",
"crates/components/universal-hash",
"crates/core",
"crates/data-fixtures",
"crates/examples",
@@ -28,7 +26,7 @@ members = [
"crates/tls/client",
"crates/tls/client-async",
"crates/tls/core",
"crates/tls/mpc",
"crates/mpc-tls",
"crates/tls/server-fixture",
"crates/verifier",
"crates/wasm",
@@ -44,45 +42,47 @@ opt-level = 1
notary-client = { path = "crates/notary/client" }
notary-server = { path = "crates/notary/server" }
tls-server-fixture = { path = "crates/tls/server-fixture" }
tlsn-aead = { path = "crates/components/aead" }
tlsn-cipher = { path = "crates/components/cipher" }
tlsn-benches-browser-core = { path = "crates/benches/browser/core" }
tlsn-benches-browser-native = { path = "crates/benches/browser/native" }
tlsn-benches-library = { path = "crates/benches/library" }
tlsn-block-cipher = { path = "crates/components/block-cipher" }
tlsn-common = { path = "crates/common" }
tlsn-core = { path = "crates/core" }
tlsn-data-fixtures = { path = "crates/data-fixtures" }
tlsn-deap = { path = "crates/components/deap" }
tlsn-formats = { path = "crates/formats" }
tlsn-hmac-sha256 = { path = "crates/components/hmac-sha256" }
tlsn-hmac-sha256-circuits = { path = "crates/components/hmac-sha256-circuits" }
tlsn-key-exchange = { path = "crates/components/key-exchange" }
tlsn-mpc-tls = { path = "crates/mpc-tls" }
tlsn-prover = { path = "crates/prover" }
tlsn-server-fixture = { path = "crates/server-fixture/server" }
tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" }
tlsn-stream-cipher = { path = "crates/components/stream-cipher" }
tlsn-tls-backend = { path = "crates/tls/backend" }
tlsn-tls-client = { path = "crates/tls/client" }
tlsn-tls-client-async = { path = "crates/tls/client-async" }
tlsn-tls-core = { path = "crates/tls/core" }
tlsn-tls-mpc = { path = "crates/tls/mpc" }
tlsn-universal-hash = { path = "crates/components/universal-hash" }
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "0040a00" }
tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "0040a00" }
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8555275" }
tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8555275" }
tlsn-verifier = { path = "crates/verifier" }
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.1" }
serio = { version = "0.1" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "0040a00" }
uid-mux = { version = "0.1", features = ["serio"] }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "0040a00" }
serio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8555275" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8555275" }
uid-mux = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8555275" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8555275" }
aes = { version = "0.8" }
aes-gcm = { version = "0.9" }
@@ -113,6 +113,7 @@ http = { version = "1.1" }
http-body-util = { version = "0.1" }
hyper = { version = "1.1" }
hyper-util = { version = "0.1" }
itybity = { version = "0.2" }
k256 = { version = "0.13" }
log = { version = "0.4" }
once_cell = { version = "1.19" }
@@ -123,6 +124,7 @@ pin-project-lite = { version = "0.2" }
rand = { version = "0.8" }
rand_chacha = { version = "0.3" }
rand_core = { version = "0.6" }
rayon = { version = "1.10" }
regex = { version = "1.10" }
ring = { version = "0.17" }
rs_merkle = { git = "https://github.com/tlsnotary/rs-merkle.git", rev = "85f3e82" }
@@ -141,6 +143,7 @@ tokio-util = { version = "0.7" }
tracing = { version = "0.1" }
tracing-subscriber = { version = "0.3" }
uuid = { version = "1.4" }
web-spawn = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "2d93c56" }
web-time = { version = "0.2" }
webpki = { version = "0.22" }
webpki-roots = { version = "0.26" }

View File

@@ -9,8 +9,5 @@ fi
# Run the benchmark binary.
../../../target/release/bench
# Run the benchmark binary in memory profiling mode.
../../../target/release/bench --memory-profiling
# Plot the results.
../../../target/release/plot metrics.csv

View File

@@ -1,71 +1,5 @@
use hmac_sha256::{MpcPrf, Prf, PrfConfig, Role};
use mpz_common::executor::test_st_executor;
use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPThread, Memory};
use mpz_ot::ideal::ot::ideal_ot;
use hmac_sha256::build_circuits;
pub async fn preprocess_prf_circuits() {
let pms = [42u8; 32];
let client_random = [69u8; 32];
let (leader_ctx_0, follower_ctx_0) = test_st_executor(128);
let (leader_ctx_1, follower_ctx_1) = test_st_executor(128);
let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot();
let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot();
let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot();
let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot();
let leader_thread_0 = DEAPThread::new(
DEAPRole::Leader,
[0u8; 32],
leader_ctx_0,
leader_ot_send_0,
leader_ot_recv_0,
);
let leader_thread_1 = leader_thread_0
.new_thread(leader_ctx_1, leader_ot_send_1, leader_ot_recv_1)
.unwrap();
let follower_thread_0 = DEAPThread::new(
DEAPRole::Follower,
[0u8; 32],
follower_ctx_0,
follower_ot_send_0,
follower_ot_recv_0,
);
let follower_thread_1 = follower_thread_0
.new_thread(follower_ctx_1, follower_ot_send_1, follower_ot_recv_1)
.unwrap();
// Set up public PMS for testing.
let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap();
let follower_pms = follower_thread_0
.new_public_input::<[u8; 32]>("pms")
.unwrap();
leader_thread_0.assign(&leader_pms, pms).unwrap();
let mut leader = MpcPrf::new(
PrfConfig::builder().role(Role::Leader).build().unwrap(),
leader_thread_0,
leader_thread_1,
);
let mut follower = MpcPrf::new(
PrfConfig::builder().role(Role::Follower).build().unwrap(),
follower_thread_0,
follower_thread_1,
);
futures::join!(
async {
leader.setup(leader_pms).await.unwrap();
leader.set_client_random(Some(client_random)).await.unwrap();
leader.preprocess().await.unwrap();
},
async {
follower.setup(follower_pms).await.unwrap();
follower.set_client_random(None).await.unwrap();
follower.preprocess().await.unwrap();
}
);
build_circuits().await;
}

View File

@@ -105,7 +105,7 @@ async fn run_instance<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
.build()?,
);
_ = verifier.verify(io.compat()).await?;
verifier.verify(io.compat()).await?;
println!("verifier done");

View File

@@ -20,12 +20,11 @@ wasm-bindgen = { version = "0.2.87" }
wasm-bindgen-futures = { version = "0.4.37" }
web-time = { workspace = true }
# Use the patched ws_stream_wasm to fix the issue https://github.com/najamelan/ws_stream_wasm/issues/12#issuecomment-1711902958
ws_stream_wasm = { version = "0.7.4", git = "https://github.com/tlsnotary/ws_stream_wasm", rev = "2ed12aad9f0236e5321f577672f309920b2aef51", features = ["tokio_io"]}
[target.'cfg(target_arch = "wasm32")'.dependencies]
wasm-bindgen-rayon = { version = "1.2", features = ["no-bundler"] }
ws_stream_wasm = { version = "0.7.4", git = "https://github.com/tlsnotary/ws_stream_wasm", rev = "2ed12aad9f0236e5321f577672f309920b2aef51", features = [
"tokio_io",
] }
[package.metadata.wasm-pack.profile.release]
# Note: these wasm-pack options should match those in crates/wasm/Cargo.toml
opt-level = "z"
opt-level = "z"
wasm-opt = true

View File

@@ -1,6 +1,6 @@
import * as Comlink from "./comlink.mjs";
import init, { wasm_main, initThreadPool, init_logging } from './tlsn_benches_browser_wasm.js';
import init, { wasm_main, initialize } from './tlsn_benches_browser_wasm.js';
class Worker {
async init() {
@@ -12,7 +12,7 @@ class Worker {
// crate_filters: undefined,
// span_events: undefined,
// });
await initThreadPool(navigator.hardwareConcurrency);
await initialize({ thread_count: navigator.hardwareConcurrency });
} catch (e) {
console.error(e);
throw e;
@@ -20,12 +20,12 @@ class Worker {
}
async run(
ws_ip,
ws_port,
wasm_to_server_port,
wasm_to_verifier_port,
wasm_to_native_port
) {
ws_ip,
ws_port,
wasm_to_server_port,
wasm_to_verifier_port,
wasm_to_native_port
) {
try {
await wasm_main(
ws_ip,

View File

@@ -1,3 +1,5 @@
#![cfg(target_arch = "wasm32")]
//! Contains the wasm component of the browser prover.
//!
//! Conceptually the browser prover consists of the native and the wasm
@@ -9,13 +11,10 @@ use tlsn_benches_browser_core::{
FramedIo,
};
use tlsn_benches_library::run_prover;
pub use tlsn_wasm::init_logging;
use anyhow::Result;
use tracing::info;
use wasm_bindgen::prelude::*;
#[cfg(target_arch = "wasm32")]
pub use wasm_bindgen_rayon::init_thread_pool;
use web_time::Instant;
use ws_stream_wasm::WsMeta;

View File

@@ -6,7 +6,7 @@ use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
use anyhow::Context;
use async_trait::async_trait;
use futures::{future::join, AsyncReadExt as _, AsyncWriteExt as _};
use futures::{future::try_join, AsyncReadExt as _, AsyncWriteExt as _, TryFutureExt};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::TokioAsyncReadCompatExt;
@@ -106,12 +106,14 @@ pub async fn run_prover(
let mut response = vec![];
mpc_tls_connection.read_to_end(&mut response).await?;
dbg!(response.len());
Ok::<(), anyhow::Error>(())
};
let (prover_task, _) = join(prover_fut, tls_fut).await;
let (prover_task, _) = try_join(prover_fut.map_err(anyhow::Error::from), tls_fut).await?;
let mut prover = prover_task?.start_prove();
let mut prover = prover_task.start_prove();
let (sent_len, recv_len) = prover.transcript().len();
prover

View File

@@ -9,13 +9,21 @@ default = []
[dependencies]
tlsn-core = { workspace = true }
tlsn-tls-core = { workspace = true }
tlsn-cipher = { workspace = true }
tlsn-utils = { workspace = true }
mpz-core = { workspace = true }
mpz-common = { workspace = true }
mpz-garble = { workspace = true }
mpz-memory-core = { workspace = true }
mpz-ot = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-zk = { workspace = true }
async-trait = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
once_cell = { workspace = true }
opaque-debug = { workspace = true }
serio = { workspace = true, features = ["codec", "bincode"] }
thiserror = { workspace = true }
tracing = { workspace = true }
@@ -23,5 +31,9 @@ uid-mux = { workspace = true, features = ["serio"] }
serde = { workspace = true, features = ["derive"] }
semver = { version = "1.0", features = ["serde"] }
[target.'cfg(target_arch = "wasm32")'.dependencies]
wasm-bindgen = { version = "0.2" }
web-spawn = { workspace = true }
[dev-dependencies]
rstest = { workspace = true }

108
crates/common/src/commit.rs Normal file
View File

@@ -0,0 +1,108 @@
//! Plaintext commitment and proof of encryption.
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{binary::Binary, DecodeFutureTyped};
use mpz_vm_core::{prelude::*, Vm};
use crate::{
transcript::Record,
zk_aes::{ZkAesCtr, ZkAesCtrError},
Role,
};
/// Commits the plaintext of the provided records, returning a proof of
/// encryption.
///
/// Writes the plaintext VM reference to the provided records.
pub fn commit_records<'record>(
vm: &mut dyn Vm<Binary>,
aes: &mut ZkAesCtr,
records: impl IntoIterator<Item = &'record mut Record>,
) -> Result<RecordProof, RecordProofError> {
let mut ciphertexts = Vec::new();
for record in records {
if record.plaintext_ref.is_some() {
return Err(ErrorRepr::PlaintextRefAlreadySet.into());
}
let (plaintext_ref, ciphertext_ref) = aes
.encrypt(vm, record.explicit_nonce.clone(), record.ciphertext.len())
.map_err(ErrorRepr::Aes)?;
record.plaintext_ref = Some(plaintext_ref);
if let Role::Prover = aes.role() {
let Some(plaintext) = record.plaintext.clone() else {
return Err(ErrorRepr::MissingPlaintext.into());
};
vm.assign(plaintext_ref, plaintext)
.map_err(RecordProofError::vm)?;
}
vm.commit(plaintext_ref).map_err(RecordProofError::vm)?;
let ciphertext = vm.decode(ciphertext_ref).map_err(RecordProofError::vm)?;
ciphertexts.push((ciphertext, record.ciphertext.clone()));
}
Ok(RecordProof { ciphertexts })
}
/// Proof of encryption.
#[derive(Debug)]
#[must_use]
#[allow(clippy::type_complexity)]
pub struct RecordProof {
ciphertexts: Vec<(DecodeFutureTyped<BitVec, Vec<u8>>, Vec<u8>)>,
}
impl RecordProof {
/// Verifies the proof.
pub fn verify(self) -> Result<(), RecordProofError> {
let Self { ciphertexts } = self;
for (mut ciphertext, expected) in ciphertexts {
let ciphertext = ciphertext
.try_recv()
.map_err(RecordProofError::vm)?
.ok_or_else(|| ErrorRepr::NotDecoded)?;
if ciphertext != expected {
return Err(ErrorRepr::InvalidCiphertext.into());
}
}
Ok(())
}
}
/// Error for [`RecordProof`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct RecordProofError(#[from] ErrorRepr);
impl RecordProofError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
#[error("record proof error: {0}")]
enum ErrorRepr {
#[error("VM error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("zk aes error: {0}")]
Aes(ZkAesCtrError),
#[error("plaintext is missing")]
MissingPlaintext,
#[error("plaintext reference is already set")]
PlaintextRefAlreadySet,
#[error("ciphertext was not decoded")]
NotDecoded,
#[error("ciphertext does not match expected")]
InvalidCiphertext,
}

View File

@@ -5,16 +5,8 @@ use semver::Version;
use serde::{Deserialize, Serialize};
use std::error::Error;
use crate::Role;
// Extra cushion room, eg. for sharing J0 blocks.
const EXTRA_OTS: usize = 16384;
const OTS_PER_BYTE_SENT: usize = 8;
// Without deferred decryption we use 16, with it we use 8.
const OTS_PER_BYTE_RECV_ONLINE: usize = 16;
const OTS_PER_BYTE_RECV_DEFER: usize = 8;
// Default is 32 bytes to decrypt the TLS protocol messages.
const DEFAULT_MAX_RECV_ONLINE: usize = 32;
// Current version that is running.
static VERSION: Lazy<Version> = Lazy::new(|| {
@@ -31,7 +23,7 @@ pub struct ProtocolConfig {
max_sent_data: usize,
/// Maximum number of bytes that can be decrypted online, i.e. while the
/// MPC-TLS connection is active.
#[builder(default = "0")]
#[builder(default = "DEFAULT_MAX_RECV_ONLINE")]
max_recv_data_online: usize,
/// Maximum number of bytes that can be received.
max_recv_data: usize,
@@ -71,26 +63,6 @@ impl ProtocolConfig {
pub fn max_recv_data(&self) -> usize {
self.max_recv_data
}
/// Returns OT sender setup count.
pub fn ot_sender_setup_count(&self, role: Role) -> usize {
ot_send_estimate(
role,
self.max_sent_data,
self.max_recv_data_online,
self.max_recv_data,
)
}
/// Returns OT receiver setup count.
pub fn ot_receiver_setup_count(&self, role: Role) -> usize {
ot_recv_estimate(
role,
self.max_sent_data,
self.max_recv_data_online,
self.max_recv_data,
)
}
}
/// Protocol configuration validator used by checker (i.e. verifier) to perform
@@ -222,42 +194,6 @@ enum ErrorKind {
Version,
}
/// Returns an estimate of the number of OTs that will be sent.
pub fn ot_send_estimate(
role: Role,
max_sent_data: usize,
max_recv_data_online: usize,
max_recv_data: usize,
) -> usize {
match role {
Role::Prover => EXTRA_OTS,
Role::Verifier => {
EXTRA_OTS
+ (max_sent_data * OTS_PER_BYTE_SENT)
+ (max_recv_data_online * OTS_PER_BYTE_RECV_ONLINE)
+ ((max_recv_data - max_recv_data_online) * OTS_PER_BYTE_RECV_DEFER)
}
}
}
/// Returns an estimate of the number of OTs that will be received.
pub fn ot_recv_estimate(
role: Role,
max_sent_data: usize,
max_recv_data_online: usize,
max_recv_data: usize,
) -> usize {
match role {
Role::Prover => {
EXTRA_OTS
+ (max_sent_data * OTS_PER_BYTE_SENT)
+ (max_recv_data_online * OTS_PER_BYTE_RECV_ONLINE)
+ ((max_recv_data - max_recv_data_online) * OTS_PER_BYTE_RECV_DEFER)
}
Role::Verifier => EXTRA_OTS,
}
}
#[cfg(test)]
mod test {
use super::*;

View File

@@ -0,0 +1,21 @@
//! Execution context.
use mpz_common::context::Multithread;
use crate::mux::MuxControl;
/// Maximum concurrency for multi-threaded context.
pub const MAX_CONCURRENCY: usize = 8;
/// Builds a multi-threaded context with the given muxer.
pub fn build_mt_context(mux: MuxControl) -> Multithread {
let builder = Multithread::builder().mux(mux).concurrency(MAX_CONCURRENCY);
#[cfg(target_arch = "wasm32")]
let builder = builder.spawn_handler(|f| {
let _ = web_spawn::spawn(f);
Ok(())
});
builder.build().unwrap()
}

View File

@@ -0,0 +1,173 @@
//! Encoding commitment protocol.
use mpz_common::Context;
use mpz_core::Block;
use serde::{Deserialize, Serialize};
use serio::{stream::IoStreamExt, SinkExt};
use tlsn_core::transcript::{
encoding::{new_encoder, Encoder, EncoderSecret, EncodingProvider},
Direction, Idx,
};
/// Bytes of encoding, per byte.
const ENCODING_SIZE: usize = 128;
#[derive(Debug, Serialize, Deserialize)]
struct Encodings {
sent: Vec<u8>,
recv: Vec<u8>,
}
/// Transfers the encodings using the provided seed and keys.
pub async fn transfer(
ctx: &mut Context,
secret: &EncoderSecret,
sent_keys: impl IntoIterator<Item = &'_ Block>,
recv_keys: impl IntoIterator<Item = &'_ Block>,
) -> Result<(), EncodingError> {
let encoder = new_encoder(secret);
let sent_keys: Vec<u8> = sent_keys
.into_iter()
.flat_map(|key| key.as_bytes())
.copied()
.collect();
let recv_keys: Vec<u8> = recv_keys
.into_iter()
.flat_map(|key| key.as_bytes())
.copied()
.collect();
assert_eq!(sent_keys.len() % ENCODING_SIZE, 0);
assert_eq!(recv_keys.len() % ENCODING_SIZE, 0);
let mut sent_encoding = encoder.encode_idx(
Direction::Sent,
&Idx::new(0..sent_keys.len() / ENCODING_SIZE),
);
let mut recv_encoding = encoder.encode_idx(
Direction::Received,
&Idx::new(0..recv_keys.len() / ENCODING_SIZE),
);
sent_encoding
.iter_mut()
.zip(sent_keys)
.for_each(|(enc, key)| *enc ^= key);
recv_encoding
.iter_mut()
.zip(recv_keys)
.for_each(|(enc, key)| *enc ^= key);
ctx.io_mut()
.send(Encodings {
sent: sent_encoding,
recv: recv_encoding,
})
.await?;
Ok(())
}
/// Receives the encodings using the provided MACs.
pub async fn receive(
ctx: &mut Context,
sent_macs: impl IntoIterator<Item = &'_ Block>,
recv_macs: impl IntoIterator<Item = &'_ Block>,
) -> Result<impl EncodingProvider, EncodingError> {
let Encodings { mut sent, mut recv } = ctx.io_mut().expect_next().await?;
let sent_macs: Vec<u8> = sent_macs
.into_iter()
.flat_map(|mac| mac.as_bytes())
.copied()
.collect();
let recv_macs: Vec<u8> = recv_macs
.into_iter()
.flat_map(|mac| mac.as_bytes())
.copied()
.collect();
assert_eq!(sent_macs.len() % ENCODING_SIZE, 0);
assert_eq!(recv_macs.len() % ENCODING_SIZE, 0);
if sent.len() != sent_macs.len() {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Sent,
expected: sent_macs.len(),
got: sent.len(),
}
.into());
}
if recv.len() != recv_macs.len() {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Received,
expected: recv_macs.len(),
got: recv.len(),
}
.into());
}
sent.iter_mut()
.zip(sent_macs)
.for_each(|(enc, mac)| *enc ^= mac);
recv.iter_mut()
.zip(recv_macs)
.for_each(|(enc, mac)| *enc ^= mac);
Ok(Provider { sent, recv })
}
#[derive(Debug)]
struct Provider {
sent: Vec<u8>,
recv: Vec<u8>,
}
impl EncodingProvider for Provider {
fn provide_encoding(&self, direction: Direction, idx: &Idx) -> Option<Vec<u8>> {
let encodings = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
};
let mut encoding = Vec::with_capacity(idx.len());
for range in idx.iter_ranges() {
let start = range.start * ENCODING_SIZE;
let end = range.end * ENCODING_SIZE;
if end > encodings.len() {
return None;
}
encoding.extend_from_slice(&encodings[start..end]);
}
Some(encoding)
}
}
/// Encoding protocol error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct EncodingError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("encoding protocol error: {0}")]
enum ErrorRepr {
#[error("I/O error: {0}")]
Io(std::io::Error),
#[error("incorrect MAC count for {direction}: expected {expected}, got {got}")]
IncorrectMacCount {
direction: Direction,
expected: usize,
got: usize,
},
}
impl From<std::io::Error> for EncodingError {
fn from(value: std::io::Error) -> Self {
Self(ErrorRepr::Io(value))
}
}

View File

@@ -4,34 +4,19 @@
#![deny(clippy::all)]
#![forbid(unsafe_code)]
pub mod commit;
pub mod config;
pub mod context;
pub mod encoding;
pub mod msg;
pub mod mux;
use serio::codec::Codec;
use crate::mux::MuxControl;
/// IO type.
pub type Io = <serio::codec::Bincode as Codec<uid_mux::yamux::Stream>>::Framed;
/// Base OT sender.
pub type BaseOTSender = mpz_ot::chou_orlandi::Sender;
/// Base OT receiver.
pub type BaseOTReceiver = mpz_ot::chou_orlandi::Receiver;
/// OT sender.
pub type OTSender = mpz_ot::kos::SharedSender<BaseOTReceiver>;
/// OT receiver.
pub type OTReceiver = mpz_ot::kos::SharedReceiver<BaseOTSender>;
/// MPC executor.
pub type Executor = mpz_common::executor::MTExecutor<MuxControl>;
/// MPC thread context.
pub type Context = mpz_common::executor::MTContext<MuxControl, Io>;
/// DEAP thread.
pub type DEAPThread = mpz_garble::protocol::deap::DEAPThread<Context, OTSender, OTReceiver>;
pub mod transcript;
pub mod zk_aes;
/// The party's role in the TLSN protocol.
///
/// A Notary is classified as a Verifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
/// The prover.
Prover,

View File

@@ -6,16 +6,15 @@ use futures::{
future::{FusedFuture, FutureExt},
AsyncRead, AsyncWrite, Future,
};
use serio::codec::Bincode;
use tracing::error;
use uid_mux::{yamux, FramedMux};
use uid_mux::yamux;
use crate::Role;
/// Multiplexer supporting unique deterministic stream IDs.
pub type Mux<Io> = yamux::Yamux<Io>;
/// Multiplexer controller providing streams with a codec attached.
pub type MuxControl = FramedMux<yamux::YamuxCtrl, Bincode>;
/// Multiplexer controller providing streams.
pub type MuxControl = yamux::YamuxCtrl;
/// Multiplexer future which must be polled for the muxer to make progress.
pub struct MuxFuture(
@@ -73,7 +72,7 @@ pub fn attach_mux<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
role: Role,
) -> (MuxFuture, MuxControl) {
let mut mux_config = yamux::Config::default();
mux_config.set_max_num_streams(64);
mux_config.set_max_num_streams(32);
let mux_role = match role {
Role::Prover => yamux::Mode::Client,
@@ -81,10 +80,10 @@ pub fn attach_mux<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
};
let mux = Mux::new(socket, mux_config, mux_role);
let ctrl = FramedMux::new(mux.control(), Bincode);
let ctrl = mux.control();
if let Role::Prover = role {
ctrl.mux().alloc(64);
ctrl.alloc(32);
}
(MuxFuture(Box::new(mux.into_future().fuse())), ctrl)

View File

@@ -0,0 +1,169 @@
//! TLS transcript.
use mpz_memory_core::{binary::U8, Vector};
use tls_core::msgs::enums::ContentType;
use tlsn_core::transcript::{Direction, Idx, Transcript};
use utils::range::Intersection;
/// A transcript of sent and received TLS records.
#[derive(Debug, Default, Clone)]
pub struct TlsTranscript {
/// Records sent by the prover.
pub sent: Vec<Record>,
/// Records received by the prover.
pub recv: Vec<Record>,
}
impl TlsTranscript {
/// Returns the application data transcript.
pub fn to_transcript(&self) -> Result<Transcript, IncompleteTranscript> {
let mut sent = Vec::new();
let mut recv = Vec::new();
for record in self
.sent
.iter()
.filter(|record| record.typ == ContentType::ApplicationData)
{
let plaintext = record
.plaintext
.as_ref()
.ok_or(IncompleteTranscript {})?
.clone();
sent.extend_from_slice(&plaintext);
}
for record in self
.recv
.iter()
.filter(|record| record.typ == ContentType::ApplicationData)
{
let plaintext = record
.plaintext
.as_ref()
.ok_or(IncompleteTranscript {})?
.clone();
recv.extend_from_slice(&plaintext);
}
Ok(Transcript::new(sent, recv))
}
/// Returns the application data transcript references.
pub fn to_transcript_refs(&self) -> Result<TranscriptRefs, IncompleteTranscript> {
let mut sent = Vec::new();
let mut recv = Vec::new();
for record in self
.sent
.iter()
.filter(|record| record.typ == ContentType::ApplicationData)
{
let plaintext_ref = record
.plaintext_ref
.as_ref()
.ok_or(IncompleteTranscript {})?;
sent.push(*plaintext_ref);
}
for record in self
.recv
.iter()
.filter(|record| record.typ == ContentType::ApplicationData)
{
let plaintext_ref = record
.plaintext_ref
.as_ref()
.ok_or(IncompleteTranscript {})?;
recv.push(*plaintext_ref);
}
Ok(TranscriptRefs { sent, recv })
}
}
/// A TLS record.
#[derive(Clone)]
pub struct Record {
/// Sequence number.
pub seq: u64,
/// Content type.
pub typ: ContentType,
/// Plaintext.
pub plaintext: Option<Vec<u8>>,
/// VM reference to the plaintext.
pub plaintext_ref: Option<Vector<U8>>,
/// Explicit nonce.
pub explicit_nonce: Vec<u8>,
/// Ciphertext.
pub ciphertext: Vec<u8>,
}
opaque_debug::implement!(Record);
/// References to the application plaintext in the transcript.
#[derive(Debug, Default, Clone)]
pub struct TranscriptRefs {
sent: Vec<Vector<U8>>,
recv: Vec<Vector<U8>>,
}
impl TranscriptRefs {
/// Returns the sent plaintext references.
pub fn sent(&self) -> &[Vector<U8>] {
&self.sent
}
/// Returns the received plaintext references.
pub fn recv(&self) -> &[Vector<U8>] {
&self.recv
}
/// Returns VM references for the given direction and index, otherwise
/// `None` if the index is out of bounds.
pub fn get(&self, direction: Direction, idx: &Idx) -> Option<Vec<Vector<U8>>> {
if idx.is_empty() {
return Some(Vec::new());
}
let refs = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
};
// Computes the transcript range for each reference.
let mut start = 0;
let mut slice_iter = refs.iter().map(move |slice| {
let out = (slice, start..start + slice.len());
start += slice.len();
out
});
let mut slices = Vec::new();
let (mut slice, mut slice_range) = slice_iter.next()?;
for range in idx.iter_ranges() {
loop {
if let Some(intersection) = slice_range.intersection(&range) {
let start = intersection.start - slice_range.start;
let end = intersection.end - slice_range.start;
slices.push(slice.get(start..end).expect("range should be in bounds"));
}
// Proceed to next range if the current slice extends beyond. Otherwise, proceed
// to the next slice.
if range.end <= slice_range.end {
break;
} else {
(slice, slice_range) = slice_iter.next()?;
}
}
}
Some(slices)
}
}
/// Error for [`TranscriptRefs::from_transcript`].
#[derive(Debug, thiserror::Error)]
#[error("not all application plaintext was committed to in the TLS transcript")]
pub struct IncompleteTranscript {}

212
crates/common/src/zk_aes.rs Normal file
View File

@@ -0,0 +1,212 @@
//! Zero-knowledge AES-CTR encryption.
use cipher::{
aes::{Aes128, AesError},
Cipher, CipherError, Keystream,
};
use mpz_memory_core::{
binary::{Binary, U8},
Array, Vector,
};
use mpz_vm_core::{prelude::*, Vm};
use crate::Role;
type Nonce = Array<U8, 8>;
type Ctr = Array<U8, 4>;
type Block = Array<U8, 16>;
const START_CTR: u32 = 2;
/// ZK AES-CTR encryption.
#[derive(Debug)]
pub struct ZkAesCtr {
role: Role,
aes: Aes128,
state: State,
}
impl ZkAesCtr {
/// Creates a new ZK AES-CTR instance.
pub fn new(role: Role) -> Self {
Self {
role,
aes: Aes128::default(),
state: State::Init,
}
}
/// Returns the role.
pub fn role(&self) -> &Role {
&self.role
}
/// Allocates `len` bytes for encryption.
pub fn alloc(&mut self, vm: &mut dyn Vm<Binary>, len: usize) -> Result<(), ZkAesCtrError> {
let State::Init = self.state.take() else {
Err(ErrorRepr::State {
reason: "must be in init state to allocate",
})?
};
// Round up to the nearest block size.
let len = 16 * len.div_ceil(16);
let input = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
let keystream = self.aes.alloc_keystream(vm, len)?;
let output = keystream.apply(vm, input)?;
match self.role {
Role::Prover => vm.mark_private(input).map_err(ZkAesCtrError::vm)?,
Role::Verifier => vm.mark_blind(input).map_err(ZkAesCtrError::vm)?,
}
self.state = State::Ready {
input,
keystream,
output,
};
Ok(())
}
/// Sets the key and IV for the cipher.
pub fn set_key(&mut self, key: Array<U8, 16>, iv: Array<U8, 4>) {
self.aes.set_key(key);
self.aes.set_iv(iv);
}
/// Proves the encryption of `len` bytes.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `explicit_nonce` - Explicit nonce.
/// * `len` - Length of the plaintext in bytes.
pub fn encrypt(
&mut self,
vm: &mut dyn Vm<Binary>,
explicit_nonce: Vec<u8>,
len: usize,
) -> Result<(Vector<U8>, Vector<U8>), ZkAesCtrError> {
let State::Ready {
input,
keystream,
output,
} = &mut self.state
else {
Err(ErrorRepr::State {
reason: "must be in ready state to encrypt",
})?
};
let explicit_nonce: [u8; 8] =
explicit_nonce
.try_into()
.map_err(|explicit_nonce: Vec<_>| ErrorRepr::ExplicitNonceLength {
expected: 8,
actual: explicit_nonce.len(),
})?;
let block_count = len.div_ceil(16);
let padded_len = block_count * 16;
let padding_len = padded_len - len;
if padded_len > input.len() {
Err(ErrorRepr::InsufficientPreprocessing {
expected: padded_len,
actual: input.len(),
})?
}
let mut input = input.split_off(input.len() - padded_len);
let keystream = keystream.consume(padded_len)?;
let mut output = output.split_off(output.len() - padded_len);
// Assign counter block inputs.
let mut ctr = START_CTR..;
keystream.assign(vm, explicit_nonce, move || {
ctr.next().expect("range is unbounded").to_be_bytes()
})?;
// Assign zeroes to the padding.
if padding_len > 0 {
let padding = input.split_off(input.len() - padding_len);
if let Role::Prover = self.role {
vm.assign(padding, vec![0; padding_len])
.map_err(ZkAesCtrError::vm)?;
}
vm.commit(padding).map_err(ZkAesCtrError::vm)?;
output.truncate(len);
}
Ok((input, output))
}
}
enum State {
Init,
Ready {
input: Vector<U8>,
keystream: Keystream<Nonce, Ctr, Block>,
output: Vector<U8>,
},
Error,
}
impl State {
fn take(&mut self) -> Self {
std::mem::replace(self, State::Error)
}
}
impl std::fmt::Debug for State {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
State::Init => write!(f, "Init"),
State::Ready { .. } => write!(f, "Ready"),
State::Error => write!(f, "Error"),
}
}
}
/// Error for [`ZkAesCtr`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ZkAesCtrError(#[from] ErrorRepr);
impl ZkAesCtrError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
#[error("zk aes error")]
enum ErrorRepr {
#[error("invalid state: {reason}")]
State { reason: &'static str },
#[error("cipher error: {0}")]
Cipher(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("invalid explicit nonce length: expected {expected}, got {actual}")]
ExplicitNonceLength { expected: usize, actual: usize },
#[error("insufficient preprocessing: expected {expected}, got {actual}")]
InsufficientPreprocessing { expected: usize, actual: usize },
}
impl From<AesError> for ZkAesCtrError {
fn from(err: AesError) -> Self {
Self(ErrorRepr::Cipher(Box::new(err)))
}
}
impl From<CipherError> for ZkAesCtrError {
fn from(err: CipherError) -> Self {
Self(ErrorRepr::Cipher(Box::new(err)))
}
}

View File

@@ -1,41 +0,0 @@
[package]
name = "tlsn-aead"
authors = ["TLSNotary Team"]
description = "This crate provides an implementation of a two-party version of AES-GCM behind an AEAD trait"
keywords = ["tls", "mpc", "2pc", "aead", "aes", "aes-gcm"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.8-pre"
edition = "2021"
[lib]
name = "aead"
[features]
default = ["mock"]
mock = ["mpz-common/test-utils", "dep:mpz-ot"]
[dependencies]
tlsn-block-cipher = { workspace = true }
tlsn-stream-cipher = { workspace = true }
tlsn-universal-hash = { workspace = true }
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", optional = true, features = [
"ideal",
] }
serio = { workspace = true }
async-trait = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
serde = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
aes-gcm = { workspace = true }

View File

@@ -1,36 +0,0 @@
use derive_builder::Builder;
/// Protocol role.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(missing_docs)]
pub enum Role {
Leader,
Follower,
}
/// Configuration for AES-GCM.
#[derive(Debug, Clone, Builder)]
pub struct AesGcmConfig {
/// The id of this instance.
#[builder(setter(into))]
id: String,
/// The protocol role.
role: Role,
}
impl AesGcmConfig {
/// Creates a new builder for the AES-GCM configuration.
pub fn builder() -> AesGcmConfigBuilder {
AesGcmConfigBuilder::default()
}
/// Returns the id of this instance.
pub fn id(&self) -> &str {
&self.id
}
/// Returns the protocol role.
pub fn role(&self) -> &Role {
&self.role
}
}

View File

@@ -1,102 +0,0 @@
use std::fmt::Display;
/// AES-GCM error.
#[derive(Debug, thiserror::Error)]
pub struct AesGcmError {
kind: ErrorKind,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync>>,
}
impl AesGcmError {
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
Self {
kind,
source: Some(source.into()),
}
}
#[cfg(test)]
pub(crate) fn kind(&self) -> ErrorKind {
self.kind
}
pub(crate) fn invalid_tag() -> Self {
Self {
kind: ErrorKind::Tag,
source: None,
}
}
pub(crate) fn peer(reason: impl Into<String>) -> Self {
Self {
kind: ErrorKind::PeerMisbehaved,
source: Some(reason.into().into()),
}
}
pub(crate) fn payload(reason: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Payload,
source: Some(reason.into().into()),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum ErrorKind {
Io,
BlockCipher,
StreamCipher,
Ghash,
Tag,
PeerMisbehaved,
Payload,
}
impl Display for AesGcmError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.kind {
ErrorKind::Io => write!(f, "io error")?,
ErrorKind::BlockCipher => write!(f, "block cipher error")?,
ErrorKind::StreamCipher => write!(f, "stream cipher error")?,
ErrorKind::Ghash => write!(f, "ghash error")?,
ErrorKind::Tag => write!(f, "payload has corrupted tag")?,
ErrorKind::PeerMisbehaved => write!(f, "peer misbehaved")?,
ErrorKind::Payload => write!(f, "payload error")?,
}
if let Some(source) = &self.source {
write!(f, " caused by: {}", source)?;
}
Ok(())
}
}
impl From<std::io::Error> for AesGcmError {
fn from(err: std::io::Error) -> Self {
Self::new(ErrorKind::Io, err)
}
}
impl From<block_cipher::BlockCipherError> for AesGcmError {
fn from(err: block_cipher::BlockCipherError) -> Self {
Self::new(ErrorKind::BlockCipher, err)
}
}
impl From<tlsn_stream_cipher::StreamCipherError> for AesGcmError {
fn from(err: tlsn_stream_cipher::StreamCipherError) -> Self {
Self::new(ErrorKind::StreamCipher, err)
}
}
impl From<tlsn_universal_hash::UniversalHashError> for AesGcmError {
fn from(err: tlsn_universal_hash::UniversalHashError) -> Self {
Self::new(ErrorKind::Ghash, err)
}
}

View File

@@ -1,96 +0,0 @@
//! Mock implementation of AES-GCM for testing purposes.
use block_cipher::{BlockCipherConfig, MpcBlockCipher};
use mpz_common::executor::{test_st_executor, STExecutor};
use mpz_garble::protocol::deap::mock::{MockFollower, MockLeader};
use mpz_ot::ideal::ot::ideal_ot;
use serio::channel::MemoryDuplex;
use tlsn_stream_cipher::{MpcStreamCipher, StreamCipherConfig};
use tlsn_universal_hash::ghash::ideal_ghash;
use super::*;
/// Creates a mock AES-GCM pair.
///
/// # Arguments
///
/// * `id` - The id of the AES-GCM instances.
/// * `(leader, follower)` - The leader and follower vms.
/// * `leader_config` - The configuration of the leader.
/// * `follower_config` - The configuration of the follower.
pub async fn create_mock_aes_gcm_pair(
id: &str,
(leader, follower): (MockLeader, MockFollower),
leader_config: AesGcmConfig,
follower_config: AesGcmConfig,
) -> (
MpcAesGcm<STExecutor<MemoryDuplex>>,
MpcAesGcm<STExecutor<MemoryDuplex>>,
) {
let block_cipher_id = format!("{}/block_cipher", id);
let (ctx_leader, ctx_follower) = test_st_executor(128);
let (leader_ot_send, follower_ot_recv) = ideal_ot();
let (follower_ot_send, leader_ot_recv) = ideal_ot();
let block_leader = leader
.new_thread(ctx_leader, leader_ot_send, leader_ot_recv)
.unwrap();
let block_follower = follower
.new_thread(ctx_follower, follower_ot_send, follower_ot_recv)
.unwrap();
let leader_block_cipher = MpcBlockCipher::new(
BlockCipherConfig::builder()
.id(block_cipher_id.clone())
.build()
.unwrap(),
block_leader,
);
let follower_block_cipher = MpcBlockCipher::new(
BlockCipherConfig::builder()
.id(block_cipher_id.clone())
.build()
.unwrap(),
block_follower,
);
let stream_cipher_id = format!("{}/stream_cipher", id);
let leader_stream_cipher = MpcStreamCipher::new(
StreamCipherConfig::builder()
.id(stream_cipher_id.clone())
.build()
.unwrap(),
leader,
);
let follower_stream_cipher = MpcStreamCipher::new(
StreamCipherConfig::builder()
.id(stream_cipher_id.clone())
.build()
.unwrap(),
follower,
);
let (ctx_a, ctx_b) = test_st_executor(128);
let (leader_ghash, follower_ghash) = ideal_ghash(ctx_a, ctx_b);
let (ctx_a, ctx_b) = test_st_executor(128);
let leader = MpcAesGcm::new(
leader_config,
ctx_a,
Box::new(leader_block_cipher),
Box::new(leader_stream_cipher),
Box::new(leader_ghash),
);
let follower = MpcAesGcm::new(
follower_config,
ctx_b,
Box::new(follower_block_cipher),
Box::new(follower_stream_cipher),
Box::new(follower_ghash),
);
(leader, follower)
}

View File

@@ -1,712 +0,0 @@
//! This module provides an implementation of 2PC AES-GCM.
mod config;
mod error;
#[cfg(feature = "mock")]
pub mod mock;
mod tag;
pub use config::{AesGcmConfig, AesGcmConfigBuilder, AesGcmConfigBuilderError, Role};
pub use error::AesGcmError;
use async_trait::async_trait;
use block_cipher::{Aes128, BlockCipher};
use futures::TryFutureExt;
use mpz_common::Context;
use mpz_garble::value::ValueRef;
use tlsn_stream_cipher::{Aes128Ctr, StreamCipher};
use tlsn_universal_hash::UniversalHash;
use tracing::instrument;
use crate::{
aes_gcm::tag::{compute_tag, verify_tag, TAG_LEN},
Aead,
};
/// MPC AES-GCM.
pub struct MpcAesGcm<Ctx> {
config: AesGcmConfig,
ctx: Ctx,
aes_block: Box<dyn BlockCipher<Aes128>>,
aes_ctr: Box<dyn StreamCipher<Aes128Ctr>>,
ghash: Box<dyn UniversalHash>,
}
impl<Ctx> std::fmt::Debug for MpcAesGcm<Ctx> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MpcAesGcm")
.field("config", &self.config)
.finish()
}
}
impl<Ctx: Context> MpcAesGcm<Ctx> {
/// Creates a new instance of [`MpcAesGcm`].
pub fn new(
config: AesGcmConfig,
context: Ctx,
aes_block: Box<dyn BlockCipher<Aes128>>,
aes_ctr: Box<dyn StreamCipher<Aes128Ctr>>,
ghash: Box<dyn UniversalHash>,
) -> Self {
Self {
config,
ctx: context,
aes_block,
aes_ctr,
ghash,
}
}
}
#[async_trait]
impl<Ctx: Context> Aead for MpcAesGcm<Ctx> {
type Error = AesGcmError;
#[instrument(level = "info", skip_all, err)]
async fn set_key(&mut self, key: ValueRef, iv: ValueRef) -> Result<(), AesGcmError> {
self.aes_block.set_key(key.clone());
self.aes_ctr.set_key(key, iv);
Ok(())
}
#[instrument(level = "info", skip_all, err)]
async fn decode_key_private(&mut self) -> Result<(), AesGcmError> {
self.aes_ctr
.decode_key_private()
.await
.map_err(AesGcmError::from)
}
#[instrument(level = "info", skip_all, err)]
async fn decode_key_blind(&mut self) -> Result<(), AesGcmError> {
self.aes_ctr
.decode_key_blind()
.await
.map_err(AesGcmError::from)
}
fn set_transcript_id(&mut self, id: &str) {
self.aes_ctr.set_transcript_id(id)
}
#[instrument(level = "debug", skip(self), err)]
async fn setup(&mut self) -> Result<(), AesGcmError> {
self.ghash.setup().await?;
Ok(())
}
#[instrument(level = "debug", skip(self), err)]
async fn preprocess(&mut self, len: usize) -> Result<(), AesGcmError> {
futures::try_join!(
// Preprocess the GHASH key block.
self.aes_block
.preprocess(block_cipher::Visibility::Public, 1)
.map_err(AesGcmError::from),
self.aes_ctr.preprocess(len).map_err(AesGcmError::from),
self.ghash.preprocess().map_err(AesGcmError::from),
)?;
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
async fn start(&mut self) -> Result<(), AesGcmError> {
let h_share = self.aes_block.encrypt_share(vec![0u8; 16]).await?;
self.ghash.set_key(h_share).await?;
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
async fn encrypt_public(
&mut self,
explicit_nonce: Vec<u8>,
plaintext: Vec<u8>,
aad: Vec<u8>,
) -> Result<Vec<u8>, AesGcmError> {
let ciphertext = self
.aes_ctr
.encrypt_public(explicit_nonce.clone(), plaintext)
.await?;
let tag = compute_tag(
&mut self.ctx,
self.aes_ctr.as_mut(),
self.ghash.as_mut(),
explicit_nonce,
ciphertext.clone(),
aad,
)
.await?;
let mut payload = ciphertext;
payload.extend(tag);
Ok(payload)
}
#[instrument(level = "debug", skip_all, err)]
async fn encrypt_private(
&mut self,
explicit_nonce: Vec<u8>,
plaintext: Vec<u8>,
aad: Vec<u8>,
) -> Result<Vec<u8>, AesGcmError> {
let ciphertext = self
.aes_ctr
.encrypt_private(explicit_nonce.clone(), plaintext)
.await?;
let tag = compute_tag(
&mut self.ctx,
self.aes_ctr.as_mut(),
self.ghash.as_mut(),
explicit_nonce,
ciphertext.clone(),
aad,
)
.await?;
let mut payload = ciphertext;
payload.extend(tag);
Ok(payload)
}
#[instrument(level = "debug", skip_all, err)]
async fn encrypt_blind(
&mut self,
explicit_nonce: Vec<u8>,
plaintext_len: usize,
aad: Vec<u8>,
) -> Result<Vec<u8>, AesGcmError> {
let ciphertext = self
.aes_ctr
.encrypt_blind(explicit_nonce.clone(), plaintext_len)
.await?;
let tag = compute_tag(
&mut self.ctx,
self.aes_ctr.as_mut(),
self.ghash.as_mut(),
explicit_nonce,
ciphertext.clone(),
aad,
)
.await?;
let mut payload = ciphertext;
payload.extend(tag);
Ok(payload)
}
#[instrument(level = "debug", skip_all, err)]
async fn decrypt_public(
&mut self,
explicit_nonce: Vec<u8>,
mut payload: Vec<u8>,
aad: Vec<u8>,
) -> Result<Vec<u8>, AesGcmError> {
let purported_tag: [u8; TAG_LEN] = payload
.split_off(payload.len() - TAG_LEN)
.try_into()
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?;
let ciphertext = payload;
verify_tag(
&mut self.ctx,
self.aes_ctr.as_mut(),
self.ghash.as_mut(),
*self.config.role(),
explicit_nonce.clone(),
ciphertext.clone(),
aad,
purported_tag,
)
.await?;
let plaintext = self
.aes_ctr
.decrypt_public(explicit_nonce, ciphertext)
.await?;
Ok(plaintext)
}
#[instrument(level = "debug", skip_all, err)]
async fn decrypt_private(
&mut self,
explicit_nonce: Vec<u8>,
mut payload: Vec<u8>,
aad: Vec<u8>,
) -> Result<Vec<u8>, AesGcmError> {
let purported_tag: [u8; TAG_LEN] = payload
.split_off(payload.len() - TAG_LEN)
.try_into()
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?;
let ciphertext = payload;
verify_tag(
&mut self.ctx,
self.aes_ctr.as_mut(),
self.ghash.as_mut(),
*self.config.role(),
explicit_nonce.clone(),
ciphertext.clone(),
aad,
purported_tag,
)
.await?;
let plaintext = self
.aes_ctr
.decrypt_private(explicit_nonce, ciphertext)
.await?;
Ok(plaintext)
}
#[instrument(level = "debug", skip_all, err)]
async fn decrypt_blind(
&mut self,
explicit_nonce: Vec<u8>,
mut payload: Vec<u8>,
aad: Vec<u8>,
) -> Result<(), AesGcmError> {
let purported_tag: [u8; TAG_LEN] = payload
.split_off(payload.len() - TAG_LEN)
.try_into()
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?;
let ciphertext = payload;
verify_tag(
&mut self.ctx,
self.aes_ctr.as_mut(),
self.ghash.as_mut(),
*self.config.role(),
explicit_nonce.clone(),
ciphertext.clone(),
aad,
purported_tag,
)
.await?;
self.aes_ctr
.decrypt_blind(explicit_nonce, ciphertext)
.await?;
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
async fn verify_tag(
&mut self,
explicit_nonce: Vec<u8>,
mut payload: Vec<u8>,
aad: Vec<u8>,
) -> Result<(), AesGcmError> {
let purported_tag: [u8; TAG_LEN] = payload
.split_off(payload.len() - TAG_LEN)
.try_into()
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?;
let ciphertext = payload;
verify_tag(
&mut self.ctx,
self.aes_ctr.as_mut(),
self.ghash.as_mut(),
*self.config.role(),
explicit_nonce,
ciphertext,
aad,
purported_tag,
)
.await?;
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
async fn prove_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
mut payload: Vec<u8>,
aad: Vec<u8>,
) -> Result<Vec<u8>, AesGcmError> {
let purported_tag: [u8; TAG_LEN] = payload
.split_off(payload.len() - TAG_LEN)
.try_into()
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?;
let ciphertext = payload;
verify_tag(
&mut self.ctx,
self.aes_ctr.as_mut(),
self.ghash.as_mut(),
*self.config.role(),
explicit_nonce.clone(),
ciphertext.clone(),
aad,
purported_tag,
)
.await?;
let plaintext = self
.aes_ctr
.prove_plaintext(explicit_nonce, ciphertext)
.await?;
Ok(plaintext)
}
#[instrument(level = "debug", skip_all, err)]
async fn prove_plaintext_no_tag(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, AesGcmError> {
self.aes_ctr
.prove_plaintext(explicit_nonce, ciphertext)
.map_err(AesGcmError::from)
.await
}
#[instrument(level = "debug", skip_all, err)]
async fn verify_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
mut payload: Vec<u8>,
aad: Vec<u8>,
) -> Result<(), AesGcmError> {
let purported_tag: [u8; TAG_LEN] = payload
.split_off(payload.len() - TAG_LEN)
.try_into()
.map_err(|_| AesGcmError::payload("payload is not long enough to contain tag"))?;
let ciphertext = payload;
verify_tag(
&mut self.ctx,
self.aes_ctr.as_mut(),
self.ghash.as_mut(),
*self.config.role(),
explicit_nonce.clone(),
ciphertext.clone(),
aad,
purported_tag,
)
.await?;
self.aes_ctr
.verify_plaintext(explicit_nonce, ciphertext)
.await?;
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
async fn verify_plaintext_no_tag(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<(), AesGcmError> {
self.aes_ctr
.verify_plaintext(explicit_nonce, ciphertext)
.map_err(AesGcmError::from)
.await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
aes_gcm::{mock::create_mock_aes_gcm_pair, AesGcmConfigBuilder, Role},
Aead,
};
use ::aes_gcm::{aead::AeadInPlace, Aes128Gcm, NewAead, Nonce};
use error::ErrorKind;
use mpz_common::executor::STExecutor;
use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory};
use serio::channel::MemoryDuplex;
fn reference_impl(
key: &[u8],
iv: &[u8],
explicit_nonce: &[u8],
plaintext: &[u8],
aad: &[u8],
) -> Vec<u8> {
let cipher = Aes128Gcm::new_from_slice(key).unwrap();
let nonce = [iv, explicit_nonce].concat();
let nonce = Nonce::from_slice(nonce.as_slice());
let mut ciphertext = plaintext.to_vec();
cipher
.encrypt_in_place(nonce, aad, &mut ciphertext)
.unwrap();
ciphertext
}
async fn setup_pair(
key: Vec<u8>,
iv: Vec<u8>,
) -> (
MpcAesGcm<STExecutor<MemoryDuplex>>,
MpcAesGcm<STExecutor<MemoryDuplex>>,
) {
let (leader_vm, follower_vm) = create_mock_deap_vm();
let leader_key = leader_vm
.new_public_array_input::<u8>("key", key.len())
.unwrap();
let leader_iv = leader_vm
.new_public_array_input::<u8>("iv", iv.len())
.unwrap();
leader_vm.assign(&leader_key, key.clone()).unwrap();
leader_vm.assign(&leader_iv, iv.clone()).unwrap();
let follower_key = follower_vm
.new_public_array_input::<u8>("key", key.len())
.unwrap();
let follower_iv = follower_vm
.new_public_array_input::<u8>("iv", iv.len())
.unwrap();
follower_vm.assign(&follower_key, key.clone()).unwrap();
follower_vm.assign(&follower_iv, iv.clone()).unwrap();
let leader_config = AesGcmConfigBuilder::default()
.id("test".to_string())
.role(Role::Leader)
.build()
.unwrap();
let follower_config = AesGcmConfigBuilder::default()
.id("test".to_string())
.role(Role::Follower)
.build()
.unwrap();
let (mut leader, mut follower) = create_mock_aes_gcm_pair(
"test",
(leader_vm, follower_vm),
leader_config,
follower_config,
)
.await;
futures::try_join!(
leader.set_key(leader_key, leader_iv),
follower.set_key(follower_key, follower_iv)
)
.unwrap();
futures::try_join!(leader.setup(), follower.setup()).unwrap();
futures::try_join!(leader.start(), follower.start()).unwrap();
(leader, follower)
}
#[tokio::test]
#[ignore = "expensive"]
async fn test_aes_gcm_encrypt_private() {
let key = vec![0u8; 16];
let iv = vec![0u8; 4];
let explicit_nonce = vec![0u8; 8];
let plaintext = vec![1u8; 32];
let aad = vec![2u8; 12];
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
let (leader_ciphertext, follower_ciphertext) = tokio::try_join!(
leader.encrypt_private(explicit_nonce.clone(), plaintext.clone(), aad.clone(),),
follower.encrypt_blind(explicit_nonce.clone(), plaintext.len(), aad.clone())
)
.unwrap();
assert_eq!(leader_ciphertext, follower_ciphertext);
assert_eq!(
leader_ciphertext,
reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad)
);
}
#[tokio::test]
#[ignore = "expensive"]
async fn test_aes_gcm_encrypt_public() {
let key = vec![0u8; 16];
let iv = vec![0u8; 4];
let explicit_nonce = vec![0u8; 8];
let plaintext = vec![1u8; 32];
let aad = vec![2u8; 12];
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
let (leader_ciphertext, follower_ciphertext) = tokio::try_join!(
leader.encrypt_public(explicit_nonce.clone(), plaintext.clone(), aad.clone(),),
follower.encrypt_public(explicit_nonce.clone(), plaintext.clone(), aad.clone(),)
)
.unwrap();
assert_eq!(leader_ciphertext, follower_ciphertext);
assert_eq!(
leader_ciphertext,
reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad)
);
}
#[tokio::test]
#[ignore = "expensive"]
async fn test_aes_gcm_decrypt_private() {
let key = vec![0u8; 16];
let iv = vec![0u8; 4];
let explicit_nonce = vec![0u8; 8];
let plaintext = vec![1u8; 32];
let aad = vec![2u8; 12];
let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad);
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
let (leader_plaintext, _) = tokio::try_join!(
leader.decrypt_private(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),),
follower.decrypt_blind(explicit_nonce.clone(), ciphertext, aad.clone(),)
)
.unwrap();
assert_eq!(leader_plaintext, plaintext);
}
#[tokio::test]
#[ignore = "expensive"]
async fn test_aes_gcm_decrypt_private_bad_tag() {
let key = vec![0u8; 16];
let iv = vec![0u8; 4];
let explicit_nonce = vec![0u8; 8];
let plaintext = vec![1u8; 32];
let aad = vec![2u8; 12];
let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad);
let len = ciphertext.len();
// corrupt tag
let mut corrupted = ciphertext.clone();
corrupted[len - 1] -= 1;
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
// leader receives corrupted tag
let err = tokio::try_join!(
leader.decrypt_private(explicit_nonce.clone(), corrupted.clone(), aad.clone(),),
follower.decrypt_blind(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),)
)
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::Tag);
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
// follower receives corrupted tag
let err = tokio::try_join!(
leader.decrypt_private(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),),
follower.decrypt_blind(explicit_nonce.clone(), corrupted.clone(), aad.clone(),)
)
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::Tag);
}
#[tokio::test]
#[ignore = "expensive"]
async fn test_aes_gcm_decrypt_public() {
let key = vec![0u8; 16];
let iv = vec![0u8; 4];
let explicit_nonce = vec![0u8; 8];
let plaintext = vec![1u8; 32];
let aad = vec![2u8; 12];
let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad);
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
let (leader_plaintext, follower_plaintext) = tokio::try_join!(
leader.decrypt_public(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),),
follower.decrypt_public(explicit_nonce.clone(), ciphertext, aad.clone(),)
)
.unwrap();
assert_eq!(leader_plaintext, plaintext);
assert_eq!(leader_plaintext, follower_plaintext);
}
#[tokio::test]
#[ignore = "expensive"]
async fn test_aes_gcm_decrypt_public_bad_tag() {
let key = vec![0u8; 16];
let iv = vec![0u8; 4];
let explicit_nonce = vec![0u8; 8];
let plaintext = vec![1u8; 32];
let aad = vec![2u8; 12];
let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad);
let len = ciphertext.len();
// Corrupt tag.
let mut corrupted = ciphertext.clone();
corrupted[len - 1] -= 1;
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
// Leader receives corrupted tag.
let err = tokio::try_join!(
leader.decrypt_public(explicit_nonce.clone(), corrupted.clone(), aad.clone(),),
follower.decrypt_public(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),)
)
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::Tag);
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
// Follower receives corrupted tag.
let err = tokio::try_join!(
leader.decrypt_public(explicit_nonce.clone(), ciphertext.clone(), aad.clone(),),
follower.decrypt_public(explicit_nonce.clone(), corrupted.clone(), aad.clone(),)
)
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::Tag);
}
#[tokio::test]
#[ignore = "expensive"]
async fn test_aes_gcm_verify_tag() {
let key = vec![0u8; 16];
let iv = vec![0u8; 4];
let explicit_nonce = vec![0u8; 8];
let plaintext = vec![1u8; 32];
let aad = vec![2u8; 12];
let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad);
let len = ciphertext.len();
let (mut leader, mut follower) = setup_pair(key.clone(), iv.clone()).await;
tokio::try_join!(
leader.verify_tag(explicit_nonce.clone(), ciphertext.clone(), aad.clone()),
follower.verify_tag(explicit_nonce.clone(), ciphertext.clone(), aad.clone())
)
.unwrap();
//Corrupt tag.
let mut corrupted = ciphertext.clone();
corrupted[len - 1] -= 1;
let (leader_res, follower_res) = tokio::join!(
leader.verify_tag(explicit_nonce.clone(), corrupted.clone(), aad.clone()),
follower.verify_tag(explicit_nonce.clone(), corrupted, aad.clone())
);
assert_eq!(leader_res.unwrap_err().kind(), ErrorKind::Tag);
assert_eq!(follower_res.unwrap_err().kind(), ErrorKind::Tag);
}
}

View File

@@ -1,179 +0,0 @@
use futures::TryFutureExt;
use mpz_common::Context;
use mpz_core::{
commit::{Decommitment, HashCommit},
hash::Hash,
};
use serde::{Deserialize, Serialize};
use serio::{stream::IoStreamExt, SinkExt};
use std::ops::Add;
use tlsn_stream_cipher::{Aes128Ctr, StreamCipher};
use tlsn_universal_hash::UniversalHash;
use tracing::instrument;
use crate::aes_gcm::{AesGcmError, Role};
pub(crate) const TAG_LEN: usize = 16;
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TagShare([u8; TAG_LEN]);
impl AsRef<[u8]> for TagShare {
fn as_ref(&self) -> &[u8] {
&self.0
}
}
impl Add for TagShare {
type Output = [u8; TAG_LEN];
fn add(self, rhs: Self) -> Self::Output {
core::array::from_fn(|i| self.0[i] ^ rhs.0[i])
}
}
#[instrument(level = "trace", skip_all, err)]
async fn compute_tag_share<C: StreamCipher<Aes128Ctr> + ?Sized, H: UniversalHash + ?Sized>(
aes_ctr: &mut C,
hasher: &mut H,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
aad: Vec<u8>,
) -> Result<TagShare, AesGcmError> {
let (j0, hash) = futures::try_join!(
aes_ctr
.share_keystream_block(explicit_nonce, 1)
.map_err(AesGcmError::from),
hasher
.finalize(build_ghash_data(aad, ciphertext))
.map_err(AesGcmError::from)
)?;
debug_assert!(j0.len() == TAG_LEN);
debug_assert!(hash.len() == TAG_LEN);
let tag_share = core::array::from_fn(|i| j0[i] ^ hash[i]);
Ok(TagShare(tag_share))
}
/// Computes the tag for a ciphertext and additional data.
///
/// The commit-reveal step is not required for computing a tag sent to the
/// Server, as it will be able to detect if the tag is incorrect.
#[instrument(level = "debug", skip_all, err)]
pub(crate) async fn compute_tag<
Ctx: Context,
C: StreamCipher<Aes128Ctr> + ?Sized,
H: UniversalHash + ?Sized,
>(
ctx: &mut Ctx,
aes_ctr: &mut C,
hasher: &mut H,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
aad: Vec<u8>,
) -> Result<[u8; TAG_LEN], AesGcmError> {
let tag_share = compute_tag_share(aes_ctr, hasher, explicit_nonce, ciphertext, aad).await?;
// TODO: The follower doesn't really need to learn the tag,
// we could reduce some latency by not sending it.
let io = ctx.io_mut();
io.send(tag_share.clone()).await?;
let other_tag_share: TagShare = io.expect_next().await?;
let tag = tag_share + other_tag_share;
Ok(tag)
}
/// Verifies a purported tag against the ciphertext and additional data.
///
/// Verifying a tag requires a commit-reveal protocol between the leader and
/// follower. Without it, the party which receives the other's tag share first
/// could trivially compute a tag share which would cause an invalid message to
/// be accepted.
#[instrument(level = "debug", skip_all, err)]
#[allow(clippy::too_many_arguments)]
pub(crate) async fn verify_tag<
Ctx: Context,
C: StreamCipher<Aes128Ctr> + ?Sized,
H: UniversalHash + ?Sized,
>(
ctx: &mut Ctx,
aes_ctr: &mut C,
hasher: &mut H,
role: Role,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
aad: Vec<u8>,
purported_tag: [u8; TAG_LEN],
) -> Result<(), AesGcmError> {
let tag_share = compute_tag_share(aes_ctr, hasher, explicit_nonce, ciphertext, aad).await?;
let io = ctx.io_mut();
let tag = match role {
Role::Leader => {
// Send commitment of tag share to follower.
let (tag_share_decommitment, tag_share_commitment) = tag_share.clone().hash_commit();
io.send(tag_share_commitment).await?;
let follower_tag_share: TagShare = io.expect_next().await?;
// Send decommitment (tag share) to follower.
io.send(tag_share_decommitment).await?;
tag_share + follower_tag_share
}
Role::Follower => {
// Wait for commitment from leader.
let commitment: Hash = io.expect_next().await?;
// Send tag share to leader.
io.send(tag_share.clone()).await?;
// Expect decommitment (tag share) from leader.
let decommitment: Decommitment<TagShare> = io.expect_next().await?;
// Verify decommitment.
decommitment.verify(&commitment).map_err(|_| {
AesGcmError::peer("leader tag share commitment verification failed")
})?;
let leader_tag_share = decommitment.into_inner();
tag_share + leader_tag_share
}
};
// Reject if tag is incorrect.
if tag != purported_tag {
return Err(AesGcmError::invalid_tag());
}
Ok(())
}
/// Builds padded data for GHASH.
fn build_ghash_data(mut aad: Vec<u8>, mut ciphertext: Vec<u8>) -> Vec<u8> {
let associated_data_bitlen = (aad.len() as u64) * 8;
let text_bitlen = (ciphertext.len() as u64) * 8;
let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128);
// Pad data to be a multiple of 16 bytes.
let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize;
aad.resize(aad_padded_block_count * 16, 0);
let ciphertext_padded_block_count =
(ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize;
ciphertext.resize(ciphertext_padded_block_count * 16, 0);
let mut data: Vec<u8> = Vec::with_capacity(aad.len() + ciphertext.len() + 16);
data.extend(aad);
data.extend(ciphertext);
data.extend_from_slice(&len_block.to_be_bytes());
data
}

View File

@@ -1,255 +0,0 @@
//! This crate provides implementations of 2PC AEADs for authenticated
//! encryption with a shared key.
//!
//! Both parties can work together to encrypt and decrypt messages with
//! different visibility configurations. See [`Aead`] for more information on
//! the interface.
//!
//! For example, one party can privately provide the plaintext to encrypt, while
//! both parties can see the ciphertext and the tag. Or, both parties can
//! cooperate to decrypt a ciphertext and verify the tag, while only one party
//! can see the plaintext.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
pub mod aes_gcm;
use async_trait::async_trait;
use mpz_garble::value::ValueRef;
/// This trait defines the interface for AEADs.
#[async_trait]
pub trait Aead: Send {
/// The error type for the AEAD.
type Error: std::error::Error + Send + Sync + 'static;
/// Sets the key for the AEAD.
async fn set_key(&mut self, key: ValueRef, iv: ValueRef) -> Result<(), Self::Error>;
/// Decodes the key for the AEAD, revealing it to this party.
async fn decode_key_private(&mut self) -> Result<(), Self::Error>;
/// Decodes the key for the AEAD, revealing it to the other party(s).
async fn decode_key_blind(&mut self) -> Result<(), Self::Error>;
/// Sets the transcript id.
///
/// The AEAD assigns unique identifiers to each byte of plaintext
/// during encryption and decryption.
///
/// For example, if the transcript id is set to `foo`, then the first byte
/// will be assigned the id `foo/0`, the second byte `foo/1`, and so on.
///
/// Each transcript id has an independent counter.
///
/// # Note
///
/// The state of a transcript counter is preserved between calls to
/// `set_transcript_id`.
fn set_transcript_id(&mut self, id: &str);
/// Performs any necessary one-time setup for the AEAD.
async fn setup(&mut self) -> Result<(), Self::Error>;
/// Preprocesses for the given number of bytes.
async fn preprocess(&mut self, len: usize) -> Result<(), Self::Error>;
/// Starts the AEAD.
///
/// This method performs initialization for the AEAD after setting the key.
async fn start(&mut self) -> Result<(), Self::Error>;
/// Encrypts a plaintext message, returning the ciphertext and tag.
///
/// The plaintext is provided by both parties.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for encryption.
/// * `plaintext` - The plaintext to encrypt.
/// * `aad` - Additional authenticated data.
async fn encrypt_public(
&mut self,
explicit_nonce: Vec<u8>,
plaintext: Vec<u8>,
aad: Vec<u8>,
) -> Result<Vec<u8>, Self::Error>;
/// Encrypts a plaintext message, hiding it from the other party, returning
/// the ciphertext and tag.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for encryption.
/// * `plaintext` - The plaintext to encrypt.
/// * `aad` - Additional authenticated data.
async fn encrypt_private(
&mut self,
explicit_nonce: Vec<u8>,
plaintext: Vec<u8>,
aad: Vec<u8>,
) -> Result<Vec<u8>, Self::Error>;
/// Encrypts a plaintext message provided by the other party, returning
/// the ciphertext and tag.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for encryption.
/// * `plaintext_len` - The length of the plaintext to encrypt.
/// * `aad` - Additional authenticated data.
async fn encrypt_blind(
&mut self,
explicit_nonce: Vec<u8>,
plaintext_len: usize,
aad: Vec<u8>,
) -> Result<Vec<u8>, Self::Error>;
/// Decrypts a ciphertext message, returning the plaintext to both parties.
///
/// This method checks the authenticity of the ciphertext, tag and
/// additional data.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for decryption.
/// * `payload` - The ciphertext and tag to authenticate and decrypt.
/// * `aad` - Additional authenticated data.
async fn decrypt_public(
&mut self,
explicit_nonce: Vec<u8>,
payload: Vec<u8>,
aad: Vec<u8>,
) -> Result<Vec<u8>, Self::Error>;
/// Decrypts a ciphertext message, returning the plaintext only to this
/// party.
///
/// This method checks the authenticity of the ciphertext, tag and
/// additional data.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for decryption.
/// * `payload` - The ciphertext and tag to authenticate and decrypt.
/// * `aad` - Additional authenticated data.
async fn decrypt_private(
&mut self,
explicit_nonce: Vec<u8>,
payload: Vec<u8>,
aad: Vec<u8>,
) -> Result<Vec<u8>, Self::Error>;
/// Decrypts a ciphertext message, returning the plaintext only to the other
/// party.
///
/// This method checks the authenticity of the ciphertext, tag and
/// additional data.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for decryption.
/// * `payload` - The ciphertext and tag to authenticate and decrypt.
/// * `aad` - Additional authenticated data.
async fn decrypt_blind(
&mut self,
explicit_nonce: Vec<u8>,
payload: Vec<u8>,
aad: Vec<u8>,
) -> Result<(), Self::Error>;
/// Verifies the tag of a ciphertext message.
///
/// This method checks the authenticity of the ciphertext, tag and
/// additional data.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for decryption.
/// * `payload` - The ciphertext and tag to authenticate and decrypt.
/// * `aad` - Additional authenticated data.
async fn verify_tag(
&mut self,
explicit_nonce: Vec<u8>,
payload: Vec<u8>,
aad: Vec<u8>,
) -> Result<(), Self::Error>;
/// Locally decrypts the provided ciphertext and then proves in ZK to the
/// other party(s) that the plaintext is correct.
///
/// Returns the plaintext.
///
/// This method requires this party to know the encryption key, which can be
/// achieved by calling the `decode_key_private` method.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
/// * `payload` - The ciphertext and tag to decrypt and prove.
/// * `aad` - Additional authenticated data.
async fn prove_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
payload: Vec<u8>,
aad: Vec<u8>,
) -> Result<Vec<u8>, Self::Error>;
/// Locally decrypts the provided ciphertext and then proves in ZK to the
/// other party(s) that the plaintext is correct.
///
/// Returns the plaintext.
///
/// This method requires this party to know the encryption key, which can be
/// achieved by calling the `decode_key_private` method.
///
/// # WARNING
///
/// This method does not verify the tag of the ciphertext. Only use this if
/// you know what you're doing.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
/// * `ciphertext` - The ciphertext to decrypt and prove.
async fn prove_plaintext_no_tag(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, Self::Error>;
/// Verifies the other party(s) can prove they know a plaintext which
/// encrypts to the given ciphertext.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
/// * `payload` - The ciphertext and tag to verify.
/// * `aad` - Additional authenticated data.
async fn verify_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
payload: Vec<u8>,
aad: Vec<u8>,
) -> Result<(), Self::Error>;
/// Verifies the other party(s) can prove they know a plaintext which
/// encrypts to the given ciphertext.
///
/// # WARNING
///
/// This method does not verify the tag of the ciphertext. Only use this if
/// you know what you're doing.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
/// * `ciphertext` - The ciphertext to verify.
async fn verify_plaintext_no_tag(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<(), Self::Error>;
}

View File

@@ -1,30 +0,0 @@
[package]
name = "tlsn-block-cipher"
authors = ["TLSNotary Team"]
description = "2PC block cipher implementation"
keywords = ["tls", "mpc", "2pc", "block-cipher"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.8-pre"
edition = "2021"
[lib]
name = "block_cipher"
[features]
default = ["mock"]
mock = []
[dependencies]
mpz-circuits = { workspace = true }
mpz-garble = { workspace = true }
tlsn-utils = { workspace = true }
async-trait = { workspace = true }
thiserror = { workspace = true }
derive_builder = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
aes = { workspace = true }
cipher = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }

View File

@@ -1,277 +0,0 @@
use std::{collections::VecDeque, marker::PhantomData};
use async_trait::async_trait;
use mpz_garble::{value::ValueRef, Decode, DecodePrivate, Execute, Load, Memory};
use tracing::instrument;
use utils::id::NestedId;
use crate::{BlockCipher, BlockCipherCircuit, BlockCipherConfig, BlockCipherError, Visibility};
#[derive(Debug)]
struct State {
private_execution_id: NestedId,
public_execution_id: NestedId,
preprocessed_private: VecDeque<BlockVars>,
preprocessed_public: VecDeque<BlockVars>,
key: Option<ValueRef>,
}
#[derive(Debug)]
struct BlockVars {
msg: ValueRef,
ciphertext: ValueRef,
}
/// An MPC block cipher.
#[derive(Debug)]
pub struct MpcBlockCipher<C, E>
where
C: BlockCipherCircuit,
E: Memory + Execute + Decode + DecodePrivate + Send + Sync,
{
state: State,
executor: E,
_cipher: PhantomData<C>,
}
impl<C, E> MpcBlockCipher<C, E>
where
C: BlockCipherCircuit,
E: Memory + Execute + Decode + DecodePrivate + Send + Sync,
{
/// Creates a new MPC block cipher.
///
/// # Arguments
///
/// * `config` - The configuration for the block cipher.
/// * `executor` - The executor to use for the MPC.
pub fn new(config: BlockCipherConfig, executor: E) -> Self {
let private_execution_id = NestedId::new(&config.id)
.append_string("private")
.append_counter();
let public_execution_id = NestedId::new(&config.id)
.append_string("public")
.append_counter();
Self {
state: State {
private_execution_id,
public_execution_id,
preprocessed_private: VecDeque::new(),
preprocessed_public: VecDeque::new(),
key: None,
},
executor,
_cipher: PhantomData,
}
}
fn define_block(&mut self, vis: Visibility) -> BlockVars {
let (id, msg) = match vis {
Visibility::Private => {
let id = self
.state
.private_execution_id
.increment_in_place()
.to_string();
let msg = self
.executor
.new_private_input::<C::BLOCK>(&format!("{}/msg", &id))
.expect("message is not defined");
(id, msg)
}
Visibility::Blind => {
let id = self
.state
.private_execution_id
.increment_in_place()
.to_string();
let msg = self
.executor
.new_blind_input::<C::BLOCK>(&format!("{}/msg", &id))
.expect("message is not defined");
(id, msg)
}
Visibility::Public => {
let id = self
.state
.public_execution_id
.increment_in_place()
.to_string();
let msg = self
.executor
.new_public_input::<C::BLOCK>(&format!("{}/msg", &id))
.expect("message is not defined");
(id, msg)
}
};
let ciphertext = self
.executor
.new_output::<C::BLOCK>(&format!("{}/ciphertext", &id))
.expect("message is not defined");
BlockVars { msg, ciphertext }
}
}
#[async_trait]
impl<C, E> BlockCipher<C> for MpcBlockCipher<C, E>
where
C: BlockCipherCircuit,
E: Memory + Load + Execute + Decode + DecodePrivate + Send + Sync + Send,
{
#[instrument(level = "trace", skip_all)]
fn set_key(&mut self, key: ValueRef) {
self.state.key = Some(key);
}
#[instrument(level = "debug", skip_all, err)]
async fn preprocess(
&mut self,
visibility: Visibility,
count: usize,
) -> Result<(), BlockCipherError> {
let key = self
.state
.key
.clone()
.ok_or_else(BlockCipherError::key_not_set)?;
for _ in 0..count {
let vars = self.define_block(visibility);
self.executor
.load(
C::circuit(),
&[key.clone(), vars.msg.clone()],
&[vars.ciphertext.clone()],
)
.await?;
match visibility {
Visibility::Private | Visibility::Blind => {
self.state.preprocessed_private.push_back(vars)
}
Visibility::Public => self.state.preprocessed_public.push_back(vars),
}
}
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
async fn encrypt_private(&mut self, plaintext: Vec<u8>) -> Result<Vec<u8>, BlockCipherError> {
let len = plaintext.len();
let block: C::BLOCK = plaintext
.try_into()
.map_err(|_| BlockCipherError::invalid_message_length::<C>(len))?;
let key = self
.state
.key
.clone()
.ok_or_else(BlockCipherError::key_not_set)?;
let BlockVars { msg, ciphertext } =
if let Some(vars) = self.state.preprocessed_private.pop_front() {
vars
} else {
self.define_block(Visibility::Private)
};
self.executor.assign(&msg, block)?;
self.executor
.execute(C::circuit(), &[key, msg], &[ciphertext.clone()])
.await?;
let mut outputs = self.executor.decode(&[ciphertext]).await?;
let ciphertext: C::BLOCK = if let Ok(ciphertext) = outputs
.pop()
.expect("ciphertext should be present")
.try_into()
{
ciphertext
} else {
panic!("ciphertext should be a block")
};
Ok(ciphertext.into())
}
#[instrument(level = "debug", skip_all, err)]
async fn encrypt_blind(&mut self) -> Result<Vec<u8>, BlockCipherError> {
let key = self
.state
.key
.clone()
.ok_or_else(BlockCipherError::key_not_set)?;
let BlockVars { msg, ciphertext } =
if let Some(vars) = self.state.preprocessed_private.pop_front() {
vars
} else {
self.define_block(Visibility::Blind)
};
self.executor
.execute(C::circuit(), &[key, msg], &[ciphertext.clone()])
.await?;
let mut outputs = self.executor.decode(&[ciphertext]).await?;
let ciphertext: C::BLOCK = if let Ok(ciphertext) = outputs
.pop()
.expect("ciphertext should be present")
.try_into()
{
ciphertext
} else {
panic!("ciphertext should be a block")
};
Ok(ciphertext.into())
}
#[instrument(level = "debug", skip_all, err)]
async fn encrypt_share(&mut self, plaintext: Vec<u8>) -> Result<Vec<u8>, BlockCipherError> {
let len = plaintext.len();
let block: C::BLOCK = plaintext
.try_into()
.map_err(|_| BlockCipherError::invalid_message_length::<C>(len))?;
let key = self
.state
.key
.clone()
.ok_or_else(BlockCipherError::key_not_set)?;
let BlockVars { msg, ciphertext } =
if let Some(vars) = self.state.preprocessed_public.pop_front() {
vars
} else {
self.define_block(Visibility::Public)
};
self.executor.assign(&msg, block)?;
self.executor
.execute(C::circuit(), &[key, msg], &[ciphertext.clone()])
.await?;
let mut outputs = self.executor.decode_shared(&[ciphertext]).await?;
let share: C::BLOCK =
if let Ok(share) = outputs.pop().expect("share should be present").try_into() {
share
} else {
panic!("share should be a block")
};
Ok(share.into())
}
}

View File

@@ -1,39 +0,0 @@
use std::sync::Arc;
use mpz_circuits::{
circuits::AES128,
types::{StaticValueType, Value},
Circuit,
};
/// A block cipher circuit.
pub trait BlockCipherCircuit: Default + Clone + Send + Sync {
/// The key type.
type KEY: StaticValueType + Send + Sync;
/// The block type.
type BLOCK: StaticValueType + TryFrom<Vec<u8>> + TryFrom<Value> + Into<Vec<u8>> + Send + Sync;
/// The length of the key.
const KEY_LEN: usize;
/// The length of the block.
const BLOCK_LEN: usize;
/// Returns the circuit of the cipher.
fn circuit() -> Arc<Circuit>;
}
/// Aes128 block cipher circuit.
#[derive(Default, Debug, Clone)]
pub struct Aes128;
impl BlockCipherCircuit for Aes128 {
type KEY = [u8; 16];
type BLOCK = [u8; 16];
const KEY_LEN: usize = 16;
const BLOCK_LEN: usize = 16;
fn circuit() -> Arc<Circuit> {
AES128.clone()
}
}

View File

@@ -1,16 +0,0 @@
use derive_builder::Builder;
/// Configuration for a block cipher.
#[derive(Debug, Clone, Builder)]
pub struct BlockCipherConfig {
/// The ID of the block cipher.
#[builder(setter(into))]
pub(crate) id: String,
}
impl BlockCipherConfig {
/// Creates a new builder for the block cipher configuration.
pub fn builder() -> BlockCipherConfigBuilder {
BlockCipherConfigBuilder::default()
}
}

View File

@@ -1,92 +0,0 @@
use core::fmt;
use std::error::Error;
use crate::BlockCipherCircuit;
/// A block cipher error.
#[derive(Debug, thiserror::Error)]
pub struct BlockCipherError {
kind: ErrorKind,
#[source]
source: Option<Box<dyn Error + Send + Sync>>,
}
impl BlockCipherError {
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync>>,
{
Self {
kind,
source: Some(source.into()),
}
}
pub(crate) fn key_not_set() -> Self {
Self {
kind: ErrorKind::Key,
source: Some("key not set".into()),
}
}
pub(crate) fn invalid_message_length<C: BlockCipherCircuit>(len: usize) -> Self {
Self {
kind: ErrorKind::Msg,
source: Some(
format!(
"message length does not equal block length: {} != {}",
len,
C::BLOCK_LEN
)
.into(),
),
}
}
}
#[derive(Debug)]
pub(crate) enum ErrorKind {
Vm,
Key,
Msg,
}
impl fmt::Display for BlockCipherError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.kind {
ErrorKind::Vm => write!(f, "vm error")?,
ErrorKind::Key => write!(f, "key error")?,
ErrorKind::Msg => write!(f, "message error")?,
}
if let Some(ref source) = self.source {
write!(f, " caused by: {}", source)?;
}
Ok(())
}
}
impl From<mpz_garble::MemoryError> for BlockCipherError {
fn from(error: mpz_garble::MemoryError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::LoadError> for BlockCipherError {
fn from(error: mpz_garble::LoadError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::ExecutionError> for BlockCipherError {
fn from(error: mpz_garble::ExecutionError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::DecodeError> for BlockCipherError {
fn from(error: mpz_garble::DecodeError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}

View File

@@ -1,236 +0,0 @@
//! This crate provides a 2PC block cipher implementation.
//!
//! Both parties work together to encrypt or share an encrypted block using a
//! shared key.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![deny(unsafe_code)]
mod cipher;
mod circuit;
mod config;
mod error;
use async_trait::async_trait;
use mpz_garble::value::ValueRef;
pub use crate::{
cipher::MpcBlockCipher,
circuit::{Aes128, BlockCipherCircuit},
};
pub use config::{BlockCipherConfig, BlockCipherConfigBuilder, BlockCipherConfigBuilderError};
pub use error::BlockCipherError;
/// Visibility of a message plaintext.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Visibility {
/// Private message.
Private,
/// Blind message.
Blind,
/// Public message.
Public,
}
/// A trait for MPC block ciphers.
#[async_trait]
pub trait BlockCipher<Cipher>: Send + Sync
where
Cipher: BlockCipherCircuit,
{
/// Sets the key for the block cipher.
fn set_key(&mut self, key: ValueRef);
/// Preprocesses `count` blocks.
///
/// # Arguments
///
/// * `visibility` - The visibility of the plaintext.
/// * `count` - The number of blocks to preprocess.
async fn preprocess(
&mut self,
visibility: Visibility,
count: usize,
) -> Result<(), BlockCipherError>;
/// Encrypts the given plaintext keeping it hidden from the other party(s).
///
/// Returns the ciphertext.
///
/// # Arguments
///
/// * `plaintext` - The plaintext to encrypt.
async fn encrypt_private(&mut self, plaintext: Vec<u8>) -> Result<Vec<u8>, BlockCipherError>;
/// Encrypts a plaintext provided by the other party(s).
///
/// Returns the ciphertext.
async fn encrypt_blind(&mut self) -> Result<Vec<u8>, BlockCipherError>;
/// Encrypts a plaintext provided by both parties. Fails if the
/// plaintext provided by both parties does not match.
///
/// Returns an additive share of the ciphertext.
///
/// # Arguments
///
/// * `plaintext` - The plaintext to encrypt.
async fn encrypt_share(&mut self, plaintext: Vec<u8>) -> Result<Vec<u8>, BlockCipherError>;
}
#[cfg(test)]
mod tests {
use super::*;
use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory};
use crate::circuit::Aes128;
use ::aes::Aes128 as TestAes128;
use ::cipher::{BlockEncrypt, KeyInit};
fn aes128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] {
let mut msg = msg.into();
let cipher = TestAes128::new(&key.into());
cipher.encrypt_block(&mut msg);
msg.into()
}
#[tokio::test]
#[ignore = "expensive"]
async fn test_block_cipher_blind() {
let leader_config = BlockCipherConfig::builder().id("test").build().unwrap();
let follower_config = BlockCipherConfig::builder().id("test").build().unwrap();
let key = [0u8; 16];
let (leader_vm, follower_vm) = create_mock_deap_vm();
// Key is public just for this test, typically it is private.
let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap();
let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap();
leader_vm.assign(&leader_key, key).unwrap();
follower_vm.assign(&follower_key, key).unwrap();
let mut leader = MpcBlockCipher::<Aes128, _>::new(leader_config, leader_vm);
leader.set_key(leader_key);
let mut follower = MpcBlockCipher::<Aes128, _>::new(follower_config, follower_vm);
follower.set_key(follower_key);
let plaintext = [0u8; 16];
let (leader_ciphertext, follower_ciphertext) = tokio::try_join!(
leader.encrypt_private(plaintext.to_vec()),
follower.encrypt_blind()
)
.unwrap();
let expected = aes128(key, plaintext);
assert_eq!(leader_ciphertext, expected.to_vec());
assert_eq!(leader_ciphertext, follower_ciphertext);
}
#[tokio::test]
#[ignore = "expensive"]
async fn test_block_cipher_share() {
let leader_config = BlockCipherConfig::builder().id("test").build().unwrap();
let follower_config = BlockCipherConfig::builder().id("test").build().unwrap();
let key = [0u8; 16];
let (leader_vm, follower_vm) = create_mock_deap_vm();
// Key is public just for this test, typically it is private.
let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap();
let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap();
leader_vm.assign(&leader_key, key).unwrap();
follower_vm.assign(&follower_key, key).unwrap();
let mut leader = MpcBlockCipher::<Aes128, _>::new(leader_config, leader_vm);
leader.set_key(leader_key);
let mut follower = MpcBlockCipher::<Aes128, _>::new(follower_config, follower_vm);
follower.set_key(follower_key);
let plaintext = [0u8; 16];
let (leader_share, follower_share) = tokio::try_join!(
leader.encrypt_share(plaintext.to_vec()),
follower.encrypt_share(plaintext.to_vec())
)
.unwrap();
let expected = aes128(key, plaintext);
let result: [u8; 16] = std::array::from_fn(|i| leader_share[i] ^ follower_share[i]);
assert_eq!(result, expected);
}
#[tokio::test]
#[ignore = "expensive"]
async fn test_block_cipher_preprocess() {
let leader_config = BlockCipherConfig::builder().id("test").build().unwrap();
let follower_config = BlockCipherConfig::builder().id("test").build().unwrap();
let key = [0u8; 16];
let (leader_vm, follower_vm) = create_mock_deap_vm();
// Key is public just for this test, typically it is private.
let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap();
let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap();
leader_vm.assign(&leader_key, key).unwrap();
follower_vm.assign(&follower_key, key).unwrap();
let mut leader = MpcBlockCipher::<Aes128, _>::new(leader_config, leader_vm);
leader.set_key(leader_key);
let mut follower = MpcBlockCipher::<Aes128, _>::new(follower_config, follower_vm);
follower.set_key(follower_key);
let plaintext = [0u8; 16];
tokio::try_join!(
leader.preprocess(Visibility::Private, 1),
follower.preprocess(Visibility::Blind, 1)
)
.unwrap();
let (leader_ciphertext, follower_ciphertext) = tokio::try_join!(
leader.encrypt_private(plaintext.to_vec()),
follower.encrypt_blind()
)
.unwrap();
let expected = aes128(key, plaintext);
assert_eq!(leader_ciphertext, expected.to_vec());
assert_eq!(leader_ciphertext, follower_ciphertext);
tokio::try_join!(
leader.preprocess(Visibility::Public, 1),
follower.preprocess(Visibility::Public, 1)
)
.unwrap();
let (leader_share, follower_share) = tokio::try_join!(
leader.encrypt_share(plaintext.to_vec()),
follower.encrypt_share(plaintext.to_vec())
)
.unwrap();
let expected = aes128(key, plaintext);
let result: [u8; 16] = std::array::from_fn(|i| leader_share[i] ^ follower_share[i]);
assert_eq!(result, expected);
}
}

View File

@@ -0,0 +1,31 @@
[package]
name = "tlsn-cipher"
authors = ["TLSNotary Team"]
description = "This crate provides implementations of ciphers for two parties"
keywords = ["tls", "mpc", "2pc", "aes"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.8-pre"
edition = "2021"
[lib]
name = "cipher"
[dependencies]
mpz-circuits = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-memory-core = { workspace = true }
async-trait = { workspace = true }
thiserror = { workspace = true }
aes = { workspace = true }
[dev-dependencies]
mpz-garble = { workspace = true }
mpz-common = { workspace = true }
mpz-ot = { workspace = true }
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
ctr = { workspace = true }
cipher = { workspace = true }

View File

@@ -1,26 +1,31 @@
use mpz_circuits::{circuits::aes128_trace, once_cell::sync::Lazy, trace, Circuit, CircuitBuilder};
use std::sync::Arc;
/// AES encrypts a counter block.
///
/// # Inputs
///
/// 0. KEY: 16-byte encryption key
/// 1. IV: 4-byte IV
/// 2. EXPLICIT_NONCE: 8-byte explicit nonce
/// 3. CTR: 4-byte counter
///
/// # Outputs
///
/// 0. ECB: 16-byte output
pub(crate) static AES_CTR: Lazy<Arc<Circuit>> = Lazy::new(|| {
/// `fn(key: [u8; 16], iv: [u8; 4], nonce: [u8; 8], ctr: [u8; 4]) -> [u8; 16]`
pub(crate) static AES128_CTR: Lazy<Arc<Circuit>> = Lazy::new(|| {
let builder = CircuitBuilder::new();
let key = builder.add_array_input::<u8, 16>();
let iv = builder.add_array_input::<u8, 4>();
let nonce = builder.add_array_input::<u8, 8>();
let ctr = builder.add_array_input::<u8, 4>();
let ecb = aes_ctr_trace(builder.state(), key, iv, nonce, ctr);
builder.add_output(ecb);
let keystream = aes_ctr_trace(builder.state(), key, iv, nonce, ctr);
builder.add_output(keystream);
Arc::new(builder.build().unwrap())
});
/// `fn(key: [u8; 16], msg: [u8; 16]) -> [u8; 16]`
pub(crate) static AES128_ECB: Lazy<Arc<Circuit>> = Lazy::new(|| {
let builder = CircuitBuilder::new();
let key = builder.add_array_input::<u8, 16>();
let message = builder.add_array_input::<u8, 16>();
let block = aes128_trace(builder.state(), key, message);
builder.add_output(block);
Arc::new(builder.build().unwrap())
});
@@ -45,13 +50,3 @@ fn aes_128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] {
aes.encrypt_block(&mut ciphertext);
ciphertext.into()
}
/// Builds a circuit for computing the XOR of two arrays.
pub(crate) fn build_array_xor(len: usize) -> Arc<Circuit> {
let builder = CircuitBuilder::new();
let a = builder.add_vec_input::<u8>(len);
let b = builder.add_vec_input::<u8>(len);
let c = a.into_iter().zip(b).map(|(a, b)| a ^ b).collect::<Vec<_>>();
builder.add_output(c);
Arc::new(builder.build().expect("circuit is valid"))
}

View File

@@ -0,0 +1,44 @@
use std::fmt::Display;
/// AES error.
#[derive(Debug, thiserror::Error)]
pub struct AesError {
kind: ErrorKind,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync>>,
}
impl AesError {
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
Self {
kind,
source: Some(source.into()),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum ErrorKind {
Vm,
Key,
Iv,
}
impl Display for AesError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.kind {
ErrorKind::Vm => write!(f, "vm error")?,
ErrorKind::Key => write!(f, "key error")?,
ErrorKind::Iv => write!(f, "iv error")?,
}
if let Some(source) = &self.source {
write!(f, " caused by: {}", source)?;
}
Ok(())
}
}

View File

@@ -0,0 +1,375 @@
//! The AES-128 block cipher.
use crate::{Cipher, CtrBlock, Keystream};
use async_trait::async_trait;
use mpz_memory_core::binary::{Binary, U8};
use mpz_vm_core::{prelude::*, Call, Vm};
use std::fmt::Debug;
mod circuit;
mod error;
pub use error::AesError;
use error::ErrorKind;
/// Computes AES-128.
#[derive(Default, Debug)]
pub struct Aes128 {
key: Option<Array<U8, 16>>,
iv: Option<Array<U8, 4>>,
}
#[async_trait]
impl Cipher for Aes128 {
type Error = AesError;
type Key = Array<U8, 16>;
type Iv = Array<U8, 4>;
type Nonce = Array<U8, 8>;
type Counter = Array<U8, 4>;
type Block = Array<U8, 16>;
fn set_key(&mut self, key: Array<U8, 16>) {
self.key = Some(key);
}
fn set_iv(&mut self, iv: Array<U8, 4>) {
self.iv = Some(iv);
}
fn key(&self) -> Option<&Array<U8, 16>> {
self.key.as_ref()
}
fn iv(&self) -> Option<&Array<U8, 4>> {
self.iv.as_ref()
}
fn alloc_block(
&self,
vm: &mut dyn Vm<Binary>,
input: Array<U8, 16>,
) -> Result<Self::Block, Self::Error> {
let key = self
.key
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
let output = vm
.call(
Call::new(circuit::AES128_ECB.clone())
.arg(key)
.arg(input)
.build()
.expect("call should be valid"),
)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
Ok(output)
}
fn alloc_ctr_block(
&self,
vm: &mut dyn Vm<Binary>,
) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error> {
let key = self
.key
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
let iv = self
.iv
.ok_or_else(|| AesError::new(ErrorKind::Iv, "iv not set"))?;
let explicit_nonce: Array<U8, 8> = vm
.alloc()
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
vm.mark_public(explicit_nonce)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
let counter: Array<U8, 4> = vm
.alloc()
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
vm.mark_public(counter)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
let output = vm
.call(
Call::new(circuit::AES128_CTR.clone())
.arg(key)
.arg(iv)
.arg(explicit_nonce)
.arg(counter)
.build()
.expect("call should be valid"),
)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
Ok(CtrBlock {
explicit_nonce,
counter,
output,
})
}
fn alloc_keystream(
&self,
vm: &mut dyn Vm<Binary>,
len: usize,
) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error> {
let key = self
.key
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
let iv = self
.iv
.ok_or_else(|| AesError::new(ErrorKind::Iv, "iv not set"))?;
let block_count = len.div_ceil(16);
let inputs = (0..block_count)
.map(|_| {
let explicit_nonce: Array<U8, 8> = vm
.alloc()
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
let counter: Array<U8, 4> = vm
.alloc()
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
vm.mark_public(explicit_nonce)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
vm.mark_public(counter)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
Ok((explicit_nonce, counter))
})
.collect::<Result<Vec<_>, AesError>>()?;
let blocks = inputs
.into_iter()
.map(|(explicit_nonce, counter)| {
let output = vm
.call(
Call::new(circuit::AES128_CTR.clone())
.arg(key)
.arg(iv)
.arg(explicit_nonce)
.arg(counter)
.build()
.expect("call should be valid"),
)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
Ok(CtrBlock {
explicit_nonce,
counter,
output,
})
})
.collect::<Result<Vec<_>, AesError>>()?;
Ok(Keystream::new(&blocks))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Cipher;
use mpz_common::context::test_st_context;
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
use mpz_memory_core::{
binary::{Binary, U8},
correlated::Delta,
Array, MemoryExt, Vector, ViewExt,
};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{Execute, Vm};
use rand::{rngs::StdRng, SeedableRng};
#[tokio::test]
async fn test_aes_ctr() {
let key = [42_u8; 16];
let iv = [3_u8; 4];
let nonce = [5_u8; 8];
let start_counter = 3u32;
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut gen, mut ev) = mock_vm();
let aes_gen = setup_ctr(key, iv, &mut gen);
let aes_ev = setup_ctr(key, iv, &mut ev);
let msg = vec![42u8; 128];
let keystream_gen = aes_gen.alloc_keystream(&mut gen, msg.len()).unwrap();
let keystream_ev = aes_ev.alloc_keystream(&mut ev, msg.len()).unwrap();
let msg_ref_gen: Vector<U8> = gen.alloc_vec(msg.len()).unwrap();
gen.mark_public(msg_ref_gen).unwrap();
gen.assign(msg_ref_gen, msg.clone()).unwrap();
gen.commit(msg_ref_gen).unwrap();
let msg_ref_ev: Vector<U8> = ev.alloc_vec(msg.len()).unwrap();
ev.mark_public(msg_ref_ev).unwrap();
ev.assign(msg_ref_ev, msg.clone()).unwrap();
ev.commit(msg_ref_ev).unwrap();
let mut ctr = start_counter..;
keystream_gen
.assign(&mut gen, nonce, move || ctr.next().unwrap().to_be_bytes())
.unwrap();
let mut ctr = start_counter..;
keystream_ev
.assign(&mut ev, nonce, move || ctr.next().unwrap().to_be_bytes())
.unwrap();
let cipher_out_gen = keystream_gen.apply(&mut gen, msg_ref_gen).unwrap();
let cipher_out_ev = keystream_ev.apply(&mut ev, msg_ref_ev).unwrap();
let (ct_gen, ct_ev) = tokio::try_join!(
async {
let out = gen.decode(cipher_out_gen).unwrap();
gen.flush(&mut ctx_a).await.unwrap();
gen.execute(&mut ctx_a).await.unwrap();
gen.flush(&mut ctx_a).await.unwrap();
out.await
},
async {
let out = ev.decode(cipher_out_ev).unwrap();
ev.flush(&mut ctx_b).await.unwrap();
ev.execute(&mut ctx_b).await.unwrap();
ev.flush(&mut ctx_b).await.unwrap();
out.await
}
)
.unwrap();
assert_eq!(ct_gen, ct_ev);
let expected = aes_apply_keystream(key, iv, nonce, start_counter as usize, msg);
assert_eq!(ct_gen, expected);
}
#[tokio::test]
async fn test_aes_ecb() {
let key = [1_u8; 16];
let input = [5_u8; 16];
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut gen, mut ev) = mock_vm();
let aes_gen = setup_block(key, &mut gen);
let aes_ev = setup_block(key, &mut ev);
let block_ref_gen: Array<U8, 16> = gen.alloc().unwrap();
gen.mark_public(block_ref_gen).unwrap();
gen.assign(block_ref_gen, input).unwrap();
gen.commit(block_ref_gen).unwrap();
let block_ref_ev: Array<U8, 16> = ev.alloc().unwrap();
ev.mark_public(block_ref_ev).unwrap();
ev.assign(block_ref_ev, input).unwrap();
ev.commit(block_ref_ev).unwrap();
let block_gen = aes_gen.alloc_block(&mut gen, block_ref_gen).unwrap();
let block_ev = aes_ev.alloc_block(&mut ev, block_ref_ev).unwrap();
let (ciphertext_gen, ciphetext_ev) = tokio::try_join!(
async {
let out = gen.decode(block_gen).unwrap();
gen.flush(&mut ctx_a).await.unwrap();
gen.execute(&mut ctx_a).await.unwrap();
gen.flush(&mut ctx_a).await.unwrap();
out.await
},
async {
let out = ev.decode(block_ev).unwrap();
ev.flush(&mut ctx_b).await.unwrap();
ev.execute(&mut ctx_b).await.unwrap();
ev.flush(&mut ctx_b).await.unwrap();
out.await
}
)
.unwrap();
assert_eq!(ciphertext_gen, ciphetext_ev);
let expected = aes128(key, input);
assert_eq!(ciphertext_gen, expected);
}
fn mock_vm() -> (impl Vm<Binary>, impl Vm<Binary>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
let gen = Generator::new(cot_send, [0u8; 16], delta);
let ev = Evaluator::new(cot_recv);
(gen, ev)
}
fn setup_ctr(key: [u8; 16], iv: [u8; 4], vm: &mut dyn Vm<Binary>) -> Aes128 {
let key_ref: Array<U8, 16> = vm.alloc().unwrap();
vm.mark_public(key_ref).unwrap();
vm.assign(key_ref, key).unwrap();
vm.commit(key_ref).unwrap();
let iv_ref: Array<U8, 4> = vm.alloc().unwrap();
vm.mark_public(iv_ref).unwrap();
vm.assign(iv_ref, iv).unwrap();
vm.commit(iv_ref).unwrap();
let mut aes = Aes128::default();
aes.set_key(key_ref);
aes.set_iv(iv_ref);
aes
}
fn setup_block(key: [u8; 16], vm: &mut dyn Vm<Binary>) -> Aes128 {
let key_ref: Array<U8, 16> = vm.alloc().unwrap();
vm.mark_public(key_ref).unwrap();
vm.assign(key_ref, key).unwrap();
vm.commit(key_ref).unwrap();
let mut aes = Aes128::default();
aes.set_key(key_ref);
aes
}
fn aes_apply_keystream(
key: [u8; 16],
iv: [u8; 4],
explicit_nonce: [u8; 8],
start_ctr: usize,
msg: Vec<u8>,
) -> Vec<u8> {
use ::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use aes::Aes128;
use ctr::Ctr32BE;
let mut full_iv = [0u8; 16];
full_iv[0..4].copy_from_slice(&iv);
full_iv[4..12].copy_from_slice(&explicit_nonce);
let mut cipher = Ctr32BE::<Aes128>::new(&key.into(), &full_iv.into());
let mut out = msg.clone();
cipher
.try_seek(start_ctr * 16)
.expect("start counter is less than keystream length");
cipher.apply_keystream(&mut out);
out
}
fn aes128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] {
use ::aes::Aes128 as TestAes128;
use ::cipher::{BlockEncrypt, KeyInit};
let mut msg = msg.into();
let cipher = TestAes128::new(&key.into());
cipher.encrypt_block(&mut msg);
msg.into()
}
}

View File

@@ -0,0 +1,23 @@
//! Ciphers and circuits.
use mpz_circuits::{types::ValueType, Circuit, CircuitBuilder, Tracer};
use std::sync::Arc;
/// Builds a circuit which XORs the provided values.
pub(crate) fn build_xor_circuit(inputs: &[ValueType]) -> Arc<Circuit> {
let builder = CircuitBuilder::new();
for input_ty in inputs {
let input_0 = builder.add_input_by_type(input_ty.clone());
let input_1 = builder.add_input_by_type(input_ty.clone());
let input_0 = Tracer::new(builder.state(), input_0);
let input_1 = Tracer::new(builder.state(), input_1);
let output = input_0 ^ input_1;
builder.add_output(output);
}
let circ = builder.build().expect("circuit should be valid");
Arc::new(circ)
}

View File

@@ -0,0 +1,299 @@
//! This crate provides implementations of 2PC ciphers for encryption with a
//! shared key.
//!
//! Both parties can work together to encrypt and decrypt messages with
//! different visibility configurations. See [`Cipher`] and [`Keystream`] for
//! more information on the interface.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
pub mod aes;
mod circuit;
use async_trait::async_trait;
use circuit::build_xor_circuit;
use mpz_circuits::types::ValueType;
use mpz_memory_core::{
binary::{Binary, U8},
FromRaw, MemoryExt, Repr, Slice, StaticSize, ToRaw, Vector,
};
use mpz_vm_core::{prelude::*, CallBuilder, CallError, Vm};
use std::collections::VecDeque;
/// Provides computation of 2PC ciphers in counter and ECB mode.
///
/// After setting `key` and `iv` allows to compute the keystream via
/// [`Cipher::alloc`] or a single block in ECB mode via
/// [`Cipher::assign_block`]. [`Keystream`] provides more tooling to compute the
/// final cipher output in counter mode.
#[async_trait]
pub trait Cipher {
/// The error type for the cipher.
type Error: std::error::Error + Send + Sync + 'static;
/// Cipher key.
type Key;
/// Cipher IV.
type Iv;
/// Cipher nonce.
type Nonce;
/// Cipher counter.
type Counter;
/// Cipher block.
type Block;
/// Sets the key.
fn set_key(&mut self, key: Self::Key);
/// Sets the initialization vector.
fn set_iv(&mut self, iv: Self::Iv);
/// Returns the key reference.
fn key(&self) -> Option<&Self::Key>;
/// Returns the iv reference.
fn iv(&self) -> Option<&Self::Iv>;
/// Allocates a single block in ECB mode.
fn alloc_block(
&self,
vm: &mut dyn Vm<Binary>,
input: Self::Block,
) -> Result<Self::Block, Self::Error>;
/// Allocates a single block in counter mode.
#[allow(clippy::type_complexity)]
fn alloc_ctr_block(
&self,
vm: &mut dyn Vm<Binary>,
) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error>;
/// Allocates a keystream in counter mode.
///
/// # Arguments
///
/// * `vm` - Virtual machine to allocate into.
/// * `len` - Length of the stream in bytes.
#[allow(clippy::type_complexity)]
fn alloc_keystream(
&self,
vm: &mut dyn Vm<Binary>,
len: usize,
) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error>;
}
/// A block in counter mode.
#[derive(Debug, Clone, Copy)]
pub struct CtrBlock<N, C, O> {
/// Explicit nonce reference.
pub explicit_nonce: N,
/// Counter reference.
pub counter: C,
/// Output reference.
pub output: O,
}
/// The keystream of the cipher.
///
/// Can be used to XOR with the cipher input to operate the cipher in counter
/// mode.
pub struct Keystream<N, C, O> {
blocks: VecDeque<CtrBlock<N, C, O>>,
}
impl<N, C, O> Default for Keystream<N, C, O> {
fn default() -> Self {
Self {
blocks: VecDeque::new(),
}
}
}
impl<N, C, O> Keystream<N, C, O>
where
N: Repr<Binary> + StaticSize<Binary> + Copy,
C: Repr<Binary> + StaticSize<Binary> + Copy,
O: Repr<Binary> + StaticSize<Binary> + Copy,
{
/// Creates a new keystream from the provided blocks.
///
/// # Panics
///
/// * If the output of the keystream is not ordered and contiguous in
/// memory.
pub fn new(blocks: &[CtrBlock<N, C, O>]) -> Self {
let mut pos = blocks
.first()
.map(|block| block.output.to_raw().ptr().as_usize())
.unwrap_or(0);
for block in blocks {
if block.output.to_raw().ptr().as_usize() != pos {
panic!("output of keystream blocks must be ordered and contiguous in memory");
}
pos += O::SIZE;
}
Self {
blocks: VecDeque::from_iter(blocks.iter().copied()),
}
}
/// Consumes keystream material.
///
/// Returns the consumed keystream material, leaving the remaining material
/// in place.
///
/// # Arguments
///
/// * `len` - Length of the keystream in bytes to return.
pub fn consume(&mut self, len: usize) -> Result<Self, CipherError> {
let block_count = len.div_ceil(self.block_size());
if block_count > self.blocks.len() {
return Err(CipherError::new("insufficient keystream"));
}
let blocks = self.blocks.split_off(self.blocks.len() - block_count);
Ok(Self { blocks })
}
/// Applies the keystream to the provided input.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `input` - Input data.
pub fn apply(
&self,
vm: &mut dyn Vm<Binary>,
input: Vector<U8>,
) -> Result<Vector<U8>, CipherError> {
if input.len() != self.len() {
return Err(CipherError::new("input length must match keystream length"));
} else if self.blocks.is_empty() {
return Err(CipherError::new("no keystream material available"));
}
let xor = build_xor_circuit(&[ValueType::new_array::<u8>(self.block_size())]);
let mut pos = 0;
let mut outputs = Vec::with_capacity(self.blocks.len());
for block in &self.blocks {
let call = CallBuilder::new(xor.clone())
.arg(block.output)
.arg(
input
.get(pos..pos + self.block_size())
.expect("input length was checked"),
)
.build()?;
let output: Vector<U8> = vm.call(call).map_err(CipherError::new)?;
outputs.push(output);
pos += self.block_size();
}
// Calls were performed contiguously, so the output data is contiguous.
let ptr = outputs
.first()
.map(|output| output.to_raw().ptr())
.expect("keystream is not empty");
let size = self.blocks.len() * O::SIZE;
let output = Vector::<U8>::from_raw(Slice::new_unchecked(ptr, size));
Ok(output)
}
/// Returns `len` bytes of the keystream as a vector.
pub fn to_vector(&self, len: usize) -> Result<Vector<U8>, CipherError> {
if len == 0 {
return Err(CipherError::new("length must be greater than 0"));
} else if self.blocks.is_empty() {
return Err(CipherError::new("no keystream material available"));
}
let block_count = len.div_ceil(self.block_size());
if block_count != self.blocks.len() {
return Err(CipherError::new("length does not match keystream length"));
}
let ptr = self
.blocks
.front()
.map(|block| block.output.to_raw().ptr())
.expect("block count should be greater than 0");
let size = block_count * O::SIZE;
let mut keystream = Vector::<U8>::from_raw(Slice::new_unchecked(ptr, size));
keystream.truncate(len);
Ok(keystream)
}
/// Assigns the keystream inputs.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `explicit_nonce` - Explicit nonce.
/// * `ctr` - Counter function. The provided function will be called to
/// assign the counter values for each block.
pub fn assign(
&self,
vm: &mut dyn Vm<Binary>,
explicit_nonce: N::Clear,
mut ctr: impl FnMut() -> C::Clear,
) -> Result<(), CipherError>
where
N::Clear: Copy,
C::Clear: Copy,
{
for block in &self.blocks {
vm.assign(block.explicit_nonce, explicit_nonce)
.map_err(CipherError::new)?;
vm.commit(block.explicit_nonce).map_err(CipherError::new)?;
vm.assign(block.counter, ctr()).map_err(CipherError::new)?;
vm.commit(block.counter).map_err(CipherError::new)?;
}
Ok(())
}
/// Returns the block size in bytes.
fn block_size(&self) -> usize {
O::SIZE / 8
}
/// Returns the length of the keystream in bytes.
fn len(&self) -> usize {
self.block_size() * self.blocks.len()
}
}
/// A cipher error.
#[derive(Debug, thiserror::Error)]
#[error("{source}")]
pub struct CipherError {
#[source]
source: Box<dyn std::error::Error + Send + Sync>,
}
impl CipherError {
pub(crate) fn new<E>(source: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
Self {
source: source.into(),
}
}
}
impl From<CallError> for CipherError {
fn from(value: CallError) -> Self {
Self::new(value)
}
}

View File

@@ -0,0 +1,25 @@
[package]
name = "tlsn-deap"
version = "0.1.0"
edition = "2021"
[dependencies]
mpz-core = { workspace = true }
mpz-common = { workspace = true }
mpz-vm-core = { workspace = true }
tlsn-utils = { workspace = true }
thiserror = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serio = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
tokio = { workspace = true, features = ["sync"] }
[dev-dependencies]
mpz-circuits = { workspace = true }
mpz-garble = { workspace = true }
mpz-ot = { workspace = true }
mpz-zk = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }

View File

@@ -0,0 +1,516 @@
//! Dual-execution with Asymmetric Privacy (DEAP) protocol.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
use std::{
mem,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use async_trait::async_trait;
use mpz_common::{scoped_futures::ScopedFutureExt as _, Context};
use mpz_core::bitvec::BitVec;
use mpz_vm_core::{
memory::{binary::Binary, DecodeFuture, Memory, Slice, View},
Call, Callable, Execute, Vm, VmError,
};
use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard};
use utils::range::{Difference, RangeSet, UnionMut};
type Error = DeapError;
/// The role of the DEAP VM.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(missing_docs)]
pub enum Role {
Leader,
Follower,
}
/// DEAP Vm.
#[derive(Debug)]
pub struct Deap<Mpc, Zk> {
role: Role,
mpc: Arc<Mutex<Mpc>>,
zk: Arc<Mutex<Zk>>,
/// Private inputs of the follower.
follower_inputs: RangeSet<usize>,
outputs: Vec<(Slice, DecodeFuture<BitVec>)>,
/// Whether the memories of the two VMs are potentially desynchronized.
desync: AtomicBool,
}
impl<Mpc, Zk> Deap<Mpc, Zk> {
/// Create a new DEAP Vm.
pub fn new(role: Role, mpc: Mpc, zk: Zk) -> Self {
Self {
role,
mpc: Arc::new(Mutex::new(mpc)),
zk: Arc::new(Mutex::new(zk)),
follower_inputs: RangeSet::default(),
outputs: Vec::default(),
desync: AtomicBool::new(false),
}
}
/// Returns the MPC and ZK VMs.
pub fn into_inner(self) -> (Mpc, Zk) {
(
Arc::into_inner(self.mpc).unwrap().into_inner(),
Arc::into_inner(self.zk).unwrap().into_inner(),
)
}
/// Returns a mutable reference to the ZK VM.
///
/// # Note
///
/// After calling this method, allocations will no longer be allowed in the
/// DEAP VM as the memory will potentially be desynchronized.
///
/// # Panics
///
/// Panics if the mutex locked by another thread.
pub fn zk(&self) -> MutexGuard<'_, Zk> {
self.desync.store(true, Ordering::Relaxed);
self.zk.try_lock().unwrap()
}
/// Returns an owned mutex guard to the ZK VM.
///
/// # Note
///
/// After calling this method, allocations will no longer be allowed in the
/// DEAP VM as the memory will potentially be desynchronized.
///
/// # Panics
///
/// Panics if the mutex locked by another thread.
pub fn zk_owned(&self) -> OwnedMutexGuard<Zk> {
self.desync.store(true, Ordering::Relaxed);
self.zk.clone().try_lock_owned().unwrap()
}
#[cfg(test)]
fn mpc(&self) -> MutexGuard<'_, Mpc> {
self.mpc.try_lock().unwrap()
}
}
impl<Mpc, Zk> Deap<Mpc, Zk>
where
Mpc: Vm<Binary> + Send + 'static,
Zk: Vm<Binary> + Send + 'static,
{
/// Finalize the DEAP Vm.
///
/// This reveals all private inputs of the follower.
pub async fn finalize(&mut self, ctx: &mut Context) -> Result<(), VmError> {
let mut mpc = self.mpc.try_lock().unwrap();
let mut zk = self.zk.try_lock().unwrap();
// Decode the private inputs of the follower.
//
// # Security
//
// This assumes that the decoding process is authenticated from the leader's
// perspective. In the case of garbled circuits, the leader should be the
// generator such that the follower proves their inputs using their committed
// MACs.
let input_futs = self
.follower_inputs
.iter_ranges()
.map(|input| mpc.decode_raw(Slice::from_range_unchecked(input)))
.collect::<Result<Vec<_>, _>>()?;
mpc.execute_all(ctx).await?;
// Assign inputs to the ZK VM.
for (mut decode, input) in input_futs
.into_iter()
.zip(self.follower_inputs.iter_ranges())
{
let input = Slice::from_range_unchecked(input);
// Follower has already assigned the inputs.
if let Role::Leader = self.role {
let value = decode
.try_recv()
.map_err(VmError::memory)?
.expect("input should be decoded");
zk.assign_raw(input, value)?;
}
// Now the follower's inputs are public.
zk.commit_raw(input)?;
}
zk.execute_all(ctx).await.map_err(VmError::execute)?;
// Follower verifies the outputs are consistent.
if let Role::Follower = self.role {
for (output, mut value) in mem::take(&mut self.outputs) {
// If the output is not available in the MPC VM, we did not execute and decode
// it. Therefore, we do not need to check for equality.
//
// This can occur if some function was preprocessed but ultimately not used.
if let Some(mpc_output) = mpc.get_raw(output)? {
let zk_output = value
.try_recv()
.map_err(VmError::memory)?
.expect("output should be decoded");
// Asserts equality of all the output values from both VMs.
if zk_output != mpc_output {
return Err(VmError::execute(Error::from(ErrorRepr::EqualityCheck)));
}
}
}
}
Ok(())
}
}
impl<Mpc, Zk> Memory<Binary> for Deap<Mpc, Zk>
where
Mpc: Memory<Binary, Error = VmError>,
Zk: Memory<Binary, Error = VmError>,
{
type Error = VmError;
fn alloc_raw(&mut self, size: usize) -> Result<Slice, VmError> {
if self.desync.load(Ordering::Relaxed) {
return Err(VmError::memory(
"DEAP VM memories are potentially desynchronized",
));
}
self.zk.try_lock().unwrap().alloc_raw(size)?;
self.mpc.try_lock().unwrap().alloc_raw(size)
}
fn assign_raw(&mut self, slice: Slice, data: BitVec) -> Result<(), VmError> {
self.zk
.try_lock()
.unwrap()
.assign_raw(slice, data.clone())?;
self.mpc.try_lock().unwrap().assign_raw(slice, data)
}
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
// Follower's private inputs are not committed in the ZK VM until finalization.
let input_minus_follower = slice.to_range().difference(&self.follower_inputs);
let mut zk = self.zk.try_lock().unwrap();
for input in input_minus_follower.iter_ranges() {
zk.commit_raw(Slice::from_range_unchecked(input))?;
}
self.mpc.try_lock().unwrap().commit_raw(slice)
}
fn get_raw(&self, slice: Slice) -> Result<Option<BitVec>, VmError> {
self.mpc.try_lock().unwrap().get_raw(slice)
}
fn decode_raw(&mut self, slice: Slice) -> Result<DecodeFuture<BitVec>, VmError> {
let fut = self.zk.try_lock().unwrap().decode_raw(slice)?;
self.outputs.push((slice, fut));
self.mpc.try_lock().unwrap().decode_raw(slice)
}
}
impl<Mpc, Zk> View<Binary> for Deap<Mpc, Zk>
where
Mpc: View<Binary, Error = VmError>,
Zk: View<Binary, Error = VmError>,
{
type Error = VmError;
fn mark_public_raw(&mut self, slice: Slice) -> Result<(), VmError> {
self.zk.try_lock().unwrap().mark_public_raw(slice)?;
self.mpc.try_lock().unwrap().mark_public_raw(slice)
}
fn mark_private_raw(&mut self, slice: Slice) -> Result<(), VmError> {
let mut zk = self.zk.try_lock().unwrap();
let mut mpc = self.mpc.try_lock().unwrap();
match self.role {
Role::Leader => {
zk.mark_private_raw(slice)?;
mpc.mark_private_raw(slice)?;
}
Role::Follower => {
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(slice)?;
mpc.mark_private_raw(slice)?;
self.follower_inputs.union_mut(&slice.to_range());
}
}
Ok(())
}
fn mark_blind_raw(&mut self, slice: Slice) -> Result<(), VmError> {
let mut zk = self.zk.try_lock().unwrap();
let mut mpc = self.mpc.try_lock().unwrap();
match self.role {
Role::Leader => {
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(slice)?;
mpc.mark_blind_raw(slice)?;
self.follower_inputs.union_mut(&slice.to_range());
}
Role::Follower => {
zk.mark_blind_raw(slice)?;
mpc.mark_blind_raw(slice)?;
}
}
Ok(())
}
}
impl<Mpc, Zk> Callable<Binary> for Deap<Mpc, Zk>
where
Mpc: Vm<Binary>,
Zk: Vm<Binary>,
{
fn call_raw(&mut self, call: Call) -> Result<Slice, VmError> {
if self.desync.load(Ordering::Relaxed) {
return Err(VmError::memory(
"DEAP VM memories are potentially desynchronized",
));
}
self.zk.try_lock().unwrap().call_raw(call.clone())?;
self.mpc.try_lock().unwrap().call_raw(call)
}
}
#[async_trait]
impl<Mpc, Zk> Execute for Deap<Mpc, Zk>
where
Mpc: Execute + Send + 'static,
Zk: Execute + Send + 'static,
{
fn wants_flush(&self) -> bool {
self.mpc.try_lock().unwrap().wants_flush() || self.zk.try_lock().unwrap().wants_flush()
}
async fn flush(&mut self, ctx: &mut Context) -> Result<(), VmError> {
let mut zk = self.zk.clone().try_lock_owned().unwrap();
let mut mpc = self.mpc.clone().try_lock_owned().unwrap();
ctx.try_join(
|ctx| async move { zk.flush(ctx).await }.scope_boxed(),
|ctx| async move { mpc.flush(ctx).await }.scope_boxed(),
)
.await
.map_err(VmError::execute)??;
Ok(())
}
fn wants_preprocess(&self) -> bool {
self.mpc.try_lock().unwrap().wants_preprocess()
|| self.zk.try_lock().unwrap().wants_preprocess()
}
async fn preprocess(&mut self, ctx: &mut Context) -> Result<(), VmError> {
let mut zk = self.zk.clone().try_lock_owned().unwrap();
let mut mpc = self.mpc.clone().try_lock_owned().unwrap();
ctx.try_join(
|ctx| async move { zk.preprocess(ctx).await }.scope_boxed(),
|ctx| async move { mpc.preprocess(ctx).await }.scope_boxed(),
)
.await
.map_err(VmError::execute)??;
Ok(())
}
fn wants_execute(&self) -> bool {
self.mpc.try_lock().unwrap().wants_execute()
}
async fn execute(&mut self, ctx: &mut Context) -> Result<(), VmError> {
// Only MPC VM is executed until finalization.
self.mpc.try_lock().unwrap().execute(ctx).await
}
}
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub(crate) struct DeapError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("equality check failed")]
EqualityCheck,
}
#[cfg(test)]
mod tests {
use mpz_circuits::circuits::AES128;
use mpz_common::context::test_st_context;
use mpz_core::Block;
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
use mpz_ot::ideal::{cot::ideal_cot, rcot::ideal_rcot};
use mpz_vm_core::{
memory::{binary::U8, correlated::Delta, Array},
prelude::*,
};
use mpz_zk::{Prover, Verifier};
use rand::{rngs::StdRng, SeedableRng};
use super::*;
#[tokio::test]
async fn test_deap() {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta.into_inner());
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
let gb = Generator::new(cot_send, [0u8; 16], delta);
let ev = Evaluator::new(cot_recv);
let prover = Prover::new(rcot_recv);
let verifier = Verifier::new(delta, rcot_send);
let mut leader = Deap::new(Role::Leader, gb, prover);
let mut follower = Deap::new(Role::Follower, ev, verifier);
let (ct_leader, ct_follower) = futures::join!(
async {
let key: Array<U8, 16> = leader.alloc().unwrap();
let msg: Array<U8, 16> = leader.alloc().unwrap();
leader.mark_private(key).unwrap();
leader.mark_blind(msg).unwrap();
leader.assign(key, [42u8; 16]).unwrap();
leader.commit(key).unwrap();
leader.commit(msg).unwrap();
let ct: Array<U8, 16> = leader
.call(Call::new(AES128.clone()).arg(key).arg(msg).build().unwrap())
.unwrap();
let ct = leader.decode(ct).unwrap();
leader.flush(&mut ctx_a).await.unwrap();
leader.execute(&mut ctx_a).await.unwrap();
leader.flush(&mut ctx_a).await.unwrap();
leader.finalize(&mut ctx_a).await.unwrap();
ct.await.unwrap()
},
async {
let key: Array<U8, 16> = follower.alloc().unwrap();
let msg: Array<U8, 16> = follower.alloc().unwrap();
follower.mark_blind(key).unwrap();
follower.mark_private(msg).unwrap();
follower.assign(msg, [69u8; 16]).unwrap();
follower.commit(key).unwrap();
follower.commit(msg).unwrap();
let ct: Array<U8, 16> = follower
.call(Call::new(AES128.clone()).arg(key).arg(msg).build().unwrap())
.unwrap();
let ct = follower.decode(ct).unwrap();
follower.flush(&mut ctx_b).await.unwrap();
follower.execute(&mut ctx_b).await.unwrap();
follower.flush(&mut ctx_b).await.unwrap();
follower.finalize(&mut ctx_b).await.unwrap();
ct.await.unwrap()
}
);
assert_eq!(ct_leader, ct_follower);
}
// Tests that the leader can not use different inputs in each VM without
// detection by the follower.
#[tokio::test]
async fn test_malicious() {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta.into_inner());
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
let gb = Generator::new(cot_send, [0u8; 16], delta);
let ev = Evaluator::new(cot_recv);
let prover = Prover::new(rcot_recv);
let verifier = Verifier::new(delta, rcot_send);
let mut leader = Deap::new(Role::Leader, gb, prover);
let mut follower = Deap::new(Role::Follower, ev, verifier);
let (_, follower_res) = futures::join!(
async {
let key: Array<U8, 16> = leader.alloc().unwrap();
let msg: Array<U8, 16> = leader.alloc().unwrap();
leader.mark_private(key).unwrap();
leader.mark_blind(msg).unwrap();
// Use different inputs in each VM.
leader.mpc().assign(key, [42u8; 16]).unwrap();
leader
.zk
.try_lock()
.unwrap()
.assign(key, [69u8; 16])
.unwrap();
leader.commit(key).unwrap();
leader.commit(msg).unwrap();
let ct: Array<U8, 16> = leader
.call(Call::new(AES128.clone()).arg(key).arg(msg).build().unwrap())
.unwrap();
let ct = leader.decode(ct).unwrap();
leader.flush(&mut ctx_a).await.unwrap();
leader.execute(&mut ctx_a).await.unwrap();
leader.flush(&mut ctx_a).await.unwrap();
leader.finalize(&mut ctx_a).await.unwrap();
ct.await.unwrap()
},
async {
let key: Array<U8, 16> = follower.alloc().unwrap();
let msg: Array<U8, 16> = follower.alloc().unwrap();
follower.mark_blind(key).unwrap();
follower.mark_private(msg).unwrap();
follower.assign(msg, [69u8; 16]).unwrap();
follower.commit(key).unwrap();
follower.commit(msg).unwrap();
let ct: Array<U8, 16> = follower
.call(Call::new(AES128.clone()).arg(key).arg(msg).build().unwrap())
.unwrap();
drop(follower.decode(ct).unwrap());
follower.flush(&mut ctx_b).await.unwrap();
follower.execute(&mut ctx_b).await.unwrap();
follower.flush(&mut ctx_b).await.unwrap();
follower.finalize(&mut ctx_b).await
}
);
assert!(follower_res.is_err());
}
}

View File

@@ -19,21 +19,23 @@ mock = []
[dependencies]
tlsn-hmac-sha256-circuits = { workspace = true }
mpz-garble = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-circuits = { workspace = true }
mpz-common = { workspace = true }
mpz-common = { workspace = true, features = ["cpu"] }
async-trait = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
futures = { workspace = true }
[dev-dependencies]
criterion = { workspace = true, features = ["async_tokio"] }
mpz-common = { workspace = true, features = ["test-utils"] }
mpz-ot = { workspace = true, features = ["ideal"] }
mpz-garble = { workspace = true }
mpz-common = { workspace = true, features = ["test-utils"] }
criterion = { workspace = true, features = ["async_tokio"] }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
[[bench]]
name = "prf"

View File

@@ -1,9 +1,16 @@
#![allow(clippy::let_underscore_future)]
use criterion::{criterion_group, criterion_main, Criterion};
use hmac_sha256::{MpcPrf, Prf, PrfConfig, Role};
use mpz_common::executor::test_mt_executor;
use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPThread, Memory};
use mpz_ot::ideal::ot::ideal_ot;
use hmac_sha256::{MpcPrf, PrfConfig, Role};
use mpz_common::context::test_mt_context;
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{
memory::{binary::U8, correlated::Delta, Array},
prelude::*,
};
use rand::{rngs::StdRng, SeedableRng};
#[allow(clippy::unit_arg)]
fn criterion_benchmark(c: &mut Criterion) {
@@ -11,178 +18,113 @@ fn criterion_benchmark(c: &mut Criterion) {
group.sample_size(10);
let rt = tokio::runtime::Runtime::new().unwrap();
group.bench_function("prf_preprocess", |b| b.to_async(&rt).iter(preprocess));
group.bench_function("prf", |b| b.to_async(&rt).iter(prf));
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
async fn preprocess() {
let (mut leader_exec, mut follower_exec) = test_mt_executor(128);
let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot();
let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot();
let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot();
let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot();
let leader_thread_0 = DEAPThread::new(
DEAPRole::Leader,
[0u8; 32],
leader_exec.new_thread().await.unwrap(),
leader_ot_send_0,
leader_ot_recv_0,
);
let leader_thread_1 = leader_thread_0
.new_thread(
leader_exec.new_thread().await.unwrap(),
leader_ot_send_1,
leader_ot_recv_1,
)
.unwrap();
let follower_thread_0 = DEAPThread::new(
DEAPRole::Follower,
[0u8; 32],
follower_exec.new_thread().await.unwrap(),
follower_ot_send_0,
follower_ot_recv_0,
);
let follower_thread_1 = follower_thread_0
.new_thread(
follower_exec.new_thread().await.unwrap(),
follower_ot_send_1,
follower_ot_recv_1,
)
.unwrap();
let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap();
let follower_pms = follower_thread_0
.new_public_input::<[u8; 32]>("pms")
.unwrap();
let mut leader = MpcPrf::new(
PrfConfig::builder().role(Role::Leader).build().unwrap(),
leader_thread_0,
leader_thread_1,
);
let mut follower = MpcPrf::new(
PrfConfig::builder().role(Role::Follower).build().unwrap(),
follower_thread_0,
follower_thread_1,
);
futures::join!(
async {
leader.setup(leader_pms).await.unwrap();
leader.set_client_random(Some([0u8; 32])).await.unwrap();
leader.preprocess().await.unwrap();
},
async {
follower.setup(follower_pms).await.unwrap();
follower.set_client_random(None).await.unwrap();
follower.preprocess().await.unwrap();
}
);
}
async fn prf() {
let (mut leader_exec, mut follower_exec) = test_mt_executor(128);
let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot();
let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot();
let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot();
let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot();
let leader_thread_0 = DEAPThread::new(
DEAPRole::Leader,
[0u8; 32],
leader_exec.new_thread().await.unwrap(),
leader_ot_send_0,
leader_ot_recv_0,
);
let leader_thread_1 = leader_thread_0
.new_thread(
leader_exec.new_thread().await.unwrap(),
leader_ot_send_1,
leader_ot_recv_1,
)
.unwrap();
let follower_thread_0 = DEAPThread::new(
DEAPRole::Follower,
[0u8; 32],
follower_exec.new_thread().await.unwrap(),
follower_ot_send_0,
follower_ot_recv_0,
);
let follower_thread_1 = follower_thread_0
.new_thread(
follower_exec.new_thread().await.unwrap(),
follower_ot_send_1,
follower_ot_recv_1,
)
.unwrap();
let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap();
let follower_pms = follower_thread_0
.new_public_input::<[u8; 32]>("pms")
.unwrap();
let mut leader = MpcPrf::new(
PrfConfig::builder().role(Role::Leader).build().unwrap(),
leader_thread_0,
leader_thread_1,
);
let mut follower = MpcPrf::new(
PrfConfig::builder().role(Role::Follower).build().unwrap(),
follower_thread_0,
follower_thread_1,
);
let mut rng = StdRng::seed_from_u64(0);
let pms = [42u8; 32];
let client_random = [0u8; 32];
let server_random = [1u8; 32];
let cf_hs_hash = [2u8; 32];
let sf_hs_hash = [3u8; 32];
let client_random = [69u8; 32];
let server_random: [u8; 32] = [96u8; 32];
let (mut leader_exec, mut follower_exec) = test_mt_context(8);
let mut leader_ctx = leader_exec.new_context().await.unwrap();
let mut follower_ctx = follower_exec.new_context().await.unwrap();
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_cot(delta.into_inner());
let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta);
let mut follower_vm = Evaluator::new(ot_recv);
let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
leader_vm.mark_public(leader_pms).unwrap();
leader_vm.assign(leader_pms, pms).unwrap();
leader_vm.commit(leader_pms).unwrap();
let follower_pms: Array<U8, 32> = follower_vm.alloc().unwrap();
follower_vm.mark_public(follower_pms).unwrap();
follower_vm.assign(follower_pms, pms).unwrap();
follower_vm.commit(follower_pms).unwrap();
let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap());
let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap());
let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap();
let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap();
leader
.set_client_random(&mut leader_vm, Some(client_random))
.unwrap();
follower.set_client_random(&mut follower_vm, None).unwrap();
leader
.set_server_random(&mut leader_vm, server_random)
.unwrap();
follower
.set_server_random(&mut follower_vm, server_random)
.unwrap();
let _ = leader_vm
.decode(leader_output.keys.client_write_key)
.unwrap();
let _ = leader_vm
.decode(leader_output.keys.server_write_key)
.unwrap();
let _ = leader_vm.decode(leader_output.keys.client_iv).unwrap();
let _ = leader_vm.decode(leader_output.keys.server_iv).unwrap();
let _ = follower_vm
.decode(follower_output.keys.client_write_key)
.unwrap();
let _ = follower_vm
.decode(follower_output.keys.server_write_key)
.unwrap();
let _ = follower_vm.decode(follower_output.keys.client_iv).unwrap();
let _ = follower_vm.decode(follower_output.keys.server_iv).unwrap();
futures::join!(
async {
leader.setup(leader_pms.clone()).await.unwrap();
leader.set_client_random(Some(client_random)).await.unwrap();
leader.preprocess().await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
leader_vm.execute(&mut leader_ctx).await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
},
async {
follower.setup(follower_pms.clone()).await.unwrap();
follower.set_client_random(None).await.unwrap();
follower.preprocess().await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
follower_vm.execute(&mut follower_ctx).await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
}
);
leader.thread_mut().assign(&leader_pms, pms).unwrap();
follower.thread_mut().assign(&follower_pms, pms).unwrap();
let cf_hs_hash = [1u8; 32];
let sf_hs_hash = [2u8; 32];
let (_leader_keys, _follower_keys) = futures::try_join!(
leader.compute_session_keys(server_random),
follower.compute_session_keys(server_random)
)
.unwrap();
leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap();
leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap();
let _ = futures::try_join!(
leader.compute_client_finished_vd(cf_hs_hash),
follower.compute_client_finished_vd(cf_hs_hash)
)
.unwrap();
follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap();
follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap();
let _ = futures::try_join!(
leader.compute_server_finished_vd(sf_hs_hash),
follower.compute_server_finished_vd(sf_hs_hash)
)
.unwrap();
let _ = leader_vm.decode(leader_output.cf_vd).unwrap();
let _ = leader_vm.decode(leader_output.sf_vd).unwrap();
futures::try_join!(
leader.thread_mut().finalize(),
follower.thread_mut().finalize()
)
.unwrap();
let _ = follower_vm.decode(follower_output.cf_vd).unwrap();
let _ = follower_vm.decode(follower_output.sf_vd).unwrap();
futures::join!(
async {
leader_vm.flush(&mut leader_ctx).await.unwrap();
leader_vm.execute(&mut leader_ctx).await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
},
async {
follower_vm.flush(&mut follower_ctx).await.unwrap();
follower_vm.execute(&mut follower_ctx).await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
}
);
}

View File

@@ -33,6 +33,10 @@ impl PrfError {
source: Some(msg.into().into()),
}
}
pub(crate) fn vm<E: Into<Box<dyn Error + Send + Sync>>>(err: E) -> Self {
Self::new(ErrorKind::Vm, err)
}
}
#[derive(Debug)]
@@ -58,26 +62,8 @@ impl fmt::Display for PrfError {
}
}
impl From<mpz_garble::MemoryError> for PrfError {
fn from(error: mpz_garble::MemoryError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::LoadError> for PrfError {
fn from(error: mpz_garble::LoadError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::ExecutionError> for PrfError {
fn from(error: mpz_garble::ExecutionError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::DecodeError> for PrfError {
fn from(error: mpz_garble::DecodeError) -> Self {
impl From<mpz_common::ContextError> for PrfError {
fn from(error: mpz_common::ContextError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}

View File

@@ -12,84 +12,52 @@ pub use config::{PrfConfig, PrfConfigBuilder, PrfConfigBuilderError, Role};
pub use error::PrfError;
pub use prf::MpcPrf;
use async_trait::async_trait;
use mpz_garble::value::ValueRef;
use mpz_vm_core::memory::{binary::U8, Array};
pub(crate) static CF_LABEL: &[u8] = b"client finished";
pub(crate) static SF_LABEL: &[u8] = b"server finished";
/// Session keys computed by the PRF.
#[derive(Debug, Clone)]
pub struct SessionKeys {
/// Client write key.
pub client_write_key: ValueRef,
/// Server write key.
pub server_write_key: ValueRef,
/// Client IV.
pub client_iv: ValueRef,
/// Server IV.
pub server_iv: ValueRef,
/// Builds the circuits for the PRF.
///
/// This function can be used ahead of time to build the circuits for the PRF,
/// which at the moment is CPU and memory intensive.
pub async fn build_circuits() {
prf::Circuits::get().await;
}
/// PRF trait for computing TLS PRF.
#[async_trait]
pub trait Prf {
/// Sets up the PRF.
///
/// # Arguments
///
/// * `pms` - The pre-master secret.
async fn setup(&mut self, pms: ValueRef) -> Result<SessionKeys, PrfError>;
/// PRF output.
#[derive(Debug, Clone, Copy)]
pub struct PrfOutput {
/// TLS session keys.
pub keys: SessionKeys,
/// Client finished verify data.
pub cf_vd: Array<U8, 12>,
/// Server finished verify data.
pub sf_vd: Array<U8, 12>,
}
/// Sets the client random.
///
/// This must be set after calling [`Prf::setup`].
///
/// Only the leader can provide the client random.
async fn set_client_random(&mut self, client_random: Option<[u8; 32]>) -> Result<(), PrfError>;
/// Preprocesses the PRF.
async fn preprocess(&mut self) -> Result<(), PrfError>;
/// Computes the client finished verify data.
///
/// # Arguments
///
/// * `handshake_hash` - The handshake transcript hash.
async fn compute_client_finished_vd(
&mut self,
handshake_hash: [u8; 32],
) -> Result<[u8; 12], PrfError>;
/// Computes the server finished verify data.
///
/// # Arguments
///
/// * `handshake_hash` - The handshake transcript hash.
async fn compute_server_finished_vd(
&mut self,
handshake_hash: [u8; 32],
) -> Result<[u8; 12], PrfError>;
/// Computes the session keys.
///
/// # Arguments
///
/// * `server_random` - The server random.
async fn compute_session_keys(
&mut self,
server_random: [u8; 32],
) -> Result<SessionKeys, PrfError>;
/// Session keys computed by the PRF.
#[derive(Debug, Clone, Copy)]
pub struct SessionKeys {
/// Client write key.
pub client_write_key: Array<U8, 16>,
/// Server write key.
pub server_write_key: Array<U8, 16>,
/// Client IV.
pub client_iv: Array<U8, 4>,
/// Server IV.
pub server_iv: Array<U8, 4>,
}
#[cfg(test)]
mod tests {
use mpz_common::executor::test_st_executor;
use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPThread, Decode, Memory};
use mpz_common::context::test_st_context;
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
use hmac_sha256_circuits::{hmac_sha256_partial, prf, session_keys};
use mpz_ot::ideal::ot::ideal_ot;
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{memory::correlated::Delta, prelude::*};
use rand::{rngs::StdRng, SeedableRng};
use super::*;
@@ -113,120 +81,89 @@ mod tests {
#[ignore = "expensive"]
#[tokio::test]
async fn test_prf() {
let mut rng = StdRng::seed_from_u64(0);
let pms = [42u8; 32];
let client_random = [69u8; 32];
let server_random: [u8; 32] = [96u8; 32];
let ms = compute_ms(pms, client_random, server_random);
let (leader_ctx_0, follower_ctx_0) = test_st_executor(128);
let (leader_ctx_1, follower_ctx_1) = test_st_executor(128);
let (mut leader_ctx, mut follower_ctx) = test_st_context(128);
let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot();
let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot();
let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot();
let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot();
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_cot(delta.into_inner());
let leader_thread_0 = DEAPThread::new(
DEAPRole::Leader,
[0u8; 32],
leader_ctx_0,
leader_ot_send_0,
leader_ot_recv_0,
);
let leader_thread_1 = leader_thread_0
.new_thread(leader_ctx_1, leader_ot_send_1, leader_ot_recv_1)
let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta);
let mut follower_vm = Evaluator::new(ot_recv);
let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
leader_vm.mark_public(leader_pms).unwrap();
leader_vm.assign(leader_pms, pms).unwrap();
leader_vm.commit(leader_pms).unwrap();
let follower_pms: Array<U8, 32> = follower_vm.alloc().unwrap();
follower_vm.mark_public(follower_pms).unwrap();
follower_vm.assign(follower_pms, pms).unwrap();
follower_vm.commit(follower_pms).unwrap();
let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap());
let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap());
let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap();
let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap();
leader
.set_client_random(&mut leader_vm, Some(client_random))
.unwrap();
follower.set_client_random(&mut follower_vm, None).unwrap();
leader
.set_server_random(&mut leader_vm, server_random)
.unwrap();
follower
.set_server_random(&mut follower_vm, server_random)
.unwrap();
let follower_thread_0 = DEAPThread::new(
DEAPRole::Follower,
[0u8; 32],
follower_ctx_0,
follower_ot_send_0,
follower_ot_recv_0,
);
let follower_thread_1 = follower_thread_0
.new_thread(follower_ctx_1, follower_ot_send_1, follower_ot_recv_1)
let leader_cwk = leader_vm
.decode(leader_output.keys.client_write_key)
.unwrap();
// Set up public PMS for testing.
let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap();
let follower_pms = follower_thread_0
.new_public_input::<[u8; 32]>("pms")
let leader_swk = leader_vm
.decode(leader_output.keys.server_write_key)
.unwrap();
let leader_civ = leader_vm.decode(leader_output.keys.client_iv).unwrap();
let leader_siv = leader_vm.decode(leader_output.keys.server_iv).unwrap();
leader_thread_0.assign(&leader_pms, pms).unwrap();
follower_thread_0.assign(&follower_pms, pms).unwrap();
let mut leader = MpcPrf::new(
PrfConfig::builder().role(Role::Leader).build().unwrap(),
leader_thread_0,
leader_thread_1,
);
let mut follower = MpcPrf::new(
PrfConfig::builder().role(Role::Follower).build().unwrap(),
follower_thread_0,
follower_thread_1,
);
let follower_cwk = follower_vm
.decode(follower_output.keys.client_write_key)
.unwrap();
let follower_swk = follower_vm
.decode(follower_output.keys.server_write_key)
.unwrap();
let follower_civ = follower_vm.decode(follower_output.keys.client_iv).unwrap();
let follower_siv = follower_vm.decode(follower_output.keys.server_iv).unwrap();
futures::join!(
async {
leader.setup(leader_pms).await.unwrap();
leader.set_client_random(Some(client_random)).await.unwrap();
leader.preprocess().await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
leader_vm.execute(&mut leader_ctx).await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
},
async {
follower.setup(follower_pms).await.unwrap();
follower.set_client_random(None).await.unwrap();
follower.preprocess().await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
follower_vm.execute(&mut follower_ctx).await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
}
);
let (leader_session_keys, follower_session_keys) = futures::try_join!(
leader.compute_session_keys(server_random),
follower.compute_session_keys(server_random)
)
.unwrap();
let leader_cwk = leader_cwk.await.unwrap();
let leader_swk = leader_swk.await.unwrap();
let leader_civ = leader_civ.await.unwrap();
let leader_siv = leader_siv.await.unwrap();
let SessionKeys {
client_write_key: leader_cwk,
server_write_key: leader_swk,
client_iv: leader_civ,
server_iv: leader_siv,
} = leader_session_keys;
let SessionKeys {
client_write_key: follower_cwk,
server_write_key: follower_swk,
client_iv: follower_civ,
server_iv: follower_siv,
} = follower_session_keys;
// Decode session keys
let (leader_session_keys, follower_session_keys) = futures::try_join!(
async {
leader
.thread_mut()
.decode(&[leader_cwk, leader_swk, leader_civ, leader_siv])
.await
},
async {
follower
.thread_mut()
.decode(&[follower_cwk, follower_swk, follower_civ, follower_siv])
.await
}
)
.unwrap();
let leader_cwk: [u8; 16] = leader_session_keys[0].clone().try_into().unwrap();
let leader_swk: [u8; 16] = leader_session_keys[1].clone().try_into().unwrap();
let leader_civ: [u8; 4] = leader_session_keys[2].clone().try_into().unwrap();
let leader_siv: [u8; 4] = leader_session_keys[3].clone().try_into().unwrap();
let follower_cwk: [u8; 16] = follower_session_keys[0].clone().try_into().unwrap();
let follower_swk: [u8; 16] = follower_session_keys[1].clone().try_into().unwrap();
let follower_civ: [u8; 4] = follower_session_keys[2].clone().try_into().unwrap();
let follower_siv: [u8; 4] = follower_session_keys[3].clone().try_into().unwrap();
let follower_cwk = follower_cwk.await.unwrap();
let follower_swk = follower_swk.await.unwrap();
let follower_civ = follower_civ.await.unwrap();
let follower_siv = follower_siv.await.unwrap();
let (expected_cwk, expected_swk, expected_civ, expected_siv) =
session_keys(pms, client_random, server_random);
@@ -244,24 +181,43 @@ mod tests {
let cf_hs_hash = [1u8; 32];
let sf_hs_hash = [2u8; 32];
let (cf_vd, _) = futures::try_join!(
leader.compute_client_finished_vd(cf_hs_hash),
follower.compute_client_finished_vd(cf_hs_hash)
)
.unwrap();
leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap();
leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap();
follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap();
follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap();
let leader_cf_vd = leader_vm.decode(leader_output.cf_vd).unwrap();
let leader_sf_vd = leader_vm.decode(leader_output.sf_vd).unwrap();
let follower_cf_vd = follower_vm.decode(follower_output.cf_vd).unwrap();
let follower_sf_vd = follower_vm.decode(follower_output.sf_vd).unwrap();
futures::join!(
async {
leader_vm.flush(&mut leader_ctx).await.unwrap();
leader_vm.execute(&mut leader_ctx).await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
},
async {
follower_vm.flush(&mut follower_ctx).await.unwrap();
follower_vm.execute(&mut follower_ctx).await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
}
);
let leader_cf_vd = leader_cf_vd.await.unwrap();
let leader_sf_vd = leader_sf_vd.await.unwrap();
let follower_cf_vd = follower_cf_vd.await.unwrap();
let follower_sf_vd = follower_sf_vd.await.unwrap();
let expected_cf_vd = compute_vd(ms, b"client finished", cf_hs_hash);
assert_eq!(cf_vd, expected_cf_vd);
let (sf_vd, _) = futures::try_join!(
leader.compute_server_finished_vd(sf_hs_hash),
follower.compute_server_finished_vd(sf_hs_hash)
)
.unwrap();
let expected_sf_vd = compute_vd(ms, b"server finished", sf_hs_hash);
assert_eq!(sf_vd, expected_sf_vd);
assert_eq!(leader_cf_vd, expected_cf_vd);
assert_eq!(leader_sf_vd, expected_sf_vd);
assert_eq!(follower_cf_vd, expected_cf_vd);
assert_eq!(follower_sf_vd, expected_sf_vd);
}
}

View File

@@ -3,60 +3,65 @@ use std::{
sync::{Arc, OnceLock},
};
use async_trait::async_trait;
use hmac_sha256_circuits::{build_session_keys, build_verify_data};
use mpz_circuits::Circuit;
use mpz_common::cpu::CpuBackend;
use mpz_garble::{config::Visibility, value::ValueRef, Decode, Execute, Load, Memory};
use mpz_vm_core::{
memory::{
binary::{Binary, U32, U8},
Array,
},
prelude::*,
Call, Vm,
};
use tracing::instrument;
use crate::{Prf, PrfConfig, PrfError, Role, SessionKeys, CF_LABEL, SF_LABEL};
use crate::{PrfConfig, PrfError, PrfOutput, Role, SessionKeys, CF_LABEL, SF_LABEL};
/// Circuit for computing TLS session keys.
static SESSION_KEYS_CIRC: OnceLock<Arc<Circuit>> = OnceLock::new();
/// Circuit for computing TLS client verify data.
static CLIENT_VD_CIRC: OnceLock<Arc<Circuit>> = OnceLock::new();
/// Circuit for computing TLS server verify data.
static SERVER_VD_CIRC: OnceLock<Arc<Circuit>> = OnceLock::new();
#[derive(Debug)]
pub(crate) struct Randoms {
pub(crate) client_random: ValueRef,
pub(crate) server_random: ValueRef,
pub(crate) struct Circuits {
session_keys: Arc<Circuit>,
client_vd: Arc<Circuit>,
server_vd: Arc<Circuit>,
}
#[derive(Debug, Clone)]
pub(crate) struct HashState {
pub(crate) ms_outer_hash_state: ValueRef,
pub(crate) ms_inner_hash_state: ValueRef,
}
impl Circuits {
pub(crate) async fn get() -> &'static Self {
static CIRCUITS: OnceLock<Circuits> = OnceLock::new();
if let Some(circuits) = CIRCUITS.get() {
return circuits;
}
#[derive(Debug)]
pub(crate) struct VerifyData {
pub(crate) handshake_hash: ValueRef,
pub(crate) vd: ValueRef,
let (session_keys, client_vd, server_vd) = futures::join!(
CpuBackend::blocking(build_session_keys),
CpuBackend::blocking(|| build_verify_data(CF_LABEL)),
CpuBackend::blocking(|| build_verify_data(SF_LABEL)),
);
_ = CIRCUITS.set(Circuits {
session_keys,
client_vd,
server_vd,
});
CIRCUITS.get().unwrap()
}
}
#[derive(Debug)]
pub(crate) enum State {
Initialized,
SessionKeys {
pms: ValueRef,
randoms: Randoms,
hash_state: HashState,
keys: crate::SessionKeys,
cf_vd: VerifyData,
sf_vd: VerifyData,
client_random: Array<U8, 32>,
server_random: Array<U8, 32>,
cf_hash: Array<U8, 32>,
sf_hash: Array<U8, 32>,
},
ClientFinished {
hash_state: HashState,
cf_vd: VerifyData,
sf_vd: VerifyData,
cf_hash: Array<U8, 32>,
sf_hash: Array<U8, 32>,
},
ServerFinished {
hash_state: HashState,
sf_vd: VerifyData,
sf_hash: Array<U8, 32>,
},
Complete,
Error,
@@ -69,14 +74,12 @@ impl State {
}
/// MPC PRF for computing TLS HMAC-SHA256 PRF.
pub struct MpcPrf<E> {
pub struct MpcPrf {
config: PrfConfig,
state: State,
thread_0: E,
thread_1: E,
}
impl<E> Debug for MpcPrf<E> {
impl Debug for MpcPrf {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MpcPrf")
.field("config", &self.config)
@@ -85,359 +88,225 @@ impl<E> Debug for MpcPrf<E> {
}
}
impl<E> MpcPrf<E>
where
E: Load + Memory + Execute + Decode + Send,
{
impl MpcPrf {
/// Creates a new instance of the PRF.
pub fn new(config: PrfConfig, thread_0: E, thread_1: E) -> MpcPrf<E> {
pub fn new(config: PrfConfig) -> MpcPrf {
MpcPrf {
config,
state: State::Initialized,
thread_0,
thread_1,
}
}
/// Returns a mutable reference to the MPC thread.
pub fn thread_mut(&mut self) -> &mut E {
&mut self.thread_0
}
/// Executes a circuit which computes TLS session keys.
/// Allocates resources for the PRF.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `pms` - The pre-master secret.
#[instrument(level = "debug", skip_all, err)]
async fn execute_session_keys(
pub fn alloc(
&mut self,
server_random: [u8; 32],
) -> Result<SessionKeys, PrfError> {
let State::SessionKeys {
pms,
randoms: randoms_refs,
hash_state,
keys,
cf_vd,
sf_vd,
} = self.state.take()
else {
return Err(PrfError::state("session keys not initialized"));
};
let circ = SESSION_KEYS_CIRC
.get()
.expect("session keys circuit is set");
self.thread_0
.assign(&randoms_refs.server_random, server_random)?;
self.thread_0
.execute(
circ.clone(),
&[pms, randoms_refs.client_random, randoms_refs.server_random],
&[
keys.client_write_key.clone(),
keys.server_write_key.clone(),
keys.client_iv.clone(),
keys.server_iv.clone(),
hash_state.ms_outer_hash_state.clone(),
hash_state.ms_inner_hash_state.clone(),
],
)
.await?;
self.state = State::ClientFinished {
hash_state,
cf_vd,
sf_vd,
};
Ok(keys)
}
#[instrument(level = "debug", skip_all, err)]
async fn execute_cf_vd(&mut self, handshake_hash: [u8; 32]) -> Result<[u8; 12], PrfError> {
let State::ClientFinished {
hash_state,
cf_vd,
sf_vd,
} = self.state.take()
else {
return Err(PrfError::state("PRF not in client finished state"));
};
let circ = CLIENT_VD_CIRC.get().expect("client vd circuit is set");
self.thread_0
.assign(&cf_vd.handshake_hash, handshake_hash)?;
self.thread_0
.execute(
circ.clone(),
&[
hash_state.ms_outer_hash_state.clone(),
hash_state.ms_inner_hash_state.clone(),
cf_vd.handshake_hash,
],
&[cf_vd.vd.clone()],
)
.await?;
let mut outputs = self.thread_0.decode(&[cf_vd.vd]).await?;
let vd: [u8; 12] = outputs.remove(0).try_into().expect("vd is 12 bytes");
self.state = State::ServerFinished { hash_state, sf_vd };
Ok(vd)
}
#[instrument(level = "debug", skip_all, err)]
async fn execute_sf_vd(&mut self, handshake_hash: [u8; 32]) -> Result<[u8; 12], PrfError> {
let State::ServerFinished { hash_state, sf_vd } = self.state.take() else {
return Err(PrfError::state("PRF not in server finished state"));
};
let circ = SERVER_VD_CIRC.get().expect("server vd circuit is set");
self.thread_0
.assign(&sf_vd.handshake_hash, handshake_hash)?;
self.thread_0
.execute(
circ.clone(),
&[
hash_state.ms_outer_hash_state,
hash_state.ms_inner_hash_state,
sf_vd.handshake_hash,
],
&[sf_vd.vd.clone()],
)
.await?;
let mut outputs = self.thread_0.decode(&[sf_vd.vd]).await?;
let vd: [u8; 12] = outputs.remove(0).try_into().expect("vd is 12 bytes");
self.state = State::Complete;
Ok(vd)
}
}
#[async_trait]
impl<E> Prf for MpcPrf<E>
where
E: Memory + Load + Execute + Decode + Send,
{
#[instrument(level = "debug", skip_all, err)]
async fn setup(&mut self, pms: ValueRef) -> Result<SessionKeys, PrfError> {
vm: &mut dyn Vm<Binary>,
pms: Array<U8, 32>,
) -> Result<PrfOutput, PrfError> {
let State::Initialized = self.state.take() else {
return Err(PrfError::state("PRF not in initialized state"));
};
let thread = &mut self.thread_0;
let circuits = futures::executor::block_on(Circuits::get());
let randoms = Randoms {
// The client random is kept private so that the handshake transcript
// hashes do not leak information about the server's identity.
client_random: thread.new_input::<[u8; 32]>(
"client_random",
match self.config.role {
Role::Leader => Visibility::Private,
Role::Follower => Visibility::Blind,
},
)?,
server_random: thread.new_input::<[u8; 32]>("server_random", Visibility::Public)?,
};
let client_random = vm.alloc().map_err(PrfError::vm)?;
let server_random = vm.alloc().map_err(PrfError::vm)?;
// The client random is kept private so that the handshake transcript
// hashes do not leak information about the server's identity.
match self.config.role {
Role::Leader => vm.mark_private(client_random),
Role::Follower => vm.mark_blind(client_random),
}
.map_err(PrfError::vm)?;
vm.mark_public(server_random).map_err(PrfError::vm)?;
#[allow(clippy::type_complexity)]
let (
client_write_key,
server_write_key,
client_iv,
server_iv,
ms_outer_hash_state,
ms_inner_hash_state,
): (
Array<U8, 16>,
Array<U8, 16>,
Array<U8, 4>,
Array<U8, 4>,
Array<U32, 8>,
Array<U32, 8>,
) = vm
.call(
Call::new(circuits.session_keys.clone())
.arg(pms)
.arg(client_random)
.arg(server_random)
.build()
.map_err(PrfError::vm)?,
)
.map_err(PrfError::vm)?;
let keys = SessionKeys {
client_write_key: thread.new_output::<[u8; 16]>("client_write_key")?,
server_write_key: thread.new_output::<[u8; 16]>("server_write_key")?,
client_iv: thread.new_output::<[u8; 4]>("client_write_iv")?,
server_iv: thread.new_output::<[u8; 4]>("server_write_iv")?,
client_write_key,
server_write_key,
client_iv,
server_iv,
};
let hash_state = HashState {
ms_outer_hash_state: thread.new_output::<[u32; 8]>("ms_outer_hash_state")?,
ms_inner_hash_state: thread.new_output::<[u32; 8]>("ms_inner_hash_state")?,
};
let cf_hash = vm.alloc().map_err(PrfError::vm)?;
vm.mark_public(cf_hash).map_err(PrfError::vm)?;
let cf_vd = VerifyData {
handshake_hash: thread.new_input::<[u8; 32]>("cf_hash", Visibility::Public)?,
vd: thread.new_output::<[u8; 12]>("cf_vd")?,
};
let cf_vd = vm
.call(
Call::new(circuits.client_vd.clone())
.arg(ms_outer_hash_state)
.arg(ms_inner_hash_state)
.arg(cf_hash)
.build()
.map_err(PrfError::vm)?,
)
.map_err(PrfError::vm)?;
let sf_vd = VerifyData {
handshake_hash: thread.new_input::<[u8; 32]>("sf_hash", Visibility::Public)?,
vd: thread.new_output::<[u8; 12]>("sf_vd")?,
};
let sf_hash = vm.alloc().map_err(PrfError::vm)?;
vm.mark_public(sf_hash).map_err(PrfError::vm)?;
let sf_vd = vm
.call(
Call::new(circuits.server_vd.clone())
.arg(ms_outer_hash_state)
.arg(ms_inner_hash_state)
.arg(sf_hash)
.build()
.map_err(PrfError::vm)?,
)
.map_err(PrfError::vm)?;
self.state = State::SessionKeys {
pms,
randoms,
hash_state,
keys: keys.clone(),
cf_vd,
sf_vd,
client_random,
server_random,
cf_hash,
sf_hash,
};
Ok(keys)
Ok(PrfOutput { keys, cf_vd, sf_vd })
}
/// Sets the client random.
///
/// Only the leader can provide the client random.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `client_random` - The client random.
#[instrument(level = "debug", skip_all, err)]
async fn set_client_random(&mut self, client_random: Option<[u8; 32]>) -> Result<(), PrfError> {
let State::SessionKeys { randoms, .. } = &self.state else {
pub fn set_client_random(
&mut self,
vm: &mut dyn Vm<Binary>,
random: Option<[u8; 32]>,
) -> Result<(), PrfError> {
let State::SessionKeys { client_random, .. } = &self.state else {
return Err(PrfError::state("PRF not set up"));
};
if self.config.role == Role::Leader {
let Some(client_random) = client_random else {
let Some(random) = random else {
return Err(PrfError::role("leader must provide client random"));
};
self.thread_0
.assign(&randoms.client_random, client_random)?;
} else if client_random.is_some() {
vm.assign(*client_random, random).map_err(PrfError::vm)?;
} else if random.is_some() {
return Err(PrfError::role("only leader can set client random"));
}
self.thread_0
.commit(&[randoms.client_random.clone()])
.await?;
vm.commit(*client_random).map_err(PrfError::vm)?;
Ok(())
}
/// Sets the server random.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `server_random` - The server random.
#[instrument(level = "debug", skip_all, err)]
async fn preprocess(&mut self) -> Result<(), PrfError> {
pub fn set_server_random(
&mut self,
vm: &mut dyn Vm<Binary>,
random: [u8; 32],
) -> Result<(), PrfError> {
let State::SessionKeys {
pms,
randoms,
hash_state,
keys,
cf_vd,
sf_vd,
server_random,
cf_hash,
sf_hash,
..
} = self.state.take()
else {
return Err(PrfError::state("PRF not set up"));
};
// Builds all circuits in parallel and preprocesses the session keys circuit.
futures::try_join!(
async {
if SESSION_KEYS_CIRC.get().is_none() {
_ = SESSION_KEYS_CIRC.set(CpuBackend::blocking(build_session_keys).await);
}
vm.assign(server_random, random).map_err(PrfError::vm)?;
vm.commit(server_random).map_err(PrfError::vm)?;
let circ = SESSION_KEYS_CIRC
.get()
.expect("session keys circuit should be built");
self.thread_0
.load(
circ.clone(),
&[
pms.clone(),
randoms.client_random.clone(),
randoms.server_random.clone(),
],
&[
keys.client_write_key.clone(),
keys.server_write_key.clone(),
keys.client_iv.clone(),
keys.server_iv.clone(),
hash_state.ms_outer_hash_state.clone(),
hash_state.ms_inner_hash_state.clone(),
],
)
.await?;
Ok::<_, PrfError>(())
},
async {
if CLIENT_VD_CIRC.get().is_none() {
_ = CLIENT_VD_CIRC
.set(CpuBackend::blocking(move || build_verify_data(CF_LABEL)).await);
}
Ok::<_, PrfError>(())
},
async {
if SERVER_VD_CIRC.get().is_none() {
_ = SERVER_VD_CIRC
.set(CpuBackend::blocking(move || build_verify_data(SF_LABEL)).await);
}
Ok::<_, PrfError>(())
}
)?;
// Finishes preprocessing the verify data circuits.
futures::try_join!(
async {
self.thread_0
.load(
CLIENT_VD_CIRC
.get()
.expect("client finished circuit should be built")
.clone(),
&[
hash_state.ms_outer_hash_state.clone(),
hash_state.ms_inner_hash_state.clone(),
cf_vd.handshake_hash.clone(),
],
&[cf_vd.vd.clone()],
)
.await
},
async {
self.thread_1
.load(
SERVER_VD_CIRC
.get()
.expect("server finished circuit should be built")
.clone(),
&[
hash_state.ms_outer_hash_state.clone(),
hash_state.ms_inner_hash_state.clone(),
sf_vd.handshake_hash.clone(),
],
&[sf_vd.vd.clone()],
)
.await
}
)?;
self.state = State::SessionKeys {
pms,
randoms,
hash_state,
keys,
cf_vd,
sf_vd,
};
self.state = State::ClientFinished { cf_hash, sf_hash };
Ok(())
}
/// Sets the client finished handshake hash.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `handshake_hash` - The handshake transcript hash.
#[instrument(level = "debug", skip_all, err)]
async fn compute_client_finished_vd(
pub fn set_cf_hash(
&mut self,
vm: &mut dyn Vm<Binary>,
handshake_hash: [u8; 32],
) -> Result<[u8; 12], PrfError> {
self.execute_cf_vd(handshake_hash).await
) -> Result<(), PrfError> {
let State::ClientFinished { cf_hash, sf_hash } = self.state.take() else {
return Err(PrfError::state("PRF not in client finished state"));
};
vm.assign(cf_hash, handshake_hash).map_err(PrfError::vm)?;
vm.commit(cf_hash).map_err(PrfError::vm)?;
self.state = State::ServerFinished { sf_hash };
Ok(())
}
/// Sets the server finished handshake hash.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `handshake_hash` - The handshake transcript hash.
#[instrument(level = "debug", skip_all, err)]
async fn compute_server_finished_vd(
pub fn set_sf_hash(
&mut self,
vm: &mut dyn Vm<Binary>,
handshake_hash: [u8; 32],
) -> Result<[u8; 12], PrfError> {
self.execute_sf_vd(handshake_hash).await
}
) -> Result<(), PrfError> {
let State::ServerFinished { sf_hash } = self.state.take() else {
return Err(PrfError::state("PRF not in server finished state"));
};
#[instrument(level = "debug", skip_all, err)]
async fn compute_session_keys(
&mut self,
server_random: [u8; 32],
) -> Result<SessionKeys, PrfError> {
self.execute_session_keys(server_random).await
vm.assign(sf_hash, handshake_hash).map_err(PrfError::vm)?;
vm.commit(sf_hash).map_err(PrfError::vm)?;
self.state = State::Complete;
Ok(())
}
}

View File

@@ -13,32 +13,30 @@ name = "key_exchange"
[features]
default = ["mock"]
mock = []
mock = ["mpz-share-conversion/test-utils", "mpz-common/ideal"]
[dependencies]
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", features = [
"ideal",
] }
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" }
mpz-vm-core = { workspace = true }
mpz-memory-core = { workspace = true }
mpz-common = { workspace = true }
mpz-fields = { workspace = true }
mpz-share-conversion = { workspace = true }
mpz-circuits = { workspace = true }
mpz-core = { workspace = true }
p256 = { workspace = true, features = ["ecdh", "serde"] }
async-trait = { workspace = true }
thiserror = { workspace = true }
serde = { workspace = true }
futures = { workspace = true }
serio = { workspace = true }
derive_builder = { workspace = true }
tracing = { workspace = true }
rand = { workspace = true }
tokio = { workspace = true, features = ["sync"] }
[dev-dependencies]
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", features = [
"ideal",
] }
mpz-ot = { workspace = true, features = ["ideal"] }
mpz-garble = { workspace = true }
rand_chacha = { workspace = true }
rand_core = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
rstest = { workspace = true }

View File

@@ -1,8 +1,7 @@
//! This module provides the circuits used in the key exchange protocol.
use std::sync::Arc;
use mpz_circuits::{circuits::big_num::nbyte_add_mod_trace, Circuit, CircuitBuilder};
use std::sync::Arc;
/// NIST P-256 prime big-endian.
static P: [u8; 32] = [

View File

@@ -1,29 +1 @@
use derive_builder::Builder;
/// Role in the key exchange protocol.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
/// Leader.
Leader,
/// Follower.
Follower,
}
/// A config used for [MpcKeyExchange](super::MpcKeyExchange).
#[derive(Debug, Clone, Builder)]
pub struct KeyExchangeConfig {
/// Protocol role.
role: Role,
}
impl KeyExchangeConfig {
/// Creates a new builder for the key exchange configuration.
pub fn builder() -> KeyExchangeConfigBuilder {
KeyExchangeConfigBuilder::default()
}
/// Get the role of this instance.
pub fn role(&self) -> &Role {
&self.role
}
}

View File

@@ -1,120 +1,87 @@
use core::fmt;
use std::error::Error;
/// A key exchange error.
/// MPC-TLS protocol error.
#[derive(Debug, thiserror::Error)]
pub struct KeyExchangeError {
kind: ErrorKind,
#[source]
source: Option<Box<dyn Error + Send + Sync>>,
#[error(transparent)]
pub struct KeyExchangeError(#[from] pub(crate) ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("key exchange error: {0}")]
pub(crate) enum ErrorRepr {
#[error("state error: {0}")]
State(Box<dyn Error + Send + Sync + 'static>),
#[error("context error: {0}")]
Ctx(Box<dyn Error + Send + Sync + 'static>),
#[error("io error: {0}")]
Io(std::io::Error),
#[error("vm error: {0}")]
Vm(Box<dyn Error + Send + Sync + 'static>),
#[error("share conversion error: {0}")]
ShareConversion(Box<dyn Error + Send + Sync + 'static>),
#[error("role error: {0}")]
Role(Box<dyn Error + Send + Sync + 'static>),
#[error("key error: {0}")]
Key(Box<dyn Error + Send + Sync + 'static>),
}
impl KeyExchangeError {
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
pub(crate) fn state<E>(err: E) -> KeyExchangeError
where
E: Into<Box<dyn Error + Send + Sync>>,
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self {
kind,
source: Some(source.into()),
}
Self(ErrorRepr::State(err.into()))
}
#[cfg(test)]
pub(crate) fn kind(&self) -> &ErrorKind {
&self.kind
pub(crate) fn ctx<E>(err: E) -> KeyExchangeError
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Ctx(err.into()))
}
pub(crate) fn state(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::State,
source: Some(msg.into().into()),
}
pub(crate) fn vm<E>(err: E) -> KeyExchangeError
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
pub(crate) fn role(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Role,
source: Some(msg.into().into()),
}
pub(crate) fn share_conversion<E>(err: E) -> KeyExchangeError
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::ShareConversion(err.into()))
}
}
#[derive(Debug)]
pub(crate) enum ErrorKind {
Io,
Context,
Vm,
ShareConversion,
Key,
State,
Role,
}
pub(crate) fn role<E>(err: E) -> KeyExchangeError
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Role(err.into()))
}
impl fmt::Display for KeyExchangeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.kind {
ErrorKind::Io => write!(f, "io error")?,
ErrorKind::Context => write!(f, "context error")?,
ErrorKind::Vm => write!(f, "vm error")?,
ErrorKind::ShareConversion => write!(f, "share conversion error")?,
ErrorKind::Key => write!(f, "key error")?,
ErrorKind::State => write!(f, "state error")?,
ErrorKind::Role => write!(f, "role error")?,
}
if let Some(ref source) = self.source {
write!(f, " caused by: {}", source)?;
}
Ok(())
pub(crate) fn key<E>(err: E) -> KeyExchangeError
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Key(err.into()))
}
}
impl From<mpz_common::ContextError> for KeyExchangeError {
fn from(error: mpz_common::ContextError) -> Self {
Self::new(ErrorKind::Context, error)
}
}
impl From<mpz_garble::MemoryError> for KeyExchangeError {
fn from(error: mpz_garble::MemoryError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::LoadError> for KeyExchangeError {
fn from(error: mpz_garble::LoadError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::ExecutionError> for KeyExchangeError {
fn from(error: mpz_garble::ExecutionError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::DecodeError> for KeyExchangeError {
fn from(error: mpz_garble::DecodeError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_share_conversion::ShareConversionError> for KeyExchangeError {
fn from(error: mpz_share_conversion::ShareConversionError) -> Self {
Self::new(ErrorKind::ShareConversion, error)
fn from(value: mpz_common::ContextError) -> Self {
Self::ctx(value)
}
}
impl From<p256::elliptic_curve::Error> for KeyExchangeError {
fn from(error: p256::elliptic_curve::Error) -> Self {
Self::new(ErrorKind::Key, error)
fn from(value: p256::elliptic_curve::Error) -> Self {
Self::key(value)
}
}
impl From<std::io::Error> for KeyExchangeError {
fn from(error: std::io::Error) -> Self {
Self::new(ErrorKind::Io, error)
fn from(err: std::io::Error) -> Self {
Self(ErrorRepr::Io(err))
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -15,63 +15,64 @@
#![forbid(unsafe_code)]
mod circuit;
mod config;
pub(crate) mod error;
mod exchange;
#[cfg(feature = "mock")]
pub mod mock;
pub(crate) mod point_addition;
pub use config::{
KeyExchangeConfig, KeyExchangeConfigBuilder, KeyExchangeConfigBuilderError, Role,
};
pub use error::KeyExchangeError;
pub use exchange::MpcKeyExchange;
use async_trait::async_trait;
use mpz_garble::value::ValueRef;
use mpz_common::Context;
use mpz_memory_core::{
binary::{Binary, U8},
Array,
};
use mpz_vm_core::Vm;
use p256::PublicKey;
/// Pre-master secret.
#[derive(Debug, Clone)]
pub struct Pms(ValueRef);
pub type Pms = Array<U8, 32>;
impl Pms {
/// Creates a new PMS.
pub fn new(value: ValueRef) -> Self {
Self(value)
}
/// Gets the value of the PMS.
pub fn into_value(self) -> ValueRef {
self.0
}
/// Role in the key exchange protocol.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
/// Leader.
Leader,
/// Follower.
Follower,
}
/// A trait for the 3-party key exchange protocol.
#[async_trait]
pub trait KeyExchange {
/// Gets the server's public key.
fn server_key(&self) -> Option<PublicKey>;
/// Allocate necessary computational resources.
fn alloc(&mut self, vm: &mut dyn Vm<Binary>) -> Result<Pms, KeyExchangeError>;
/// Sets the server's public key.
async fn set_server_key(&mut self, server_key: PublicKey) -> Result<(), KeyExchangeError>;
fn set_server_key(&mut self, server_key: PublicKey) -> Result<(), KeyExchangeError>;
/// Gets the server's public key.
fn server_key(&self) -> Option<PublicKey>;
/// Computes the client's public key.
///
/// The client's public key in this context is the combined public key (EC
/// point addition) of the leader's public key and the follower's public
/// key.
async fn client_key(&mut self) -> Result<PublicKey, KeyExchangeError>;
fn client_key(&self) -> Result<PublicKey, KeyExchangeError>;
/// Performs any necessary one-time setup, returning a reference to the PMS.
///
/// The PMS will not be assigned until `compute_pms` is called.
async fn setup(&mut self) -> Result<Pms, KeyExchangeError>;
/// Performs one-time setup for the key exchange protocol.
async fn setup(&mut self, ctx: &mut Context) -> Result<(), KeyExchangeError>;
/// Preprocesses the key exchange.
async fn preprocess(&mut self) -> Result<(), KeyExchangeError>;
/// Computes the shares of the PMS.
async fn compute_shares(&mut self, ctx: &mut Context) -> Result<(), KeyExchangeError>;
/// Computes the PMS.
async fn compute_pms(&mut self) -> Result<Pms, KeyExchangeError>;
/// Assigns the PMS shares to the VM.
fn assign(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), KeyExchangeError>;
/// Finalizes the key exchange protocol.
async fn finalize(&mut self) -> Result<(), KeyExchangeError>;
}

View File

@@ -1,71 +1,51 @@
//! This module provides mock types for key exchange leader and follower and a
//! function to create such a pair.
use crate::{KeyExchangeConfig, MpcKeyExchange, Role};
use mpz_common::executor::{test_st_executor, STExecutor};
use mpz_garble::{Decode, Execute, Memory};
use mpz_share_conversion::ideal::{ideal_share_converter, IdealShareConverter};
use serio::channel::MemoryDuplex;
use crate::{MpcKeyExchange, Role};
use mpz_core::Block;
use mpz_fields::p256::P256;
use mpz_share_conversion::ideal::{
ideal_share_convert, IdealShareConvertReceiver, IdealShareConvertSender,
};
/// A mock key exchange instance.
pub type MockKeyExchange<E> =
MpcKeyExchange<STExecutor<MemoryDuplex>, IdealShareConverter, IdealShareConverter, E>;
pub type MockKeyExchange =
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>;
/// Creates a mock pair of key exchange leader and follower.
pub fn create_mock_key_exchange_pair<E: Memory + Execute + Decode + Send>(
leader_executor: E,
follower_executor: E,
) -> (MockKeyExchange<E>, MockKeyExchange<E>) {
let (leader_ctx, follower_ctx) = test_st_executor(8);
let (leader_converter_0, follower_converter_0) = ideal_share_converter();
let (leader_converter_1, follower_converter_1) = ideal_share_converter();
pub fn create_mock_key_exchange_pair() -> (MockKeyExchange, MockKeyExchange) {
let (leader_converter_0, follower_converter_0) = ideal_share_convert(Block::ZERO);
let (follower_converter_1, leader_converter_1) = ideal_share_convert(Block::ZERO);
let key_exchange_config_leader = KeyExchangeConfig::builder()
.role(Role::Leader)
.build()
.unwrap();
let leader = MpcKeyExchange::new(Role::Leader, leader_converter_0, leader_converter_1);
let key_exchange_config_follower = KeyExchangeConfig::builder()
.role(Role::Follower)
.build()
.unwrap();
let leader = MpcKeyExchange::new(
key_exchange_config_leader,
leader_ctx,
leader_converter_0,
leader_converter_1,
leader_executor,
);
let follower = MpcKeyExchange::new(
key_exchange_config_follower,
follower_ctx,
follower_converter_0,
follower_converter_1,
follower_executor,
);
let follower = MpcKeyExchange::new(Role::Follower, follower_converter_1, follower_converter_0);
(leader, follower)
}
#[cfg(test)]
mod tests {
use mpz_garble::protocol::deap::mock::create_mock_deap_vm;
use crate::KeyExchange;
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
use mpz_ot::ideal::cot::{IdealCOTReceiver, IdealCOTSender};
use super::*;
use crate::KeyExchange;
#[test]
fn test_mock_is_ke() {
let (leader_vm, follower_vm) = create_mock_deap_vm();
let (leader, follower) = create_mock_key_exchange_pair(leader_vm, follower_vm);
let (leader, follower) = create_mock_key_exchange_pair();
fn is_key_exchange<T: KeyExchange>(_: T) {}
fn is_key_exchange<T: KeyExchange, V>(_: T) {}
is_key_exchange(leader);
is_key_exchange(follower);
is_key_exchange::<
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
Generator<IdealCOTSender>,
>(leader);
is_key_exchange::<
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
Evaluator<IdealCOTReceiver>,
>(follower);
}
}

View File

@@ -1,47 +0,0 @@
//! This module contains the message types exchanged between the prover and the TLS verifier.
use std::fmt::{self, Display, Formatter};
use p256::{elliptic_curve::sec1::ToEncodedPoint, PublicKey as P256PublicKey};
use serde::{Deserialize, Serialize};
/// A type for messages exchanged between the prover and the TLS verifier during the key exchange
/// protocol.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum KeyExchangeMessage {
FollowerPublicKey(PublicKey),
ServerPublicKey(PublicKey),
}
/// A wrapper for a serialized public key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PublicKey {
/// The sec1 serialized public key.
pub key: Vec<u8>,
}
/// An error that can occur during parsing of a public key.
#[derive(Debug, thiserror::Error)]
pub struct KeyParseError(#[from] p256::elliptic_curve::Error);
impl Display for KeyParseError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "Unable to parse public key: {}", self.0)
}
}
impl From<P256PublicKey> for PublicKey {
fn from(value: P256PublicKey) -> Self {
let key = value.to_encoded_point(false).as_bytes().to_vec();
PublicKey { key }
}
}
impl TryFrom<PublicKey> for P256PublicKey {
type Error = KeyParseError;
fn try_from(value: PublicKey) -> Result<Self, Self::Error> {
P256PublicKey::from_sec1_bytes(&value.key).map_err(Into::into)
}
}

View File

@@ -1,27 +1,28 @@
//! This module implements a secure two-party computation protocol for adding
//! two private EC points and secret-sharing the resulting x coordinate (the
//! shares are field elements of the field underlying the elliptic curve).
//! This protocol has semi-honest security.
//! shares are field elements of the field underlying the elliptic curve). This
//! protocol has semi-honest security.
//!
//! The protocol is described in <https://docs.tlsnotary.org/protocol/notarization/key_exchange.html>
//! The protocol is described in
//! <https://docs.tlsnotary.org/protocol/notarization/key_exchange.html>
use mpz_common::Context;
use crate::{KeyExchangeError, Role};
use mpz_common::{Context, Flush};
use mpz_fields::{p256::P256, Field};
use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive};
use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive, ShareConvert};
use p256::EncodedPoint;
use crate::{config::Role, error::ErrorKind, KeyExchangeError};
/// Derives the x-coordinate share of an elliptic curve point.
pub(crate) async fn derive_x_coord_share<Ctx, C>(
pub(crate) async fn derive_x_coord_share<C>(
ctx: &mut Context,
role: Role,
ctx: &mut Ctx,
converter: &mut C,
share: EncodedPoint,
) -> Result<P256, KeyExchangeError>
where
Ctx: Context,
C: AdditiveToMultiplicative<Ctx, P256> + MultiplicativeToAdditive<Ctx, P256>,
C: ShareConvert<P256> + Flush + Send,
<C as AdditiveToMultiplicative<P256>>::Future: Send,
<C as MultiplicativeToAdditive<P256>>::Future: Send,
{
let [x, y] = decompose_point(share)?;
@@ -31,16 +32,40 @@ where
Role::Follower => vec![-y, -x],
};
let [a, b] = converter
.to_multiplicative(ctx, inputs)
.await?
let a2m = converter
.queue_to_multiplicative(&inputs)
.map_err(KeyExchangeError::share_conversion)?;
converter
.flush(ctx)
.await
.map_err(KeyExchangeError::share_conversion)?;
let [a, b] = a2m
.await
.map_err(KeyExchangeError::share_conversion)?
.shares
.try_into()
.expect("output is same length as input");
let c = a * b.inverse();
let c = a * b
.inverse()
.expect("field element should not be zero when inverting");
let c = c * c;
let d = converter.to_additive(ctx, vec![c]).await?[0];
let m2a = converter
.queue_to_additive(&[c])
.map_err(KeyExchangeError::share_conversion)?;
converter
.flush(ctx)
.await
.map_err(KeyExchangeError::share_conversion)?;
let d = m2a
.await
.map_err(KeyExchangeError::share_conversion)?
.shares[0];
let x_r = d + -x;
@@ -50,13 +75,11 @@ where
/// Decomposes the x and y coordinates of a SEC1 encoded point.
fn decompose_point(point: EncodedPoint) -> Result<[P256; 2], KeyExchangeError> {
// Coordinates are stored as big-endian bytes.
let mut x: [u8; 32] = (*point.x().ok_or(KeyExchangeError::new(
ErrorKind::Key,
"key share is an identity point",
))?)
let mut x: [u8; 32] = (*point
.x()
.ok_or(KeyExchangeError::key("key share is an identity point"))?)
.into();
let mut y: [u8; 32] = (*point.y().ok_or(KeyExchangeError::new(
ErrorKind::Key,
let mut y: [u8; 32] = (*point.y().ok_or(KeyExchangeError::key(
"key share is an identity point or compressed",
))?)
.into();
@@ -75,20 +98,20 @@ fn decompose_point(point: EncodedPoint) -> Result<[P256; 2], KeyExchangeError> {
mod tests {
use super::*;
use mpz_common::executor::test_st_executor;
use mpz_common::context::test_st_context;
use mpz_core::Block;
use mpz_fields::{p256::P256, Field};
use mpz_share_conversion::ideal::ideal_share_converter;
use mpz_share_conversion::ideal::ideal_share_convert;
use p256::{
elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint},
EncodedPoint, NonZeroScalar, ProjectivePoint, PublicKey,
};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha12Rng;
use rand::{rngs::StdRng, Rng, SeedableRng};
#[tokio::test]
async fn test_point_addition() {
let (mut ctx_a, mut ctx_b) = test_st_executor(8);
let mut rng = ChaCha12Rng::from_seed([0u8; 32]);
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let mut rng = StdRng::seed_from_u64(0);
let p1: [u8; 32] = rng.gen();
let p2: [u8; 32] = rng.gen();
@@ -98,11 +121,11 @@ mod tests {
let p = add_curve_points(&p1, &p2);
let (mut c_a, mut c_b) = ideal_share_converter();
let (mut c_a, mut c_b) = ideal_share_convert(Block::ZERO);
let (a, b) = tokio::try_join!(
derive_x_coord_share(Role::Leader, &mut ctx_a, &mut c_a, p1),
derive_x_coord_share(Role::Follower, &mut ctx_b, &mut c_b, p2)
derive_x_coord_share(&mut ctx_a, Role::Leader, &mut c_a, p1),
derive_x_coord_share(&mut ctx_b, Role::Follower, &mut c_b, p2)
)
.unwrap();
@@ -113,7 +136,7 @@ mod tests {
#[test]
fn test_decompose_point() {
let mut rng = ChaCha12Rng::from_seed([0_u8; 32]);
let mut rng = StdRng::seed_from_u64(0);
let p_expected: [u8; 32] = rng.gen();
let p_expected = curve_point_from_be_bytes(p_expected);

View File

@@ -1,37 +0,0 @@
[package]
name = "tlsn-stream-cipher"
authors = ["TLSNotary Team"]
description = "2PC stream cipher implementation"
keywords = ["tls", "mpc", "2pc", "stream-cipher"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.8-pre"
edition = "2021"
[features]
default = ["mock"]
rayon = ["mpz-garble/rayon"]
mock = []
[dependencies]
mpz-circuits = { workspace = true }
mpz-garble = { workspace = true }
tlsn-utils = { workspace = true }
aes = { workspace = true }
ctr = { workspace = true }
cipher = { workspace = true }
async-trait = { workspace = true }
thiserror = { workspace = true }
derive_builder = { workspace = true }
tracing = { workspace = true }
opaque-debug = { workspace = true }
[dev-dependencies]
futures = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
rstest = { workspace = true, features = ["async-timeout"] }
criterion = { workspace = true, features = ["async_tokio"] }
[[bench]]
name = "mock"
harness = false

View File

@@ -1,132 +0,0 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory};
use tlsn_stream_cipher::{
Aes128Ctr, CtrCircuit, MpcStreamCipher, StreamCipher, StreamCipherConfigBuilder,
};
async fn bench_stream_cipher_encrypt(len: usize) {
let (leader_vm, follower_vm) = create_mock_deap_vm();
let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap();
let leader_iv = leader_vm.new_public_input::<[u8; 4]>("iv").unwrap();
leader_vm.assign(&leader_key, [0u8; 16]).unwrap();
leader_vm.assign(&leader_iv, [0u8; 4]).unwrap();
let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap();
let follower_iv = follower_vm.new_public_input::<[u8; 4]>("iv").unwrap();
follower_vm.assign(&follower_key, [0u8; 16]).unwrap();
follower_vm.assign(&follower_iv, [0u8; 4]).unwrap();
let leader_config = StreamCipherConfigBuilder::default()
.id("test".to_string())
.build()
.unwrap();
let follower_config = StreamCipherConfigBuilder::default()
.id("test".to_string())
.build()
.unwrap();
let mut leader = MpcStreamCipher::<Aes128Ctr, _>::new(leader_config, leader_vm);
leader.set_key(leader_key, leader_iv);
let mut follower = MpcStreamCipher::<Aes128Ctr, _>::new(follower_config, follower_vm);
follower.set_key(follower_key, follower_iv);
let plaintext = vec![0u8; len];
let explicit_nonce = vec![0u8; 8];
_ = tokio::try_join!(
leader.encrypt_private(explicit_nonce.clone(), plaintext),
follower.encrypt_blind(explicit_nonce, len)
)
.unwrap();
_ = tokio::try_join!(
leader.thread_mut().finalize(),
follower.thread_mut().finalize()
)
.unwrap();
}
async fn bench_stream_cipher_zk(len: usize) {
let (leader_vm, follower_vm) = create_mock_deap_vm();
let key = [0u8; 16];
let iv = [0u8; 4];
let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap();
let leader_iv = leader_vm.new_public_input::<[u8; 4]>("iv").unwrap();
leader_vm.assign(&leader_key, key).unwrap();
leader_vm.assign(&leader_iv, iv).unwrap();
let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap();
let follower_iv = follower_vm.new_public_input::<[u8; 4]>("iv").unwrap();
follower_vm.assign(&follower_key, key).unwrap();
follower_vm.assign(&follower_iv, iv).unwrap();
let leader_config = StreamCipherConfigBuilder::default()
.id("test".to_string())
.build()
.unwrap();
let follower_config = StreamCipherConfigBuilder::default()
.id("test".to_string())
.build()
.unwrap();
let mut leader = MpcStreamCipher::<Aes128Ctr, _>::new(leader_config, leader_vm);
leader.set_key(leader_key, leader_iv);
let mut follower = MpcStreamCipher::<Aes128Ctr, _>::new(follower_config, follower_vm);
follower.set_key(follower_key, follower_iv);
futures::try_join!(leader.decode_key_private(), follower.decode_key_blind()).unwrap();
let plaintext = vec![0u8; len];
let explicit_nonce = [0u8; 8];
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &plaintext).unwrap();
_ = tokio::try_join!(
leader.prove_plaintext(explicit_nonce.to_vec(), plaintext),
follower.verify_plaintext(explicit_nonce.to_vec(), ciphertext)
)
.unwrap();
_ = tokio::try_join!(
leader.thread_mut().finalize(),
follower.thread_mut().finalize()
)
.unwrap();
}
fn criterion_benchmark(c: &mut Criterion) {
let rt = tokio::runtime::Runtime::new().unwrap();
let len = 1024;
let mut group = c.benchmark_group("stream_cipher/encrypt_private");
group.throughput(Throughput::Bytes(len as u64));
group.bench_function(BenchmarkId::from_parameter(len), |b| {
b.to_async(&rt)
.iter(|| async { bench_stream_cipher_encrypt(len).await })
});
drop(group);
let mut group = c.benchmark_group("stream_cipher/zk");
group.throughput(Throughput::Bytes(len as u64));
group.bench_function(BenchmarkId::from_parameter(len), |b| {
b.to_async(&rt)
.iter(|| async { bench_stream_cipher_zk(len).await })
});
drop(group);
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

View File

@@ -1,118 +0,0 @@
use std::sync::Arc;
use mpz_circuits::{
types::{StaticValueType, Value},
Circuit,
};
use crate::{circuit::AES_CTR, StreamCipherError};
/// A counter-mode block cipher circuit.
pub trait CtrCircuit: Default + Clone + Send + Sync + 'static {
/// The key type.
type KEY: StaticValueType + TryFrom<Vec<u8>> + Send + Sync + 'static;
/// The block type.
type BLOCK: StaticValueType
+ TryFrom<Vec<u8>>
+ TryFrom<Value>
+ Into<Vec<u8>>
+ Default
+ Send
+ Sync
+ 'static;
/// The IV type.
type IV: StaticValueType
+ TryFrom<Vec<u8>>
+ TryFrom<Value>
+ Into<Vec<u8>>
+ Send
+ Sync
+ 'static;
/// The nonce type.
type NONCE: StaticValueType
+ TryFrom<Vec<u8>>
+ TryFrom<Value>
+ Into<Vec<u8>>
+ Clone
+ Copy
+ Send
+ Sync
+ std::fmt::Debug
+ 'static;
/// The length of the key.
const KEY_LEN: usize;
/// The length of the block.
const BLOCK_LEN: usize;
/// The length of the IV.
const IV_LEN: usize;
/// The length of the nonce.
const NONCE_LEN: usize;
/// Returns the circuit of the cipher.
fn circuit() -> Arc<Circuit>;
/// Applies the keystream to the message.
fn apply_keystream(
key: &[u8],
iv: &[u8],
start_ctr: usize,
explicit_nonce: &[u8],
msg: &[u8],
) -> Result<Vec<u8>, StreamCipherError>;
}
/// A circuit for AES-128 in counter mode.
#[derive(Default, Debug, Clone)]
pub struct Aes128Ctr;
impl CtrCircuit for Aes128Ctr {
type KEY = [u8; 16];
type BLOCK = [u8; 16];
type IV = [u8; 4];
type NONCE = [u8; 8];
const KEY_LEN: usize = 16;
const BLOCK_LEN: usize = 16;
const IV_LEN: usize = 4;
const NONCE_LEN: usize = 8;
fn circuit() -> Arc<Circuit> {
AES_CTR.clone()
}
fn apply_keystream(
key: &[u8],
iv: &[u8],
start_ctr: usize,
explicit_nonce: &[u8],
msg: &[u8],
) -> Result<Vec<u8>, StreamCipherError> {
use ::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use aes::Aes128;
use ctr::Ctr32BE;
let key: &[u8; 16] = key
.try_into()
.map_err(|_| StreamCipherError::key_len::<Self>(key.len()))?;
let iv: &[u8; 4] = iv
.try_into()
.map_err(|_| StreamCipherError::iv_len::<Self>(iv.len()))?;
let explicit_nonce: &[u8; 8] = explicit_nonce
.try_into()
.map_err(|_| StreamCipherError::explicit_nonce_len::<Self>(explicit_nonce.len()))?;
let mut full_iv = [0u8; 16];
full_iv[0..4].copy_from_slice(iv);
full_iv[4..12].copy_from_slice(explicit_nonce);
let mut cipher = Ctr32BE::<Aes128>::new(key.into(), &full_iv.into());
let mut buf = msg.to_vec();
cipher
.try_seek(start_ctr * Self::BLOCK_LEN)
.expect("start counter is less than keystream length");
cipher.apply_keystream(&mut buf);
Ok(buf)
}
}

View File

@@ -1,68 +0,0 @@
use derive_builder::Builder;
use std::fmt::Debug;
/// Configuration for a stream cipher.
#[derive(Debug, Clone, Builder)]
pub struct StreamCipherConfig {
/// The ID of the stream cipher.
#[builder(setter(into))]
pub(crate) id: String,
/// The start block counter value.
#[builder(default = "2")]
pub(crate) start_ctr: usize,
/// Transcript ID used to determine the unique identifiers
/// for the plaintext bytes during encryption and decryption.
#[builder(setter(into), default = "\"transcript\".to_string()")]
pub(crate) transcript_id: String,
}
impl StreamCipherConfig {
/// Creates a new builder for the stream cipher configuration.
pub fn builder() -> StreamCipherConfigBuilder {
StreamCipherConfigBuilder::default()
}
}
pub(crate) enum InputText {
Public { ids: Vec<String>, text: Vec<u8> },
Private { ids: Vec<String>, text: Vec<u8> },
Blind { ids: Vec<String> },
}
impl std::fmt::Debug for InputText {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Public { ids, .. } => f
.debug_struct("Public")
.field("ids", ids)
.field("text", &"{{ ... }}")
.finish(),
Self::Private { ids, .. } => f
.debug_struct("Private")
.field("ids", ids)
.field("text", &"{{ ... }}")
.finish(),
Self::Blind { ids, .. } => f.debug_struct("Blind").field("ids", ids).finish(),
}
}
}
/// The mode of execution.
#[derive(Debug, Clone, Copy)]
pub(crate) enum ExecutionMode {
/// Computes either the plaintext or the ciphertext.
Mpc,
/// Computes the ciphertext and proves its authenticity and correctness.
Prove,
/// Computes the ciphertext and verifies its authenticity and correctness.
Verify,
}
pub(crate) fn is_valid_mode(mode: &ExecutionMode, input_text: &InputText) -> bool {
matches!(
(mode, input_text),
(ExecutionMode::Mpc, _)
| (ExecutionMode::Prove, InputText::Private { .. })
| (ExecutionMode::Verify, InputText::Blind { .. })
)
}

View File

@@ -1,122 +0,0 @@
use core::fmt;
use std::error::Error;
use crate::CtrCircuit;
/// A stream cipher error.
#[derive(Debug, thiserror::Error)]
pub struct StreamCipherError {
kind: ErrorKind,
#[source]
source: Option<Box<dyn Error + Send + Sync>>,
}
impl StreamCipherError {
pub(crate) fn new<E>(kind: ErrorKind, source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync>>,
{
Self {
kind,
source: Some(source.into()),
}
}
pub(crate) fn key_len<C: CtrCircuit>(len: usize) -> Self {
Self {
kind: ErrorKind::Key,
source: Some(
format!("invalid key length: expected {}, got {}", C::KEY_LEN, len).into(),
),
}
}
pub(crate) fn iv_len<C: CtrCircuit>(len: usize) -> Self {
Self {
kind: ErrorKind::Iv,
source: Some(format!("invalid iv length: expected {}, got {}", C::IV_LEN, len).into()),
}
}
pub(crate) fn explicit_nonce_len<C: CtrCircuit>(len: usize) -> Self {
Self {
kind: ErrorKind::ExplicitNonce,
source: Some(
format!(
"invalid explicit nonce length: expected {}, got {}",
C::NONCE_LEN,
len
)
.into(),
),
}
}
pub(crate) fn key_not_set() -> Self {
Self {
kind: ErrorKind::Key,
source: Some("key not set".into()),
}
}
}
#[derive(Debug)]
pub(crate) enum ErrorKind {
Vm,
Key,
Iv,
ExplicitNonce,
}
impl fmt::Display for StreamCipherError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.kind {
ErrorKind::Vm => write!(f, "vm error")?,
ErrorKind::Key => write!(f, "key error")?,
ErrorKind::Iv => write!(f, "iv error")?,
ErrorKind::ExplicitNonce => write!(f, "explicit nonce error")?,
}
if let Some(ref source) = self.source {
write!(f, " caused by: {}", source)?;
}
Ok(())
}
}
impl From<mpz_garble::MemoryError> for StreamCipherError {
fn from(error: mpz_garble::MemoryError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::LoadError> for StreamCipherError {
fn from(error: mpz_garble::LoadError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::ExecutionError> for StreamCipherError {
fn from(error: mpz_garble::ExecutionError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::ProveError> for StreamCipherError {
fn from(error: mpz_garble::ProveError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::VerifyError> for StreamCipherError {
fn from(error: mpz_garble::VerifyError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
impl From<mpz_garble::DecodeError> for StreamCipherError {
fn from(error: mpz_garble::DecodeError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}

View File

@@ -1,216 +0,0 @@
use std::{collections::VecDeque, marker::PhantomData};
use mpz_garble::{value::ValueRef, Execute, Load, Memory, Prove, Thread, Verify};
use tracing::instrument;
use utils::id::NestedId;
use crate::{config::ExecutionMode, CtrCircuit, StreamCipherError};
pub(crate) struct KeyStream<C> {
block_counter: NestedId,
preprocessed: BlockVars,
_pd: PhantomData<C>,
}
#[derive(Default)]
struct BlockVars {
blocks: VecDeque<ValueRef>,
nonces: VecDeque<ValueRef>,
ctrs: VecDeque<ValueRef>,
}
impl BlockVars {
fn is_empty(&self) -> bool {
self.blocks.is_empty()
}
fn len(&self) -> usize {
self.blocks.len()
}
fn drain(&mut self, count: usize) -> BlockVars {
let blocks = self.blocks.drain(0..count).collect();
let nonces = self.nonces.drain(0..count).collect();
let ctrs = self.ctrs.drain(0..count).collect();
BlockVars {
blocks,
nonces,
ctrs,
}
}
fn extend(&mut self, vars: BlockVars) {
self.blocks.extend(vars.blocks);
self.nonces.extend(vars.nonces);
self.ctrs.extend(vars.ctrs);
}
fn iter(&self) -> impl Iterator<Item = (&ValueRef, &ValueRef, &ValueRef)> {
self.blocks
.iter()
.zip(self.nonces.iter())
.zip(self.ctrs.iter())
.map(|((block, nonce), ctr)| (block, nonce, ctr))
}
fn flatten(&self, len: usize) -> Vec<ValueRef> {
self.blocks
.iter()
.flat_map(|block| block.iter())
.take(len)
.cloned()
.map(|byte| ValueRef::Value { id: byte })
.collect()
}
}
impl<C: CtrCircuit> KeyStream<C> {
pub(crate) fn new(id: &str) -> Self {
let block_counter = NestedId::new(id).append_counter();
Self {
block_counter,
preprocessed: BlockVars::default(),
_pd: PhantomData,
}
}
fn define_vars(
&mut self,
mem: &mut impl Memory,
count: usize,
) -> Result<BlockVars, StreamCipherError> {
let mut vars = BlockVars::default();
for _ in 0..count {
let block_id = self.block_counter.increment_in_place();
let block = mem.new_output::<C::BLOCK>(&block_id.to_string())?;
let nonce =
mem.new_public_input::<C::NONCE>(&block_id.append_string("nonce").to_string())?;
let ctr =
mem.new_public_input::<[u8; 4]>(&block_id.append_string("ctr").to_string())?;
vars.blocks.push_back(block);
vars.nonces.push_back(nonce);
vars.ctrs.push_back(ctr);
}
Ok(vars)
}
#[instrument(level = "debug", skip_all, err)]
pub(crate) async fn preprocess<T>(
&mut self,
thread: &mut T,
key: &ValueRef,
iv: &ValueRef,
len: usize,
) -> Result<(), StreamCipherError>
where
T: Thread + Memory + Load + Send + 'static,
{
let block_count = (len / C::BLOCK_LEN) + (len % C::BLOCK_LEN != 0) as usize;
let vars = self.define_vars(thread, block_count)?;
let calls = vars
.iter()
.map(|(block, nonce, ctr)| {
(
C::circuit(),
vec![key.clone(), iv.clone(), nonce.clone(), ctr.clone()],
vec![block.clone()],
)
})
.collect::<Vec<_>>();
for (circ, inputs, outputs) in calls {
thread.load(circ, &inputs, &outputs).await?;
}
self.preprocessed.extend(vars);
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
#[allow(clippy::too_many_arguments)]
pub(crate) async fn compute<T>(
&mut self,
thread: &mut T,
mode: ExecutionMode,
key: &ValueRef,
iv: &ValueRef,
explicit_nonce: Vec<u8>,
start_ctr: usize,
len: usize,
) -> Result<ValueRef, StreamCipherError>
where
T: Thread + Memory + Execute + Prove + Verify + Send + 'static,
{
let block_count = (len / C::BLOCK_LEN) + (len % C::BLOCK_LEN != 0) as usize;
let explicit_nonce_len = explicit_nonce.len();
let explicit_nonce: C::NONCE = explicit_nonce
.try_into()
.map_err(|_| StreamCipherError::explicit_nonce_len::<C>(explicit_nonce_len))?;
// Take any preprocessed blocks if available, and define new ones if needed.
let vars = if !self.preprocessed.is_empty() {
let mut vars = self
.preprocessed
.drain(block_count.min(self.preprocessed.len()));
if vars.len() < block_count {
vars.extend(self.define_vars(thread, block_count - vars.len())?)
}
vars
} else {
self.define_vars(thread, block_count)?
};
let mut calls = Vec::with_capacity(vars.len());
let mut inputs = Vec::with_capacity(vars.len() * 4);
for (i, (block, nonce_ref, ctr_ref)) in vars.iter().enumerate() {
thread.assign(nonce_ref, explicit_nonce)?;
thread.assign(ctr_ref, ((start_ctr + i) as u32).to_be_bytes())?;
inputs.push(key.clone());
inputs.push(iv.clone());
inputs.push(nonce_ref.clone());
inputs.push(ctr_ref.clone());
calls.push((
C::circuit(),
vec![key.clone(), iv.clone(), nonce_ref.clone(), ctr_ref.clone()],
vec![block.clone()],
));
}
match mode {
ExecutionMode::Mpc => {
thread.commit(&inputs).await?;
for (circ, inputs, outputs) in calls {
thread.execute(circ, &inputs, &outputs).await?;
}
}
ExecutionMode::Prove => {
// Note that after the circuit execution, the value of `block` can be considered
// as implicitly authenticated since `key` and `iv` have already
// been authenticated earlier and `nonce_ref` and `ctr_ref` are
// public. [Prove::prove] will **not** be called on `block` at
// any later point.
thread.commit_prove(&inputs).await?;
for (circ, inputs, outputs) in calls {
thread.execute_prove(circ, &inputs, &outputs).await?;
}
}
ExecutionMode::Verify => {
thread.commit_verify(&inputs).await?;
for (circ, inputs, outputs) in calls {
thread.execute_verify(circ, &inputs, &outputs).await?;
}
}
}
let keystream = thread.array_from_values(&vars.flatten(len))?;
Ok(keystream)
}
}

View File

@@ -1,474 +0,0 @@
//! This crate provides a 2PC stream cipher implementation using a block cipher
//! in counter mode.
//!
//! Each party plays a specific role, either the `StreamCipherLeader` or the
//! `StreamCipherFollower`. Both parties work together to encrypt and decrypt
//! messages using a shared key.
//!
//! # Transcript
//!
//! Using the `record` flag, the `StreamCipherFollower` can optionally use a
//! dedicated stream when encoding the plaintext labels, which allows the
//! `StreamCipherLeader` to build a transcript of active labels which are pushed
//! to the provided `TranscriptSink`.
//!
//! Afterwards, the `StreamCipherLeader` can create commitments to the
//! transcript which can be used in a selective disclosure protocol.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![deny(unsafe_code)]
mod cipher;
mod circuit;
mod config;
pub(crate) mod error;
pub(crate) mod keystream;
mod stream_cipher;
pub use self::cipher::{Aes128Ctr, CtrCircuit};
pub use config::{StreamCipherConfig, StreamCipherConfigBuilder, StreamCipherConfigBuilderError};
pub use error::StreamCipherError;
pub use stream_cipher::MpcStreamCipher;
use async_trait::async_trait;
use mpz_garble::value::ValueRef;
/// A trait for MPC stream ciphers.
#[async_trait]
pub trait StreamCipher<Cipher>: Send + Sync
where
Cipher: cipher::CtrCircuit,
{
/// Sets the key and iv for the stream cipher.
fn set_key(&mut self, key: ValueRef, iv: ValueRef);
/// Decodes the key for the stream cipher, revealing it to this party.
async fn decode_key_private(&mut self) -> Result<(), StreamCipherError>;
/// Decodes the key for the stream cipher, revealing it to the other
/// party(s).
async fn decode_key_blind(&mut self) -> Result<(), StreamCipherError>;
/// Sets the transcript id
///
/// The stream cipher assigns unique identifiers to each byte of plaintext
/// during encryption and decryption.
///
/// For example, if the transcript id is set to `foo`, then the first byte
/// will be assigned the id `foo/0`, the second byte `foo/1`, and so on.
///
/// Each transcript id has an independent counter.
///
/// # Note
///
/// The state of a transcript counter is preserved between calls to
/// `set_transcript_id`.
fn set_transcript_id(&mut self, id: &str);
/// Preprocesses the keystream for the given number of bytes.
async fn preprocess(&mut self, len: usize) -> Result<(), StreamCipherError>;
/// Applies the keystream to the given plaintext, where all parties
/// provide the plaintext as an input.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
/// * `plaintext` - The message to apply the keystream to.
async fn encrypt_public(
&mut self,
explicit_nonce: Vec<u8>,
plaintext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError>;
/// Applies the keystream to the given plaintext without revealing it
/// to the other party(s).
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
/// * `plaintext` - The message to apply the keystream to.
async fn encrypt_private(
&mut self,
explicit_nonce: Vec<u8>,
plaintext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError>;
/// Applies the keystream to a plaintext provided by another party.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
/// * `len` - The length of the plaintext provided by another party.
async fn encrypt_blind(
&mut self,
explicit_nonce: Vec<u8>,
len: usize,
) -> Result<Vec<u8>, StreamCipherError>;
/// Decrypts a ciphertext by removing the keystream, where the plaintext
/// is revealed to all parties.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
/// * `ciphertext` - The ciphertext to decrypt.
async fn decrypt_public(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError>;
/// Decrypts a ciphertext by removing the keystream, where the plaintext
/// is only revealed to this party.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
/// * `ciphertext` - The ciphertext to decrypt.
async fn decrypt_private(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError>;
/// Decrypts a ciphertext by removing the keystream, where the plaintext
/// is not revealed to this party.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
/// * `ciphertext` - The ciphertext to decrypt.
async fn decrypt_blind(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<(), StreamCipherError>;
/// Locally decrypts the provided ciphertext and then proves in ZK to the
/// other party(s) that the plaintext is correct.
///
/// Returns the plaintext.
///
/// This method requires this party to know the encryption key, which can be
/// achieved by calling the `decode_key_private` method.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
/// * `ciphertext` - The ciphertext to decrypt and prove.
async fn prove_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError>;
/// Verifies the other party(s) can prove they know a plaintext which
/// encrypts to the given ciphertext.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream.
/// * `ciphertext` - The ciphertext to verify.
async fn verify_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<(), StreamCipherError>;
/// Returns an additive share of the keystream block for the given explicit
/// nonce and counter.
///
/// # Arguments
///
/// * `explicit_nonce` - The explicit nonce to use for the keystream block.
/// * `ctr` - The counter to use for the keystream block.
async fn share_keystream_block(
&mut self,
explicit_nonce: Vec<u8>,
ctr: usize,
) -> Result<Vec<u8>, StreamCipherError>;
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use crate::cipher::Aes128Ctr;
use super::*;
use mpz_garble::{
protocol::deap::mock::{create_mock_deap_vm, MockFollower, MockLeader},
Memory,
};
use rstest::*;
async fn create_test_pair<C: CtrCircuit>(
start_ctr: usize,
key: [u8; 16],
iv: [u8; 4],
) -> (
MpcStreamCipher<C, MockLeader>,
MpcStreamCipher<C, MockFollower>,
) {
let (leader_vm, follower_vm) = create_mock_deap_vm();
let leader_key = leader_vm.new_public_input::<[u8; 16]>("key").unwrap();
let leader_iv = leader_vm.new_public_input::<[u8; 4]>("iv").unwrap();
leader_vm.assign(&leader_key, key).unwrap();
leader_vm.assign(&leader_iv, iv).unwrap();
let follower_key = follower_vm.new_public_input::<[u8; 16]>("key").unwrap();
let follower_iv = follower_vm.new_public_input::<[u8; 4]>("iv").unwrap();
follower_vm.assign(&follower_key, key).unwrap();
follower_vm.assign(&follower_iv, iv).unwrap();
let leader_config = StreamCipherConfig::builder()
.id("test")
.start_ctr(start_ctr)
.build()
.unwrap();
let follower_config = StreamCipherConfig::builder()
.id("test")
.start_ctr(start_ctr)
.build()
.unwrap();
let mut leader = MpcStreamCipher::<C, _>::new(leader_config, leader_vm);
leader.set_key(leader_key, leader_iv);
let mut follower = MpcStreamCipher::<C, _>::new(follower_config, follower_vm);
follower.set_key(follower_key, follower_iv);
(leader, follower)
}
#[rstest]
#[timeout(Duration::from_millis(10000))]
#[tokio::test]
#[ignore = "expensive"]
async fn test_stream_cipher_public() {
let key = [0u8; 16];
let iv = [0u8; 4];
let explicit_nonce = [0u8; 8];
let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec();
let (mut leader, mut follower) = create_test_pair::<Aes128Ctr>(1, key, iv).await;
let leader_fut = async {
let leader_encrypted_msg = leader
.encrypt_public(explicit_nonce.to_vec(), msg.clone())
.await
.unwrap();
let leader_decrypted_msg = leader
.decrypt_public(explicit_nonce.to_vec(), leader_encrypted_msg.clone())
.await
.unwrap();
(leader_encrypted_msg, leader_decrypted_msg)
};
let follower_fut = async {
let follower_encrypted_msg = follower
.encrypt_public(explicit_nonce.to_vec(), msg.clone())
.await
.unwrap();
let follower_decrypted_msg = follower
.decrypt_public(explicit_nonce.to_vec(), follower_encrypted_msg.clone())
.await
.unwrap();
(follower_encrypted_msg, follower_decrypted_msg)
};
let (
(leader_encrypted_msg, leader_decrypted_msg),
(follower_encrypted_msg, follower_decrypted_msg),
) = futures::join!(leader_fut, follower_fut);
let reference = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap();
assert_eq!(leader_encrypted_msg, reference);
assert_eq!(leader_decrypted_msg, msg);
assert_eq!(follower_encrypted_msg, reference);
assert_eq!(follower_decrypted_msg, msg);
}
#[rstest]
#[timeout(Duration::from_millis(10000))]
#[tokio::test]
#[ignore = "expensive"]
async fn test_stream_cipher_private() {
let key = [0u8; 16];
let iv = [0u8; 4];
let explicit_nonce = [1u8; 8];
let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec();
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap();
let (mut leader, mut follower) = create_test_pair::<Aes128Ctr>(1, key, iv).await;
let leader_fut = async {
let leader_decrypted_msg = leader
.decrypt_private(explicit_nonce.to_vec(), ciphertext.clone())
.await
.unwrap();
let leader_encrypted_msg = leader
.encrypt_private(explicit_nonce.to_vec(), leader_decrypted_msg.clone())
.await
.unwrap();
(leader_encrypted_msg, leader_decrypted_msg)
};
let follower_fut = async {
follower
.decrypt_blind(explicit_nonce.to_vec(), ciphertext.clone())
.await
.unwrap();
follower
.encrypt_blind(explicit_nonce.to_vec(), msg.len())
.await
.unwrap()
};
let ((leader_encrypted_msg, leader_decrypted_msg), follower_encrypted_msg) =
futures::join!(leader_fut, follower_fut);
assert_eq!(leader_encrypted_msg, ciphertext);
assert_eq!(leader_decrypted_msg, msg);
assert_eq!(follower_encrypted_msg, ciphertext);
futures::try_join!(
leader.thread_mut().finalize(),
follower.thread_mut().finalize()
)
.unwrap();
}
#[rstest]
#[timeout(Duration::from_millis(10000))]
#[tokio::test]
#[ignore = "expensive"]
async fn test_stream_cipher_share_key_block() {
let key = [0u8; 16];
let iv = [0u8; 4];
let explicit_nonce = [0u8; 8];
let (mut leader, mut follower) = create_test_pair::<Aes128Ctr>(1, key, iv).await;
let leader_fut = async {
leader
.share_keystream_block(explicit_nonce.to_vec(), 1)
.await
.unwrap()
};
let follower_fut = async {
follower
.share_keystream_block(explicit_nonce.to_vec(), 1)
.await
.unwrap()
};
let (leader_share, follower_share) = futures::join!(leader_fut, follower_fut);
let key_block = leader_share
.into_iter()
.zip(follower_share)
.map(|(a, b)| a ^ b)
.collect::<Vec<u8>>();
let reference =
Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &[0u8; 16]).unwrap();
assert_eq!(reference, key_block);
}
#[rstest]
#[timeout(Duration::from_millis(10000))]
#[tokio::test]
#[ignore = "expensive"]
async fn test_stream_cipher_zk() {
let key = [0u8; 16];
let iv = [0u8; 4];
let explicit_nonce = [1u8; 8];
let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec();
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &msg).unwrap();
let (mut leader, mut follower) = create_test_pair::<Aes128Ctr>(2, key, iv).await;
futures::try_join!(leader.decode_key_private(), follower.decode_key_blind()).unwrap();
futures::try_join!(
leader.prove_plaintext(explicit_nonce.to_vec(), ciphertext.clone()),
follower.verify_plaintext(explicit_nonce.to_vec(), ciphertext)
)
.unwrap();
futures::try_join!(
leader.thread_mut().finalize(),
follower.thread_mut().finalize()
)
.unwrap();
}
#[rstest]
#[case::one_block(16)]
#[case::partial(17)]
#[case::extra(128)]
#[timeout(Duration::from_millis(10000))]
#[tokio::test]
#[ignore = "expensive"]
async fn test_stream_cipher_preprocess(#[case] len: usize) {
let key = [0u8; 16];
let iv = [0u8; 4];
let explicit_nonce = [1u8; 8];
let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec();
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap();
let (mut leader, mut follower) = create_test_pair::<Aes128Ctr>(1, key, iv).await;
let leader_fut = async {
leader.preprocess(len).await.unwrap();
leader
.decrypt_private(explicit_nonce.to_vec(), ciphertext.clone())
.await
.unwrap()
};
let follower_fut = async {
follower.preprocess(len).await.unwrap();
follower
.decrypt_blind(explicit_nonce.to_vec(), ciphertext.clone())
.await
.unwrap();
};
let (leader_decrypted_msg, _) = futures::join!(leader_fut, follower_fut);
assert_eq!(leader_decrypted_msg, msg);
futures::try_join!(
leader.thread_mut().finalize(),
follower.thread_mut().finalize()
)
.unwrap();
}
}

View File

@@ -1,671 +0,0 @@
use async_trait::async_trait;
use mpz_circuits::types::Value;
use std::collections::HashMap;
use tracing::instrument;
use mpz_garble::{value::ValueRef, Decode, DecodePrivate, Execute, Load, Prove, Thread, Verify};
use utils::id::NestedId;
use crate::{
cipher::CtrCircuit,
circuit::build_array_xor,
config::{is_valid_mode, ExecutionMode, InputText, StreamCipherConfig},
keystream::KeyStream,
StreamCipher, StreamCipherError,
};
/// An MPC stream cipher.
#[derive(Debug)]
pub struct MpcStreamCipher<C, E>
where
C: CtrCircuit,
E: Thread + Execute + Decode + DecodePrivate + Send + Sync,
{
config: StreamCipherConfig,
state: State<C>,
thread: E,
}
struct State<C> {
/// Encoded key and IV for the cipher.
encoded_key_iv: Option<EncodedKeyAndIv>,
/// Key and IV for the cipher.
key_iv: Option<KeyAndIv>,
/// Keystream state.
keystream: KeyStream<C>,
/// Current transcript.
transcript: Transcript,
/// Maps a transcript ID to the corresponding transcript.
transcripts: HashMap<String, Transcript>,
/// Number of messages operated on.
counter: usize,
}
opaque_debug::implement!(State<C>);
#[derive(Clone)]
struct EncodedKeyAndIv {
key: ValueRef,
iv: ValueRef,
}
#[derive(Clone)]
struct KeyAndIv {
key: Vec<u8>,
iv: Vec<u8>,
}
/// A subset of plaintext bytes processed by the stream cipher.
///
/// Note that `Transcript` does not store the actual bytes. Instead, it provides
/// IDs which are assigned to plaintext bytes of the stream cipher.
struct Transcript {
/// The ID of this transcript.
id: String,
/// The ID for the next plaintext byte.
plaintext: NestedId,
}
impl Transcript {
fn new(id: &str) -> Self {
Self {
id: id.to_string(),
plaintext: NestedId::new(id).append_counter(),
}
}
/// Returns unique identifiers for the next plaintext bytes in the
/// transcript.
fn extend_plaintext(&mut self, len: usize) -> Vec<String> {
(0..len)
.map(|_| self.plaintext.increment_in_place().to_string())
.collect()
}
}
impl<C, E> MpcStreamCipher<C, E>
where
C: CtrCircuit,
E: Thread + Execute + Load + Prove + Verify + Decode + DecodePrivate + Send + Sync + 'static,
{
/// Creates a new counter-mode cipher.
pub fn new(config: StreamCipherConfig, thread: E) -> Self {
let keystream = KeyStream::new(&config.id);
let transcript = Transcript::new(&config.transcript_id);
Self {
config,
state: State {
encoded_key_iv: None,
key_iv: None,
keystream,
transcript,
transcripts: HashMap::new(),
counter: 0,
},
thread,
}
}
/// Returns a mutable reference to the underlying thread.
pub fn thread_mut(&mut self) -> &mut E {
&mut self.thread
}
/// Computes a keystream of the given length.
async fn compute_keystream(
&mut self,
explicit_nonce: Vec<u8>,
start_ctr: usize,
len: usize,
mode: ExecutionMode,
) -> Result<ValueRef, StreamCipherError> {
let EncodedKeyAndIv { key, iv } = self
.state
.encoded_key_iv
.as_ref()
.ok_or_else(StreamCipherError::key_not_set)?;
let keystream = self
.state
.keystream
.compute(
&mut self.thread,
mode,
key,
iv,
explicit_nonce,
start_ctr,
len,
)
.await?;
self.state.counter += 1;
Ok(keystream)
}
/// Applies the keystream to the provided input text.
async fn apply_keystream(
&mut self,
mode: ExecutionMode,
input_text: InputText,
keystream: ValueRef,
) -> Result<ValueRef, StreamCipherError> {
debug_assert!(
is_valid_mode(&mode, &input_text),
"invalid execution mode for input text"
);
let input_text = match input_text {
InputText::Public { ids, text } => {
let refs = text
.into_iter()
.zip(ids)
.map(|(byte, id)| {
let value_ref = self.thread.new_public_input::<u8>(&id)?;
self.thread.assign(&value_ref, byte)?;
Ok::<_, StreamCipherError>(value_ref)
})
.collect::<Result<Vec<_>, _>>()?;
self.thread.array_from_values(&refs)?
}
InputText::Private { ids, text } => {
let refs = text
.into_iter()
.zip(ids)
.map(|(byte, id)| {
let value_ref = self.thread.new_private_input::<u8>(&id)?;
self.thread.assign(&value_ref, byte)?;
Ok::<_, StreamCipherError>(value_ref)
})
.collect::<Result<Vec<_>, _>>()?;
self.thread.array_from_values(&refs)?
}
InputText::Blind { ids } => {
let refs = ids
.into_iter()
.map(|id| self.thread.new_blind_input::<u8>(&id))
.collect::<Result<Vec<_>, _>>()?;
self.thread.array_from_values(&refs)?
}
};
let output_text = self.thread.new_array_output::<u8>(
&format!("{}/out/{}", self.config.id, self.state.counter),
input_text.len(),
)?;
let circ = build_array_xor(input_text.len());
match mode {
ExecutionMode::Mpc => {
self.thread
.execute(circ, &[input_text, keystream], &[output_text.clone()])
.await?;
}
ExecutionMode::Prove => {
self.thread
.execute_prove(circ, &[input_text, keystream], &[output_text.clone()])
.await?;
}
ExecutionMode::Verify => {
self.thread
.execute_verify(circ, &[input_text, keystream], &[output_text.clone()])
.await?;
}
}
Ok(output_text)
}
async fn decode_public(&mut self, value: ValueRef) -> Result<Value, StreamCipherError> {
self.thread
.decode(&[value])
.await
.map_err(StreamCipherError::from)
.map(|mut output| output.pop().unwrap())
}
async fn decode_shared(&mut self, value: ValueRef) -> Result<Value, StreamCipherError> {
self.thread
.decode_shared(&[value])
.await
.map_err(StreamCipherError::from)
.map(|mut output| output.pop().unwrap())
}
async fn decode_private(&mut self, value: ValueRef) -> Result<Value, StreamCipherError> {
self.thread
.decode_private(&[value])
.await
.map_err(StreamCipherError::from)
.map(|mut output| output.pop().unwrap())
}
async fn decode_blind(&mut self, value: ValueRef) -> Result<(), StreamCipherError> {
self.thread.decode_blind(&[value]).await?;
Ok(())
}
async fn prove(&mut self, value: ValueRef) -> Result<(), StreamCipherError> {
self.thread.prove(&[value]).await?;
Ok(())
}
async fn verify(&mut self, value: ValueRef, expected: Value) -> Result<(), StreamCipherError> {
self.thread.verify(&[value], &[expected]).await?;
Ok(())
}
}
#[async_trait]
impl<C, E> StreamCipher<C> for MpcStreamCipher<C, E>
where
C: CtrCircuit,
E: Thread + Execute + Load + Prove + Verify + Decode + DecodePrivate + Send + Sync + 'static,
{
fn set_key(&mut self, key: ValueRef, iv: ValueRef) {
self.state.encoded_key_iv = Some(EncodedKeyAndIv { key, iv });
}
#[instrument(level = "debug", skip_all, err)]
async fn decode_key_private(&mut self) -> Result<(), StreamCipherError> {
let EncodedKeyAndIv { key, iv } = self
.state
.encoded_key_iv
.clone()
.ok_or_else(StreamCipherError::key_not_set)?;
let [key, iv]: [_; 2] = self
.thread
.decode_private(&[key, iv])
.await?
.try_into()
.expect("decoded 2 values");
let key: Vec<u8> = key.try_into().expect("key is an array");
let iv: Vec<u8> = iv.try_into().expect("iv is an array");
self.state.key_iv = Some(KeyAndIv { key, iv });
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
async fn decode_key_blind(&mut self) -> Result<(), StreamCipherError> {
let EncodedKeyAndIv { key, iv } = self
.state
.encoded_key_iv
.clone()
.ok_or_else(StreamCipherError::key_not_set)?;
self.thread.decode_blind(&[key, iv]).await?;
Ok(())
}
fn set_transcript_id(&mut self, id: &str) {
if id == self.state.transcript.id {
return;
}
let transcript = self
.state
.transcripts
.remove(id)
.unwrap_or_else(|| Transcript::new(id));
let old_transcript = std::mem::replace(&mut self.state.transcript, transcript);
self.state
.transcripts
.insert(old_transcript.id.clone(), old_transcript);
}
#[instrument(level = "debug", skip_all, err)]
async fn preprocess(&mut self, len: usize) -> Result<(), StreamCipherError> {
let EncodedKeyAndIv { key, iv } = self
.state
.encoded_key_iv
.as_ref()
.ok_or_else(StreamCipherError::key_not_set)?;
self.state
.keystream
.preprocess(&mut self.thread, key, iv, len)
.await
}
#[instrument(level = "debug", skip_all, err)]
async fn encrypt_public(
&mut self,
explicit_nonce: Vec<u8>,
plaintext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError> {
let keystream = self
.compute_keystream(
explicit_nonce,
self.config.start_ctr,
plaintext.len(),
ExecutionMode::Mpc,
)
.await?;
let plaintext_ids = self.state.transcript.extend_plaintext(plaintext.len());
let ciphertext = self
.apply_keystream(
ExecutionMode::Mpc,
InputText::Public {
ids: plaintext_ids,
text: plaintext,
},
keystream,
)
.await?;
let ciphertext: Vec<u8> = self
.decode_public(ciphertext)
.await?
.try_into()
.expect("ciphertext is array");
Ok(ciphertext)
}
#[instrument(level = "debug", skip_all, err)]
async fn encrypt_private(
&mut self,
explicit_nonce: Vec<u8>,
plaintext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError> {
let keystream = self
.compute_keystream(
explicit_nonce,
self.config.start_ctr,
plaintext.len(),
ExecutionMode::Mpc,
)
.await?;
let plaintext_ids = self.state.transcript.extend_plaintext(plaintext.len());
let ciphertext = self
.apply_keystream(
ExecutionMode::Mpc,
InputText::Private {
ids: plaintext_ids,
text: plaintext,
},
keystream,
)
.await?;
let ciphertext: Vec<u8> = self
.decode_public(ciphertext)
.await?
.try_into()
.expect("ciphertext is array");
Ok(ciphertext)
}
#[instrument(level = "debug", skip_all, err)]
async fn encrypt_blind(
&mut self,
explicit_nonce: Vec<u8>,
len: usize,
) -> Result<Vec<u8>, StreamCipherError> {
let keystream = self
.compute_keystream(
explicit_nonce,
self.config.start_ctr,
len,
ExecutionMode::Mpc,
)
.await?;
let plaintext_ids = self.state.transcript.extend_plaintext(len);
let ciphertext = self
.apply_keystream(
ExecutionMode::Mpc,
InputText::Blind { ids: plaintext_ids },
keystream,
)
.await?;
let ciphertext: Vec<u8> = self
.decode_public(ciphertext)
.await?
.try_into()
.expect("ciphertext is array");
Ok(ciphertext)
}
#[instrument(level = "debug", skip_all, err)]
async fn decrypt_public(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError> {
// TODO: We may want to support writing to the transcript when decrypting
// in public mode.
let keystream = self
.compute_keystream(
explicit_nonce,
self.config.start_ctr,
ciphertext.len(),
ExecutionMode::Mpc,
)
.await?;
let ciphertext_ids = (0..ciphertext.len())
.map(|i| format!("ct/{}/{}", self.state.counter, i))
.collect();
let plaintext = self
.apply_keystream(
ExecutionMode::Mpc,
InputText::Public {
ids: ciphertext_ids,
text: ciphertext,
},
keystream,
)
.await?;
let plaintext: Vec<u8> = self
.decode_public(plaintext)
.await?
.try_into()
.expect("plaintext is array");
Ok(plaintext)
}
#[instrument(level = "debug", skip_all, err)]
async fn decrypt_private(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError> {
let keystream_ref = self
.compute_keystream(
explicit_nonce,
self.config.start_ctr,
ciphertext.len(),
ExecutionMode::Mpc,
)
.await?;
let keystream: Vec<u8> = self
.decode_private(keystream_ref.clone())
.await?
.try_into()
.expect("keystream is array");
let plaintext = ciphertext
.into_iter()
.zip(keystream)
.map(|(c, k)| c ^ k)
.collect::<Vec<_>>();
// Prove plaintext encrypts back to ciphertext.
let plaintext_ids = self.state.transcript.extend_plaintext(plaintext.len());
let ciphertext = self
.apply_keystream(
ExecutionMode::Prove,
InputText::Private {
ids: plaintext_ids,
text: plaintext.clone(),
},
keystream_ref,
)
.await?;
self.prove(ciphertext).await?;
Ok(plaintext)
}
#[instrument(level = "debug", skip_all, err)]
async fn decrypt_blind(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<(), StreamCipherError> {
let keystream_ref = self
.compute_keystream(
explicit_nonce,
self.config.start_ctr,
ciphertext.len(),
ExecutionMode::Mpc,
)
.await?;
self.decode_blind(keystream_ref.clone()).await?;
// Verify the plaintext encrypts back to ciphertext.
let plaintext_ids = self.state.transcript.extend_plaintext(ciphertext.len());
let ciphertext_ref = self
.apply_keystream(
ExecutionMode::Verify,
InputText::Blind { ids: plaintext_ids },
keystream_ref,
)
.await?;
self.verify(ciphertext_ref, ciphertext.into()).await?;
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
async fn prove_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError> {
let KeyAndIv { key, iv } = self
.state
.key_iv
.clone()
.ok_or_else(StreamCipherError::key_not_set)?;
let plaintext = C::apply_keystream(
&key,
&iv,
self.config.start_ctr,
&explicit_nonce,
&ciphertext,
)?;
// Prove plaintext encrypts back to ciphertext.
let keystream = self
.compute_keystream(
explicit_nonce,
self.config.start_ctr,
plaintext.len(),
ExecutionMode::Prove,
)
.await?;
let plaintext_ids = self.state.transcript.extend_plaintext(plaintext.len());
let ciphertext = self
.apply_keystream(
ExecutionMode::Prove,
InputText::Private {
ids: plaintext_ids,
text: plaintext.clone(),
},
keystream,
)
.await?;
self.prove(ciphertext).await?;
Ok(plaintext)
}
#[instrument(level = "debug", skip_all, err)]
async fn verify_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<(), StreamCipherError> {
let keystream = self
.compute_keystream(
explicit_nonce,
self.config.start_ctr,
ciphertext.len(),
ExecutionMode::Verify,
)
.await?;
let plaintext_ids = self.state.transcript.extend_plaintext(ciphertext.len());
let ciphertext_ref = self
.apply_keystream(
ExecutionMode::Verify,
InputText::Blind { ids: plaintext_ids },
keystream,
)
.await?;
self.verify(ciphertext_ref, ciphertext.into()).await?;
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
async fn share_keystream_block(
&mut self,
explicit_nonce: Vec<u8>,
ctr: usize,
) -> Result<Vec<u8>, StreamCipherError> {
let EncodedKeyAndIv { key, iv } = self
.state
.encoded_key_iv
.as_ref()
.ok_or_else(StreamCipherError::key_not_set)?;
let key_block = self
.state
.keystream
.compute(
&mut self.thread,
ExecutionMode::Mpc,
key,
iv,
explicit_nonce,
ctr,
C::BLOCK_LEN,
)
.await?;
let share = self
.decode_shared(key_block)
.await?
.try_into()
.expect("key block is array");
Ok(share)
}
}

View File

@@ -38,7 +38,7 @@ pub struct Ghash<C, Ctx> {
impl<C, Ctx> Ghash<C, Ctx>
where
Ctx: Context,
C: ShareConvert<Ctx, Gf2_128>,
{
/// Creates a new instance.
@@ -91,7 +91,7 @@ impl<C, Ctx> Debug for Ghash<C, Ctx> {
#[async_trait]
impl<Ctx, C> UniversalHash for Ghash<C, Ctx>
where
Ctx: Context,
C: Preprocess<Ctx, Error = ShareConversionError> + ShareConvert<Ctx, Gf2_128> + Send,
{
#[instrument(level = "info", fields(thread = %self.context.id()), skip_all, err)]

View File

@@ -30,6 +30,7 @@ opaque-debug = { workspace = true }
p256 = { workspace = true, features = ["serde"] }
rand = { workspace = true }
rand_core = { workspace = true }
rand_chacha = { workspace = true }
rs_merkle = { workspace = true, features = ["serde"] }
rstest = { workspace = true, optional = true }
serde = { workspace = true }
@@ -38,6 +39,7 @@ thiserror = { workspace = true }
tiny-keccak = { version = "2.0", features = ["keccak"] }
web-time = { workspace = true }
webpki-roots = { workspace = true }
itybity = { workspace = true }
[dev-dependencies]
bincode = { workspace = true }

View File

@@ -12,6 +12,7 @@ use crate::{
request::Request,
serialize::CanonicalSerialize,
signing::SignatureAlgId,
transcript::encoding::EncoderSecret,
CryptoProvider,
};
@@ -25,7 +26,7 @@ pub struct Sign {
server_ephemeral_key: Option<ServerEphemKey>,
cert_commitment: ServerCertCommitment,
encoding_commitment_root: Option<TypedHash>,
encoding_seed: Option<Vec<u8>>,
encoder_secret: Option<EncoderSecret>,
}
/// An attestation builder.
@@ -91,7 +92,7 @@ impl<'a> AttestationBuilder<'a, Accept> {
server_ephemeral_key: None,
cert_commitment,
encoding_commitment_root,
encoding_seed: None,
encoder_secret: None,
},
})
}
@@ -110,9 +111,9 @@ impl AttestationBuilder<'_, Sign> {
self
}
/// Sets the encoding seed.
pub fn encoding_seed(&mut self, seed: Vec<u8>) -> &mut Self {
self.state.encoding_seed = Some(seed);
/// Sets the encoder secret.
pub fn encoder_secret(&mut self, secret: EncoderSecret) -> &mut Self {
self.state.encoder_secret = Some(secret);
self
}
@@ -125,7 +126,7 @@ impl AttestationBuilder<'_, Sign> {
server_ephemeral_key,
cert_commitment,
encoding_commitment_root,
encoding_seed,
encoder_secret,
} = self.state;
let hasher = provider.hash.get(&hash_alg).map_err(|_| {
@@ -144,14 +145,14 @@ impl AttestationBuilder<'_, Sign> {
})?;
let encoding_commitment = if let Some(root) = encoding_commitment_root {
let Some(seed) = encoding_seed else {
let Some(secret) = encoder_secret else {
return Err(AttestationBuilderError::new(
ErrorKind::Field,
"encoding commitment requested but seed was not set",
"encoding commitment requested but encoder_secret was not set",
));
};
Some(EncodingCommitment { root, seed })
Some(EncodingCommitment { root, secret })
} else {
None
};
@@ -246,7 +247,7 @@ mod test {
use crate::{
connection::{HandshakeData, HandshakeDataV1_2},
fixtures::{
encoder_seed, encoding_provider, request_fixture, ConnectionFixture, RequestFixture,
encoder_secret, encoding_provider, request_fixture, ConnectionFixture, RequestFixture,
},
hash::Blake3,
transcript::Transcript,
@@ -435,7 +436,7 @@ mod test {
attestation_builder
.connection_info(connection_info)
.encoding_seed(encoder_seed().to_vec());
.encoder_secret(encoder_secret());
let err = attestation_builder.build(crypto_provider).err().unwrap();
assert!(matches!(err.kind, ErrorKind::Field));
@@ -471,7 +472,7 @@ mod test {
attestation_builder
.server_ephemeral_key(server_ephemeral_key)
.encoding_seed(encoder_seed().to_vec());
.encoder_secret(encoder_secret());
let err = attestation_builder.build(crypto_provider).err().unwrap();
assert!(matches!(err.kind, ErrorKind::Field));

View File

@@ -2,7 +2,7 @@
mod provider;
pub use provider::ChaChaProvider;
pub use provider::FixtureEncodingProvider;
use hex::FromHex;
use p256::ecdsa::SigningKey;
@@ -17,7 +17,7 @@ use crate::{
request::{Request, RequestConfig},
signing::SignatureAlgId,
transcript::{
encoding::{EncodingProvider, EncodingTree},
encoding::{EncoderSecret, EncodingProvider, EncodingTree},
Transcript, TranscriptCommitConfigBuilder,
},
CryptoProvider,
@@ -131,12 +131,26 @@ impl ConnectionFixture {
/// Returns an encoding provider fixture.
pub fn encoding_provider(tx: &[u8], rx: &[u8]) -> impl EncodingProvider {
ChaChaProvider::new(encoder_seed(), Transcript::new(tx, rx))
let secret = encoder_secret();
FixtureEncodingProvider::new(&secret, Transcript::new(tx, rx))
}
/// Returns an encoder seed fixture.
pub fn encoder_seed() -> [u8; 32] {
[0u8; 32]
/// Seed fixture.
const SEED: [u8; 32] = [0; 32];
/// Delta fixture.
const DELTA: [u8; 16] = [1; 16];
/// Returns an encoder secret fixture.
pub fn encoder_secret() -> EncoderSecret {
EncoderSecret::new(SEED, DELTA)
}
/// Returns a tampered encoder secret fixture.
pub fn encoder_secret_tampered_seed() -> EncoderSecret {
let mut seed = SEED;
seed[0] += 1;
EncoderSecret::new(seed, DELTA)
}
/// Returns a notary signing key fixture.
@@ -205,7 +219,7 @@ pub fn attestation_fixture(
request: Request,
connection: ConnectionFixture,
signature_alg: SignatureAlgId,
encoding_seed: Vec<u8>,
secret: EncoderSecret,
) -> Attestation {
let ConnectionFixture {
connection_info,
@@ -237,7 +251,7 @@ pub fn attestation_fixture(
attestation_builder
.connection_info(connection_info)
.server_ephemeral_key(server_ephemeral_key)
.encoding_seed(encoding_seed);
.encoder_secret(secret);
attestation_builder.build(&provider).unwrap()
}

View File

@@ -1,27 +1,25 @@
use mpz_garble_core::ChaChaEncoder;
use crate::transcript::{
encoding::{Encoder, EncodingProvider},
encoding::{new_encoder, Encoder, EncoderSecret, EncodingProvider},
Direction, Idx, Transcript,
};
/// A ChaCha encoding provider fixture.
pub struct ChaChaProvider {
encoder: ChaChaEncoder,
/// A encoding provider fixture.
pub struct FixtureEncodingProvider {
encoder: Box<dyn Encoder>,
transcript: Transcript,
}
impl ChaChaProvider {
/// Creates a new ChaCha encoding provider.
pub(crate) fn new(seed: [u8; 32], transcript: Transcript) -> Self {
impl FixtureEncodingProvider {
/// Creates a new encoding provider fixture.
pub(crate) fn new(secret: &EncoderSecret, transcript: Transcript) -> Self {
Self {
encoder: ChaChaEncoder::new(seed),
encoder: Box::new(new_encoder(secret)),
transcript,
}
}
}
impl EncodingProvider for ChaChaProvider {
impl EncodingProvider for FixtureEncodingProvider {
fn provide_encoding(&self, direction: Direction, idx: &Idx) -> Option<Vec<u8>> {
let seq = self.transcript.get(direction, idx)?;
Some(self.encoder.encode_subsequence(direction, &seq))

View File

@@ -97,7 +97,7 @@ mod test {
use crate::{
connection::{ServerCertOpening, TranscriptLength},
fixtures::{
attestation_fixture, encoder_seed, encoding_provider, request_fixture,
attestation_fixture, encoder_secret, encoding_provider, request_fixture,
ConnectionFixture, RequestFixture,
},
hash::{Blake3, Hash, HashAlgId},
@@ -122,7 +122,7 @@ mod test {
request.clone(),
connection,
SignatureAlgId::SECP256K1,
encoder_seed().to_vec(),
encoder_secret(),
);
assert!(request.validate(&attestation).is_ok())
@@ -144,7 +144,7 @@ mod test {
request.clone(),
connection,
SignatureAlgId::SECP256K1,
encoder_seed().to_vec(),
encoder_secret(),
);
request.signature_alg = SignatureAlgId::SECP256R1;
@@ -169,7 +169,7 @@ mod test {
request.clone(),
connection,
SignatureAlgId::SECP256K1,
encoder_seed().to_vec(),
encoder_secret(),
);
request.hash_alg = HashAlgId::SHA256;
@@ -194,7 +194,7 @@ mod test {
request.clone(),
connection,
SignatureAlgId::SECP256K1,
encoder_seed().to_vec(),
encoder_secret(),
);
let ConnectionFixture {
@@ -229,7 +229,7 @@ mod test {
request.clone(),
connection,
SignatureAlgId::SECP256K1,
encoder_seed().to_vec(),
encoder_secret(),
);
request.encoding_commitment_root = Some(TypedHash {

View File

@@ -52,11 +52,6 @@ pub use proof::{
TranscriptProof, TranscriptProofBuilder, TranscriptProofBuilderError, TranscriptProofError,
};
/// Sent data transcript ID.
pub static TX_TRANSCRIPT_ID: &str = "tx";
/// Received data transcript ID.
pub static RX_TRANSCRIPT_ID: &str = "rx";
/// A transcript contains all the data communicated over a TLS connection.
#[derive(Clone, Serialize, Deserialize)]
pub struct Transcript {
@@ -588,17 +583,6 @@ impl Subsequence {
#[error("invalid subsequence: {0}")]
pub struct InvalidSubsequence(&'static str);
/// Returns the value ID for each byte in the provided range set.
#[doc(hidden)]
pub fn get_value_ids(direction: Direction, idx: &Idx) -> impl Iterator<Item = String> + '_ {
let id = match direction {
Direction::Sent => TX_TRANSCRIPT_ID,
Direction::Received => RX_TRANSCRIPT_ID,
};
idx.iter().map(move |idx| format!("{}/{}", id, idx))
}
mod validation {
use super::*;
@@ -716,7 +700,7 @@ mod validation {
.sent_idx
.0
.iter_ranges()
.last()
.next_back()
.unwrap()
.end;

View File

@@ -8,7 +8,7 @@ mod proof;
mod provider;
mod tree;
pub(crate) use encoder::{new_encoder, Encoder};
pub use encoder::{new_encoder, Encoder, EncoderSecret};
pub use proof::{EncodingProof, EncodingProofError};
pub use provider::EncodingProvider;
pub use tree::EncodingTree;
@@ -32,7 +32,7 @@ pub struct EncodingCommitment {
/// Merkle root of the encoding commitments.
pub root: TypedHash,
/// Seed used to generate the encodings.
pub seed: Vec<u8>,
pub secret: EncoderSecret,
}
impl_domain_separator!(EncodingCommitment);

View File

@@ -1,18 +1,79 @@
use mpz_circuits::types::ValueType;
use mpz_core::serialize::CanonicalSerialize;
use mpz_garble_core::ChaChaEncoder;
use crate::transcript::{Direction, Idx, Subsequence};
use itybity::ToBits;
use rand::{RngCore, SeedableRng};
use rand_chacha::ChaCha12Rng;
use serde::{Deserialize, Serialize};
use crate::transcript::{Direction, Subsequence, RX_TRANSCRIPT_ID, TX_TRANSCRIPT_ID};
/// The size of the encoding for 1 bit, in bytes.
const BIT_ENCODING_SIZE: usize = 16;
/// The size of the encoding for 1 byte, in bytes.
const BYTE_ENCODING_SIZE: usize = 128;
pub(crate) fn new_encoder(seed: [u8; 32]) -> impl Encoder {
ChaChaEncoder::new(seed)
/// Secret used by an encoder to generate encodings.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EncoderSecret {
seed: [u8; 32],
delta: [u8; BIT_ENCODING_SIZE],
}
opaque_debug::implement!(EncoderSecret);
impl EncoderSecret {
/// Creates a new secret.
///
/// # Arguments
///
/// * `seed` - The seed for the PRG.
/// * `delta` - Delta for deriving the one-encodings.
pub fn new(seed: [u8; 32], delta: [u8; 16]) -> Self {
Self { seed, delta }
}
/// Returns the seed.
pub fn seed(&self) -> &[u8; 32] {
&self.seed
}
/// Returns the delta.
pub fn delta(&self) -> &[u8; 16] {
&self.delta
}
}
/// Creates a new encoder.
pub fn new_encoder(secret: &EncoderSecret) -> impl Encoder {
ChaChaEncoder::new(secret)
}
pub(crate) struct ChaChaEncoder {
seed: [u8; 32],
delta: [u8; 16],
}
impl ChaChaEncoder {
pub(crate) fn new(secret: &EncoderSecret) -> Self {
let seed = *secret.seed();
let delta = *secret.delta();
Self { seed, delta }
}
pub(crate) fn new_prg(&self, stream_id: u64) -> ChaCha12Rng {
let mut prg = ChaCha12Rng::from_seed(self.seed);
prg.set_stream(stream_id);
prg.set_word_pos(0);
prg
}
}
/// A transcript encoder.
///
/// This is an internal implementation detail that should not be exposed to the
/// public API.
pub(crate) trait Encoder {
pub trait Encoder {
/// Returns the zero encoding for the given index.
fn encode_idx(&self, direction: Direction, idx: &Idx) -> Vec<u8>;
/// Returns the encoding for the given subsequence of the transcript.
///
/// # Arguments
@@ -22,28 +83,45 @@ pub(crate) trait Encoder {
}
impl Encoder for ChaChaEncoder {
fn encode_subsequence(&self, direction: Direction, seq: &Subsequence) -> Vec<u8> {
let id = match direction {
Direction::Sent => TX_TRANSCRIPT_ID,
Direction::Received => RX_TRANSCRIPT_ID,
fn encode_idx(&self, direction: Direction, idx: &Idx) -> Vec<u8> {
// ChaCha20 encoder works with 32-bit words. Each encoded bit is 128 bits long.
const WORDS_PER_BYTE: u128 = 8 * 128 / 32;
let stream_id: u64 = match direction {
Direction::Sent => 0,
Direction::Received => 1,
};
let mut encoding = Vec::with_capacity(seq.len() * 16);
for (byte_id, &byte) in seq.index().iter().zip(seq.data()) {
let id_hash = mpz_core::utils::blake3(format!("{}/{}", id, byte_id).as_bytes());
let id = u64::from_be_bytes(id_hash[..8].try_into().unwrap());
let mut prg = self.new_prg(stream_id);
let mut encoding: Vec<u8> = vec![0u8; idx.len() * BYTE_ENCODING_SIZE];
encoding.extend(
<ChaChaEncoder as mpz_garble_core::Encoder>::encode_by_type(
self,
id,
&ValueType::U8,
)
.select(byte)
.expect("encoding is a byte encoding")
.to_bytes(),
)
let mut pos = 0;
for range in idx.iter_ranges() {
let len = range.len() * BYTE_ENCODING_SIZE;
prg.set_word_pos(range.start as u128 * WORDS_PER_BYTE);
prg.fill_bytes(&mut encoding[pos..pos + len]);
pos += len;
}
encoding
}
fn encode_subsequence(&self, direction: Direction, seq: &Subsequence) -> Vec<u8> {
const ZERO: [u8; 16] = [0; BIT_ENCODING_SIZE];
let mut encoding = self.encode_idx(direction, seq.index());
for (byte_idx, &byte) in seq.data().iter().enumerate() {
let start = byte_idx * BYTE_ENCODING_SIZE;
for (bit_idx, bit) in byte.iter_lsb0().enumerate() {
let pos = start + (bit_idx * BIT_ENCODING_SIZE);
let delta = if bit { &self.delta } else { &ZERO };
encoding[pos..pos + BIT_ENCODING_SIZE]
.iter_mut()
.zip(delta)
.for_each(|(a, b)| *a ^= *b);
}
}
encoding
}
}

View File

@@ -51,11 +51,7 @@ impl EncodingProof {
) -> Result<PartialTranscript, EncodingProofError> {
let hasher = provider.hash.get(&commitment.root.alg)?;
let seed: [u8; 32] = commitment.seed.clone().try_into().map_err(|_| {
EncodingProofError::new(ErrorKind::Commitment, "encoding seed not 32 bytes")
})?;
let encoder = new_encoder(seed);
let encoder = new_encoder(&commitment.secret);
let Self {
inclusion_proof,
openings,
@@ -149,7 +145,6 @@ impl EncodingProofError {
#[derive(Debug)]
enum ErrorKind {
Provider,
Commitment,
Proof,
}
@@ -159,7 +154,6 @@ impl fmt::Display for EncodingProofError {
match self.kind {
ErrorKind::Provider => f.write_str("provider error")?,
ErrorKind::Commitment => f.write_str("commitment error")?,
ErrorKind::Proof => f.write_str("proof error")?,
}
@@ -188,9 +182,12 @@ mod test {
use tlsn_data_fixtures::http::{request::POST_JSON, response::OK_JSON};
use crate::{
fixtures::{encoder_seed, encoding_provider},
fixtures::{encoder_secret, encoder_secret_tampered_seed, encoding_provider},
hash::Blake3,
transcript::{encoding::EncodingTree, Idx, Transcript},
transcript::{
encoding::{EncoderSecret, EncodingTree},
Idx, Transcript,
},
};
use super::*;
@@ -201,7 +198,7 @@ mod test {
commitment: EncodingCommitment,
}
fn new_encoding_fixture(seed: Vec<u8>) -> EncodingFixture {
fn new_encoding_fixture(secret: EncoderSecret) -> EncodingFixture {
let transcript = Transcript::new(POST_JSON, OK_JSON);
let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len()));
@@ -226,7 +223,7 @@ mod test {
let commitment = EncodingCommitment {
root: tree.root(),
seed,
secret,
};
EncodingFixture {
@@ -237,12 +234,12 @@ mod test {
}
#[test]
fn test_verify_encoding_proof_invalid_seed() {
fn test_verify_encoding_proof_tampered_seed() {
let EncodingFixture {
transcript,
proof,
commitment,
} = new_encoding_fixture(encoder_seed().to_vec().split_off(1));
} = new_encoding_fixture(encoder_secret_tampered_seed());
let err = proof
.verify_with_provider(
@@ -252,7 +249,7 @@ mod test {
)
.unwrap_err();
assert!(matches!(err.kind, ErrorKind::Commitment));
assert!(matches!(err.kind, ErrorKind::Proof));
}
#[test]
@@ -261,7 +258,7 @@ mod test {
transcript,
proof,
commitment,
} = new_encoding_fixture(encoder_seed().to_vec());
} = new_encoding_fixture(encoder_secret());
let err = proof
.verify_with_provider(
@@ -283,7 +280,7 @@ mod test {
transcript,
mut proof,
commitment,
} = new_encoding_fixture(encoder_seed().to_vec());
} = new_encoding_fixture(encoder_secret());
let Opening { seq, .. } = proof.openings.values_mut().next().unwrap();
@@ -306,7 +303,7 @@ mod test {
transcript,
mut proof,
commitment,
} = new_encoding_fixture(encoder_seed().to_vec());
} = new_encoding_fixture(encoder_secret());
let Opening { blinder, .. } = proof.openings.values_mut().next().unwrap();

View File

@@ -198,7 +198,7 @@ impl EncodingTree {
mod tests {
use super::*;
use crate::{
fixtures::{encoder_seed, encoding_provider},
fixtures::{encoder_secret, encoding_provider},
hash::Blake3,
transcript::encoding::EncodingCommitment,
CryptoProvider,
@@ -235,7 +235,7 @@ mod tests {
let commitment = EncodingCommitment {
root: tree.root(),
seed: encoder_seed().to_vec(),
secret: encoder_secret(),
};
let partial_transcript = proof
@@ -272,7 +272,7 @@ mod tests {
let commitment = EncodingCommitment {
root: tree.root(),
seed: encoder_seed().to_vec(),
secret: encoder_secret(),
};
let partial_transcript = proof

View File

@@ -361,7 +361,7 @@ mod tests {
use crate::{
fixtures::{
attestation_fixture, encoder_seed, encoding_provider, request_fixture,
attestation_fixture, encoder_secret, encoding_provider, request_fixture,
ConnectionFixture, RequestFixture,
},
hash::Blake3,
@@ -448,7 +448,7 @@ mod tests {
request,
connection,
SignatureAlgId::SECP256K1,
encoder_seed().to_vec(),
encoder_secret(),
);
let provider = CryptoProvider::default();

View File

@@ -1,7 +1,7 @@
use tlsn_core::{
attestation::{Attestation, AttestationConfig},
connection::{HandshakeData, HandshakeDataV1_2},
fixtures::{self, encoder_seed, ConnectionFixture},
fixtures::{self, encoder_secret, ConnectionFixture},
hash::Blake3,
presentation::PresentationOutput,
request::{Request, RequestConfig},
@@ -84,7 +84,7 @@ fn test_api() {
.connection_info(connection_info.clone())
// Server key Notary received during handshake
.server_ephemeral_key(server_ephemeral_key)
.encoding_seed(encoder_seed().to_vec());
.encoder_secret(encoder_secret());
let attestation = attestation_builder.build(&provider).unwrap();

72
crates/mpc-tls/Cargo.toml Normal file
View File

@@ -0,0 +1,72 @@
[package]
name = "tlsn-mpc-tls"
authors = ["TLSNotary Team"]
description = "Implementation of the backend trait for 2PC"
keywords = ["tls", "mpc", "2pc"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.0.0"
publish = false
edition = "2021"
[lib]
name = "mpc_tls"
[features]
default = []
[dependencies]
tlsn-cipher = { workspace = true }
tlsn-common = { workspace = true }
tlsn-hmac-sha256 = { workspace = true }
tlsn-key-exchange = { workspace = true }
tlsn-tls-backend = { workspace = true }
tlsn-tls-core = { workspace = true, features = ["serde"] }
tlsn-utils-aio = { workspace = true }
mpz-common = { workspace = true }
mpz-core = { workspace = true }
mpz-fields = { workspace = true }
mpz-ot = { workspace = true }
mpz-ole = { workspace = true }
mpz-share-conversion = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-memory-core = { workspace = true }
mpz-circuits = { workspace = true }
ludi = { git = "https://github.com/sinui0/ludi", rev = "e511c3b", default-features = false }
serio = { workspace = true }
async-trait = { workspace = true }
derive_builder = { workspace = true }
enum-try-as-inner = { workspace = true }
futures = { workspace = true }
p256 = { workspace = true }
serde = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
rand = { workspace = true }
opaque-debug = { workspace = true }
aes = { workspace = true }
aes-gcm = { workspace = true }
ctr = { workspace = true }
ghash_rc = { package = "ghash", version = "0.5" }
cipher-crate = { package = "cipher", version = "0.4" }
tokio = { workspace = true, features = ["sync"] }
pin-project-lite = { workspace = true }
[dev-dependencies]
mpz-ole = { workspace = true, features = ["test-utils"] }
mpz-ot = { workspace = true }
mpz-garble = { workspace = true }
tls-server-fixture = { workspace = true }
tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
tokio-util = { workspace = true, features = ["compat"] }
tracing-subscriber = { workspace = true }
rand_chacha = { workspace = true }
generic-array = { workspace = true }
uid-mux = { workspace = true, features = ["serio", "test-utils"] }
rstest = { workspace = true }

View File

@@ -0,0 +1,67 @@
use derive_builder::Builder;
const MIN_SENT: usize = 32;
const MIN_SENT_RECORDS: usize = 8;
const MIN_RECV: usize = 32;
const MIN_RECV_RECORDS: usize = 8;
/// MPC-TLS configuration.
#[derive(Debug, Clone, Builder)]
#[builder(build_fn(skip))]
pub struct Config {
/// Defers decryption of received data until after the MPC-TLS connection is
/// closed.
///
/// The received data will be decrypted locally without MPC, thus improving
/// bandwidth usage and performance.
pub(crate) defer_decryption: bool,
/// Maximum number of sent TLS records. Data is transmitted in records up to
/// 16KB long.
pub(crate) max_sent_records: usize,
/// Maximum number of sent bytes.
pub(crate) max_sent: usize,
/// Maximum number of received TLS records. Data is transmitted in records
/// up to 16KB long.
pub(crate) max_recv_records: usize,
/// Maximum number of received bytes which will be decrypted while
/// the TLS connection is active. Data which can be decrypted after the TLS
/// connection will be decrypted for free.
pub(crate) max_recv_online: usize,
}
impl Config {
/// Creates a new builder.
pub fn builder() -> ConfigBuilder {
ConfigBuilder::default()
}
}
impl ConfigBuilder {
/// Builds the configuration.
pub fn build(&self) -> Result<Config, ConfigBuilderError> {
let defer_decryption = self.defer_decryption.unwrap_or(true);
let max_sent = MIN_SENT
+ self
.max_sent
.ok_or(ConfigBuilderError::UninitializedField("max_sent"))?;
let max_recv_online = MIN_RECV
+ self
.max_recv_online
.ok_or(ConfigBuilderError::UninitializedField("max_recv_online"))?;
let max_sent_records = self
.max_sent_records
.unwrap_or_else(|| MIN_SENT_RECORDS + max_sent.div_ceil(16384));
let max_recv_records = self
.max_recv_records
.unwrap_or_else(|| MIN_RECV_RECORDS + max_recv_online.div_ceil(16384));
Ok(Config {
defer_decryption,
max_sent_records,
max_sent,
max_recv_records,
max_recv_online,
})
}
}

View File

@@ -0,0 +1,72 @@
use std::{
array::from_fn,
future::Future,
pin::Pin,
task::{ready, Context, Poll},
};
use crate::Role;
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
binary::{Binary, U8},
DecodeError, DecodeFutureTyped,
};
use mpz_vm_core::{prelude::*, Vm, VmError};
use pin_project_lite::pin_project;
use rand::{thread_rng, Rng};
pin_project! {
/// Supports decoding into additive shares.
#[project = OneTimePadSharedProj]
pub(crate) enum OneTimePadShared<T> {
Leader {
otp: T,
},
Follower {
#[pin] value: DecodeFutureTyped<BitVec, T>,
otp: T,
}
}
}
impl<const N: usize> OneTimePadShared<[u8; N]> {
pub(crate) fn new(
role: Role,
value: Array<U8, N>,
vm: &mut dyn Vm<Binary>,
) -> Result<Self, VmError> {
let mut rng = thread_rng();
let otp: [u8; N] = from_fn(|_| rng.gen());
match role {
Role::Leader => {
let masked = vm.mask_private(value, otp)?;
let masked = vm.mask_blind(masked)?;
_ = vm.decode(masked)?;
Ok(Self::Leader { otp })
}
Role::Follower => {
let masked = vm.mask_blind(value)?;
let masked = vm.mask_private(masked, otp)?;
let value = vm.decode(masked)?;
Ok(Self::Follower { value, otp })
}
}
}
}
impl Future for OneTimePadShared<[u8; 16]> {
type Output = Result<[u8; 16], DecodeError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project() {
OneTimePadSharedProj::Leader { otp } => Poll::Ready(Ok(*otp)),
OneTimePadSharedProj::Follower { value, otp } => {
let mut value = ready!(value.poll(cx))?;
value.iter_mut().zip(otp).for_each(|(a, b)| *a ^= *b);
Poll::Ready(Ok(value))
}
}
}
}

113
crates/mpc-tls/src/error.rs Normal file
View File

@@ -0,0 +1,113 @@
use hmac_sha256::PrfError;
use key_exchange::KeyExchangeError;
use tls_backend::BackendError;
/// MPC-TLS error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct MpcTlsError(#[from] ErrorRepr);
impl MpcTlsError {
pub(crate) fn peer<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Peer(err.into()))
}
pub(crate) fn actor<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Actor(err.into()))
}
pub(crate) fn state<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::State(err.into()))
}
pub(crate) fn alloc<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Alloc(err.into()))
}
pub(crate) fn preprocess<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Preprocess(err.into()))
}
pub(crate) fn hs<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Handshake(err.into()))
}
pub(crate) fn record_layer<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::RecordLayer(err.into()))
}
pub(crate) fn other<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Other(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
#[error("mpc-tls error: {0}")]
enum ErrorRepr {
#[error("peer misbehaved")]
Peer(Box<dyn std::error::Error + Send + Sync>),
#[error("I/O error: {0}")]
Io(std::io::Error),
#[error("actor error: {0}")]
Actor(Box<dyn std::error::Error + Send + Sync>),
#[error("state error: {0}")]
State(Box<dyn std::error::Error + Send + Sync>),
#[error("allocation error: {0}")]
Alloc(Box<dyn std::error::Error + Send + Sync>),
#[error("preprocess error: {0}")]
Preprocess(Box<dyn std::error::Error + Send + Sync>),
#[error("handshake error: {0}")]
Handshake(Box<dyn std::error::Error + Send + Sync>),
#[error("record layer error: {0}")]
RecordLayer(Box<dyn std::error::Error + Send + Sync>),
#[error("other: {0}")]
Other(Box<dyn std::error::Error + Send + Sync>),
}
impl From<std::io::Error> for MpcTlsError {
fn from(value: std::io::Error) -> Self {
MpcTlsError(ErrorRepr::Io(value))
}
}
impl From<MpcTlsError> for BackendError {
fn from(value: MpcTlsError) -> Self {
BackendError::InternalError(value.to_string())
}
}
impl From<KeyExchangeError> for MpcTlsError {
fn from(value: KeyExchangeError) -> Self {
MpcTlsError::hs(value)
}
}
impl From<PrfError> for MpcTlsError {
fn from(value: PrfError) -> Self {
MpcTlsError::hs(value)
}
}

View File

@@ -0,0 +1,582 @@
use crate::{
msg::Message,
record_layer::{aead::MpcAesGcm, RecordLayer},
Config, FollowerData, MpcTlsError, Role, SessionKeys, Vm,
};
use hmac_sha256::{MpcPrf, PrfConfig, PrfOutput};
use ke::KeyExchange;
use key_exchange::{self as ke, MpcKeyExchange};
use mpz_common::{scoped_futures::ScopedFutureExt, Context, Flush};
use mpz_core::{bitvec::BitVec, Block};
use mpz_memory_core::{DecodeFutureTyped, MemoryExt};
use mpz_ole::{Receiver as OLEReceiver, Sender as OLESender};
use mpz_ot::{
rcot::{RCOTReceiver, RCOTSender},
rot::{
any::{AnyReceiver, AnySender},
randomize::{RandomizeRCOTReceiver, RandomizeRCOTSender},
},
};
use mpz_share_conversion::{ShareConversionReceiver, ShareConversionSender};
use rand::{thread_rng, Rng};
use serio::stream::IoStreamExt;
use std::mem;
use tls_core::msgs::{
alert::AlertMessagePayload,
codec::{Codec, Reader},
enums::{AlertDescription, ContentType, NamedGroup, ProtocolVersion},
handshake::{HandshakeMessagePayload, HandshakePayload},
};
use tlsn_common::transcript::TlsTranscript;
use tracing::{debug, instrument};
/// MPC-TLS follower.
#[derive(Debug)]
pub struct MpcTlsFollower {
config: Config,
ctx: Context,
state: State,
}
impl MpcTlsFollower {
/// Creates a new follower.
pub fn new<CS, CR>(
config: Config,
ctx: Context,
vm: Vm,
cot_send: CS,
cot_recv: (CR, CR, CR),
) -> Self
where
CS: RCOTSender<Block> + Flush + Send + Sync + 'static,
CR: RCOTReceiver<bool, Block> + Flush + Send + Sync + 'static,
{
let mut rng = thread_rng();
let ke = Box::new(MpcKeyExchange::new(
key_exchange::Role::Follower,
ShareConversionReceiver::new(OLEReceiver::new(AnyReceiver::new(
RandomizeRCOTReceiver::new(cot_recv.0),
))),
ShareConversionSender::new(OLESender::new(
rng.gen(),
AnySender::new(RandomizeRCOTSender::new(cot_send)),
)),
)) as Box<dyn KeyExchange + Send + Sync>;
let prf = MpcPrf::new(
PrfConfig::builder()
.role(hmac_sha256::Role::Follower)
.build()
.expect("PRF config is valid"),
);
let encrypter = MpcAesGcm::new(
ShareConversionReceiver::new(OLEReceiver::new(AnyReceiver::new(
RandomizeRCOTReceiver::new(cot_recv.1),
))),
Role::Follower,
);
let decrypter = MpcAesGcm::new(
ShareConversionReceiver::new(OLEReceiver::new(AnyReceiver::new(
RandomizeRCOTReceiver::new(cot_recv.2),
))),
Role::Follower,
);
let record_layer = RecordLayer::new(Role::Follower, encrypter, decrypter);
Self {
config,
ctx,
state: State::Init {
vm,
ke,
prf,
record_layer,
},
}
}
/// Allocates resources for the connection.
pub fn alloc(&mut self) -> Result<SessionKeys, MpcTlsError> {
let State::Init {
vm,
mut ke,
mut prf,
mut record_layer,
} = self.state.take()
else {
return Err(MpcTlsError::state("must be in init state to allocate"));
};
let (keys, cf_vd, sf_vd) = {
let vm = &mut (*vm
.try_lock()
.map_err(|_| MpcTlsError::other("VM lock is held"))?);
let pms = ke.alloc(vm)?;
let PrfOutput { keys, cf_vd, sf_vd } = prf.alloc(vm, pms)?;
record_layer.set_keys(
keys.client_write_key,
keys.client_iv,
keys.server_write_key,
keys.server_iv,
)?;
prf.set_client_random(vm, None)?;
let cf_vd = vm.decode(cf_vd).map_err(MpcTlsError::alloc)?;
let sf_vd = vm.decode(sf_vd).map_err(MpcTlsError::alloc)?;
record_layer.alloc(
vm,
self.config.max_sent_records,
self.config.max_recv_records,
self.config.max_sent,
self.config.max_recv_online,
)?;
(keys, cf_vd, sf_vd)
};
self.state = State::Setup {
vm,
keys: keys.into(),
ke,
prf,
record_layer,
cf_vd,
sf_vd,
};
Ok(keys.into())
}
/// Preprocesses the connection.
#[instrument(skip_all, err)]
pub async fn preprocess(&mut self) -> Result<(), MpcTlsError> {
let State::Setup {
vm,
keys,
mut ke,
prf,
mut record_layer,
cf_vd,
sf_vd,
} = self.state.take()
else {
return Err(MpcTlsError::state("must be in setup state to preprocess"));
};
let (ke, record_layer, _) = {
let mut vm = vm
.clone()
.try_lock_owned()
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
self.ctx
.try_join3(
|ctx| {
async move {
ke.setup(ctx)
.await
.map(|_| ke)
.map_err(MpcTlsError::preprocess)
}
.scope_boxed()
},
|ctx| {
async move {
record_layer
.preprocess(ctx)
.await
.map(|_| record_layer)
.map_err(MpcTlsError::preprocess)
}
.scope_boxed()
},
|ctx| {
async move {
vm.flush(ctx).await.map_err(MpcTlsError::preprocess)?;
vm.preprocess(ctx).await.map_err(MpcTlsError::preprocess)?;
vm.flush(ctx).await.map_err(MpcTlsError::preprocess)?;
Ok::<_, MpcTlsError>(())
}
.scope_boxed()
},
)
.await
.map_err(MpcTlsError::hs)??
};
self.state = State::Ready {
vm,
keys,
ke,
prf,
record_layer,
cf_vd,
sf_vd,
};
Ok(())
}
/// Runs the follower.
#[instrument(skip_all, err)]
pub async fn run(mut self) -> Result<(Context, FollowerData), MpcTlsError> {
let State::Ready {
vm,
keys,
mut ke,
mut prf,
mut record_layer,
cf_vd: mut cf_vd_fut,
sf_vd: mut sf_vd_fut,
} = self.state.take()
else {
return Err(MpcTlsError::state("must be in ready state to run"));
};
let mut server_random = None;
let mut server_key = None;
let mut cf_vd = None;
let mut sf_vd = None;
loop {
let msg: Message = self.ctx.io_mut().expect_next().await?;
match msg {
Message::SetServerRandom(random) => {
if server_random.is_some() {
return Err(MpcTlsError::hs("server random already set"));
}
let mut vm = vm
.try_lock()
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
prf.set_server_random(&mut (*vm), random.random)?;
server_random = Some(random);
}
Message::SetServerKey(key) => {
if server_key.is_some() {
return Err(MpcTlsError::hs("server key already set"));
}
let key = key.key;
let NamedGroup::secp256r1 = key.group else {
return Err(MpcTlsError::hs("unsupported server key group"));
};
ke.set_server_key(
p256::PublicKey::from_sec1_bytes(&key.key)
.map_err(|_| MpcTlsError::hs("failed to parse server key"))?,
)?;
server_key = Some(key);
let mut vm = vm
.try_lock()
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
ke.compute_shares(&mut self.ctx).await?;
ke.assign(&mut (*vm))?;
vm.execute_all(&mut self.ctx)
.await
.map_err(MpcTlsError::hs)?;
ke.finalize().await?;
record_layer.setup(&mut self.ctx).await?;
}
Message::ClientFinishedVd(vd) => {
if cf_vd.is_some() {
return Err(MpcTlsError::hs("client finished VD already computed"));
}
let mut vm = vm
.try_lock()
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
prf.set_cf_hash(&mut (*vm), vd.handshake_hash)?;
vm.execute_all(&mut self.ctx)
.await
.map_err(MpcTlsError::hs)?;
cf_vd = Some(
cf_vd_fut
.try_recv()
.map_err(MpcTlsError::hs)?
.ok_or(MpcTlsError::hs("client finished VD not computed"))?,
);
}
Message::ServerFinishedVd(vd) => {
if sf_vd.is_some() {
return Err(MpcTlsError::hs("server finished VD already computed"));
}
let mut vm = vm
.try_lock()
.map_err(|_| MpcTlsError::other("VM lock is held"))?;
prf.set_sf_hash(&mut (*vm), vd.handshake_hash)?;
vm.execute_all(&mut self.ctx)
.await
.map_err(MpcTlsError::hs)?;
sf_vd = Some(
sf_vd_fut
.try_recv()
.map_err(MpcTlsError::hs)?
.ok_or(MpcTlsError::hs("server finished VD not computed"))?,
);
}
Message::Encrypt(encrypt) => {
record_layer
.push_encrypt(
encrypt.typ,
encrypt.version,
encrypt.len,
encrypt.plaintext,
encrypt.mode,
)
.map_err(MpcTlsError::record_layer)?;
}
Message::Decrypt(decrypt) => {
record_layer
.push_decrypt(
decrypt.typ,
decrypt.version,
decrypt.explicit_nonce,
decrypt.ciphertext,
decrypt.tag,
decrypt.mode,
)
.map_err(MpcTlsError::record_layer)?;
}
Message::Flush { is_decrypting } => {
record_layer
.flush(&mut self.ctx, vm.clone(), is_decrypting)
.await?;
debug!("flushed record layer");
}
Message::CloseConnection => {
break;
}
}
}
debug!("committing");
let transcript = record_layer.commit(&mut self.ctx, vm).await?;
debug!("committed");
let server_key = server_key.ok_or(MpcTlsError::hs("server key not set"))?;
let cf_vd = cf_vd.ok_or(MpcTlsError::hs("client finished VD not computed"))?;
let sf_vd = sf_vd.ok_or(MpcTlsError::hs("server finished VD not computed"))?;
validate_transcript(cf_vd, sf_vd, &transcript)?;
Ok((
self.ctx,
FollowerData {
server_key,
transcript,
keys,
},
))
}
}
enum State {
Init {
vm: Vm,
ke: Box<dyn KeyExchange + Send + Sync + 'static>,
prf: MpcPrf,
record_layer: RecordLayer,
},
Setup {
vm: Vm,
keys: SessionKeys,
ke: Box<dyn KeyExchange + Send + Sync + 'static>,
prf: MpcPrf,
record_layer: RecordLayer,
cf_vd: DecodeFutureTyped<BitVec, [u8; 12]>,
sf_vd: DecodeFutureTyped<BitVec, [u8; 12]>,
},
Ready {
vm: Vm,
keys: SessionKeys,
ke: Box<dyn KeyExchange + Send + Sync + 'static>,
prf: MpcPrf,
record_layer: RecordLayer,
cf_vd: DecodeFutureTyped<BitVec, [u8; 12]>,
sf_vd: DecodeFutureTyped<BitVec, [u8; 12]>,
},
Error,
}
impl State {
fn take(&mut self) -> Self {
mem::replace(self, State::Error)
}
}
impl std::fmt::Debug for State {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Init { .. } => f.debug_struct("Init").finish_non_exhaustive(),
Self::Setup { .. } => f.debug_struct("Setup").finish_non_exhaustive(),
Self::Ready { .. } => f.debug_struct("Ready").finish_non_exhaustive(),
Self::Error => write!(f, "Error"),
}
}
}
fn validate_transcript(
cf_vd: [u8; 12],
sf_vd: [u8; 12],
transcript: &TlsTranscript,
) -> Result<(), MpcTlsError> {
let mut sent = transcript.sent.iter();
let mut recv = transcript.recv.iter();
// Make sure the client finished verify data message was consistent.
if let Some(record) = sent.next() {
let payload = record.plaintext.as_ref().ok_or(MpcTlsError::record_layer(
"client finished message was hidden from the follower",
))?;
let mut reader = Reader::init(payload);
let payload = HandshakeMessagePayload::read_version(&mut reader, ProtocolVersion::TLSv1_2)
.ok_or(MpcTlsError::record_layer(
"first record sent was not a handshake message",
))?;
let HandshakePayload::Finished(actual_cf_vd) = payload.payload else {
return Err(MpcTlsError::record_layer(
"first record sent was not a client finished message",
));
};
if cf_vd != actual_cf_vd.0.as_slice() {
return Err(MpcTlsError::record_layer(format!(
"client finished verify data does not match output from PRF: {:?} != {:?}",
cf_vd, actual_cf_vd
)));
}
} else {
return Err(MpcTlsError::record_layer("no records were sent"));
}
// Make sure the server finished verify data message was consistent.
if let Some(record) = recv.next() {
let payload = record.plaintext.as_ref().ok_or(MpcTlsError::record_layer(
"server finished message was hidden from the follower",
))?;
let mut reader = Reader::init(payload);
let payload = HandshakeMessagePayload::read_version(&mut reader, ProtocolVersion::TLSv1_2)
.ok_or(MpcTlsError::record_layer(
"first record received was not a handshake message",
))?;
let HandshakePayload::Finished(actual_sf_vd) = payload.payload else {
return Err(MpcTlsError::record_layer(
"first record received was not a server finished message",
));
};
if sf_vd != actual_sf_vd.0.as_slice() {
return Err(MpcTlsError::record_layer(format!(
"server finished verify data does not match output from PRF: {:?} != {:?}",
sf_vd, actual_sf_vd
)));
}
} else {
return Err(MpcTlsError::record_layer("no records were received"));
}
// Verify last record sent was either application data or close notify.
if let Some(record) = sent.next_back() {
match record.typ {
ContentType::ApplicationData => {}
ContentType::Alert => {
// Ensure the alert is a close notify.
let payload = record.plaintext.as_ref().ok_or(MpcTlsError::record_layer(
"alert content was hidden from the follower",
))?;
let mut reader = Reader::init(payload);
let payload = AlertMessagePayload::read(&mut reader)
.ok_or(MpcTlsError::record_layer("alert message was malformed"))?;
let AlertDescription::CloseNotify = payload.description else {
return Err(MpcTlsError::record_layer(
"sent alert that is not close notify",
));
};
}
typ => {
return Err(MpcTlsError::record_layer(format!(
"sent unexpected record content type: {:?}",
typ
)))
}
}
}
// Verify last record received was either application data or close notify.
if let Some(record) = recv.next_back() {
match record.typ {
ContentType::ApplicationData => {}
ContentType::Alert => {
// Ensure the alert is a close notify.
let payload = record.plaintext.as_ref().ok_or(MpcTlsError::record_layer(
"alert content was hidden from the follower",
))?;
let mut reader = Reader::init(payload);
let payload = AlertMessagePayload::read(&mut reader)
.ok_or(MpcTlsError::record_layer("alert message was malformed"))?;
let AlertDescription::CloseNotify = payload.description else {
return Err(MpcTlsError::record_layer(
"received alert that is not close notify",
));
};
}
typ => {
return Err(MpcTlsError::record_layer(format!(
"received unexpected record content type: {:?}",
typ
)))
}
}
}
// Ensure all other records were application data.
for record in sent {
if record.typ != ContentType::ApplicationData {
return Err(MpcTlsError::record_layer(format!(
"sent unexpected record content type: {:?}",
record.typ
)));
}
}
for record in recv {
if record.typ != ContentType::ApplicationData {
return Err(MpcTlsError::record_layer(format!(
"received unexpected record content type: {:?}",
record.typ
)));
}
}
Ok(())
}

1004
crates/mpc-tls/src/leader.rs Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

106
crates/mpc-tls/src/lib.rs Normal file
View File

@@ -0,0 +1,106 @@
//! TLSNotary MPC-TLS protocol implementation.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]
mod config;
mod decode;
mod error;
pub(crate) mod follower;
pub(crate) mod leader;
mod msg;
mod record_layer;
pub(crate) mod utils;
pub use config::{Config, ConfigBuilder, ConfigBuilderError};
pub use error::MpcTlsError;
pub use follower::MpcTlsFollower;
pub use leader::{LeaderCtrl, MpcTlsLeader};
use std::{future::Future, pin::Pin, sync::Arc};
use mpz_memory_core::{
binary::{Binary, U8},
Array,
};
use mpz_vm_core::Vm as VmTrait;
use tls_core::{
cert::ServerCertDetails,
ke::ServerKxDetails,
key::PublicKey,
msgs::{
enums::{CipherSuite, ProtocolVersion},
handshake::Random,
},
};
use tlsn_common::transcript::TlsTranscript;
use tokio::sync::Mutex;
pub(crate) type BoxFut<T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'static>>;
/// Virtual machine type.
pub type Vm = Arc<Mutex<dyn VmTrait<Binary> + Send + Sync + 'static>>;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Role {
Leader,
Follower,
}
/// TLS session keys.
#[derive(Debug, Clone)]
pub struct SessionKeys {
/// Client write key.
pub client_write_key: Array<U8, 16>,
/// Client write IV.
pub client_write_iv: Array<U8, 4>,
/// Server write key.
pub server_write_key: Array<U8, 16>,
/// Server write IV.
pub server_write_iv: Array<U8, 4>,
}
impl From<hmac_sha256::SessionKeys> for SessionKeys {
fn from(keys: hmac_sha256::SessionKeys) -> Self {
Self {
client_write_key: keys.client_write_key,
client_write_iv: keys.client_iv,
server_write_key: keys.server_write_key,
server_write_iv: keys.server_iv,
}
}
}
/// MPC-TLS Leader output.
#[derive(Debug)]
pub struct LeaderOutput {
/// TLS protocol version.
pub protocol_version: ProtocolVersion,
/// TLS cipher suite.
pub cipher_suite: CipherSuite,
/// Server ephemeral public key.
pub server_key: PublicKey,
/// Server certificate chain and related details.
pub server_cert_details: ServerCertDetails,
/// Key exchange details.
pub server_kx_details: ServerKxDetails,
/// Client random
pub client_random: Random,
/// Server random
pub server_random: Random,
/// TLS transcript.
pub transcript: TlsTranscript,
/// TLS session keys.
pub keys: SessionKeys,
}
/// MPC-TLS Follower output.
#[derive(Debug)]
pub struct FollowerData {
/// Server ephemeral public key.
pub server_key: PublicKey,
/// TLS transcript.
pub transcript: TlsTranscript,
/// TLS session keys.
pub keys: SessionKeys,
}

62
crates/mpc-tls/src/msg.rs Normal file
View File

@@ -0,0 +1,62 @@
use serde::{Deserialize, Serialize};
use tls_core::{
key::PublicKey,
msgs::enums::{ContentType, ProtocolVersion},
};
use crate::record_layer::{DecryptMode, EncryptMode};
/// MPC-TLS protocol message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum Message {
SetServerRandom(SetServerRandom),
SetServerKey(SetServerKey),
ClientFinishedVd(ClientFinishedVd),
ServerFinishedVd(ServerFinishedVd),
Encrypt(Encrypt),
Decrypt(Decrypt),
Flush { is_decrypting: bool },
CloseConnection,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct SetServerRandom {
pub(crate) random: [u8; 32],
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct SetServerKey {
pub(crate) key: PublicKey,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Decrypt {
pub(crate) typ: ContentType,
pub(crate) version: ProtocolVersion,
pub(crate) explicit_nonce: Vec<u8>,
pub(crate) ciphertext: Vec<u8>,
pub(crate) tag: Vec<u8>,
pub(crate) mode: DecryptMode,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Encrypt {
pub(crate) typ: ContentType,
pub(crate) version: ProtocolVersion,
pub(crate) len: usize,
pub(crate) plaintext: Option<Vec<u8>>,
pub(crate) mode: EncryptMode,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ClientFinishedVd {
pub handshake_hash: [u8; 32],
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ServerFinishedVd {
pub handshake_hash: [u8; 32],
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct CloseConnection;

View File

@@ -0,0 +1,558 @@
//! TLS record layer.
pub(crate) mod aead;
mod aes_ctr;
mod decrypt;
mod encrypt;
use std::{collections::VecDeque, mem::take, sync::Arc};
use aead::MpcAesGcm;
use futures::TryFutureExt;
use mpz_common::{scoped_futures::ScopedFutureExt, Context, Task};
use mpz_memory_core::{
binary::{Binary, U8},
Array,
};
use mpz_vm_core::Vm as VmTrait;
use rand::{thread_rng, RngCore};
use serde::{Deserialize, Serialize};
use tls_core::{
cipher::make_tls12_aad,
msgs::enums::{ContentType, ProtocolVersion},
};
use tlsn_common::transcript::{Record, TlsTranscript};
use tokio::sync::Mutex;
use crate::{
record_layer::{aes_ctr::AesCtr, decrypt::DecryptOp, encrypt::EncryptOp},
MpcTlsError, Role, Vm,
};
pub(crate) use decrypt::DecryptMode;
pub(crate) use encrypt::EncryptMode;
const MAX_RECORD_SIZE: usize = 1026 * 16;
// This limits how much the leader can cause the follower to allocate.
const MAX_BUFFER_SIZE: usize = (16 * (1 << 20)) / MAX_RECORD_SIZE;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct PlainRecord {
pub(crate) typ: ContentType,
pub(crate) version: ProtocolVersion,
pub(crate) plaintext: Option<Vec<u8>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct EncryptedRecord {
pub(crate) typ: ContentType,
pub(crate) version: ProtocolVersion,
pub(crate) explicit_nonce: Vec<u8>,
pub(crate) ciphertext: Vec<u8>,
pub(crate) tag: Option<Vec<u8>>,
}
enum State {
Init,
Online {
recv_otp: Option<Vec<u8>>,
sent_records: Vec<Record>,
recv_records: Vec<Record>,
},
Complete,
Error,
}
impl State {
fn take(&mut self) -> Self {
std::mem::replace(self, State::Error)
}
}
/// MPC-TLS record layer.
pub(crate) struct RecordLayer {
role: Role,
write_seq: u64,
read_seq: u64,
encrypter: Arc<Mutex<MpcAesGcm>>,
decrypt: Arc<Mutex<MpcAesGcm>>,
aes_ctr: AesCtr,
state: State,
encrypt_buffer: Vec<EncryptOp>,
decrypt_buffer: Vec<DecryptOp>,
encrypted_buffer: VecDeque<EncryptedRecord>,
decrypted_buffer: VecDeque<PlainRecord>,
}
impl RecordLayer {
/// Creates a new record layer.
pub(crate) fn new(role: Role, encrypt: MpcAesGcm, decrypt: MpcAesGcm) -> Self {
Self {
role,
write_seq: 0,
read_seq: 0,
encrypter: Arc::new(Mutex::new(encrypt)),
decrypt: Arc::new(Mutex::new(decrypt)),
aes_ctr: AesCtr::new(role),
state: State::Init,
encrypt_buffer: Vec::new(),
decrypt_buffer: Vec::new(),
encrypted_buffer: VecDeque::new(),
decrypted_buffer: VecDeque::new(),
}
}
/// Allocates resources for the record layer.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `sent_records` - Number of sent records to allocate.
/// * `recv_records` - Number of received records to allocate.
/// * `sent_len` - Total length of sent records to allocate.
/// * `recv_len` - Total length of received records to allocate.
pub(crate) fn alloc(
&mut self,
vm: &mut dyn VmTrait<Binary>,
sent_records: usize,
recv_records: usize,
sent_len: usize,
recv_len: usize,
) -> Result<(), MpcTlsError> {
let State::Init = self.state.take() else {
return Err(MpcTlsError::other("record layer is already allocated"));
};
let mut encrypt = self
.encrypter
.try_lock()
.map_err(|_| MpcTlsError::other("encrypt lock is held"))?;
let mut decrypt = self
.decrypt
.try_lock()
.map_err(|_| MpcTlsError::other("decrypt lock is held"))?;
encrypt
.alloc(vm, sent_records, sent_len)
.map_err(MpcTlsError::record_layer)?;
decrypt
.alloc(vm, recv_records, recv_len)
.map_err(MpcTlsError::record_layer)?;
let recv_otp = match self.role {
Role::Leader => {
let mut recv_otp = vec![0u8; recv_len];
thread_rng().fill_bytes(&mut recv_otp);
Some(recv_otp)
}
Role::Follower => None,
};
self.aes_ctr.alloc(vm)?;
self.state = State::Online {
recv_otp,
sent_records: Vec::new(),
recv_records: Vec::new(),
};
Ok(())
}
pub(crate) async fn preprocess(&mut self, ctx: &mut Context) -> Result<(), MpcTlsError> {
let mut encrypt = self
.encrypter
.clone()
.try_lock_owned()
.map_err(|_| MpcTlsError::other("encrypt lock is held"))?;
let mut decrypt = self
.decrypt
.clone()
.try_lock_owned()
.map_err(|_| MpcTlsError::other("decrypt lock is held"))?;
// Computes GHASH keys in parallel.
ctx.try_join(
|ctx| async move { encrypt.preprocess(ctx).await }.scope_boxed(),
|ctx| async move { decrypt.preprocess(ctx).await }.scope_boxed(),
)
.await
.map_err(MpcTlsError::record_layer)?
.map_err(MpcTlsError::record_layer)?;
Ok(())
}
/// Sets the keys for the record layer.
pub(crate) fn set_keys(
&mut self,
client_write_key: Array<U8, 16>,
client_iv: Array<U8, 4>,
server_write_key: Array<U8, 16>,
server_iv: Array<U8, 4>,
) -> Result<(), MpcTlsError> {
let mut encrypt = self
.encrypter
.try_lock()
.map_err(|_| MpcTlsError::other("encrypt lock is held"))?;
let mut decrypt = self
.decrypt
.try_lock()
.map_err(|_| MpcTlsError::other("decrypt lock is held"))?;
encrypt.set_key(client_write_key);
encrypt.set_iv(client_iv);
decrypt.set_key(server_write_key);
decrypt.set_iv(server_iv);
self.aes_ctr.set_key(server_write_key, server_iv);
Ok(())
}
/// Sets up the record layer.
pub(crate) async fn setup(&mut self, ctx: &mut Context) -> Result<(), MpcTlsError> {
let mut encrypt = self
.encrypter
.clone()
.try_lock_owned()
.map_err(|_| MpcTlsError::other("encrypt lock is held"))?;
let mut decrypt = self
.decrypt
.clone()
.try_lock_owned()
.map_err(|_| MpcTlsError::other("decrypt lock is held"))?;
// Computes GHASH keys in parallel.
ctx.try_join(
|ctx| async move { encrypt.setup(ctx).await }.scope_boxed(),
|ctx| async move { decrypt.setup(ctx).await }.scope_boxed(),
)
.await
.map_err(MpcTlsError::record_layer)?
.map_err(MpcTlsError::record_layer)?;
Ok(())
}
pub(crate) fn is_empty(&self) -> bool {
self.encrypt_buffer.is_empty()
&& self.decrypt_buffer.is_empty()
&& self.encrypted_buffer.is_empty()
&& self.decrypted_buffer.is_empty()
}
pub(crate) fn wants_flush(&self) -> bool {
!self.encrypt_buffer.is_empty() || !self.decrypt_buffer.is_empty()
}
pub(crate) fn push_encrypt(
&mut self,
typ: ContentType,
version: ProtocolVersion,
len: usize,
plaintext: Option<Vec<u8>>,
mode: EncryptMode,
) -> Result<(), MpcTlsError> {
if self.encrypt_buffer.len() >= MAX_BUFFER_SIZE {
return Err(MpcTlsError::peer("encrypt buffer is full"));
}
let (seq, explicit_nonce, aad) = self.next_write(typ, version, len);
self.encrypt_buffer.push(EncryptOp::new(
seq,
typ,
version,
len,
plaintext,
explicit_nonce,
aad,
mode,
)?);
Ok(())
}
pub(crate) fn push_decrypt(
&mut self,
typ: ContentType,
version: ProtocolVersion,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
tag: Vec<u8>,
mode: DecryptMode,
) -> Result<(), MpcTlsError> {
if self.decrypt_buffer.len() >= MAX_BUFFER_SIZE {
return Err(MpcTlsError::peer("decrypt buffer is full"));
}
let (seq, aad) = self.next_read(typ, version, ciphertext.len());
self.decrypt_buffer.push(DecryptOp::new(
seq,
typ,
version,
explicit_nonce,
ciphertext,
aad,
tag,
mode,
));
Ok(())
}
/// Returns the next encrypted record.
pub(crate) fn next_encrypted(&mut self) -> Option<EncryptedRecord> {
self.encrypted_buffer.pop_front()
}
/// Returns the next decrypted record.
pub(crate) fn next_decrypted(&mut self) -> Option<PlainRecord> {
self.decrypted_buffer.pop_front()
}
pub(crate) async fn flush(
&mut self,
ctx: &mut Context,
vm: Vm,
is_decrypting: bool,
) -> Result<(), MpcTlsError> {
let State::Online {
recv_otp,
sent_records,
recv_records,
..
} = &mut self.state
else {
return Err(MpcTlsError::state(
"record layer must be in online state to flush",
));
};
let mut vm = vm
.try_lock_owned()
.map_err(|_| MpcTlsError::record_layer("VM lock is held"))?;
let mut encrypter = self
.encrypter
.try_lock()
.map_err(|_| MpcTlsError::record_layer("encrypt lock is held"))?;
let mut decrypter = self
.decrypt
.try_lock()
.map_err(|_| MpcTlsError::record_layer("decrypt lock is held"))?;
let encrypt_ops = take(&mut self.encrypt_buffer);
let decrypt_end = if is_decrypting {
self.decrypt_buffer.len()
} else {
// Position of the first application data in the decrypt buffer.
self.decrypt_buffer
.iter()
.position(|op| op.typ == ContentType::ApplicationData)
.unwrap_or(self.decrypt_buffer.len())
};
let decrypt_ops: Vec<_> = self.decrypt_buffer.drain(..decrypt_end).collect();
let (pending_encrypt, compute_tags) =
encrypt::encrypt(&mut (*vm), &mut encrypter, &encrypt_ops)?;
let pending_decrypt =
decrypt::decrypt_mpc(&mut (*vm), &mut decrypter, recv_otp.as_mut(), &decrypt_ops)?;
let verify_tags = decrypt::verify_tags(&mut (*vm), &mut decrypter, &decrypt_ops)?;
// Run tag computation and VM in parallel.
let (mut tags, _, _) = ctx
.try_join3(
|ctx| {
compute_tags
.run(ctx)
.map_err(MpcTlsError::record_layer)
.scope_boxed()
},
|ctx| {
verify_tags
.run(ctx)
.map_err(MpcTlsError::record_layer)
.scope_boxed()
},
|ctx| {
async move { vm.execute_all(ctx).map_err(MpcTlsError::record_layer).await }
.scope_boxed()
},
)
.await
.map_err(MpcTlsError::record_layer)??;
// Reverse tags, as we will be popping from the back.
if let Some(tags) = tags.as_mut() {
tags.reverse();
}
for (op, pending) in encrypt_ops.into_iter().zip(pending_encrypt) {
let ciphertext = pending.output.try_encrypt()?;
self.encrypted_buffer.push_back(EncryptedRecord {
typ: op.typ,
version: op.version,
explicit_nonce: op.explicit_nonce.clone(),
ciphertext: ciphertext.clone(),
tag: tags.as_mut().and_then(Vec::pop),
});
sent_records.push(Record {
seq: op.seq,
typ: op.typ,
plaintext: op.plaintext,
plaintext_ref: pending.plaintext_ref,
explicit_nonce: op.explicit_nonce,
ciphertext,
});
}
for (op, pending) in decrypt_ops.into_iter().zip(pending_decrypt) {
let plaintext = pending.output.try_decrypt()?;
self.decrypted_buffer.push_back(PlainRecord {
typ: op.typ,
version: op.version,
plaintext: plaintext.clone(),
});
recv_records.push(Record {
seq: op.seq,
typ: op.typ,
plaintext,
plaintext_ref: None,
explicit_nonce: op.explicit_nonce,
ciphertext: op.ciphertext,
});
}
Ok(())
}
pub(crate) async fn commit(
&mut self,
ctx: &mut Context,
vm: Vm,
) -> Result<TlsTranscript, MpcTlsError> {
let State::Online {
sent_records,
mut recv_records,
..
} = self.state.take()
else {
return Err(MpcTlsError::state(
"record layer must be in online state to commit",
));
};
if !self.encrypt_buffer.is_empty() {
return Err(MpcTlsError::state(
"record layer can not commit with pending encrypt operations",
));
}
let mut vm = vm
.try_lock_owned()
.map_err(|_| MpcTlsError::record_layer("VM lock is held"))?;
let mut decrypter = self
.decrypt
.try_lock()
.map_err(|_| MpcTlsError::record_layer("decrypt lock is held"))?;
let buffered_ops = take(&mut self.decrypt_buffer);
// Verify tags of buffered ciphertexts.
let verify_tags = decrypt::verify_tags(&mut (*vm), &mut decrypter, &buffered_ops)?;
vm.execute_all(ctx)
.await
.map_err(MpcTlsError::record_layer)?;
verify_tags
.run(ctx)
.await
.map_err(MpcTlsError::record_layer)?;
// Reveal decrypt key to the leader.
self.aes_ctr.decode_key(&mut (*vm))?;
vm.flush(ctx).await.map_err(MpcTlsError::record_layer)?;
self.aes_ctr.finish_decode()?;
let pending_decrypts = decrypt::decrypt_local(
self.role,
&mut (*vm),
&mut decrypter,
&mut self.aes_ctr,
&buffered_ops,
)?;
vm.execute_all(ctx)
.await
.map_err(MpcTlsError::record_layer)?;
for (op, pending) in buffered_ops.into_iter().zip(pending_decrypts) {
let plaintext = pending.output.try_decrypt()?;
self.decrypted_buffer.push_back(PlainRecord {
typ: op.typ,
version: op.version,
plaintext: plaintext.clone(),
});
recv_records.push(Record {
seq: op.seq,
typ: op.typ,
plaintext,
plaintext_ref: None,
explicit_nonce: op.explicit_nonce,
ciphertext: op.ciphertext,
});
}
self.state = State::Complete;
Ok(TlsTranscript {
sent: sent_records,
recv: recv_records,
})
}
fn next_write(
&mut self,
typ: ContentType,
version: ProtocolVersion,
len: usize,
) -> (u64, Vec<u8>, Vec<u8>) {
let seq = self.write_seq;
self.write_seq += 1;
let explicit_nonce = seq.to_be_bytes().to_vec();
let aad = make_tls12_aad(seq, typ, version, len).to_vec();
(seq, explicit_nonce, aad)
}
fn next_read(
&mut self,
typ: ContentType,
version: ProtocolVersion,
len: usize,
) -> (u64, Vec<u8>) {
let seq = self.read_seq;
self.read_seq += 1;
let aad = make_tls12_aad(seq, typ, version, len).to_vec();
(seq, aad)
}
}
#[derive(Clone)]
pub(crate) struct TagData {
pub(crate) explicit_nonce: Vec<u8>,
pub(crate) aad: Vec<u8>,
}

View File

@@ -0,0 +1,69 @@
mod aes_gcm;
mod ghash;
pub(crate) use aes_gcm::MpcAesGcm;
use cipher::{aes::AesError, CipherError};
pub(crate) use ghash::{ComputeTags, VerifyTags};
use mpz_memory_core::{binary::U8, Array};
use mpz_vm_core::VmError;
type Nonce = Array<U8, 8>;
type Ctr = Array<U8, 4>;
type Block = Array<U8, 16>;
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub(crate) struct AeadError(ErrorRepr);
impl AeadError {
pub(crate) fn state<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::State(err.into()))
}
pub(crate) fn cipher<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Cipher(err.into()))
}
pub(crate) fn tag<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Tag(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
#[error("aead error: {0}")]
enum ErrorRepr {
#[error("state error: {0}")]
State(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("cipher error: {0}")]
Cipher(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("tag error: {0}")]
Tag(Box<dyn std::error::Error + Send + Sync + 'static>),
}
impl From<VmError> for AeadError {
fn from(err: VmError) -> Self {
Self(ErrorRepr::Cipher(Box::new(err)))
}
}
impl From<CipherError> for AeadError {
fn from(err: CipherError) -> Self {
Self(ErrorRepr::Cipher(Box::new(err)))
}
}
impl From<AesError> for AeadError {
fn from(err: AesError) -> Self {
Self(ErrorRepr::Cipher(Box::new(err)))
}
}

View File

@@ -0,0 +1,664 @@
use std::{future::Future, sync::Arc};
use crate::{
decode::OneTimePadShared,
record_layer::{
aead::{
ghash::{ComputeTagData, ComputeTags, Ghash, MpcGhash, VerifyTagData, VerifyTags},
AeadError, Block, Ctr, Nonce,
},
TagData,
},
Role,
};
use cipher::{aes::Aes128, Cipher, CtrBlock, Keystream};
use mpz_common::{Context, Flush};
use mpz_fields::gf2_128::Gf2_128;
use mpz_memory_core::{
binary::{Binary, U8},
Vector,
};
use mpz_share_conversion::ShareConvert;
use mpz_vm_core::{prelude::*, Vm};
use tracing::instrument;
const START_CTR: u32 = 2;
#[allow(clippy::type_complexity)]
enum State {
Init {
ghash: Box<dyn Ghash + Send + Sync>,
},
Setup {
input: Vector<U8>,
keystream: Keystream<Nonce, Ctr, Block>,
j0s: Vec<(CtrBlock<Nonce, Ctr, Block>, OneTimePadShared<[u8; 16]>)>,
output: Vector<U8>,
ghash_key: OneTimePadShared<[u8; 16]>,
ghash: Box<dyn Ghash + Send + Sync>,
},
Ready {
input: Vector<U8>,
keystream: Keystream<Nonce, Ctr, Block>,
j0s: Vec<(CtrBlock<Nonce, Ctr, Block>, OneTimePadShared<[u8; 16]>)>,
output: Vector<U8>,
ghash: Arc<dyn Ghash + Send + Sync>,
},
Error,
}
impl State {
fn take(&mut self) -> Self {
std::mem::replace(self, State::Error)
}
}
pub(crate) struct MpcAesGcm {
role: Role,
aes: Aes128,
state: State,
}
impl MpcAesGcm {
/// Creates a new AES-GCM instance.
pub(crate) fn new<C>(converter: C, role: Role) -> Self
where
C: ShareConvert<Gf2_128> + Flush + Send + Sync + 'static,
{
Self {
role,
aes: Aes128::default(),
state: State::Init {
ghash: Box::new(MpcGhash::new(converter)),
},
}
}
/// Allocates resources.
///
/// # Arguments
///
/// * `vm` - Virtual machine to allocate in.
/// * `records` - Number of records to allocate.
/// * `len` - Length of the input text in bytes.
pub(crate) fn alloc(
&mut self,
vm: &mut dyn Vm<Binary>,
records: usize,
len: usize,
) -> Result<(), AeadError> {
let State::Init { mut ghash } = self.state.take() else {
return Err(AeadError::state("must be in init state to allocate"));
};
let zero_block: Array<U8, 16> = vm.alloc()?;
vm.mark_public(zero_block)?;
vm.assign(zero_block, [0u8; 16])?;
vm.commit(zero_block)?;
ghash.alloc()?;
let ghash_key = self.aes.alloc_block(vm, zero_block)?;
let ghash_key = OneTimePadShared::<[u8; 16]>::new(self.role, ghash_key, vm)?;
// Allocate J0 secret sharing for GHASH.
let mut j0s = Vec::with_capacity(records);
for _ in 0..records {
let j0 = self.aes.alloc_ctr_block(vm)?;
let j0_shared = OneTimePadShared::<[u8; 16]>::new(self.role, j0.output, vm)?;
j0s.push((j0, j0_shared));
}
// Allocate encryption/decryption.
// Round up the length to the nearest multiple of the block count.
let len = 16 * len.div_ceil(16);
let input = vm.alloc_vec::<U8>(len)?;
match self.role {
Role::Leader => {
vm.mark_private(input)?;
}
Role::Follower => {
vm.mark_blind(input)?;
}
}
let keystream = self.aes.alloc_keystream(vm, len)?;
let output = keystream.apply(vm, input)?;
self.state = State::Setup {
input,
keystream,
j0s,
output,
ghash,
ghash_key,
};
Ok(())
}
pub(crate) async fn preprocess(&mut self, ctx: &mut Context) -> Result<(), AeadError> {
let State::Setup { ghash, .. } = &mut self.state else {
return Err(AeadError::state("must be in setup state to allocate"));
};
ghash.preprocess(ctx).await?;
Ok(())
}
pub(crate) fn set_key(&mut self, key: Array<U8, 16>) {
self.aes.set_key(key);
}
pub(crate) fn set_iv(&mut self, iv: Array<U8, 4>) {
self.aes.set_iv(iv);
}
pub(crate) async fn setup(&mut self, ctx: &mut Context) -> Result<(), AeadError> {
let State::Setup {
input,
keystream,
j0s,
output,
mut ghash,
ghash_key,
} = self.state.take()
else {
return Err(AeadError::state("must be in setup state to set up"));
};
let key = ghash_key.await.map_err(AeadError::tag)?;
ghash.set_key(key.to_vec())?;
ghash.setup(ctx).await?;
self.state = State::Ready {
input,
keystream,
j0s,
output,
ghash: Arc::from(ghash),
};
Ok(())
}
/// Returns `len` bytes of input and output text.
///
/// The outer context is responsible for assigning to the input text.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `explicit_nonce` - Explicit nonce.
/// * `len` - Number of bytes to take.
#[instrument(level = "debug", skip_all, err)]
pub(crate) fn apply_keystream(
&mut self,
vm: &mut dyn Vm<Binary>,
explicit_nonce: Vec<u8>,
len: usize,
) -> Result<(Vector<U8>, Vector<U8>), AeadError> {
let State::Ready {
input,
keystream,
output,
..
} = &mut self.state
else {
return Err(AeadError::state(
"must be in ready state to apply keystream",
));
};
let explicit_nonce: [u8; 8] = explicit_nonce.try_into().map_err(|nonce: Vec<_>| {
AeadError::cipher(format!(
"explicit nonce length: expected {}, got {}",
16,
nonce.len()
))
})?;
let block_count = len.div_ceil(16);
let padded_len = block_count * 16;
let padding_len = padded_len - len;
if padded_len > input.len() {
return Err(AeadError::cipher(format!(
"input length exceeds allocated: {} > {}",
padded_len,
input.len()
)));
}
let mut input = input.split_off(input.len() - padded_len);
let keystream = keystream.consume(padded_len)?;
let mut output = output.split_off(output.len() - padded_len);
// Assign counter block inputs.
let mut ctr = START_CTR..;
keystream.assign(vm, explicit_nonce, move || {
ctr.next().expect("range is unbounded").to_be_bytes()
})?;
// Assign zeroes to the padding.
if padding_len > 0 {
let padding = input.split_off(input.len() - padding_len);
if let Role::Leader = self.role {
vm.assign(padding, vec![0; padding_len])?;
}
vm.commit(padding)?;
output.truncate(len);
}
Ok((input, output))
}
/// Returns `len` bytes of keystream.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `explicit_nonce` - Explicit nonce.
/// * `len` - Number of bytes to take.
#[instrument(level = "debug", skip_all, err)]
pub(crate) fn take_keystream(
&mut self,
vm: &mut dyn Vm<Binary>,
explicit_nonce: Vec<u8>,
len: usize,
) -> Result<Vector<U8>, AeadError> {
let State::Ready {
input,
keystream,
output,
..
} = &mut self.state
else {
return Err(AeadError::state("must be in ready state to take keystream"));
};
let explicit_nonce: [u8; 8] = explicit_nonce.try_into().map_err(|nonce: Vec<_>| {
AeadError::cipher(format!(
"explicit nonce length: expected {}, got {}",
16,
nonce.len()
))
})?;
let block_count = len.div_ceil(16);
let padded_len = block_count * 16;
if padded_len > input.len() {
return Err(AeadError::cipher(format!(
"input length exceeds allocated: {} > {}",
padded_len,
input.len()
)));
}
// Drop the input and output text, we won't be needing them.
// This leaves them allocated but unassigned in the VM.
_ = input.split_off(input.len() - padded_len);
_ = output.split_off(output.len() - padded_len);
let keystream = keystream.consume(len)?;
// Assign counter block inputs.
let mut ctr = START_CTR..;
keystream.assign(vm, explicit_nonce, move || {
ctr.next().expect("range is unbounded").to_be_bytes()
})?;
Ok(keystream.to_vector(len)?)
}
/// Computes tags for the provided ciphertext. See
/// [`verify_tags`](MpcAesGcm::verify_tags) for a method that verifies an
/// tags instead.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `explicit_nonce` - Explicit nonce.
/// * `ciphertext` - Ciphertext to compute the tag for.
#[instrument(level = "debug", skip_all, err)]
pub(crate) fn compute_tags<C>(
&mut self,
vm: &mut dyn Vm<Binary>,
ciphertexts: Vec<C>,
data: Vec<TagData>,
) -> Result<ComputeTags, AeadError>
where
C: Future<Output = Result<Vec<u8>, AeadError>> + Send + Sync + 'static,
{
let State::Ready { j0s, ghash, .. } = &mut self.state else {
return Err(AeadError::state("must be in ready state to compute tags"));
};
if ciphertexts.len() != data.len() {
return Err(AeadError::tag("ciphertext and data length mismatch"));
} else if ciphertexts.len() > j0s.len() {
return Err(AeadError::tag("ciphertext length exceeds allocated"));
}
let mut tag_data = Vec::with_capacity(ciphertexts.len());
for (ciphertext, data) in ciphertexts.into_iter().zip(data) {
let explicit_nonce: [u8; 8] =
data.explicit_nonce.try_into().map_err(|nonce: Vec<_>| {
AeadError::cipher(format!(
"explicit nonce length: expected {}, got {}",
8,
nonce.len()
))
})?;
let (j0, j0_shared) = j0s.pop().expect("j0 length was checked");
assign_j0(vm, j0, explicit_nonce)?;
tag_data.push(ComputeTagData {
j0: j0_shared,
ciphertext: Box::pin(ciphertext),
aad: data.aad,
});
}
let tags = ComputeTags::new(self.role, tag_data, ghash.clone());
Ok(tags)
}
/// Verifies the tags for the provided ciphertexts.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `inputs` - Data to verify the tags for.
pub(crate) fn verify_tags(
&mut self,
vm: &mut dyn Vm<Binary>,
data: Vec<TagData>,
ciphertexts: Vec<Vec<u8>>,
tags: Vec<Vec<u8>>,
) -> Result<VerifyTags, AeadError> {
let State::Ready { j0s, ghash, .. } = &mut self.state else {
return Err(AeadError::state("must be in ready state to verify tags"));
};
if ciphertexts.len() != data.len() {
return Err(AeadError::tag("ciphertext and data length mismatch"));
} else if ciphertexts.len() != tags.len() {
return Err(AeadError::tag("ciphertext and tag length mismatch"));
} else if ciphertexts.len() > j0s.len() {
return Err(AeadError::tag("ciphertext length exceeds allocated"));
}
let mut tag_data = Vec::with_capacity(ciphertexts.len());
for ((ciphertext, data), tag) in ciphertexts.into_iter().zip(data).zip(tags) {
let explicit_nonce: [u8; 8] =
data.explicit_nonce.try_into().map_err(|nonce: Vec<_>| {
AeadError::cipher(format!(
"explicit nonce length: expected {}, got {}",
8,
nonce.len()
))
})?;
let (j0, j0_shared) = j0s.pop().expect("j0 length was checked");
assign_j0(vm, j0, explicit_nonce)?;
tag_data.push(VerifyTagData {
j0: j0_shared,
ciphertext,
aad: data.aad,
tag,
});
}
let tags = VerifyTags::new(self.role, tag_data, ghash.clone());
Ok(tags)
}
}
fn assign_j0(
vm: &mut dyn Vm<Binary>,
j0: CtrBlock<Nonce, Ctr, Block>,
explicit_nonce: [u8; 8],
) -> Result<(), AeadError> {
vm.assign(j0.explicit_nonce, explicit_nonce)?;
vm.commit(j0.explicit_nonce)?;
vm.assign(j0.counter, 1u32.to_be_bytes())?;
vm.commit(j0.counter)?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use aes_gcm::{
aead::{AeadInPlace, NewAead},
Aes128Gcm,
};
use mpz_common::context::test_st_context;
use mpz_core::Block;
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
use mpz_memory_core::{binary::U8, correlated::Delta};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_share_conversion::ideal::ideal_share_convert;
use rand::{rngs::StdRng, Rng, SeedableRng};
use rstest::*;
static SHORT_MSG: &[u8] = b"hello world";
static LONG_MSG: &[u8] = b"this message exceeds one block in length";
#[derive(Clone, Copy)]
struct Vars {
key: Array<U8, 16>,
iv: Array<U8, 4>,
}
#[rstest]
#[case::short(SHORT_MSG, 1)]
#[case::long(LONG_MSG, 1)]
#[case::short_multiple(SHORT_MSG, 3)]
#[case::long_multiple(LONG_MSG, 3)]
#[tokio::test]
async fn test_aes_gcm_encrypt(#[case] msg: &[u8], #[case] count: usize) {
let (mut ctx_0, mut ctx_1) = test_st_context(8);
let key = [42u8; 16];
let iv = [0u8; 4];
let ((mut vm_0, vars_0), (mut vm_1, vars_1)) = create_vm(key, iv);
let (mut leader, mut follower) = create_pair(vars_0, vars_1);
leader.alloc(&mut vm_0, count, 256).unwrap();
follower.alloc(&mut vm_1, count, 256).unwrap();
run_vms(&mut vm_0, &mut ctx_0, &mut vm_1, &mut ctx_1).await;
tokio::try_join!(leader.setup(&mut ctx_0), follower.setup(&mut ctx_1)).unwrap();
for i in 0u64..count as u64 {
let explicit_nonce = i.to_be_bytes().to_vec();
let (msg_0, ct_0) = leader
.apply_keystream(&mut vm_0, explicit_nonce.clone(), msg.len())
.unwrap();
let (msg_1, ct_1) = follower
.apply_keystream(&mut vm_1, explicit_nonce.clone(), msg.len())
.unwrap();
vm_0.assign(msg_0, msg.to_vec()).unwrap();
vm_0.commit(msg_0).unwrap();
vm_1.commit(msg_1).unwrap();
let ct_0 = vm_0.decode(ct_0).unwrap();
let ct_1 = vm_1.decode(ct_1).unwrap();
run_vms(&mut vm_0, &mut ctx_0, &mut vm_1, &mut ctx_1).await;
let ct_0 = ct_0.await.unwrap();
let ct_1 = ct_1.await.unwrap();
let (expected, _) = expected(&key, &iv, &explicit_nonce, msg, &[]);
assert_eq!(ct_0, expected);
assert_eq!(ct_1, expected);
}
}
#[rstest]
#[case::short(SHORT_MSG, 1)]
#[case::long(LONG_MSG, 1)]
#[case::short_multiple(SHORT_MSG, 3)]
#[case::long_multiple(LONG_MSG, 3)]
#[tokio::test]
async fn test_aes_gcm_decrypt(#[case] msg: &[u8], #[case] count: usize) {
let (mut ctx_0, mut ctx_1) = test_st_context(8);
let key = [42u8; 16];
let iv = [0u8; 4];
let ((mut vm_0, vars_0), (mut vm_1, vars_1)) = create_vm(key, iv);
let (mut leader, mut follower) = create_pair(vars_0, vars_1);
leader.alloc(&mut vm_0, count, 256).unwrap();
follower.alloc(&mut vm_1, count, 256).unwrap();
run_vms(&mut vm_0, &mut ctx_0, &mut vm_1, &mut ctx_1).await;
tokio::try_join!(leader.setup(&mut ctx_0), follower.setup(&mut ctx_1)).unwrap();
for i in 0u64..count as u64 {
let explicit_nonce = i.to_be_bytes().to_vec();
let (ct, _) = expected(&key, &iv, &explicit_nonce, msg, &[]);
let (ct_0, msg_0) = leader
.apply_keystream(&mut vm_0, explicit_nonce.clone(), ct.len())
.unwrap();
let (ct_1, msg_1) = follower
.apply_keystream(&mut vm_1, explicit_nonce.clone(), ct.len())
.unwrap();
vm_0.assign(ct_0, ct.clone()).unwrap();
vm_0.commit(ct_0).unwrap();
vm_1.commit(ct_1).unwrap();
let msg_0 = vm_0.decode(msg_0).unwrap();
let msg_1 = vm_1.decode(msg_1).unwrap();
run_vms(&mut vm_0, &mut ctx_0, &mut vm_1, &mut ctx_1).await;
let msg_0 = msg_0.await.unwrap();
let msg_1 = msg_1.await.unwrap();
assert_eq!(&msg_0, msg);
assert_eq!(&msg_1, msg);
}
}
fn create_vm(key: [u8; 16], iv: [u8; 4]) -> ((impl Vm<Binary>, Vars), (impl Vm<Binary>, Vars)) {
let mut rng = StdRng::seed_from_u64(0);
let block = Block::random(&mut rng);
let (sender, receiver) = ideal_cot(block);
let delta = Delta::new(block);
let mut vm_0 = Generator::new(sender, [0u8; 16], delta);
let mut vm_1 = Evaluator::new(receiver);
let key_ref_0 = vm_0.alloc::<Array<U8, 16>>().unwrap();
vm_0.mark_public(key_ref_0).unwrap();
vm_0.assign(key_ref_0, key).unwrap();
vm_0.commit(key_ref_0).unwrap();
let key_ref_1 = vm_1.alloc::<Array<U8, 16>>().unwrap();
vm_1.mark_public(key_ref_1).unwrap();
vm_1.assign(key_ref_1, key).unwrap();
vm_1.commit(key_ref_1).unwrap();
let iv_ref_0 = vm_0.alloc::<Array<U8, 4>>().unwrap();
vm_0.mark_public(iv_ref_0).unwrap();
vm_0.assign(iv_ref_0, iv).unwrap();
vm_0.commit(iv_ref_0).unwrap();
let iv_ref_1 = vm_1.alloc::<Array<U8, 4>>().unwrap();
vm_1.mark_public(iv_ref_1).unwrap();
vm_1.assign(iv_ref_1, iv).unwrap();
vm_1.commit(iv_ref_1).unwrap();
(
(
vm_0,
Vars {
key: key_ref_0,
iv: iv_ref_0,
},
),
(
vm_1,
Vars {
key: key_ref_1,
iv: iv_ref_1,
},
),
)
}
fn create_pair(vars_0: Vars, vars_1: Vars) -> (MpcAesGcm, MpcAesGcm) {
let mut rng = StdRng::seed_from_u64(0);
let (c_0, c_1) = ideal_share_convert(rng.gen());
let mut leader = MpcAesGcm::new(c_0, Role::Leader);
let mut follower = MpcAesGcm::new(c_1, Role::Follower);
leader.set_key(vars_0.key);
leader.set_iv(vars_0.iv);
follower.set_key(vars_1.key);
follower.set_iv(vars_1.iv);
(leader, follower)
}
async fn run_vms(
vm_0: &mut (dyn Vm<Binary> + Send),
ctx_0: &mut Context,
vm_1: &mut (dyn Vm<Binary> + Send),
ctx_1: &mut Context,
) {
tokio::join!(
async {
vm_0.execute_all(ctx_0).await.unwrap();
},
async {
vm_1.execute_all(ctx_1).await.unwrap();
}
);
}
fn expected(
key: &[u8],
iv: &[u8],
explicit_nonce: &[u8],
msg: &[u8],
aad: &[u8],
) -> (Vec<u8>, Vec<u8>) {
let key: [u8; 16] = key.try_into().unwrap();
let aes = Aes128Gcm::new(&key.into());
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(iv);
nonce[4..].copy_from_slice(explicit_nonce);
let mut payload = msg.to_vec();
let tag = aes
.encrypt_in_place_detached(&nonce.into(), aad, &mut payload)
.unwrap();
(payload, tag.to_vec())
}
}

View File

@@ -0,0 +1,531 @@
mod compute;
mod verify;
pub(crate) use compute::{ComputeTagData, ComputeTags};
pub(crate) use verify::{VerifyTagData, VerifyTags};
use std::{fmt::Debug, ops::Add};
use async_trait::async_trait;
use mpz_common::{future::Output, Context, Flush};
use mpz_core::Block;
use mpz_fields::{gf2_128::Gf2_128, Field};
use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive};
use serde::{Deserialize, Serialize};
use crate::record_layer::aead::AeadError;
/// Maximum key share power.
const MAX_POWER: usize = 1026;
#[async_trait]
pub(crate) trait Ghash {
/// Allocates resources needed for GHASH.
fn alloc(&mut self) -> Result<(), GhashError>;
/// Preprocesses GHASH.
async fn preprocess(&mut self, ctx: &mut Context) -> Result<(), GhashError>;
/// Sets the key for the hash function.
fn set_key(&mut self, key: Vec<u8>) -> Result<(), GhashError>;
/// Sets up GHASH, computing the key shares.
async fn setup(&mut self, ctx: &mut Context) -> Result<(), GhashError>;
/// Computes the GHASH tag.
fn compute(&self, input: &[u8]) -> Result<Vec<u8>, GhashError>;
}
/// MPC GHASH implementation.
pub(crate) struct MpcGhash<C> {
state: State,
converter: C,
alloc: bool,
}
#[derive(Debug)]
enum State {
Init,
SetKey { key: Gf2_128 },
Ready { shares: Vec<Gf2_128> },
Error,
}
impl State {
fn take(&mut self) -> Self {
std::mem::replace(self, State::Error)
}
}
impl<C> MpcGhash<C> {
/// Creates a new instance.
///
/// # Arguments
///
/// * `converter` - GF2_128 share converter.
pub(crate) fn new(converter: C) -> Self {
Self {
state: State::Init,
converter,
alloc: false,
}
}
}
#[async_trait]
impl<C> Ghash for MpcGhash<C>
where
C: AdditiveToMultiplicative<Gf2_128> + Flush + Send,
C: MultiplicativeToAdditive<Gf2_128> + Flush + Send,
{
fn alloc(&mut self) -> Result<(), GhashError> {
if !self.alloc {
// We need only half the number of `MAX_POWER` M2As because of the free
// squaring trick and we need one extra A2M conversion in the beginning.
// Both M2A and A2M, each require a single OLE.
AdditiveToMultiplicative::<Gf2_128>::alloc(&mut self.converter, 1)
.map_err(GhashError::conversion)?;
MultiplicativeToAdditive::<Gf2_128>::alloc(&mut self.converter, (MAX_POWER / 2) - 1)
.map_err(GhashError::conversion)?;
self.alloc = true;
}
Ok(())
}
async fn preprocess(&mut self, ctx: &mut Context) -> Result<(), GhashError> {
self.converter
.flush(ctx)
.await
.map_err(GhashError::conversion)
}
fn set_key(&mut self, key: Vec<u8>) -> Result<(), GhashError> {
if key.len() != 16 {
return Err(ErrorRepr::KeyLength {
expected: 16,
actual: key.len(),
}
.into());
}
let State::Init = self.state.take() else {
return Err(GhashError::state("Key already set"));
};
let mut h_additive = [0u8; 16];
h_additive.copy_from_slice(key.as_slice());
// GHASH reflects the bits of the key.
let h_additive = Gf2_128::new(u128::from_be_bytes(h_additive).reverse_bits());
self.state = State::SetKey { key: h_additive };
Ok(())
}
async fn setup(&mut self, ctx: &mut Context) -> Result<(), GhashError> {
let State::SetKey { key: add_key } = self.state.take() else {
return Err(GhashError::state("can not setup before key is set"));
};
let mut mult_key = self
.converter
.queue_to_multiplicative(&[add_key])
.map_err(GhashError::conversion)?;
self.converter
.flush(ctx)
.await
.map_err(GhashError::conversion)?;
let mult_key = mult_key
.try_recv()
.map_err(GhashError::conversion)?
.expect("share should be computed")
.shares[0];
// Compute the odd powers of the multiplicative key share.
//
// Resulting vector contains odd powers of H from H^3 to H^1025.
let odd_shares: Vec<_> = (0..MAX_POWER)
.scan(mult_key, |acc, _| {
let power_n = *acc;
*acc = power_n * mult_key;
Some(power_n)
})
// Start from H^3
.skip(2)
// Skip even powers
.step_by(2)
.collect();
// Compute the additive shares of the odd powers.
let mut add_shares_odd = self
.converter
.queue_to_additive(&odd_shares)
.map_err(GhashError::conversion)?;
self.converter
.flush(ctx)
.await
.map_err(GhashError::conversion)?;
let add_shares_odd = add_shares_odd
.try_recv()
.map_err(GhashError::conversion)?
.expect("share should be computed")
.shares;
let shares = compute_shares(add_key, &add_shares_odd);
self.state = State::Ready { shares };
Ok(())
}
fn compute(&self, input: &[u8]) -> Result<Vec<u8>, GhashError> {
let State::Ready { shares } = &self.state else {
return Err(GhashError::state("key shares are not computed"));
};
// Divide by block length and round up.
let block_count = input.len() / 16 + (input.len() % 16 != 0) as usize;
if block_count > MAX_POWER {
return Err(ErrorRepr::InputLength {
len: block_count,
max: MAX_POWER * 16,
}
.into());
}
let mut input = input.to_vec();
// Pad input to a multiple of 16 bytes.
input.resize(block_count * 16, 0);
// Convert input to blocks.
let blocks = input
.chunks_exact(16)
.map(|chunk| {
let mut block = [0u8; 16];
block.copy_from_slice(chunk);
Block::from(block)
})
.collect::<Vec<Block>>();
let offset = shares.len() - blocks.len();
let tag: Block = blocks
.iter()
.zip(shares.iter().rev().skip(offset))
.fold(Gf2_128::zero(), |acc, (block, share)| {
acc + Gf2_128::from(block.reverse_bits()) * *share
})
.into();
Ok(tag.reverse_bits().to_bytes().to_vec())
}
}
/// Computes shares of powers of H.
///
/// # Arguments
///
/// * `key` - Additive share of H.
/// * `odd_powers` - Additive shares of odd powers of H starting at H^3.
fn compute_shares(key: Gf2_128, odd_powers: &[Gf2_128]) -> Vec<Gf2_128> {
let mut shares = Vec::with_capacity(MAX_POWER);
// H^1
shares.push(key);
let mut odd_idx = 0;
for i in 2..=MAX_POWER {
if i % 2 == 0 {
// Even power, compute by squaring the square root power.
let base = shares[i / 2 - 1];
shares.push(base * base);
} else {
// Odd power
shares.push(odd_powers[odd_idx]);
odd_idx += 1;
}
}
shares
}
/// Builds padded data for GHASH.
pub(crate) fn build_ghash_data(mut aad: Vec<u8>, mut ciphertext: Vec<u8>) -> Vec<u8> {
let associated_data_bitlen = (aad.len() as u64) * 8;
let text_bitlen = (ciphertext.len() as u64) * 8;
let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128);
// Pad data to be a multiple of 16 bytes.
let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize;
aad.resize(aad_padded_block_count * 16, 0);
let ciphertext_padded_block_count =
(ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize;
ciphertext.resize(ciphertext_padded_block_count * 16, 0);
let mut data: Vec<u8> = Vec::with_capacity(aad.len() + ciphertext.len() + 16);
data.extend(aad);
data.extend(ciphertext);
data.extend_from_slice(&len_block.to_be_bytes());
data
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TagShare([u8; 16]);
impl Add for TagShare {
type Output = Vec<u8>;
fn add(mut self, rhs: Self) -> Self::Output {
self.0.iter_mut().zip(rhs.0).for_each(|(a, b)| *a ^= b);
self.0.to_vec()
}
}
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub(crate) struct GhashError(#[from] ErrorRepr);
impl GhashError {
fn conversion<E>(error: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::ShareConversion(error.into()))
}
fn state(reason: impl ToString) -> Self {
Self(ErrorRepr::State(reason.to_string()))
}
}
#[derive(Debug, thiserror::Error)]
#[error("ghash error: {0}")]
enum ErrorRepr {
#[error("share conversion error: {0}")]
ShareConversion(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("invalid state: {0}")]
State(String),
#[error("incorrect key length, expected: {expected}, actual: {actual}")]
KeyLength { expected: usize, actual: usize },
#[error("input length exceeds maximum: {len} > {max}")]
InputLength { len: usize, max: usize },
}
impl From<GhashError> for AeadError {
fn from(value: GhashError) -> Self {
AeadError::tag(value)
}
}
#[cfg(test)]
mod tests {
use super::*;
use ghash_rc::{
universal_hash::{KeyInit, UniversalHash as UniversalHashReference},
GHash as GhashReference,
};
use mpz_common::context::test_st_context;
use mpz_core::Block;
use mpz_fields::{gf2_128::Gf2_128, UniformRand};
use mpz_share_conversion::ideal::{
ideal_share_convert, IdealShareConvertReceiver, IdealShareConvertSender,
};
use rand::{rngs::StdRng, Rng, SeedableRng};
fn create_pair() -> (
MpcGhash<IdealShareConvertSender<Gf2_128>>,
MpcGhash<IdealShareConvertReceiver<Gf2_128>>,
) {
let (convert_a, convert_b) = ideal_share_convert(Block::ZERO);
let (mut sender, mut receiver) = (MpcGhash::new(convert_a), MpcGhash::new(convert_b));
sender.alloc().unwrap();
receiver.alloc().unwrap();
(sender, receiver)
}
#[test]
fn test_compute_shares() {
let mut rng = StdRng::seed_from_u64(0);
let key = Gf2_128::rand(&mut rng);
let expected_powers: Vec<_> = (0..MAX_POWER)
.scan(key, |acc, _| {
let power_n = *acc;
*acc = power_n * key;
Some(power_n)
})
.collect();
let odd_powers = expected_powers
.iter()
.skip(2)
.step_by(2)
.cloned()
.collect::<Vec<_>>();
let powers = compute_shares(key, &odd_powers);
assert_eq!(powers, expected_powers);
}
#[tokio::test]
async fn test_ghash_output() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let mut rng = StdRng::seed_from_u64(0);
let h: u128 = rng.gen();
let sender_key: u128 = rng.gen();
let receiver_key: u128 = h ^ sender_key;
let message: Vec<u8> = (0..16).map(|_| rng.gen()).collect();
let (mut sender, mut receiver) = create_pair();
sender.set_key(sender_key.to_be_bytes().to_vec()).unwrap();
receiver
.set_key(receiver_key.to_be_bytes().to_vec())
.unwrap();
tokio::try_join!(sender.setup(&mut ctx_a), receiver.setup(&mut ctx_b)).unwrap();
let sender_share = sender.compute(&message).unwrap();
let receiver_share = receiver.compute(&message).unwrap();
let tag = sender_share
.iter()
.zip(receiver_share.iter())
.map(|(a, b)| a ^ b)
.collect::<Vec<u8>>();
assert_eq!(tag, ghash_reference_impl(h, &message));
}
#[tokio::test]
async fn test_ghash_output_padded() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let mut rng = StdRng::seed_from_u64(0);
let h: u128 = rng.gen();
let sender_key: u128 = rng.gen();
let receiver_key: u128 = h ^ sender_key;
// Message length is not a multiple of the block length
let message: Vec<u8> = (0..14).map(|_| rng.gen()).collect();
let (mut sender, mut receiver) = create_pair();
sender.set_key(sender_key.to_be_bytes().to_vec()).unwrap();
receiver
.set_key(receiver_key.to_be_bytes().to_vec())
.unwrap();
tokio::try_join!(sender.setup(&mut ctx_a), receiver.setup(&mut ctx_b)).unwrap();
let sender_share = sender.compute(&message).unwrap();
let receiver_share = receiver.compute(&message).unwrap();
let tag = sender_share
.iter()
.zip(receiver_share.iter())
.map(|(a, b)| a ^ b)
.collect::<Vec<u8>>();
assert_eq!(tag, ghash_reference_impl(h, &message));
}
#[tokio::test]
async fn test_ghash_long_message() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let mut rng = StdRng::seed_from_u64(0);
let h: u128 = rng.gen();
let sender_key: u128 = rng.gen();
let receiver_key: u128 = h ^ sender_key;
// A longer message.
let long_message: Vec<u8> = (0..30).map(|_| rng.gen()).collect();
let (mut sender, mut receiver) = create_pair();
sender.set_key(sender_key.to_be_bytes().to_vec()).unwrap();
receiver
.set_key(receiver_key.to_be_bytes().to_vec())
.unwrap();
tokio::try_join!(sender.setup(&mut ctx_a), receiver.setup(&mut ctx_b)).unwrap();
let sender_share = sender.compute(&long_message).unwrap();
let receiver_share = receiver.compute(&long_message).unwrap();
let tag = sender_share
.iter()
.zip(receiver_share.iter())
.map(|(a, b)| a ^ b)
.collect::<Vec<u8>>();
assert_eq!(tag, ghash_reference_impl(h, &long_message));
}
#[tokio::test]
async fn test_ghash_repeated() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let mut rng = StdRng::seed_from_u64(0);
let h: u128 = rng.gen();
let sender_key: u128 = rng.gen();
let receiver_key: u128 = h ^ sender_key;
// Two messages.
let first_message: Vec<u8> = (0..14).map(|_| rng.gen()).collect();
let second_message: Vec<u8> = (0..32).map(|_| rng.gen()).collect();
let (mut sender, mut receiver) = create_pair();
sender.set_key(sender_key.to_be_bytes().to_vec()).unwrap();
receiver
.set_key(receiver_key.to_be_bytes().to_vec())
.unwrap();
tokio::try_join!(sender.setup(&mut ctx_a), receiver.setup(&mut ctx_b)).unwrap();
// Compute and check first message.
let sender_share = sender.compute(&first_message).unwrap();
let receiver_share = receiver.compute(&first_message).unwrap();
let tag = sender_share
.iter()
.zip(receiver_share.iter())
.map(|(a, b)| a ^ b)
.collect::<Vec<u8>>();
assert_eq!(tag, ghash_reference_impl(h, &first_message));
// Compute and check second message.
let sender_share = sender.compute(&second_message).unwrap();
let receiver_share = receiver.compute(&second_message).unwrap();
let tag = sender_share
.iter()
.zip(receiver_share.iter())
.map(|(a, b)| a ^ b)
.collect::<Vec<u8>>();
assert_eq!(tag, ghash_reference_impl(h, &second_message));
}
fn ghash_reference_impl(h: u128, message: &[u8]) -> Vec<u8> {
let mut ghash = GhashReference::new(&h.to_be_bytes().into());
ghash.update_padded(message);
let mac = ghash.finalize();
mac.to_vec()
}
}

View File

@@ -0,0 +1,119 @@
use std::{future::Future, pin::Pin, sync::Arc};
use async_trait::async_trait;
use futures::{stream::FuturesOrdered, StreamExt as _};
use mpz_common::{Context, Task};
use serio::{stream::IoStreamExt, SinkExt};
use crate::{
decode::OneTimePadShared,
record_layer::aead::{
ghash::{build_ghash_data, Ghash, TagShare},
AeadError,
},
Role,
};
pub(crate) struct ComputeTagData {
pub(crate) j0: OneTimePadShared<[u8; 16]>,
pub(crate) ciphertext: Pin<Box<dyn Future<Output = Result<Vec<u8>, AeadError>> + Send + Sync>>,
pub(crate) aad: Vec<u8>,
}
#[must_use = "compute tags operation must be awaited"]
pub(crate) struct ComputeTags {
role: Role,
data: Vec<ComputeTagData>,
ghash: Arc<dyn Ghash + Send + Sync>,
}
impl ComputeTags {
pub(crate) fn new(
role: Role,
data: Vec<ComputeTagData>,
ghash: Arc<dyn Ghash + Send + Sync>,
) -> Self {
Self { role, data, ghash }
}
}
#[async_trait]
impl Task for ComputeTags {
type Output = Result<Option<Vec<Vec<u8>>>, AeadError>;
async fn run(self, ctx: &mut Context) -> Self::Output {
let Self {
role,
mut data,
ghash,
} = self;
if data.is_empty() {
return Ok(None);
}
let mut j0_shares = Vec::with_capacity(data.len());
{
let mut futs = FuturesOrdered::from_iter(data.iter_mut().map(|data| &mut data.j0));
while let Some(j0_share) = futs.next().await.transpose().map_err(AeadError::tag)? {
j0_shares.push(j0_share);
}
}
let mut ciphertexts = Vec::with_capacity(data.len());
{
let mut futs =
FuturesOrdered::from_iter(data.iter_mut().map(|data| &mut data.ciphertext));
while let Some(ciphertext) = futs.next().await.transpose().map_err(AeadError::tag)? {
ciphertexts.push(ciphertext);
}
}
let mut tag_shares = Vec::with_capacity(data.len());
for ((mut tag_share, ciphertext), data) in j0_shares.into_iter().zip(ciphertexts).zip(data)
{
let ghash_share = ghash
.compute(&build_ghash_data(data.aad, ciphertext))
.map_err(AeadError::tag)?;
tag_share
.iter_mut()
.zip(ghash_share)
.for_each(|(a, b)| *a ^= b);
tag_shares.push(TagShare(tag_share));
}
let tags = match role {
Role::Leader => {
let follower_tag_shares: Vec<TagShare> =
ctx.io_mut().expect_next().await.map_err(AeadError::tag)?;
if follower_tag_shares.len() != tag_shares.len() {
return Err(AeadError::tag("follower tag shares length mismatch"));
}
let tags = tag_shares
.into_iter()
.zip(follower_tag_shares)
.map(|(a, b)| (a + b).to_vec())
.collect();
Some(tags)
}
Role::Follower => {
ctx.io_mut()
.send(tag_shares)
.await
.map_err(AeadError::tag)?;
None
}
};
Ok(tags)
}
async fn run_boxed(self: Box<Self>, ctx: &mut Context) -> Self::Output {
self.run(ctx).await
}
}

View File

@@ -0,0 +1,134 @@
use std::sync::Arc;
use async_trait::async_trait;
use futures::{stream::FuturesOrdered, StreamExt};
use mpz_common::{Context, Task};
use mpz_core::commit::{Decommitment, HashCommit};
use serio::{stream::IoStreamExt, SinkExt};
use crate::{
decode::OneTimePadShared,
record_layer::aead::{
ghash::{build_ghash_data, Ghash, TagShare},
AeadError,
},
Role,
};
pub(crate) struct VerifyTagData {
pub(crate) j0: OneTimePadShared<[u8; 16]>,
pub(crate) ciphertext: Vec<u8>,
pub(crate) aad: Vec<u8>,
pub(crate) tag: Vec<u8>,
}
#[must_use = "verify tags operation must be awaited"]
pub(crate) struct VerifyTags {
role: Role,
data: Vec<VerifyTagData>,
ghash: Arc<dyn Ghash + Send + Sync>,
}
impl VerifyTags {
pub(crate) fn new(
role: Role,
data: Vec<VerifyTagData>,
ghash: Arc<dyn Ghash + Send + Sync>,
) -> Self {
Self { role, data, ghash }
}
}
#[async_trait]
impl Task for VerifyTags {
type Output = Result<(), AeadError>;
async fn run(self, ctx: &mut Context) -> Self::Output {
let Self {
role,
mut data,
ghash,
} = self;
if data.is_empty() {
return Ok(());
}
let mut j0_shares = Vec::with_capacity(data.len());
{
let mut futs = FuturesOrdered::from_iter(data.iter_mut().map(|data| &mut data.j0));
while let Some(j0_share) = futs.next().await.transpose().map_err(AeadError::tag)? {
j0_shares.push(j0_share);
}
}
let mut tag_shares = Vec::with_capacity(data.len());
let mut tags = Vec::with_capacity(data.len());
for (mut tag_share, data) in j0_shares.into_iter().zip(data) {
let ghash_share = ghash
.compute(&build_ghash_data(data.aad, data.ciphertext))
.map_err(AeadError::tag)?;
tag_share
.iter_mut()
.zip(ghash_share)
.for_each(|(a, b)| *a ^= b);
tag_shares.push(TagShare(tag_share));
tags.push(data.tag);
}
let io = ctx.io_mut();
let peer_tag_shares = match role {
Role::Leader => {
// Send commitment to follower.
let (decommitment, commitment) = tag_shares.clone().hash_commit();
io.send(commitment).await.map_err(AeadError::tag)?;
let follower_tag_shares: Vec<TagShare> =
io.expect_next().await.map_err(AeadError::tag)?;
if follower_tag_shares.len() != tag_shares.len() {
return Err(AeadError::tag("follower tag shares length mismatch"));
}
// Send decommitment to follower.
io.send(decommitment).await.map_err(AeadError::tag)?;
follower_tag_shares
}
Role::Follower => {
// Wait for commitment from leader.
let commitment = io.expect_next().await.map_err(AeadError::tag)?;
// Send tag shares to leader.
io.send(tag_shares.clone()).await.map_err(AeadError::tag)?;
// Expect decommitment from leader.
let decommitment: Decommitment<Vec<TagShare>> =
io.expect_next().await.map_err(AeadError::tag)?;
// Verify decommitment.
decommitment.verify(&commitment).map_err(AeadError::tag)?;
decommitment.into_inner()
}
};
let expected_tags = tag_shares
.into_iter()
.zip(peer_tag_shares)
.map(|(tag_share, peer_tag_share)| tag_share + peer_tag_share)
.collect::<Vec<_>>();
if tags != expected_tags {
return Err(AeadError::tag("failed to verify tags"));
}
Ok(())
}
async fn run_boxed(self: Box<Self>, ctx: &mut Context) -> Self::Output {
self.run(ctx).await
}
}

View File

@@ -0,0 +1,263 @@
use cipher_crate::{KeyIvInit, StreamCipher as _, StreamCipherSeek};
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
binary::{Binary, U8},
Array, DecodeFutureTyped,
};
use mpz_vm_core::{prelude::*, Vm};
use rand::{thread_rng, RngCore};
use crate::{MpcTlsError, Role};
type LocalAesCtr = ctr::Ctr32BE<aes::Aes128>;
enum State {
Init,
Alloc {
masked_key: Array<U8, 16>,
masked_iv: Array<U8, 4>,
key_otp: Option<[u8; 16]>,
iv_otp: Option<[u8; 4]>,
},
Decode {
masked_key: DecodeFutureTyped<BitVec, [u8; 16]>,
masked_iv: DecodeFutureTyped<BitVec, [u8; 4]>,
key_otp: Option<[u8; 16]>,
iv_otp: Option<[u8; 4]>,
},
Ready {
key: Option<[u8; 16]>,
iv: Option<[u8; 4]>,
},
Error,
}
impl State {
fn take(&mut self) -> Self {
std::mem::replace(self, State::Error)
}
}
pub(crate) struct AesCtr {
role: Role,
key: Option<Array<U8, 16>>,
iv: Option<Array<U8, 4>>,
state: State,
}
impl AesCtr {
pub(crate) fn new(role: Role) -> Self {
Self {
role,
key: None,
iv: None,
state: State::Init,
}
}
pub(crate) fn set_key(&mut self, key: Array<U8, 16>, iv: Array<U8, 4>) {
self.key = Some(key);
self.iv = Some(iv);
}
pub(crate) fn alloc(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), MpcTlsError> {
let State::Init = self.state.take() else {
Err(MpcTlsError::record_layer(
"aes-ctr must be in initialized state to allocate",
))?
};
let key = self
.key
.ok_or_else(|| MpcTlsError::record_layer("key not set in aes-ctr"))?;
let iv = self
.iv
.ok_or_else(|| MpcTlsError::record_layer("iv not set in aes-ctr"))?;
let (masked_key, key_otp, masked_iv, iv_otp) = match self.role {
Role::Leader => {
let mut key_otp = [0u8; 16];
thread_rng().fill_bytes(&mut key_otp);
let mut iv_otp = [0u8; 4];
thread_rng().fill_bytes(&mut iv_otp);
let masked_key = vm
.mask_private(key, key_otp)
.map_err(MpcTlsError::record_layer)?;
let masked_iv = vm
.mask_private(iv, iv_otp)
.map_err(MpcTlsError::record_layer)?;
(masked_key, Some(key_otp), masked_iv, Some(iv_otp))
}
Role::Follower => {
let masked_key = vm.mask_blind(key).map_err(MpcTlsError::record_layer)?;
let masked_iv = vm.mask_blind(iv).map_err(MpcTlsError::record_layer)?;
(masked_key, None, masked_iv, None)
}
};
self.state = State::Alloc {
masked_key,
masked_iv,
key_otp,
iv_otp,
};
Ok(())
}
pub(crate) fn decode_key(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), MpcTlsError> {
let State::Alloc {
masked_key,
masked_iv,
key_otp,
iv_otp,
} = self.state.take()
else {
Err(MpcTlsError::record_layer(
"aes-ctr must be in allocated state to decode key",
))?
};
let masked_key = vm.decode(masked_key).map_err(MpcTlsError::record_layer)?;
let masked_iv = vm.decode(masked_iv).map_err(MpcTlsError::record_layer)?;
self.state = State::Decode {
masked_key,
masked_iv,
key_otp,
iv_otp,
};
Ok(())
}
pub(crate) fn finish_decode(&mut self) -> Result<(), MpcTlsError> {
let State::Decode {
mut masked_key,
mut masked_iv,
key_otp,
iv_otp,
} = self.state.take()
else {
Err(MpcTlsError::record_layer(
"aes-ctr must be in decode state to finish decode",
))?
};
let (key, iv) = if let Role::Leader = self.role {
let key_otp = key_otp.expect("leader knows key otp");
let iv_otp = iv_otp.expect("leader knows iv otp");
let masked_key = masked_key
.try_recv()
.map_err(MpcTlsError::record_layer)?
.ok_or_else(|| MpcTlsError::record_layer("masked key is not decoded"))?;
let masked_iv = masked_iv
.try_recv()
.map_err(MpcTlsError::record_layer)?
.ok_or_else(|| MpcTlsError::record_layer("masked iv is not decoded"))?;
let mut key = masked_key;
let mut iv = masked_iv;
key.iter_mut().zip(key_otp).for_each(|(key, otp)| {
*key ^= otp;
});
iv.iter_mut().zip(iv_otp).for_each(|(iv, otp)| {
*iv ^= otp;
});
(Some(key), Some(iv))
} else {
(None, None)
};
self.state = State::Ready { key, iv };
Ok(())
}
pub(crate) fn decrypt(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, MpcTlsError> {
let State::Ready { key, iv, .. } = &self.state else {
Err(MpcTlsError::record_layer(
"aes-ctr must be in ready state to decrypt",
))?
};
if let Role::Follower = self.role {
return Err(MpcTlsError::record_layer(
"aes-ctr must be in leader role to decrypt",
));
}
let key = key.as_ref().expect("leader knows key");
let iv = iv.as_ref().expect("leader knows iv");
let explicit_nonce: [u8; 8] =
explicit_nonce
.try_into()
.map_err(|explicit_nonce: Vec<_>| {
MpcTlsError::record_layer(format!(
"incorrect explicit nonce length: {} != 8",
explicit_nonce.len()
))
})?;
let mut full_iv = [0u8; 16];
full_iv[..4].copy_from_slice(iv);
full_iv[4..12].copy_from_slice(&explicit_nonce);
let mut aes = LocalAesCtr::new(key.into(), &full_iv.into());
// Skip the first 32 bytes of the keystream to match the AES-GCM implementation.
aes.seek(32);
let mut plaintext = ciphertext;
aes.apply_keystream(&mut plaintext);
Ok(plaintext)
}
}
#[cfg(test)]
mod tests {
use super::*;
use aes_gcm::{aead::AeadMutInPlace, Aes128Gcm, NewAead};
#[test]
fn test_aes_ctr_local() {
let key = [0u8; 16];
let iv = [42u8; 4];
let explicit_nonce = [69u8; 8];
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&iv);
nonce[4..].copy_from_slice(&explicit_nonce);
let mut aes_ctr = AesCtr::new(Role::Leader);
aes_ctr.state = State::Ready {
key: Some(key),
iv: Some(iv),
};
let mut aes_gcm = Aes128Gcm::new(&key.into());
let msg = b"hello world";
let mut ciphertext = msg.to_vec();
_ = aes_gcm
.encrypt_in_place_detached(&nonce.into(), &[], &mut ciphertext)
.unwrap();
let decrypted = aes_ctr
.decrypt(explicit_nonce.to_vec(), ciphertext)
.unwrap();
assert_eq!(msg, decrypted.as_slice());
}
}

View File

@@ -0,0 +1,296 @@
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{binary::Binary, DecodeFutureTyped};
use mpz_vm_core::{prelude::*, Vm};
use serde::{Deserialize, Serialize};
use tls_core::msgs::enums::{ContentType, ProtocolVersion};
use crate::{
record_layer::{
aead::{MpcAesGcm, VerifyTags},
aes_ctr::AesCtr,
TagData,
},
MpcTlsError, Role,
};
pub(crate) fn private_mpc(
vm: &mut dyn Vm<Binary>,
decrypter: &mut MpcAesGcm,
otp: Option<&mut Vec<u8>>,
op: &DecryptOp,
) -> Result<DecryptOutput, MpcTlsError> {
if let Some(otp) = otp.as_ref() {
if op.ciphertext.len() > otp.len() {
return Err(MpcTlsError::record_layer(format!(
"ciphertext length exceeds allocated: {} > {}",
op.ciphertext.len(),
otp.len()
)));
}
}
let (otp_ref, masked_keystream) = decrypter
.apply_keystream(vm, op.explicit_nonce.clone(), op.ciphertext.len())
.map_err(MpcTlsError::record_layer)?;
let otp = otp.map(|otp| otp.split_off(otp.len() - op.ciphertext.len()));
if let Some(otp) = otp.clone() {
vm.assign(otp_ref, otp).map_err(MpcTlsError::record_layer)?;
}
vm.commit(otp_ref).map_err(MpcTlsError::record_layer)?;
// Decode the masked keystream.
let masked_keystream = vm
.decode(masked_keystream)
.map_err(MpcTlsError::record_layer)?;
Ok(DecryptOutput::Private(DecryptPrivate {
masked_keystream,
otp,
ciphertext: op.ciphertext.clone(),
}))
}
pub(crate) fn public(
vm: &mut dyn Vm<Binary>,
decrypter: &mut MpcAesGcm,
op: &DecryptOp,
) -> Result<DecryptOutput, MpcTlsError> {
// Instead of computing the plaintext in MPC, we only compute the keystream and
// decode it for both parties. Each party then locally computes the plaintext.
let keystream = decrypter
.take_keystream(vm, op.explicit_nonce.clone(), op.ciphertext.len())
.map_err(MpcTlsError::record_layer)?;
Ok(DecryptOutput::Public(DecryptPublic {
keystream: vm.decode(keystream).map_err(MpcTlsError::record_layer)?,
ciphertext: op.ciphertext.clone(),
}))
}
pub(crate) fn decrypt_mpc(
vm: &mut dyn Vm<Binary>,
decrypter: &mut MpcAesGcm,
mut otp: Option<&mut Vec<u8>>,
ops: &[DecryptOp],
) -> Result<Vec<PendingDecrypt>, MpcTlsError> {
let mut pending_decrypt = Vec::with_capacity(ops.len());
for op in ops {
match op.mode {
DecryptMode::Private => {
pending_decrypt.push(PendingDecrypt {
output: private_mpc(vm, decrypter, otp.as_deref_mut(), op)?,
});
}
DecryptMode::Public => {
pending_decrypt.push(PendingDecrypt {
output: public(vm, decrypter, op)?,
});
}
}
}
Ok(pending_decrypt)
}
pub(crate) fn decrypt_local(
role: Role,
vm: &mut dyn Vm<Binary>,
mpc_decrypter: &mut MpcAesGcm,
local_decrypter: &mut AesCtr,
ops: &[DecryptOp],
) -> Result<Vec<PendingDecrypt>, MpcTlsError> {
let mut pending_decrypt = Vec::with_capacity(ops.len());
for op in ops {
match op.mode {
DecryptMode::Private => {
let plaintext = if let Role::Leader = role {
let plaintext = local_decrypter
.decrypt(op.explicit_nonce.clone(), op.ciphertext.clone())?;
Some(plaintext)
} else {
None
};
pending_decrypt.push(PendingDecrypt {
output: DecryptOutput::Local(DecryptLocal {
plaintext: plaintext.clone(),
}),
});
}
DecryptMode::Public => {
pending_decrypt.push(PendingDecrypt {
output: public(vm, mpc_decrypter, op)?,
});
}
}
}
Ok(pending_decrypt)
}
pub(crate) fn verify_tags(
vm: &mut dyn Vm<Binary>,
decrypter: &mut MpcAesGcm,
ops: &[DecryptOp],
) -> Result<VerifyTags, MpcTlsError> {
let mut ciphertexts = Vec::with_capacity(ops.len());
let mut tags = Vec::with_capacity(ops.len());
let mut tags_data = Vec::with_capacity(ops.len());
for DecryptOp {
ciphertext,
tag,
explicit_nonce,
aad,
..
} in ops
{
ciphertexts.push(ciphertext.clone());
tags.push(tag.clone());
tags_data.push(TagData {
explicit_nonce: explicit_nonce.clone(),
aad: aad.clone(),
});
}
decrypter
.verify_tags(vm, tags_data, ciphertexts, tags)
.map_err(MpcTlsError::record_layer)
}
pub(crate) struct DecryptOp {
pub(crate) seq: u64,
pub(crate) typ: ContentType,
pub(crate) version: ProtocolVersion,
pub(crate) explicit_nonce: Vec<u8>,
pub(crate) ciphertext: Vec<u8>,
pub(crate) aad: Vec<u8>,
pub(crate) tag: Vec<u8>,
pub(crate) mode: DecryptMode,
}
impl DecryptOp {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
seq: u64,
typ: ContentType,
version: ProtocolVersion,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
aad: Vec<u8>,
tag: Vec<u8>,
mode: DecryptMode,
) -> Self {
Self {
seq,
typ,
version,
explicit_nonce,
ciphertext,
aad,
tag,
mode,
}
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub(crate) enum DecryptMode {
Private,
Public,
}
pub(crate) enum DecryptOutput {
Private(DecryptPrivate),
Public(DecryptPublic),
Local(DecryptLocal),
}
impl DecryptOutput {
pub(crate) fn try_decrypt(self) -> Result<Option<Vec<u8>>, MpcTlsError> {
match self {
DecryptOutput::Private(decrypt) => decrypt.try_decrypt(),
DecryptOutput::Public(decrypt) => decrypt.try_decrypt().map(Some),
DecryptOutput::Local(decrypt) => decrypt.try_decrypt(),
}
}
}
pub(crate) struct PendingDecrypt {
pub(crate) output: DecryptOutput,
}
pub(crate) struct DecryptPrivate {
masked_keystream: DecodeFutureTyped<BitVec, Vec<u8>>,
otp: Option<Vec<u8>>,
ciphertext: Vec<u8>,
}
impl DecryptPrivate {
pub(crate) fn try_decrypt(mut self) -> Result<Option<Vec<u8>>, MpcTlsError> {
let masked_keystream = self
.masked_keystream
.try_recv()
.map_err(MpcTlsError::record_layer)?
.ok_or_else(|| MpcTlsError::record_layer("masked keystream is not ready"))?;
let Some(otp) = self.otp else {
return Ok(None);
};
// Recover the plaintext by removing the OTP from the masked keystream and
// applying the ciphertext.
let mut plaintext = self.ciphertext;
plaintext
.iter_mut()
.zip(otp)
.zip(masked_keystream)
.for_each(|((a, b), c)| *a ^= b ^ c);
Ok(Some(plaintext))
}
}
pub(crate) struct DecryptPublic {
keystream: DecodeFutureTyped<BitVec, Vec<u8>>,
ciphertext: Vec<u8>,
}
impl DecryptPublic {
/// Decrypts the ciphertext.
pub(crate) fn try_decrypt(mut self) -> Result<Vec<u8>, MpcTlsError> {
let keystream = self
.keystream
.try_recv()
.map_err(MpcTlsError::record_layer)?
.ok_or_else(|| MpcTlsError::record_layer("keystream is not ready"))?;
if keystream.len() != self.ciphertext.len() {
return Err(MpcTlsError::record_layer(format!(
"keystream length does not match ciphertext: {} != {}",
keystream.len(),
self.ciphertext.len()
)));
}
let mut plaintext = self.ciphertext;
plaintext
.iter_mut()
.zip(keystream)
.for_each(|(a, b)| *a ^= b);
Ok(plaintext)
}
}
#[derive(Clone)]
pub(crate) struct DecryptLocal {
pub(crate) plaintext: Option<Vec<u8>>,
}
impl DecryptLocal {
pub(crate) fn try_decrypt(self) -> Result<Option<Vec<u8>>, MpcTlsError> {
Ok(self.plaintext)
}
}

View File

@@ -0,0 +1,264 @@
use futures::TryFutureExt as _;
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
binary::{Binary, U8},
DecodeFutureTyped, Vector,
};
use mpz_vm_core::{prelude::*, Vm};
use serde::{Deserialize, Serialize};
use tls_core::msgs::enums::{ContentType, ProtocolVersion};
use crate::{
record_layer::{
aead::{AeadError, ComputeTags, MpcAesGcm},
TagData,
},
BoxFut, MpcTlsError,
};
#[allow(clippy::type_complexity)]
fn private(
vm: &mut dyn Vm<Binary>,
encrypter: &mut MpcAesGcm,
op: &EncryptOp,
) -> Result<
(
Vector<U8>,
EncryptOutput,
BoxFut<Result<Vec<u8>, AeadError>>,
),
MpcTlsError,
> {
let (plaintext, ciphertext) = encrypter
.apply_keystream(vm, op.explicit_nonce.clone(), op.len)
.map_err(MpcTlsError::record_layer)?;
if let Some(data) = op.plaintext.clone() {
vm.assign(plaintext, data)
.map_err(MpcTlsError::record_layer)?;
}
vm.commit(plaintext).map_err(MpcTlsError::record_layer)?;
let ciphertext_fut = Box::pin(
vm.decode(ciphertext)
.map_err(MpcTlsError::record_layer)?
.map_err(AeadError::tag),
);
Ok((
plaintext,
EncryptOutput::Private(EncryptPrivate {
ciphertext: vm.decode(ciphertext).map_err(MpcTlsError::record_layer)?,
}),
ciphertext_fut,
))
}
#[allow(clippy::type_complexity)]
fn public(
vm: &mut dyn Vm<Binary>,
encrypter: &mut MpcAesGcm,
op: &EncryptOp,
) -> Result<(EncryptOutput, BoxFut<Result<Vec<u8>, AeadError>>), MpcTlsError> {
// Instead of computing the ciphertext in MPC, we only compute the keystream and
// decode it for both parties. Each party then locally computes the ciphertext.
let Some(plaintext) = op.plaintext.clone() else {
return Err(MpcTlsError::record_layer(
"plaintext must be provided in public mode",
));
};
let keystream = encrypter
.take_keystream(vm, op.explicit_nonce.clone(), op.len)
.map_err(MpcTlsError::record_layer)?;
let keystream_fut = vm.decode(keystream).map_err(MpcTlsError::record_layer)?;
let ciphertext_fut = {
let plaintext = plaintext.clone();
Box::pin(async move {
let mut ciphertext = keystream_fut.await.map_err(AeadError::tag)?;
ciphertext
.iter_mut()
.zip(plaintext)
.for_each(|(a, b)| *a ^= b);
Ok(ciphertext)
})
};
Ok((
EncryptOutput::Public(EncryptPublic {
keystream: vm.decode(keystream).map_err(MpcTlsError::record_layer)?,
plaintext,
}),
ciphertext_fut,
))
}
pub(crate) fn encrypt(
vm: &mut dyn Vm<Binary>,
encrypter: &mut MpcAesGcm,
ops: &[EncryptOp],
) -> Result<(Vec<PendingEncrypt>, ComputeTags), MpcTlsError> {
let mut outputs = Vec::new();
let mut ciphertext_futs = Vec::new();
let mut tags_data = Vec::new();
for op in ops {
match op.mode {
EncryptMode::Private => {
let (plaintext_ref, output, ciphertext_fut) = private(vm, encrypter, op)?;
outputs.push(PendingEncrypt {
plaintext_ref: Some(plaintext_ref),
output,
});
ciphertext_futs.push(ciphertext_fut);
}
EncryptMode::Public => {
let (output, ciphertext_fut) = public(vm, encrypter, op)?;
outputs.push(PendingEncrypt {
plaintext_ref: None,
output,
});
ciphertext_futs.push(ciphertext_fut);
}
}
tags_data.push(TagData {
explicit_nonce: op.explicit_nonce.clone(),
aad: op.aad.clone(),
});
}
let compute_tags = encrypter
.compute_tags(vm, ciphertext_futs, tags_data)
.map_err(MpcTlsError::record_layer)?;
Ok((outputs, compute_tags))
}
pub(crate) struct EncryptOp {
pub(crate) seq: u64,
pub(crate) typ: ContentType,
pub(crate) version: ProtocolVersion,
pub(crate) len: usize,
pub(crate) plaintext: Option<Vec<u8>>,
pub(crate) explicit_nonce: Vec<u8>,
pub(crate) aad: Vec<u8>,
pub(crate) mode: EncryptMode,
}
impl EncryptOp {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
seq: u64,
typ: ContentType,
version: ProtocolVersion,
len: usize,
plaintext: Option<Vec<u8>>,
explicit_nonce: Vec<u8>,
aad: Vec<u8>,
mode: EncryptMode,
) -> Result<Self, MpcTlsError> {
if let Some(plaintext) = &plaintext {
if plaintext.len() != len {
return Err(MpcTlsError::record_layer(format!(
"inconsistent plaintext length: {} != {}",
plaintext.len(),
len
)));
}
}
if mode == EncryptMode::Public && plaintext.is_none() {
return Err(MpcTlsError::record_layer(
"plaintext must be provided in public mode",
));
}
Ok(Self {
seq,
typ,
version,
len,
plaintext,
explicit_nonce,
aad,
mode,
})
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) enum EncryptMode {
Private,
Public,
}
pub(crate) enum EncryptOutput {
Private(EncryptPrivate),
Public(EncryptPublic),
}
impl EncryptOutput {
pub(crate) fn try_encrypt(self) -> Result<Vec<u8>, MpcTlsError> {
match self {
EncryptOutput::Private(encrypt) => encrypt.try_encrypt(),
EncryptOutput::Public(encrypt) => encrypt.try_encrypt(),
}
}
}
pub(crate) struct PendingEncrypt {
pub(crate) plaintext_ref: Option<Vector<U8>>,
pub(crate) output: EncryptOutput,
}
pub(crate) struct EncryptPrivate {
ciphertext: DecodeFutureTyped<BitVec, Vec<u8>>,
}
impl EncryptPrivate {
/// Encrypts the plaintext.
pub(crate) fn try_encrypt(mut self) -> Result<Vec<u8>, MpcTlsError> {
let ciphertext = self
.ciphertext
.try_recv()
.map_err(MpcTlsError::record_layer)?
.ok_or_else(|| MpcTlsError::record_layer("ciphertext is not ready"))?;
Ok(ciphertext)
}
}
pub(crate) struct EncryptPublic {
keystream: DecodeFutureTyped<BitVec, Vec<u8>>,
plaintext: Vec<u8>,
}
impl EncryptPublic {
pub(crate) fn try_encrypt(mut self) -> Result<Vec<u8>, MpcTlsError> {
let keystream = self
.keystream
.try_recv()
.map_err(MpcTlsError::record_layer)?
.ok_or_else(|| MpcTlsError::record_layer("keystream is not ready"))?;
if keystream.len() != self.plaintext.len() {
return Err(MpcTlsError::record_layer(format!(
"keystream length does not match plaintext length: {} != {}",
keystream.len(),
self.plaintext.len()
)));
}
let mut ciphertext = self.plaintext;
ciphertext
.iter_mut()
.zip(keystream)
.for_each(|(a, b)| *a ^= b);
Ok(ciphertext)
}
}

View File

@@ -0,0 +1,19 @@
use crate::MpcTlsError;
/// Split an opaque message into its constituent parts.
///
/// Returns the explicit nonce, ciphertext, and tag, respectively.
#[allow(clippy::type_complexity)]
pub(crate) fn opaque_into_parts(
mut msg: Vec<u8>,
) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>), MpcTlsError> {
let tag = msg.split_off(msg.len() - 16);
let ciphertext = msg.split_off(8);
let explicit_nonce = msg;
if explicit_nonce.len() != 8 {
return Err(MpcTlsError::other("explicit nonce length is not 8"));
}
Ok((explicit_nonce, ciphertext, tag))
}

View File

@@ -0,0 +1,169 @@
use std::sync::Arc;
use futures::{AsyncReadExt, AsyncWriteExt};
use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader};
use mpz_common::context::test_mt_context;
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
use mpz_memory_core::correlated::Delta;
use mpz_ot::{
cot::{DerandCOTReceiver, DerandCOTSender},
ideal::rcot::ideal_rcot,
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use tls_client::Certificate;
use tls_client_async::bind_client;
use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN};
use tokio::sync::Mutex;
use tokio_util::compat::TokioAsyncReadCompatExt;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "expensive"]
async fn mpc_tls_test() {
tracing_subscriber::fmt::init();
let config = Config::builder()
.defer_decryption(false)
.max_sent(1 << 13)
.max_recv_online(1 << 13)
.build()
.unwrap();
let (leader, follower) = build_pair(config);
tokio::try_join!(
tokio::spawn(leader_task(leader)),
tokio::spawn(follower_task(follower))
)
.unwrap();
}
async fn leader_task(mut leader: MpcTlsLeader) {
leader.alloc().unwrap();
leader.preprocess().await.unwrap();
let (leader_ctrl, leader_fut) = leader.run();
tokio::spawn(async { leader_fut.await.unwrap() });
let mut root_store = tls_client::RootCertStore::empty();
root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap();
let config = tls_client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let server_name = SERVER_DOMAIN.try_into().unwrap();
let client = tls_client::ClientConnection::new(
Arc::new(config),
Box::new(leader_ctrl.clone()),
server_name,
)
.unwrap();
let (client_socket, server_socket) = tokio::io::duplex(1 << 16);
tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let (mut conn, conn_fut) = bind_client(client_socket.compat(), client);
let handle = tokio::spawn(async { conn_fut.await.unwrap() });
let msg = concat!(
"POST /echo HTTP/1.1\r\n",
"Host: test-server.io\r\n",
"Connection: keep-alive\r\n",
"Accept-Encoding: identity\r\n",
"Content-Length: 5\r\n",
"\r\n",
"hello",
"\r\n"
);
conn.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 48];
conn.read_exact(&mut buf).await.unwrap();
leader_ctrl.defer_decryption().await.unwrap();
let msg = concat!(
"POST /echo HTTP/1.1\r\n",
"Host: test-server.io\r\n",
"Connection: close\r\n",
"Accept-Encoding: identity\r\n",
"Content-Length: 5\r\n",
"\r\n",
"hello",
"\r\n"
);
conn.write_all(msg.as_bytes()).await.unwrap();
conn.close().await.unwrap();
let mut buf = vec![0u8; 1024];
conn.read_to_end(&mut buf).await.unwrap();
leader_ctrl.stop().await.unwrap();
handle.await.unwrap();
}
async fn follower_task(mut follower: MpcTlsFollower) {
follower.alloc().unwrap();
follower.preprocess().await.unwrap();
follower.run().await.unwrap();
}
fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
let mut rng = StdRng::seed_from_u64(0);
let (mut mt_a, mut mt_b) = test_mt_context(8);
let ctx_a = futures::executor::block_on(mt_a.new_context()).unwrap();
let ctx_b = futures::executor::block_on(mt_b.new_context()).unwrap();
let delta_a = Delta::new(rng.gen());
let delta_b = Delta::new(rng.gen());
let (rcot_send_a, rcot_recv_b) = ideal_rcot(rng.gen(), delta_a.into_inner());
let (rcot_send_b, rcot_recv_a) = ideal_rcot(rng.gen(), delta_b.into_inner());
let mut rcot_send_a = SharedRCOTSender::new(4, rcot_send_a);
let mut rcot_send_b = SharedRCOTSender::new(1, rcot_send_b);
let mut rcot_recv_a = SharedRCOTReceiver::new(1, rcot_recv_a);
let mut rcot_recv_b = SharedRCOTReceiver::new(4, rcot_recv_b);
let mpc_a = Arc::new(Mutex::new(Generator::new(
DerandCOTSender::new(rcot_send_a.next().unwrap()),
rng.gen(),
delta_a,
)));
let mpc_b = Arc::new(Mutex::new(Evaluator::new(DerandCOTReceiver::new(
rcot_recv_b.next().unwrap(),
))));
let leader = MpcTlsLeader::new(
config.clone(),
ctx_a,
mpc_a,
(
rcot_send_a.next().unwrap(),
rcot_send_a.next().unwrap(),
rcot_send_a.next().unwrap(),
),
rcot_recv_a.next().unwrap(),
);
let follower = MpcTlsFollower::new(
config,
ctx_b,
mpc_b,
rcot_send_b.next().unwrap(),
(
rcot_recv_b.next().unwrap(),
rcot_recv_b.next().unwrap(),
rcot_recv_b.next().unwrap(),
),
);
(leader, follower)
}

View File

@@ -115,7 +115,7 @@ String
## Logging
The default logging strategy of this server is set to `DEBUG` verbosity level for the crates that are useful for most debugging scenarios, i.e. using the following filtering logic:
`notary_server=DEBUG,tlsn_verifier=DEBUG,tls_mpc=DEBUG,tls_client_async=DEBUG`
`notary_server=DEBUG,tlsn_verifier=DEBUG,mpc_tls=DEBUG,tls_client_async=DEBUG`
In the config [file](./config/config.yaml), one can toggle the verbosity level for these crates using the `level` field under `logging`.

View File

@@ -66,7 +66,7 @@ pub struct NotarySigningKeyProperties {
#[derive(Clone, Debug, Deserialize, Default)]
pub struct LoggingProperties {
/// Log verbosity level of the default filtering logic, which is
/// notary_server=<level>,tlsn_verifier=<level>,tls_mpc=<level> Must be either of <https://docs.rs/tracing/latest/tracing/struct.Level.html#implementations>
/// notary_server=<level>,tlsn_verifier=<level>,mpc_tls=<level> Must be either of <https://docs.rs/tracing/latest/tracing/struct.Level.html#implementations>
pub level: String,
/// Custom filtering logic, refer to the syntax here https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#example-syntax
/// This will override the default filtering logic above

View File

@@ -13,7 +13,7 @@ pub fn init_tracing(config: &NotaryServerProperties) -> Result<()> {
// Use the default filter when only verbosity level is provided
None => {
let level = Level::from_str(&config.logging.level)?;
format!("notary_server={level},tlsn_verifier={level},tls_mpc={level}")
format!("notary_server={level},tlsn_verifier={level},mpc_tls={level}")
}
};
let filter_layer = EnvFilter::builder().parse(directives)?;

View File

@@ -16,9 +16,11 @@ force-st = ["mpz-common/force-st"]
[dependencies]
tlsn-common = { workspace = true }
tlsn-core = { workspace = true }
tlsn-deap = { workspace = true }
tlsn-tls-client = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tlsn-tls-mpc = { workspace = true }
tlsn-tls-core = { workspace = true }
tlsn-mpc-tls = { workspace = true }
serio = { workspace = true, features = ["compat"] }
uid-mux = { workspace = true, features = ["serio"] }
@@ -27,8 +29,11 @@ mpz-common = { workspace = true }
mpz-core = { workspace = true }
mpz-garble = { workspace = true }
mpz-garble-core = { workspace = true }
mpz-memory-core = { workspace = true }
mpz-ole = { workspace = true }
mpz-ot = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-zk = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
@@ -37,4 +42,4 @@ rand = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
web-time = { workspace = true }
tokio = { workspace = true, features = ["sync"] }

View File

@@ -1,7 +1,6 @@
use std::sync::Arc;
use mpz_ot::{chou_orlandi, kos};
use tls_mpc::{MpcTlsCommonConfig, MpcTlsLeaderConfig, TranscriptConfig};
use mpc_tls::Config;
use tlsn_common::config::ProtocolConfig;
use tlsn_core::{connection::ServerName, CryptoProvider};
@@ -15,8 +14,6 @@ pub struct ProverConfig {
protocol_config: ProtocolConfig,
/// Whether the `deferred decryption` feature is toggled on from the start
/// of the MPC-TLS connection.
///
/// See `defer_decryption_from_start` in [tls_mpc::MpcTlsLeaderConfig].
#[builder(default = "true")]
defer_decryption_from_start: bool,
/// Cryptography provider.
@@ -51,52 +48,11 @@ impl ProverConfig {
self.defer_decryption_from_start
}
pub(crate) fn build_mpc_tls_config(&self) -> MpcTlsLeaderConfig {
MpcTlsLeaderConfig::builder()
.common(
MpcTlsCommonConfig::builder()
.tx_config(
TranscriptConfig::default_tx()
.max_online_size(self.protocol_config.max_sent_data())
.build()
.unwrap(),
)
.rx_config(
TranscriptConfig::default_rx()
.max_online_size(self.protocol_config.max_recv_data_online())
.max_offline_size(
self.protocol_config.max_recv_data()
- self.protocol_config.max_recv_data_online(),
)
.build()
.unwrap(),
)
.handshake_commit(true)
.build()
.unwrap(),
)
.build()
.unwrap()
}
pub(crate) fn build_base_ot_sender_config(&self) -> chou_orlandi::SenderConfig {
chou_orlandi::SenderConfig::builder()
.receiver_commit()
.build()
.unwrap()
}
pub(crate) fn build_base_ot_receiver_config(&self) -> chou_orlandi::ReceiverConfig {
chou_orlandi::ReceiverConfig::default()
}
pub(crate) fn build_ot_sender_config(&self) -> kos::SenderConfig {
kos::SenderConfig::default()
}
pub(crate) fn build_ot_receiver_config(&self) -> kos::ReceiverConfig {
kos::ReceiverConfig::builder()
.sender_commit()
pub(crate) fn build_mpc_tls_config(&self) -> Config {
Config::builder()
.defer_decryption(self.defer_decryption_from_start)
.max_sent(self.protocol_config.max_sent_data())
.max_recv_online(self.protocol_config.max_recv_data_online())
.build()
.unwrap()
}

View File

@@ -1,5 +1,6 @@
use mpc_tls::MpcTlsError;
use std::{error::Error, fmt};
use tls_mpc::MpcTlsError;
use tlsn_common::{encoding::EncodingError, zk_aes::ZkAesCtrError};
/// Error for [`Prover`](crate::Prover).
#[derive(Debug, thiserror::Error)]
@@ -26,6 +27,27 @@ impl ProverError {
Self::new(ErrorKind::Config, source)
}
pub(crate) fn mpc<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Mpc, source)
}
pub(crate) fn zk<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Zk, source)
}
pub(crate) fn commit<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Commit, source)
}
pub(crate) fn attestation<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
@@ -38,7 +60,9 @@ impl ProverError {
enum ErrorKind {
Io,
Mpc,
Zk,
Config,
Commit,
Attestation,
}
@@ -49,7 +73,9 @@ impl fmt::Display for ProverError {
match self.kind {
ErrorKind::Io => f.write_str("io error")?,
ErrorKind::Mpc => f.write_str("mpc error")?,
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Config => f.write_str("config error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::Attestation => f.write_str("attestation error")?,
}
@@ -91,50 +117,14 @@ impl From<MpcTlsError> for ProverError {
}
}
impl From<mpz_ot::OTError> for ProverError {
fn from(e: mpz_ot::OTError) -> Self {
Self::new(ErrorKind::Mpc, e)
impl From<ZkAesCtrError> for ProverError {
fn from(e: ZkAesCtrError) -> Self {
Self::new(ErrorKind::Zk, e)
}
}
impl From<mpz_ot::kos::SenderError> for ProverError {
fn from(e: mpz_ot::kos::SenderError) -> Self {
Self::new(ErrorKind::Mpc, e)
}
}
impl From<mpz_ole::OLEError> for ProverError {
fn from(e: mpz_ole::OLEError) -> Self {
Self::new(ErrorKind::Mpc, e)
}
}
impl From<mpz_ot::kos::ReceiverError> for ProverError {
fn from(e: mpz_ot::kos::ReceiverError) -> Self {
Self::new(ErrorKind::Mpc, e)
}
}
impl From<mpz_garble::VmError> for ProverError {
fn from(e: mpz_garble::VmError) -> Self {
Self::new(ErrorKind::Mpc, e)
}
}
impl From<mpz_garble::protocol::deap::DEAPError> for ProverError {
fn from(e: mpz_garble::protocol::deap::DEAPError) -> Self {
Self::new(ErrorKind::Mpc, e)
}
}
impl From<mpz_garble::MemoryError> for ProverError {
fn from(e: mpz_garble::MemoryError) -> Self {
Self::new(ErrorKind::Mpc, e)
}
}
impl From<mpz_garble::ProveError> for ProverError {
fn from(e: mpz_garble::ProveError) -> Self {
Self::new(ErrorKind::Mpc, e)
impl From<EncodingError> for ProverError {
fn from(e: EncodingError) -> Self {
Self::new(ErrorKind::Commit, e)
}
}

View File

@@ -14,21 +14,20 @@ pub mod state;
pub use config::{ProverConfig, ProverConfigBuilder, ProverConfigBuilderError};
pub use error::ProverError;
pub use future::ProverFuture;
use mpz_common::Context;
use mpz_garble_core::Delta;
use state::{Notarize, Prove};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpz_common::Allocate;
use mpz_garble::config::Role as DEAPRole;
use mpz_ot::{chou_orlandi, kos};
use rand::Rng;
use serio::{SinkExt, StreamExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader};
use rand::{thread_rng, Rng};
use serio::SinkExt;
use std::sync::Arc;
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{bind_client, ClosedConnection, TlsConnection};
use tls_mpc::{build_components, LeaderCtrl, MpcTlsLeader, TlsRole};
use tls_client_async::{bind_client, TlsConnection};
use tls_core::msgs::enums::ContentType;
use tlsn_common::{
mux::{attach_mux, MuxControl},
DEAPThread, Executor, OTReceiver, OTSender, Role,
commit::commit_records, context::build_mt_context, mux::attach_mux, zk_aes::ZkAesCtr, Role,
};
use tlsn_core::{
connection::{
@@ -37,10 +36,24 @@ use tlsn_core::{
},
transcript::Transcript,
};
use uid_mux::FramedUidMux as _;
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{debug, info_span, instrument, Instrument, Span};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
mpz_core::Block,
>;
pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
mpz_ot::ferret::Receiver<mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>>,
bool,
mpz_core::Block,
>;
pub(crate) type Mpc =
mpz_garble::protocol::semihonest::Generator<mpz_ot::cot::DerandCOTSender<RCOTSender>>;
pub(crate) type Zk = mpz_zk::Prover<RCOTReceiver>;
/// A prover instance.
#[derive(Debug)]
pub struct Prover<T: state::ProverState> {
@@ -78,37 +91,43 @@ impl Prover<state::Initialized> {
socket: S,
) -> Result<Prover<state::Setup>, ProverError> {
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
let mut io = mux_fut
.poll_with(mux_ctrl.open_framed(b"tlsnotary"))
.await?;
let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
// Sends protocol configuration to verifier for compatibility check.
mux_fut
.poll_with(io.send(self.config.protocol_config().clone()))
.poll_with(ctx.io_mut().send(self.config.protocol_config().clone()))
.await?;
// Maximum thread forking concurrency of 8.
// TODO: Determine the optimal number of threads.
let mut exec = Executor::new(mux_ctrl.clone(), 8);
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
let (mpc_tls, vm, ot_recv) = mux_fut
.poll_with(setup_mpc_backend(&self.config, &mux_ctrl, &mut exec))
.await?;
// Allocate resources for MPC-TLS in VM.
let keys = mpc_tls.alloc()?;
// Allocate for committing to plaintext.
let mut zk_aes = ZkAesCtr::new(Role::Prover);
zk_aes.set_key(keys.server_write_key, keys.server_write_iv);
zk_aes.alloc(
&mut (*vm.try_lock().expect("VM is not locked").zk()),
self.config.protocol_config().max_recv_data(),
)?;
let ctx = mux_fut.poll_with(exec.new_thread()).await?;
debug!("setting up mpc-tls");
mux_fut.poll_with(mpc_tls.preprocess()).await?;
debug!("mpc-tls setup complete");
Ok(Prover {
config: self.config,
span: self.span,
state: state::Setup {
io,
mux_ctrl,
mux_fut,
mt,
mpc_tls,
zk_aes,
keys,
vm,
ot_recv,
ctx,
},
})
}
@@ -129,13 +148,13 @@ impl Prover<state::Setup> {
socket: S,
) -> Result<(TlsConnection, ProverFuture), ProverError> {
let state::Setup {
io,
mux_ctrl,
mut mux_fut,
mt,
mpc_tls,
mut zk_aes,
keys,
vm,
ot_recv,
ctx,
} = self.state;
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
@@ -158,79 +177,122 @@ impl Prover<state::Setup> {
let (conn, conn_fut) = bind_client(socket, client);
let start_time = web_time::UNIX_EPOCH.elapsed().unwrap().as_secs();
let start_time = web_time::UNIX_EPOCH
.elapsed()
.expect("system time is available")
.as_secs();
let fut = Box::pin({
let span = self.span.clone();
let mpc_ctrl = mpc_ctrl.clone();
async move {
let conn_fut = async {
let ClosedConnection { sent, recv, .. } = mux_fut
mux_fut
.poll_with(conn_fut.map_err(ProverError::from))
.await?;
mpc_ctrl.close_connection().await?;
mpc_ctrl.stop().await?;
Ok::<_, ProverError>((sent, recv))
Ok::<_, ProverError>(())
};
let ((sent, recv), mpc_tls_data) = futures::try_join!(
let (_, (mut ctx, mut data)) = futures::try_join!(
conn_fut,
mpc_fut.in_current_span().map_err(ProverError::from)
)?;
{
let mut vm = vm.try_lock().expect("VM should not be locked");
// Prove received plaintext. Prover drops the proof output, as they trust
// themselves.
_ = commit_records(
&mut (*vm.zk()),
&mut zk_aes,
data.transcript
.recv
.iter_mut()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(ProverError::zk)?;
debug!("finalizing mpc");
// Finalize DEAP and execute the plaintext proofs.
mux_fut
.poll_with(vm.finalize(&mut ctx))
.await
.map_err(ProverError::mpc)?;
debug!("mpc finalized");
}
let transcript = data
.transcript
.to_transcript()
.expect("transcript is complete");
let transcript_refs = data
.transcript
.to_transcript_refs()
.expect("transcript is complete");
let connection_info = ConnectionInfo {
time: start_time,
version: mpc_tls_data
version: data
.protocol_version
.try_into()
.expect("only supported version should have been accepted"),
transcript_length: TranscriptLength {
sent: sent.len() as u32,
received: recv.len() as u32,
sent: transcript.sent().len() as u32,
received: transcript.received().len() as u32,
},
};
let server_cert_data = ServerCertData {
certs: mpc_tls_data
.server_cert_details
.cert_chain()
.iter()
.cloned()
.map(|c| c.into())
.collect(),
sig: ServerSignature {
scheme: mpc_tls_data
.server_kx_details
.kx_sig()
.scheme
.try_into()
.expect("only supported signature scheme should have been accepted"),
sig: mpc_tls_data.server_kx_details.kx_sig().sig.0.clone(),
},
handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
client_random: mpc_tls_data.client_random.0,
server_random: mpc_tls_data.server_random.0,
server_ephemeral_key: mpc_tls_data
.server_public_key
.try_into()
.expect("only supported key scheme should have been accepted"),
}),
};
let server_cert_data =
ServerCertData {
certs: data
.server_cert_details
.cert_chain()
.iter()
.cloned()
.map(|c| c.into())
.collect(),
sig: ServerSignature {
scheme: data.server_kx_details.kx_sig().scheme.try_into().expect(
"only supported signature scheme should have been accepted",
),
sig: data.server_kx_details.kx_sig().sig.0.clone(),
},
handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
client_random: data.client_random.0,
server_random: data.server_random.0,
server_ephemeral_key: data
.server_key
.try_into()
.expect("only supported key scheme should have been accepted"),
}),
};
// Pull out ZK VM
let (_, vm) = Arc::into_inner(vm)
.expect("vm should have only 1 reference")
.into_inner()
.into_inner();
Ok(Prover {
config: self.config,
span: self.span,
state: state::Closed {
io,
mux_ctrl,
mux_fut,
vm,
ot_recv,
mt,
ctx,
_keys: keys,
vm,
connection_info,
server_cert_data,
transcript: Transcript::new(sent, recv),
transcript,
transcript_refs,
},
})
}
@@ -279,123 +341,55 @@ impl Prover<state::Closed> {
}
}
/// Performs a setup of the various MPC subprotocols.
#[instrument(level = "debug", skip_all, err)]
async fn setup_mpc_backend(
config: &ProverConfig,
mux: &MuxControl,
exec: &mut Executor,
) -> Result<(MpcTlsLeader, DEAPThread, OTReceiver), ProverError> {
debug!("starting MPC backend setup");
fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsLeader) {
let mut rng = thread_rng();
let delta = Delta::new(rng.gen());
let mut ot_sender = kos::Sender::new(
config.build_ot_sender_config(),
chou_orlandi::Receiver::new(config.build_base_ot_receiver_config()),
let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
let rcot_send = mpz_ot::kos::Sender::new(
mpz_ot::kos::SenderConfig::default(),
delta.into_inner(),
base_ot_recv,
);
ot_sender.alloc(config.protocol_config().ot_sender_setup_count(Role::Prover));
let mut ot_receiver = kos::Receiver::new(
config.build_ot_receiver_config(),
chou_orlandi::Sender::new(config.build_base_ot_sender_config()),
);
ot_receiver.alloc(
config
.protocol_config()
.ot_receiver_setup_count(Role::Prover),
let rcot_recv =
mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
let rcot_recv = mpz_ot::ferret::Receiver::new(
mpz_ot::ferret::FerretConfig::builder()
.lpn_type(mpz_ot::ferret::LpnType::Regular)
.build()
.expect("ferret config is valid"),
rng.gen(),
rcot_recv,
);
let ot_sender = OTSender::new(ot_sender);
let ot_receiver = OTReceiver::new(ot_receiver);
let mut rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(4, rcot_send);
let mut rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(2, rcot_recv);
let (
ctx_vm,
ctx_ke_0,
ctx_ke_1,
ctx_prf_0,
ctx_prf_1,
ctx_encrypter_block_cipher,
ctx_encrypter_stream_cipher,
ctx_encrypter_ghash,
ctx_encrypter,
ctx_decrypter_block_cipher,
ctx_decrypter_stream_cipher,
ctx_decrypter_ghash,
ctx_decrypter,
) = futures::try_join!(
exec.new_thread(),
exec.new_thread(),
exec.new_thread(),
exec.new_thread(),
exec.new_thread(),
exec.new_thread(),
exec.new_thread(),
exec.new_thread(),
exec.new_thread(),
exec.new_thread(),
exec.new_thread(),
exec.new_thread(),
exec.new_thread(),
)?;
let vm = DEAPThread::new(
DEAPRole::Leader,
rand::rngs::OsRng.gen(),
ctx_vm,
ot_sender.clone(),
ot_receiver.clone(),
let mpc = Mpc::new(
mpz_ot::cot::DerandCOTSender::new(rcot_send.next().expect("enough senders are available")),
rng.gen(),
delta,
);
let mpc_tls_config = config.build_mpc_tls_config();
let (ke, prf, encrypter, decrypter) = build_components(
TlsRole::Leader,
mpc_tls_config.common(),
ctx_ke_0,
ctx_encrypter,
ctx_decrypter,
ctx_encrypter_ghash,
ctx_decrypter_ghash,
vm.new_thread(ctx_ke_1, ot_sender.clone(), ot_receiver.clone())?,
vm.new_thread(ctx_prf_0, ot_sender.clone(), ot_receiver.clone())?,
vm.new_thread(ctx_prf_1, ot_sender.clone(), ot_receiver.clone())?,
vm.new_thread(
ctx_encrypter_block_cipher,
ot_sender.clone(),
ot_receiver.clone(),
)?,
vm.new_thread(
ctx_decrypter_block_cipher,
ot_sender.clone(),
ot_receiver.clone(),
)?,
vm.new_thread(
ctx_encrypter_stream_cipher,
ot_sender.clone(),
ot_receiver.clone(),
)?,
vm.new_thread(
ctx_decrypter_stream_cipher,
ot_sender.clone(),
ot_receiver.clone(),
)?,
ot_sender.clone(),
ot_receiver.clone(),
);
let zk = Zk::new(rcot_recv.next().expect("enough receivers are available"));
let channel = mux.open_framed(b"mpc_tls").await?;
let mut mpc_tls = MpcTlsLeader::new(
mpc_tls_config,
Box::new(StreamExt::compat_stream(channel)),
ke,
prf,
encrypter,
decrypter,
);
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
mpc_tls.setup().await?;
debug!("MPC backend setup complete");
Ok((mpc_tls, vm, ot_receiver))
(
vm.clone(),
MpcTlsLeader::new(
config.build_mpc_tls_config(),
ctx,
vm,
(
rcot_send.next().expect("enough senders are available"),
rcot_send.next().expect("enough senders are available"),
rcot_send.next().expect("enough senders are available"),
),
rcot_recv.next().expect("enough receivers are available"),
),
)
}
/// A controller for the prover.

Some files were not shown because too many files have changed in this diff Show More