Add tests and refactor

This commit is contained in:
th4s
2023-07-27 15:33:33 +02:00
parent 1e38bee459
commit a6bd12be4f
2 changed files with 121 additions and 50 deletions

17
src/dot.rs Normal file
View File

@@ -0,0 +1,17 @@
use mpz_share_conversion_core::Field;
pub trait DotProduct: Copy {
type Output;
fn dot(self, other: Self) -> Self::Output;
}
impl<T: Field, const N: usize> DotProduct for [T; N] {
type Output = T;
fn dot(self, other: Self) -> Self::Output {
self.into_iter()
.zip(other)
.fold(T::zero(), |acc, (a, b)| acc + a * b)
}
}

View File

@@ -5,6 +5,9 @@ use mpz_share_conversion_core::Field;
use rand::{rngs::ThreadRng, thread_rng, Rng};
use std::marker::PhantomData;
mod dot;
use dot::DotProduct;
// Some constants defined in the paper
//
// Bits needed for to represent elements of the field
@@ -32,14 +35,12 @@ impl<T: Field> M2A<T> {
std::array::from_fn(|_| T::rand(&mut self.rng))
}
// Note that the return dimension is [[[T; 2]; ZETA]; L] which we simplify a little bit
fn alpha(&mut self, a_tilde: [T; L], a_head: [T; L]) -> [[T; 2]; ETA] {
let mut alpha = [[T::zero(); 2]; ETA];
fn alpha(&mut self, a_tilde: [T; L], a_head: [T; L]) -> [[T; ETA]; 2] {
let mut alpha = [[T::zero(); ETA]; 2];
for k in 0..L {
for i in 0..ZETA {
alpha[k * ZETA + i] = [a_tilde[k], a_head[k]];
}
alpha[0][k * ZETA..(k + 1) * ZETA].copy_from_slice(&[a_tilde[k]; ZETA]);
alpha[1][k * ZETA..(k + 1) * ZETA].copy_from_slice(&[a_head[k]; ZETA]);
}
alpha
@@ -73,8 +74,11 @@ impl<T: Field> M2A<T> {
b_tilde
}
fn omega_a(&mut self) -> [[T; 2]; ETA] {
std::array::from_fn(|_| [T::rand(&mut self.rng), T::rand(&mut self.rng)])
fn omega_a(&mut self) -> [[T; ETA]; 2] {
let first = std::array::from_fn(|_| T::rand(&mut self.rng));
let second = std::array::from_fn(|_| T::rand(&mut self.rng));
[first, second]
}
fn chi_head(&self) -> [T; L] {
@@ -113,7 +117,7 @@ impl<T: Field> M2A<T> {
u
}
fn bob_check(
fn check(
&self,
r: [T; ZETA],
u: [T; L],
@@ -151,13 +155,12 @@ impl<T: Field> M2A<T> {
gamma
}
//TODO: Is there a mistake in the paper with b_tilde ?
fn output(&self, input: [T; L], gamma: [T; L], gadget: [T; ZETA], z_tilde: [T; ETA]) -> [T; L] {
let mut z = [T::zero(); L];
for (i, el) in z.iter_mut().enumerate() {
*el = input[i] * gamma[i]
+ (0..ZETA).fold(T::zero(), |acc, j| acc + gadget[i] * z_tilde[i * ZETA + j]);
+ (0..ZETA).fold(T::zero(), |acc, j| acc + gadget[j] * z_tilde[i * ZETA + j]);
}
z
}
@@ -165,7 +168,7 @@ impl<T: Field> M2A<T> {
impl<T: Field> Default for M2A<T> {
fn default() -> Self {
let mut rng = thread_rng();
let rng = thread_rng();
Self {
_field: PhantomData,
@@ -174,68 +177,119 @@ impl<T: Field> Default for M2A<T> {
}
}
pub fn func_cote<T: Field>(
alpha: [[T; 2]; ETA],
fn func_cote<T: Field>(
alpha: [[T; ETA]; 2],
beta: [T; ETA],
omega_a: [[T; 2]; ETA],
) -> ([[T; 2]; ETA], [[T; 2]; ETA]) {
let mut omega_b: [[T; 2]; ETA] = [[T::zero(); 2]; ETA];
omega_a: [[T; ETA]; 2],
) -> ([[T; ETA]; 2], [[T; ETA]; 2]) {
let mut omega_b: [[T; ETA]; 2] = [[T::zero(); ETA]; 2];
for k in 0..ETA {
omega_b[k] = [
alpha[k][0] * beta[k] + -omega_a[k][0],
alpha[k][1] * beta[k] + -omega_a[k][1],
];
omega_b[0][k] = alpha[0][k] * beta[k] + -omega_a[0][k];
omega_b[1][k] = alpha[1][k] * beta[k] + -omega_a[1][k];
}
(omega_a, omega_b)
}
trait DotProduct: Copy {
type Output;
pub fn protocol<T: Field>(
a: [T; L],
b: [T; L],
malicious_omega_a: Option<[[T; ETA]; 2]>,
) -> ([T; L], [T; L]) {
//Setup
let mut m2a = M2A::<T>::default();
let gadget = m2a.gadget();
fn dot(self, other: Self) -> Self::Output;
}
// Step 1
let beta = m2a.beta();
let b_tilde = m2a.b_tilde(gadget, beta);
impl<T: Field, const N: usize> DotProduct for [T; N] {
type Output = T;
// Step 2
let a_tilde = m2a.a_tilde();
let a_head = m2a.a_head();
let alpha = m2a.alpha(a_tilde, a_head);
fn dot(self, other: Self) -> Self::Output {
self.into_iter()
.zip(other)
.fold(T::zero(), |acc, (a, b)| acc + a * b)
// Step 3
// If Alice is malicious, she can choose omega_a in the implementation of F_COTE
let (omega_a, omega_b) = if let Some(omega_a) = malicious_omega_a {
func_cote(alpha, beta, omega_a)
} else {
// If she is not malicious she supplies random input
func_cote(alpha, beta, m2a.omega_a())
};
let [z_tilde_a, z_head_a] = [omega_a[0], omega_a[1]];
let [z_tilde_b, z_head_b] = [omega_b[0], omega_b[1]];
// Step 4
let (chi_tilde, chi_head) = (m2a.chi_tilde(), m2a.chi_head());
// Step 5
let r = m2a.r(chi_tilde, chi_head, z_tilde_a, z_head_a);
let u = m2a.u(chi_tilde, chi_head, a_tilde, a_head);
// Step 6
let check = m2a.check(r, u, beta, chi_tilde, chi_head, z_tilde_b, z_head_b);
if !check {
panic!("Consistency check failed!");
}
// Step 7
let gamma_a = m2a.gamma(a, a_tilde);
let gamma_b = m2a.gamma(b, b_tilde);
// Step 8
let z_a = m2a.output(a, gamma_b, gadget, z_tilde_a);
let z_b = m2a.output(b_tilde, gamma_a, gadget, z_tilde_b);
(z_a, z_b)
}
#[cfg(test)]
mod tests {
use mpz_share_conversion_core::fields::p256::P256;
use mpz_share_conversion_core::fields::{p256::P256, UniformRand};
use super::*;
#[test]
fn test_two_party_mul() {
let mut m2a = M2A::<P256>::default();
let gadget = m2a.gadget();
// Get rng
let mut rng = thread_rng();
// Step 1
let beta = m2a.beta();
let b_tilde = m2a.b_tilde(gadget, beta);
// Random input
let a = [P256::rand(&mut rng)];
let b = [P256::rand(&mut rng)];
// Step 2
let a_tilde = m2a.a_tilde();
let a_head = m2a.a_head();
let alpha = m2a.alpha(a_tilde, a_head);
// Execute protocol
// Alice is honest
let (z_a, z_b) = protocol(a, b, None);
// Step 3
let (omega_a, omega_b) = func_cote(alpha, beta, m2a.omega_a());
let [z_tilde_a, z_head_a] = [omega_a[0], omega_a[1]];
let [z_tilde_b, z_head_b] = [omega_b[0], omega_b[1]];
// Check result
for k in 0..L {
assert_eq!(z_a[k] + z_b[k], a[k] * b[k]);
}
}
// Step 4
let (chi_tilde, chi_head) = (m2a.chi_tilde(), m2a.chi_head());
#[test]
fn test_alice_malicious() {
// Get rng
let mut rng = thread_rng();
// Step 5
let r = m2a.r(chi_tilde, chi_head, z_tilde_a, z_head_a);
let u = m2a.u(chi_tilde, chi_head, a_tilde, a_head);
// Alice is malicious and supplies zero as input. This allows her to infer Bob's output z_b
let a = [P256::zero()];
let b = [P256::rand(&mut rng)];
// Execute protocol
// No need for Alice to choose omega_a
let (z_a, z_b) = protocol(a, b, None);
// Check result
for k in 0..L {
//Protocol works
assert_eq!(z_a[k] + z_b[k], a[k] * b[k]);
//Check Bob's output. Indeed z_b = -z_a
assert_ne!(z_b[k], -z_a[k]);
}
}
}