Fix use of transcript

This commit is contained in:
Soowon Jeong
2024-11-12 09:43:29 +00:00
parent 38bd64fc72
commit 58faedebeb
3 changed files with 35 additions and 19 deletions

View File

@@ -4,6 +4,7 @@ mod test {
use crate::pcs::PolynomialCommitmentScheme;
use crate::poly::multilinear::MultilinearPolynomial;
use crate::sumcheck::classic::ClassicSumcheck;
use crate::sumcheck::cuda::CudaSumcheck;
use crate::utils::{end_timer, start_timer};
use crate::utils::{
random_fe,
@@ -15,8 +16,8 @@ mod test {
use std::cmp::max;
use std::io::Cursor;
type ClookupProver = Prover<Fr, MultilinearKzg<Bn256>, ClassicSumcheck>;
type ClookupVerifier = Verifier<Fr, MultilinearKzg<Bn256>, ClassicSumcheck>;
type ClookupProver = Prover<Fr, MultilinearKzg<Bn256>, CudaSumcheck>;
type ClookupVerifier = Verifier<Fr, MultilinearKzg<Bn256>, CudaSumcheck>;
#[test]
pub fn test_clookup() -> Result<(), ProtocolError> {

View File

@@ -14,6 +14,7 @@ use cuda_sumcheck::{
use cudarc::nvrtc::Ptx;
use ff::PrimeField;
use itertools::Itertools;
use sha3::Keccak256;
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
const MULTILINEAR_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/multilinear.ptx"));
@@ -58,7 +59,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
pp: &Self::ProverParam,
combine_function: &impl Fn(&Vec<F>) -> F,
sum: F,
mut virtual_poly: VirtualPolynomial<F>,
virtual_poly: VirtualPolynomial<F>,
transcript: &mut impl FieldTranscriptWrite<F>,
) -> Result<(Vec<F>, Vec<F>), ProtocolError> {
let mut challenges = vec![];
@@ -86,7 +87,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
],
)
.map_err(|_| ProtocolError::CudaLibraryError)?;
let mut transcript = Keccak256Transcript::<F>::new();
let mut cuda_transcript = Keccak256Transcript::<F>::new();
let polys: Vec<Vec<F>> = virtual_poly
.polys()
.iter()
@@ -133,26 +134,40 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
buf_view,
&mut challenges_cuda,
round_evals_view,
&mut transcript,
&mut challenges,
&mut cuda_transcript,
)
.map_err(|_| ProtocolError::CudaLibraryError)?;
gpu_api_wrapper
.gpu
.synchronize()
.map_err(|_| ProtocolError::CudaLibraryError)?;
.gpu
.synchronize()
.map_err(|_| ProtocolError::CudaLibraryError)?;
let evaluations = (0..polys.len())
.map(|i| {
gpu_api_wrapper.dtoh_sync_copy(
&gpu_polys.slice(i << pp.num_vars..i << pp.num_vars + 1),
true,
).map(|res| res.first().unwrap().clone()).map_err(|_| ProtocolError::CudaLibraryError)
})
.collect::<Result<Vec<F>, _>>()
.map_err(|_| ProtocolError::CudaLibraryError)?;
.map(|i| {
gpu_api_wrapper
.dtoh_sync_copy(
&gpu_polys.slice(i << pp.num_vars..(i * 2 + 1) << (pp.num_vars - 1)),
true,
)
.map(|res| res.first().unwrap().clone())
.map_err(|_| ProtocolError::CudaLibraryError)
})
.collect::<Result<Vec<F>, _>>()
.map_err(|_| ProtocolError::CudaLibraryError)?;
let mut cuda_transcript = cuda_sumcheck::transcript::Transcript::<Keccak256, F>::from_proof(
cuda_transcript.into_proof().as_slice(),
);
for _ in 0..pp.num_vars {
transcript.write_field_elements(
cuda_transcript
.read_field_elements(pp.max_degree + 1)
.unwrap()
.iter(),
)?;
challenges.push(transcript.squeeze_challenge());
}
transcript.write_field_elements(evaluations.iter())?;
challenges.reverse();
Ok((challenges, evaluations))