Sumcheck prover

Co-authored-by: jeong0982 <soowon1106@gmail.com>
This commit is contained in:
DoHoonKim
2024-05-13 22:05:58 +09:00
parent c928660689
commit 9e3da9e229
2 changed files with 91 additions and 22 deletions

View File

@@ -1,3 +1,4 @@
use std::fmt::{Debug, Formatter, Result as fmtResult};
use super::{SumCheck, VirtualPolynomial};
use crate::utils::transcript::{FieldTranscriptRead, FieldTranscriptWrite};
use crate::utils::ProtocolError;
@@ -6,16 +7,61 @@ use ff::{Field, PrimeField};
#[derive(Clone, Debug)]
struct ClassicSumcheck;
#[derive(Clone, Debug)]
struct ClassicSumcheckProverParam<F: Field> {
num_vars: usize,
max_degree: usize,
combine_function: fn(&Vec<F>) -> F,
struct CombineFunction<F>(Box<dyn Fn(&Vec<F>) -> F>);
pub trait CombineFunctionClone<F> {
fn clone_box(&self) -> Box<dyn Fn(&Vec<F>) -> F>;
}
impl<F> CombineFunctionClone<F> for dyn Fn(&Vec<F>) -> F {
fn clone_box(&self) -> Box<dyn Fn(&Vec<F>) -> F> {
Box::new(self.clone())
}
}
impl<F> Clone for CombineFunction<F> {
fn clone(&self) -> Self {
Self(self.0.clone_box())
}
}
impl<F> Debug for CombineFunction<F> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmtResult {
write!(f, "debug")
}
}
impl<F> CombineFunction<F> {
pub fn new(combine_function: impl Fn(&Vec<F>) -> F) -> Self {
Self(Box::new(combine_function))
}
pub fn apply(&self, evals: &Vec<F>) -> F {
(self.0)(evals)
}
}
#[derive(Clone, Debug)]
struct ClassicSumcheckVerifierParam<F: Field> {
combine_function: fn(Vec<F>) -> F,
pub struct ClassicSumcheckProverParam<F: Field> {
num_vars: usize,
max_degree: usize,
combine_function: CombineFunction<F>,
}
impl<F: Field> ClassicSumcheckProverParam<F> {
pub fn new(num_vars: usize, max_degree: usize, combine_function: impl Fn(&Vec<F>) -> F) -> Self {
ClassicSumcheckProverParam {
num_vars,
max_degree,
combine_function: CombineFunction::new(combine_function),
}
}
}
#[derive(Clone, Debug)]
struct ClassicSumcheckVerifierParam {
num_vars: usize,
max_degree: usize,
}
impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
@@ -34,8 +80,8 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
.map(|_| vec![F::ZERO; r_degree + 1])
.collect();
transcript.write_field_element(&sum)?;
let mut challenges = vec![];
let mut evaluations = vec![];
for round_index in 0..pp.num_vars {
for k in 0..(r_degree + 1) {
for i in 0..virtual_poly.polys()[0].size() {
@@ -50,7 +96,7 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
.collect::<Vec<F>>();
// apply combine function
r_polys[round_index][k] += (pp.combine_function)(&evaluations_at_k);
r_polys[round_index][k] += pp.combine_function.apply(&evaluations_at_k);
}
}
// append the round polynomial (i.e. prover message) to the transcript
@@ -58,16 +104,22 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
// generate challenge α_i = H( transcript );
let alpha = transcript.squeeze_challenge();
challenges.push(alpha);
// update prover state polynomials
virtual_poly.fold_into_half(alpha);
if round_index == pp.num_vars - 1 {
// last round
evaluations = virtual_poly.evaluations(alpha);
} else {
// update prover state polynomials
virtual_poly.fold_into_half(alpha);
}
}
for i in 0..virtual_poly.polys().len() {
assert_eq!(virtual_poly.polys()[i].size(), 1);
}
Ok((vec![], vec![]))
Ok((challenges, evaluations))
}
fn verify(
@@ -86,7 +138,7 @@ mod test {
use crate::{
poly::multilinear::MultilinearPolynomial,
sumcheck::{EvalTable, SumCheck, VirtualPolynomial},
sumcheck::{EvalTable, SumCheck, VirtualPolynomial, classic::CombineFunction},
utils::{transcript::Keccak256Transcript, ProtocolError},
};
use ff::Field;
@@ -151,10 +203,10 @@ mod test {
let pp = ClassicSumcheckProverParam {
num_vars,
max_degree: 3,
combine_function,
combine_function: CombineFunction::new(combine_function),
};
let mut transcript = Keccak256Transcript::<Cursor<Vec<u8>>>::default();
let virtual_poly = VirtualPolynomial::new(num_vars, &polys);
let virtual_poly = VirtualPolynomial::new(num_vars, polys.iter().collect_vec().borrow());
ClassicSumcheck::prove(&pp, claimed_sum, virtual_poly, &mut transcript)?;
Ok(())

View File

@@ -1,4 +1,4 @@
use std::fmt::Debug;
use std::{borrow::Borrow, fmt::Debug};
use ff::Field;
use itertools::Itertools;
@@ -48,9 +48,13 @@ impl<F: Field> EvalTable<F> {
}
pub fn fold_into_half(&mut self, challenge: F) {
for i in 0..self.table.len() / 2 {
let eval_pair_distant = self.table[i + self.table.len() / 2].clone();
let eval_pair = &mut self.table[i];
assert_ne!(self.table.len(), 1);
let len = self.table.len();
let (pairs, pairs_distant) = self.table.split_at_mut(len / 2);
let pairs_distant = &*pairs_distant;
for i in 0..len / 2 {
let eval_pair_distant = &pairs_distant[i];
let eval_pair = &mut pairs[i];
eval_pair.even = eval_pair.even + challenge * (eval_pair.odd - eval_pair.even);
eval_pair.odd = eval_pair_distant.even
+ challenge * (eval_pair_distant.odd - eval_pair_distant.even);
@@ -65,7 +69,7 @@ pub struct VirtualPolynomial<F: Field> {
}
impl<F: Field> VirtualPolynomial<F> {
pub fn new(num_vars: usize, polys: &Vec<MultilinearPolynomial<F>>) -> Self {
pub fn new(num_vars: usize, polys: &[&MultilinearPolynomial<F>]) -> Self {
let polys = polys
.iter()
.map(|poly| EvalTable::new(num_vars, poly))
@@ -82,6 +86,19 @@ impl<F: Field> VirtualPolynomial<F> {
poly.fold_into_half(challenge);
}
}
/// called at the last round of sumcheck
pub fn evaluations(&self, challenge: F) -> Vec<F> {
self.polys.iter().for_each(|poly| {
assert_eq!(poly.size(), 1);
});
self.polys
.iter()
.map(|poly| {
poly.table()[0].even + challenge * (poly.table()[0].odd - poly.table()[0].even)
})
.collect_vec()
}
}
pub trait SumCheck<F: Field>: Clone + Debug {
@@ -101,5 +118,5 @@ pub trait SumCheck<F: Field>: Clone + Debug {
degree: usize,
sum: F,
transcript: &mut impl FieldTranscriptRead<F>,
) -> Result<(Vec<F>, Vec<F>), ProtocolError>;
) -> Result<(F, Vec<F>), ProtocolError>;
}