Integrate, fix, cuda prover

This commit is contained in:
DoHoonKim8
2024-11-13 11:59:25 +00:00
parent 58faedebeb
commit c60482924f
6 changed files with 146 additions and 18 deletions

View File

@@ -1,3 +1,4 @@
pub mod cuda_prover;
pub mod precomputation;
pub mod prover;
pub mod test;

114
src/core/cuda_prover.rs Normal file
View File

@@ -0,0 +1,114 @@
use super::precomputation::Table;
use crate::{
pcs::{Evaluation, PolynomialCommitmentScheme},
poly::multilinear::MultilinearPolynomial,
sumcheck::{cuda::CudaSumcheck, SumCheck, VirtualPolynomial},
utils::{
arithmetic::powers, end_timer, start_timer, transcript::TranscriptWrite, transpose,
ProtocolError,
},
};
use cuda_sumcheck::fieldbinding::{FromFieldBinding, ToFieldBinding};
use ff::PrimeField;
use itertools::Itertools;
use rand::RngCore;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
use std::{cmp::max, hash::Hash, iter, marker::PhantomData};
#[derive(Clone, Debug)]
pub struct CudaProver<
F: PrimeField + Hash,
Pcs: PolynomialCommitmentScheme<F, Polynomial = MultilinearPolynomial<F>>,
>(PhantomData<F>, PhantomData<Pcs>);
impl<
F: PrimeField + Hash + FromFieldBinding<F> + ToFieldBinding<F>,
Pcs: PolynomialCommitmentScheme<F, Polynomial = MultilinearPolynomial<F>>,
> CudaProver<F, Pcs>
{
pub fn setup(
table: &Table<F>,
witness: &Vec<F>,
rng: impl RngCore,
) -> Result<Pcs::Param, ProtocolError> {
let poly_size = max(table.len(), witness.len());
let batch_size = 1 + 1 + table.num_vars();
Pcs::setup(poly_size, batch_size, rng)
}
fn sigma_polys(
table: &Table<F>,
witness: &Vec<F>,
) -> Result<Vec<MultilinearPolynomial<F>>, ProtocolError> {
let indices = table.find_indices(&witness)?;
let sigma: Vec<MultilinearPolynomial<F>> = transpose(indices)
.par_iter()
.map(|idx| MultilinearPolynomial::eval_to_coeff(idx, idx.len().ilog2() as usize))
.collect();
Ok(sigma)
}
fn h_function<'a>(
table_poly: &'a MultilinearPolynomial<F>,
gamma: F,
) -> impl Fn(&Vec<F>) -> F + 'a {
move |evals: &Vec<F>| {
gamma
}
}
pub fn prove(
pp: &Pcs::ProverParam,
transcript: &mut impl TranscriptWrite<Pcs::CommitmentChunk, F>,
table: &Table<F>,
witness: &Vec<F>,
) -> Result<(), ProtocolError> {
let witness_poly =
MultilinearPolynomial::new(witness.clone(), vec![], witness.len().ilog2() as usize);
let table_poly = table.polynomial();
let num_vars = witness_poly.num_vars();
let max_degree = 2;
// get sigma_polys
let timer = start_timer(|| "sigma_polys");
let sigma_polys = Self::sigma_polys(table, witness)?;
end_timer(timer);
// commit to sigma_polys, witness polys, table polys
let witness_poly_comm = Pcs::commit_and_write(pp, &witness_poly, transcript)?;
let sigma_polys_comms = Pcs::batch_commit_and_write(pp, &sigma_polys, transcript)?;
// squeeze challenges
let gamma = transcript.squeeze_challenge();
let ys = transcript.squeeze_challenges(num_vars);
let eq = MultilinearPolynomial::eq_xy(&ys);
let h_function = Self::h_function(&table_poly, gamma);
// proceed sumcheck
let (x, evals) = {
let virtual_poly = VirtualPolynomial::new(
num_vars,
iter::once(&witness_poly)
.chain(sigma_polys.iter())
.chain(&[eq])
.collect_vec()
.as_ref(),
);
let pp = <CudaSumcheck as SumCheck<F>>::generate_pp(num_vars, max_degree)?;
CudaSumcheck::prove(&pp, &h_function, F::ZERO, virtual_poly, transcript)?
};
// open polynomials at x
let witness_poly_x = evals.first().unwrap();
let sigma_polys_x = evals
.iter()
.skip(1)
.take(table_poly.num_vars())
.collect_vec();
let polys = iter::once(&witness_poly).chain(sigma_polys.iter());
let comms = iter::once(&witness_poly_comm).chain(sigma_polys_comms.iter());
let points = iter::repeat(x).take(1 + sigma_polys.len()).collect_vec();
let evals = iter::once(witness_poly_x)
.chain(sigma_polys_x)
.enumerate()
.map(|(poly, value)| Evaluation::new(poly, 0, *value))
.collect_vec();
Pcs::batch_open(pp, polys, comms, &points, &evals, transcript)
}
}

View File

@@ -1,4 +1,5 @@
mod test {
use crate::core::cuda_prover::CudaProver;
use crate::core::{precomputation::Table, prover::Prover, verifier::Verifier};
use crate::pcs::multilinear::kzg::MultilinearKzg;
use crate::pcs::PolynomialCommitmentScheme;
@@ -16,21 +17,22 @@ mod test {
use std::cmp::max;
use std::io::Cursor;
type ClookupProver = Prover<Fr, MultilinearKzg<Bn256>, CudaSumcheck>;
type ClookupProver = CudaProver<Fr, MultilinearKzg<Bn256>>;
type ClookupVerifier = Verifier<Fr, MultilinearKzg<Bn256>, CudaSumcheck>;
#[test]
pub fn test_clookup() -> Result<(), ProtocolError> {
let table_dim = 8;
let witness_dim = 4;
let table_vec: Vec<Fr> = (0..1 << table_dim).map(|_| random_fe()).collect_vec();
// Range table 0..1 << table_dim - 1
let table_vec: Vec<Fr> = (0..1 << table_dim).map(|i| Fr::from(i)).collect_vec();
let witness_vec = table_vec
.iter()
.take(1 << witness_dim)
.cloned()
.collect_vec();
let table: Table<Fr> = table_vec.try_into()?;
let max_degree = 1 + max(2, table_dim);
let max_degree = 3;
let (pp, vp) = {
let rng = rand::thread_rng();
let param = ClookupProver::setup(&table, &witness_vec, rng)?;

View File

@@ -64,7 +64,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
) -> Result<(Vec<F>, Vec<F>), ProtocolError> {
let mut challenges = vec![];
let mut gpu_api_wrapper =
GPUApiWrapper::<F>::setup().map_err(|_| ProtocolError::CudaLibraryError)?;
GPUApiWrapper::<F>::setup().map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
gpu_api_wrapper
.gpu
.load_ptx(
@@ -72,7 +72,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
"multilinear",
&["convert_to_montgomery_form"],
)
.map_err(|_| ProtocolError::CudaLibraryError)?;
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
gpu_api_wrapper
.gpu
@@ -86,8 +86,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
"sum",
],
)
.map_err(|_| ProtocolError::CudaLibraryError)?;
let mut cuda_transcript = Keccak256Transcript::<F>::new();
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
let polys: Vec<Vec<F>> = virtual_poly
.polys()
.iter()
@@ -96,7 +95,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
let mut gpu_polys = gpu_api_wrapper
.copy_to_device(&polys.concat())
.map_err(|_| ProtocolError::CudaLibraryError)?;
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
let device_ks = (0..pp.max_degree + 1)
.map(|k| {
gpu_api_wrapper
@@ -104,22 +103,28 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
.htod_copy(vec![F::to_montgomery_form(F::from(k as u64))])
})
.collect::<Result<Vec<_>, _>>()
.map_err(|_| ProtocolError::CudaLibraryError)?;
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
let mut buf = gpu_api_wrapper
.malloc_on_device(polys.len() << (pp.num_vars - 1))
.map_err(|_| ProtocolError::CudaLibraryError)?;
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
let buf_view = RefCell::new(buf.slice_mut(..));
let mut challenges_cuda = gpu_api_wrapper
.malloc_on_device(1)
.map_err(|_| ProtocolError::CudaLibraryError)?;
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
let mut round_evals = gpu_api_wrapper
.malloc_on_device(pp.max_degree + 1)
.map_err(|e| ProtocolError::CudaLibraryError)?;
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
let round_evals_view = RefCell::new(round_evals.slice_mut(..));
// This is shit code
let gamma = gpu_api_wrapper
.gpu
.htod_copy(vec![F::to_montgomery_form(combine_function(&vec![]))])
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
gpu_api_wrapper
.prove_sumcheck(
pp.num_vars,
@@ -135,13 +140,14 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
&mut challenges_cuda,
round_evals_view,
&mut cuda_transcript,
&gamma.slice(..),
)
.map_err(|_| ProtocolError::CudaLibraryError)?;
.map_err(|e| ProtocolError::CudaLibraryError(String::from("library error")))?;
gpu_api_wrapper
.gpu
.synchronize()
.map_err(|_| ProtocolError::CudaLibraryError)?;
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
let evaluations = (0..polys.len())
.map(|i| {
@@ -151,10 +157,10 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
true,
)
.map(|res| res.first().unwrap().clone())
.map_err(|_| ProtocolError::CudaLibraryError)
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))
})
.collect::<Result<Vec<F>, _>>()
.map_err(|_| ProtocolError::CudaLibraryError)?;
.map_err(|e| ProtocolError::CudaLibraryError(String::from("")))?;
let mut cuda_transcript = cuda_sumcheck::transcript::Transcript::<Keccak256, F>::from_proof(
cuda_transcript.into_proof().as_slice(),
);
@@ -170,6 +176,8 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
transcript.write_field_elements(evaluations.iter())?;
challenges.reverse();
println!("{:?}", challenges);
Ok((challenges, evaluations))
}
@@ -190,6 +198,8 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
(msgs, challenges)
};
let evaluations = transcript.read_field_elements(num_polys)?;
let mut expected_sum = sum.clone();
let points_vec: Vec<F> = (0..vp.max_degree + 1)
@@ -212,6 +222,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
// Check r_{i}(α_i) == r_{i+1}(0) + r_{i+1}(1)
if computed_sum != expected_sum {
println!("round : {:?}", round_index);
return Err(ProtocolError::InvalidSumcheck(format!(
"computed sum != expected sum"
)));

View File

@@ -15,7 +15,7 @@ pub enum ProtocolError {
InvalidSumcheck(String),
InvalidPcsParam(String),
InvalidPcsOpen(String),
CudaLibraryError,
CudaLibraryError(String),
SizeError,
NotInclusion,
Transcript(std::io::ErrorKind, String),