mirror of
https://github.com/pseXperiments/clookup.git
synced 2026-01-08 23:28:10 -05:00
@@ -1,3 +1,4 @@
|
||||
use std::fmt::{Debug, Formatter, Result as fmtResult};
|
||||
use super::{SumCheck, VirtualPolynomial};
|
||||
use crate::utils::transcript::{FieldTranscriptRead, FieldTranscriptWrite};
|
||||
use crate::utils::ProtocolError;
|
||||
@@ -6,16 +7,61 @@ use ff::{Field, PrimeField};
|
||||
#[derive(Clone, Debug)]
|
||||
struct ClassicSumcheck;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ClassicSumcheckProverParam<F: Field> {
|
||||
num_vars: usize,
|
||||
max_degree: usize,
|
||||
combine_function: fn(&Vec<F>) -> F,
|
||||
struct CombineFunction<F>(Box<dyn Fn(&Vec<F>) -> F>);
|
||||
|
||||
pub trait CombineFunctionClone<F> {
|
||||
fn clone_box(&self) -> Box<dyn Fn(&Vec<F>) -> F>;
|
||||
}
|
||||
|
||||
impl<F> CombineFunctionClone<F> for dyn Fn(&Vec<F>) -> F {
|
||||
fn clone_box(&self) -> Box<dyn Fn(&Vec<F>) -> F> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Clone for CombineFunction<F> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone_box())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Debug for CombineFunction<F> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmtResult {
|
||||
write!(f, "debug")
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> CombineFunction<F> {
|
||||
pub fn new(combine_function: impl Fn(&Vec<F>) -> F) -> Self {
|
||||
Self(Box::new(combine_function))
|
||||
}
|
||||
|
||||
pub fn apply(&self, evals: &Vec<F>) -> F {
|
||||
(self.0)(evals)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ClassicSumcheckVerifierParam<F: Field> {
|
||||
combine_function: fn(Vec<F>) -> F,
|
||||
pub struct ClassicSumcheckProverParam<F: Field> {
|
||||
num_vars: usize,
|
||||
max_degree: usize,
|
||||
combine_function: CombineFunction<F>,
|
||||
}
|
||||
|
||||
impl<F: Field> ClassicSumcheckProverParam<F> {
|
||||
pub fn new(num_vars: usize, max_degree: usize, combine_function: impl Fn(&Vec<F>) -> F) -> Self {
|
||||
ClassicSumcheckProverParam {
|
||||
num_vars,
|
||||
max_degree,
|
||||
combine_function: CombineFunction::new(combine_function),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ClassicSumcheckVerifierParam {
|
||||
num_vars: usize,
|
||||
max_degree: usize,
|
||||
}
|
||||
|
||||
impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
|
||||
@@ -34,8 +80,8 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
|
||||
.map(|_| vec![F::ZERO; r_degree + 1])
|
||||
.collect();
|
||||
|
||||
transcript.write_field_element(&sum)?;
|
||||
|
||||
let mut challenges = vec![];
|
||||
let mut evaluations = vec![];
|
||||
for round_index in 0..pp.num_vars {
|
||||
for k in 0..(r_degree + 1) {
|
||||
for i in 0..virtual_poly.polys()[0].size() {
|
||||
@@ -50,7 +96,7 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
|
||||
.collect::<Vec<F>>();
|
||||
|
||||
// apply combine function
|
||||
r_polys[round_index][k] += (pp.combine_function)(&evaluations_at_k);
|
||||
r_polys[round_index][k] += pp.combine_function.apply(&evaluations_at_k);
|
||||
}
|
||||
}
|
||||
// append the round polynomial (i.e. prover message) to the transcript
|
||||
@@ -58,16 +104,22 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
|
||||
|
||||
// generate challenge α_i = H( transcript );
|
||||
let alpha = transcript.squeeze_challenge();
|
||||
challenges.push(alpha);
|
||||
|
||||
// update prover state polynomials
|
||||
virtual_poly.fold_into_half(alpha);
|
||||
if round_index == pp.num_vars - 1 {
|
||||
// last round
|
||||
evaluations = virtual_poly.evaluations(alpha);
|
||||
} else {
|
||||
// update prover state polynomials
|
||||
virtual_poly.fold_into_half(alpha);
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..virtual_poly.polys().len() {
|
||||
assert_eq!(virtual_poly.polys()[i].size(), 1);
|
||||
}
|
||||
|
||||
Ok((vec![], vec![]))
|
||||
Ok((challenges, evaluations))
|
||||
}
|
||||
|
||||
fn verify(
|
||||
@@ -86,7 +138,7 @@ mod test {
|
||||
|
||||
use crate::{
|
||||
poly::multilinear::MultilinearPolynomial,
|
||||
sumcheck::{EvalTable, SumCheck, VirtualPolynomial},
|
||||
sumcheck::{EvalTable, SumCheck, VirtualPolynomial, classic::CombineFunction},
|
||||
utils::{transcript::Keccak256Transcript, ProtocolError},
|
||||
};
|
||||
use ff::Field;
|
||||
@@ -151,10 +203,10 @@ mod test {
|
||||
let pp = ClassicSumcheckProverParam {
|
||||
num_vars,
|
||||
max_degree: 3,
|
||||
combine_function,
|
||||
combine_function: CombineFunction::new(combine_function),
|
||||
};
|
||||
let mut transcript = Keccak256Transcript::<Cursor<Vec<u8>>>::default();
|
||||
let virtual_poly = VirtualPolynomial::new(num_vars, &polys);
|
||||
let virtual_poly = VirtualPolynomial::new(num_vars, polys.iter().collect_vec().borrow());
|
||||
ClassicSumcheck::prove(&pp, claimed_sum, virtual_poly, &mut transcript)?;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::fmt::Debug;
|
||||
use std::{borrow::Borrow, fmt::Debug};
|
||||
|
||||
use ff::Field;
|
||||
use itertools::Itertools;
|
||||
@@ -48,9 +48,13 @@ impl<F: Field> EvalTable<F> {
|
||||
}
|
||||
|
||||
pub fn fold_into_half(&mut self, challenge: F) {
|
||||
for i in 0..self.table.len() / 2 {
|
||||
let eval_pair_distant = self.table[i + self.table.len() / 2].clone();
|
||||
let eval_pair = &mut self.table[i];
|
||||
assert_ne!(self.table.len(), 1);
|
||||
let len = self.table.len();
|
||||
let (pairs, pairs_distant) = self.table.split_at_mut(len / 2);
|
||||
let pairs_distant = &*pairs_distant;
|
||||
for i in 0..len / 2 {
|
||||
let eval_pair_distant = &pairs_distant[i];
|
||||
let eval_pair = &mut pairs[i];
|
||||
eval_pair.even = eval_pair.even + challenge * (eval_pair.odd - eval_pair.even);
|
||||
eval_pair.odd = eval_pair_distant.even
|
||||
+ challenge * (eval_pair_distant.odd - eval_pair_distant.even);
|
||||
@@ -65,7 +69,7 @@ pub struct VirtualPolynomial<F: Field> {
|
||||
}
|
||||
|
||||
impl<F: Field> VirtualPolynomial<F> {
|
||||
pub fn new(num_vars: usize, polys: &Vec<MultilinearPolynomial<F>>) -> Self {
|
||||
pub fn new(num_vars: usize, polys: &[&MultilinearPolynomial<F>]) -> Self {
|
||||
let polys = polys
|
||||
.iter()
|
||||
.map(|poly| EvalTable::new(num_vars, poly))
|
||||
@@ -82,6 +86,19 @@ impl<F: Field> VirtualPolynomial<F> {
|
||||
poly.fold_into_half(challenge);
|
||||
}
|
||||
}
|
||||
|
||||
/// called at the last round of sumcheck
|
||||
pub fn evaluations(&self, challenge: F) -> Vec<F> {
|
||||
self.polys.iter().for_each(|poly| {
|
||||
assert_eq!(poly.size(), 1);
|
||||
});
|
||||
self.polys
|
||||
.iter()
|
||||
.map(|poly| {
|
||||
poly.table()[0].even + challenge * (poly.table()[0].odd - poly.table()[0].even)
|
||||
})
|
||||
.collect_vec()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SumCheck<F: Field>: Clone + Debug {
|
||||
@@ -101,5 +118,5 @@ pub trait SumCheck<F: Field>: Clone + Debug {
|
||||
degree: usize,
|
||||
sum: F,
|
||||
transcript: &mut impl FieldTranscriptRead<F>,
|
||||
) -> Result<(Vec<F>, Vec<F>), ProtocolError>;
|
||||
) -> Result<(F, Vec<F>), ProtocolError>;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user