This commit is contained in:
DoHoonKim
2024-05-08 14:38:05 +09:00
parent b2f21dec81
commit e664e2f3cd
2 changed files with 47 additions and 61 deletions

View File

@@ -1,5 +1,4 @@
use super::{SumCheck, VirtualPolynomial};
use crate::poly::multilinear::MultilinearPolynomial;
use crate::utils::transcript::{FieldTranscriptRead, FieldTranscriptWrite};
use crate::utils::ProtocolError;
use ff::{Field, PrimeField};
@@ -10,7 +9,6 @@ struct ClassicSumcheck;
#[derive(Clone, Debug)]
struct ClassicSumcheckProverParam<F: Field> {
num_vars: usize,
round: usize,
max_degree: usize,
combine_function: fn(&Vec<F>) -> F,
}
@@ -26,9 +24,8 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
fn prove(
pp: &Self::ProverParam,
num_vars: usize,
sum: F,
mut virtual_polys: Vec<VirtualPolynomial<F>>,
mut virtual_poly: VirtualPolynomial<F>,
transcript: &mut impl FieldTranscriptWrite<F>,
) -> Result<(F, Vec<F>), ProtocolError> {
// Declare r_polys and initialise it with 0s
@@ -40,14 +37,14 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
transcript.write_field_element(&sum)?;
for round_index in 0..pp.num_vars {
let virtual_polynomial_len = virtual_polys[0].evals().len();
for k in 0..(r_degree + 1) {
for i in 0..virtual_polynomial_len {
let evaluations_at_k = virtual_polys
for i in 0..virtual_poly.polys()[0].size() {
let evaluations_at_k = virtual_poly
.polys()
.iter()
.map(|virtual_poly| {
let o = virtual_poly.evals()[i].odd;
let e = virtual_poly.evals()[i].even;
.map(|poly| {
let o = poly.table()[i].odd;
let e = poly.table()[i].even;
e + F::from(k as u64) * (o - e)
})
.collect::<Vec<F>>();
@@ -63,16 +60,13 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
let alpha = transcript.squeeze_challenge();
// update prover state polynomials
for j in 0..virtual_polys.len() {
virtual_polys[j].fold_into_half(alpha);
}
virtual_poly.fold_into_half(alpha);
}
Ok((F::ONE, vec![]))
}
fn verify(
vp: &Self::VerifierParam,
num_vars: usize,
degree: usize,
sum: F,
transcript: &mut impl FieldTranscriptRead<F>,
@@ -83,33 +77,37 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
#[cfg(test)]
mod test {
use crate::sumcheck::classic::*;
use crate::{
poly::multilinear::MultilinearPolynomial,
sumcheck::EvalTable,
};
use ff::Field;
use halo2curves::bn256::Fr;
#[test]
fn test_fold_into_half() {
let num_vars = 3;
let evals_list = (0..1 << num_vars)
let num_vars = 16;
let evals = (0..1 << num_vars)
.map(|_| crate::utils::random_fe())
.collect::<Vec<Fr>>();
let poly = MultilinearPolynomial::new(evals, vec![], num_vars);
let mut virtual_poly: VirtualPolynomial<Fr> = VirtualPolynomial::new(num_vars, evals_list);
let evals = virtual_poly.evals().clone();
let size_before = virtual_poly.size();
let mut eval_table = EvalTable::new(num_vars, &poly);
let evals = eval_table.table().clone();
let size_before = eval_table.size();
let alpha: Fr = crate::utils::random_fe();
virtual_poly.fold_into_half(alpha);
let size_after = virtual_poly.size();
eval_table.fold_into_half(alpha);
let size_after = eval_table.size();
assert_eq!(2 * size_after, size_before);
for i in 0..virtual_poly.size() {
for i in 0..eval_table.size() {
let expected_even = (Fr::ONE - alpha) * evals[i].even + alpha * evals[i].odd;
let expected_odd =
(Fr::ONE - alpha) * evals[size_after + i].even + alpha * evals[size_after + i].odd;
assert_eq!(virtual_poly.evals()[i].even, expected_even);
assert_eq!(virtual_poly.evals()[i].odd, expected_odd);
assert_eq!(eval_table.table()[i].even, expected_even);
assert_eq!(eval_table.table()[i].odd, expected_odd);
}
}
}

View File

@@ -2,12 +2,13 @@ use std::fmt::Debug;
use ff::Field;
use itertools::Itertools;
use rand::RngCore;
use crate::utils::{
random_fe,
transcript::{FieldTranscriptRead, FieldTranscriptWrite},
ProtocolError,
use crate::{
poly::multilinear::MultilinearPolynomial,
utils::{
transcript::{FieldTranscriptRead, FieldTranscriptWrite},
ProtocolError,
},
};
pub mod classic;
@@ -20,72 +21,59 @@ pub(super) struct EvalPair<F: Field> {
}
#[derive(Clone, Debug)]
pub struct VirtualPolynomial<F: Field> {
pub(super) struct EvalTable<F: Field> {
num_vars: usize,
evals: Vec<EvalPair<F>>,
table: Vec<EvalPair<F>>,
}
impl<F: Field> VirtualPolynomial<F> {
pub fn new(num_vars: usize, evals: Vec<F>) -> Self {
assert_eq!(evals.len(), 1 << num_vars);
let evals = evals[..1 << (num_vars - 1)]
impl<F: Field> EvalTable<F> {
pub fn new(num_vars: usize, poly: &MultilinearPolynomial<F>) -> Self {
assert_eq!(poly.evals().len(), 1 << num_vars);
let table = poly
.iter()
.take(1 << (num_vars - 1))
.copied()
.zip(evals[(1 << (num_vars - 1))..].iter().copied())
.zip(poly.iter().skip(1 << (num_vars - 1)).copied())
.map(|(even, odd)| EvalPair { even, odd })
.collect_vec();
Self { num_vars, evals }
Self { num_vars, table }
}
pub fn size(&self) -> usize {
self.evals.len()
self.table.len()
}
pub(super) fn evals(&self) -> &Vec<EvalPair<F>> {
&self.evals
pub(super) fn table(&self) -> &Vec<EvalPair<F>> {
&self.table
}
pub fn fold_into_half(&mut self, challenge: F) {
for i in 0..self.evals.len() / 2 {
let eval_pair_distant = self.evals[i + self.evals.len() / 2].clone();
let eval_pair = &mut self.evals[i];
for i in 0..self.table.len() / 2 {
let eval_pair_distant = self.table[i + self.table.len() / 2].clone();
let eval_pair = &mut self.table[i];
eval_pair.even = eval_pair.even + challenge * (eval_pair.odd - eval_pair.even);
eval_pair.odd = eval_pair_distant.even
+ challenge * (eval_pair_distant.odd - eval_pair_distant.even);
}
self.num_vars -= 1;
self.evals.truncate(self.evals.len() / 2);
}
pub fn get_random_virtual(num_vars: usize) -> Self {
Self {
num_vars,
evals: (0..1 << num_vars)
.map(|_| EvalPair {
even: random_fe(),
odd: random_fe(),
})
.collect(),
}
self.table.truncate(self.table.len() / 2);
}
}
pub struct VirtualPolynomial<F: Field> {
num_vars: usize,
polys: Vec<EvalTable<F>>,
}
impl<F: Field> VirtualPolynomial<F> {
pub fn new(num_vars: usize, polys: &Vec<MultilinearPolynomial<F>>) -> Self {
assert_eq!(polys[0].evals().len(), 1 << num_vars);
let polys = polys
.iter()
.map(|poly| EvalTable::new(num_vars, poly))
.collect_vec();
Self { num_vars, polys }
Self { polys }
}
pub fn polys(&self) -> &Vec<EvalTable<F>> {
pub(super) fn polys(&self) -> &Vec<EvalTable<F>> {
&self.polys
}