Write small test

This commit is contained in:
DoHoonKim
2024-05-11 02:07:15 +09:00
parent e0ed47894a
commit 162e89933a

View File

@@ -27,7 +27,7 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
sum: F,
mut virtual_poly: VirtualPolynomial<F>,
transcript: &mut impl FieldTranscriptWrite<F>,
) -> Result<(F, Vec<F>), ProtocolError> {
) -> Result<(Vec<F>, Vec<F>), ProtocolError> {
// Declare r_polys and initialise it with 0s
let r_degree = pp.max_degree;
let mut r_polys: Vec<Vec<F>> = (0..pp.num_vars)
@@ -62,7 +62,12 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
// update prover state polynomials
virtual_poly.fold_into_half(alpha);
}
Ok((F::ONE, vec![]))
for i in 0..virtual_poly.polys().len() {
assert_eq!(virtual_poly.polys()[i].size(), 1);
}
Ok((vec![], vec![]))
}
fn verify(
@@ -70,19 +75,25 @@ impl<F: PrimeField> SumCheck<F> for ClassicSumcheck {
degree: usize,
sum: F,
transcript: &mut impl FieldTranscriptRead<F>,
) -> Result<(F, Vec<F>), ProtocolError> {
Ok((F::ONE, vec![]))
) -> Result<(Vec<F>, Vec<F>), ProtocolError> {
Ok((vec![], vec![]))
}
}
#[cfg(test)]
mod test {
use std::{borrow::Borrow, io::Cursor, iter};
use crate::{
poly::multilinear::MultilinearPolynomial,
sumcheck::EvalTable,
sumcheck::{EvalTable, SumCheck, VirtualPolynomial},
utils::{transcript::Keccak256Transcript, ProtocolError},
};
use ff::Field;
use halo2curves::bn256::Fr;
use itertools::Itertools;
use super::{ClassicSumcheck, ClassicSumcheckProverParam};
#[test]
fn test_fold_into_half() {
@@ -110,4 +121,42 @@ mod test {
assert_eq!(eval_table.table()[i].odd, expected_odd);
}
}
#[test]
fn test_sumcheck() -> Result<(), ProtocolError> {
// Take a simple polynomial
let num_vars = 3;
let evals = (0..1 << num_vars)
.map(|_| crate::utils::random_fe::<Fr>())
.collect_vec();
let polys = iter::repeat(MultilinearPolynomial::new(evals, vec![], num_vars))
.take(3)
.collect_vec();
let combine_function = |evals: &Vec<Fr>| evals.iter().product();
let claimed_sum = {
(0..polys[0].evals().len())
.map(|idx| {
combine_function(
polys
.iter()
.map(|poly| poly.evals()[idx])
.collect_vec()
.borrow(),
)
})
.sum()
};
// Prover
let pp = ClassicSumcheckProverParam {
num_vars,
max_degree: 3,
combine_function,
};
let mut transcript = Keccak256Transcript::<Cursor<Vec<u8>>>::default();
let virtual_poly = VirtualPolynomial::new(num_vars, &polys);
ClassicSumcheck::prove(&pp, claimed_sum, virtual_poly, &mut transcript)?;
Ok(())
}
}