From 0287e2cd924e13cfacbe30db092e91c2a96158ba Mon Sep 17 00:00:00 2001 From: "sinu.eth" <65924192+sinui0@users.noreply.github.com> Date: Wed, 11 Jan 2023 09:24:27 -0800 Subject: [PATCH] Label Encoder (#148) * Label encoder * fix word pos storage * fix test Co-authored-by: themighty1 --- mpc-core/Cargo.toml | 4 + mpc-core/benches/encoder.rs | 25 +++++ mpc-core/src/garble/label/encoder.rs | 135 +++++++++++++++++++++++++++ mpc-core/src/garble/label/mod.rs | 2 + mpc-core/src/garble/mod.rs | 4 +- 5 files changed, 168 insertions(+), 2 deletions(-) create mode 100644 mpc-core/benches/encoder.rs create mode 100644 mpc-core/src/garble/label/encoder.rs diff --git a/mpc-core/Cargo.toml b/mpc-core/Cargo.toml index f540102cb..9752472cb 100644 --- a/mpc-core/Cargo.toml +++ b/mpc-core/Cargo.toml @@ -52,3 +52,7 @@ harness = false [[bench]] name = "ot" harness = false + +[[bench]] +name = "encoder" +harness = false diff --git a/mpc-core/benches/encoder.rs b/mpc-core/benches/encoder.rs new file mode 100644 index 000000000..9f32c6dc4 --- /dev/null +++ b/mpc-core/benches/encoder.rs @@ -0,0 +1,25 @@ +use std::sync::Arc; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use mpc_circuits::{Circuit, WireGroup, AES_128_REVERSE}; +use mpc_core::garble::ChaChaEncoder; + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("encoder"); + + let circ = Arc::new(Circuit::load_bytes(AES_128_REVERSE).unwrap()); + group.bench_function(circ.name(), |b| { + let mut enc = ChaChaEncoder::new([0u8; 32], 0); + b.iter(|| { + black_box( + circ.inputs() + .iter() + .map(|input| enc.encode(input.id() as u32, input)) + .collect::>(), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/mpc-core/src/garble/label/encoder.rs b/mpc-core/src/garble/label/encoder.rs new file mode 100644 index 000000000..4087c9895 --- /dev/null +++ b/mpc-core/src/garble/label/encoder.rs @@ -0,0 +1,135 @@ +use std::collections::HashMap; + +use mpc_circuits::Input; +use rand::SeedableRng; +use rand_chacha::ChaCha20Rng; + +use super::{Delta, FullInputLabels}; + +/// Encodes wire labels using the ChaCha algorithm and a global offset (delta). +/// +/// An encoder instance is configured using a domain id. Domain ids can be used in combination +/// with stream ids to partition label sets. +#[derive(Debug)] +pub struct ChaChaEncoder { + seed: [u8; 32], + domain: u32, + rng: ChaCha20Rng, + stream_state: HashMap, + delta: Delta, +} + +impl ChaChaEncoder { + /// Creates a new encoder with the provided seed + /// + /// * `seed` - 32-byte seed for ChaChaRng + /// * `domain` - Domain id + /// + /// Domain id must be less than 2^31 + pub fn new(seed: [u8; 32], domain: u32) -> Self { + assert!(domain <= u32::MAX >> 1); + + let mut rng = ChaCha20Rng::from_seed(seed); + + // Stream id 0 is reserved to generate delta. + // This way there is only ever 1 delta per seed + rng.set_stream(0); + let delta = Delta::random(&mut rng); + + Self { + seed, + domain, + rng, + stream_state: HashMap::default(), + delta, + } + } + + /// Returns encoder's rng seed + pub fn get_seed(&self) -> [u8; 32] { + self.seed + } + + /// Encodes input using the provided stream id + /// + /// * `stream_id` - Stream id which can be used to partition label sets + /// * `input` - Circuit input to encode + pub fn encode(&mut self, stream_id: u32, input: &Input) -> FullInputLabels { + self.set_stream(stream_id); + FullInputLabels::generate(&mut self.rng, input.clone(), self.delta) + } + + /// Sets the selected stream id, restoring word position if a stream + /// has been used before. + fn set_stream(&mut self, id: u32) { + // MSB -> LSB + // 31 bits 32 bits 1 bit + // [domain] [id] [reserved] + // The reserved bit ensures that we never pull from stream 0 which + // is reserved to generate delta + let new_id = ((self.domain as u64) << 33) + ((id as u64) << 1) + 1; + + let current_id = self.rng.get_stream(); + + // noop if stream already set + if new_id == current_id { + return; + } + + // Store word position for current stream + self.stream_state + .insert(current_id, self.rng.get_word_pos()); + + // Update stream id + self.rng.set_stream(new_id); + + // Get word position if stored, otherwise default to 0 + let word_pos = self.stream_state.get(&new_id).copied().unwrap_or(0); + + // Update word position + self.rng.set_word_pos(word_pos); + } +} + +#[cfg(test)] +mod test { + use mpc_circuits::{Circuit, WireGroup, ADDER_64}; + + use super::*; + use rstest::*; + + #[fixture] + fn circ() -> Circuit { + Circuit::load_bytes(ADDER_64).unwrap() + } + + #[rstest] + fn test_encoder(circ: Circuit) { + let mut enc = ChaChaEncoder::new([0u8; 32], 0); + + for input in circ.inputs() { + enc.encode(input.id() as u32, input); + } + } + + #[rstest] + fn test_encoder_no_duplicates(circ: Circuit) { + let input = circ.input(0).unwrap(); + + let mut enc = ChaChaEncoder::new([0u8; 32], 0); + + // Pull from stream 0 + let a = enc.encode(0, &input); + + // Pull from a different stream + let c = enc.encode(1, &input); + + // Pull from stream 0 again + let b = enc.encode(0, &input); + + // Switching back to the same stream should preserve the word position + assert_ne!(a, b); + // Different stream ids should produce different labels + assert_ne!(a, c); + } +} diff --git a/mpc-core/src/garble/label/mod.rs b/mpc-core/src/garble/label/mod.rs index 98514fda0..5f68d5388 100644 --- a/mpc-core/src/garble/label/mod.rs +++ b/mpc-core/src/garble/label/mod.rs @@ -1,6 +1,7 @@ //! Collection of labels corresponding to a wire group. mod digest; +mod encoder; pub(crate) mod input; pub(crate) mod output; mod state; @@ -15,6 +16,7 @@ use std::{ use crate::{block::Block, garble::LabelError}; pub use digest::LabelsDigest; +pub use encoder::ChaChaEncoder; pub(crate) use input::SanitizedInputLabels; pub use output::OutputLabelsCommitment; diff --git a/mpc-core/src/garble/mod.rs b/mpc-core/src/garble/mod.rs index 623a929f5..1500dcc94 100644 --- a/mpc-core/src/garble/mod.rs +++ b/mpc-core/src/garble/mod.rs @@ -15,8 +15,8 @@ pub(crate) mod label; pub use circuit::{state as gc_state, CircuitOpening, GarbledCircuit}; pub use error::{Error, InputError, LabelError}; pub use label::{ - ActiveInputLabels, ActiveOutputLabels, Delta, FullInputLabels, FullOutputLabels, Labels, - LabelsDecodingInfo, LabelsDigest, WireLabel, WireLabelPair, + ActiveInputLabels, ActiveOutputLabels, ChaChaEncoder, Delta, FullInputLabels, FullOutputLabels, + Labels, LabelsDecodingInfo, LabelsDigest, WireLabel, WireLabelPair, }; #[cfg(test)]