mirror of
https://github.com/pseXperiments/clookup.git
synced 2026-01-08 23:28:10 -05:00
Integrate, fix, cuda prover
This commit is contained in:
Submodule cuda-sumcheck updated: 6ee13529f9...3cae446d1b
@@ -1,3 +1,4 @@
|
||||
pub mod cuda_prover;
|
||||
pub mod precomputation;
|
||||
pub mod prover;
|
||||
pub mod test;
|
||||
|
||||
114
src/core/cuda_prover.rs
Normal file
114
src/core/cuda_prover.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use super::precomputation::Table;
|
||||
use crate::{
|
||||
pcs::{Evaluation, PolynomialCommitmentScheme},
|
||||
poly::multilinear::MultilinearPolynomial,
|
||||
sumcheck::{cuda::CudaSumcheck, SumCheck, VirtualPolynomial},
|
||||
utils::{
|
||||
arithmetic::powers, end_timer, start_timer, transcript::TranscriptWrite, transpose,
|
||||
ProtocolError,
|
||||
},
|
||||
};
|
||||
use cuda_sumcheck::fieldbinding::{FromFieldBinding, ToFieldBinding};
|
||||
use ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use rand::RngCore;
|
||||
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
|
||||
use std::{cmp::max, hash::Hash, iter, marker::PhantomData};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CudaProver<
|
||||
F: PrimeField + Hash,
|
||||
Pcs: PolynomialCommitmentScheme<F, Polynomial = MultilinearPolynomial<F>>,
|
||||
>(PhantomData<F>, PhantomData<Pcs>);
|
||||
|
||||
impl<
|
||||
F: PrimeField + Hash + FromFieldBinding<F> + ToFieldBinding<F>,
|
||||
Pcs: PolynomialCommitmentScheme<F, Polynomial = MultilinearPolynomial<F>>,
|
||||
> CudaProver<F, Pcs>
|
||||
{
|
||||
pub fn setup(
|
||||
table: &Table<F>,
|
||||
witness: &Vec<F>,
|
||||
rng: impl RngCore,
|
||||
) -> Result<Pcs::Param, ProtocolError> {
|
||||
let poly_size = max(table.len(), witness.len());
|
||||
let batch_size = 1 + 1 + table.num_vars();
|
||||
Pcs::setup(poly_size, batch_size, rng)
|
||||
}
|
||||
|
||||
fn sigma_polys(
|
||||
table: &Table<F>,
|
||||
witness: &Vec<F>,
|
||||
) -> Result<Vec<MultilinearPolynomial<F>>, ProtocolError> {
|
||||
let indices = table.find_indices(&witness)?;
|
||||
let sigma: Vec<MultilinearPolynomial<F>> = transpose(indices)
|
||||
.par_iter()
|
||||
.map(|idx| MultilinearPolynomial::eval_to_coeff(idx, idx.len().ilog2() as usize))
|
||||
.collect();
|
||||
Ok(sigma)
|
||||
}
|
||||
|
||||
fn h_function<'a>(
|
||||
table_poly: &'a MultilinearPolynomial<F>,
|
||||
gamma: F,
|
||||
) -> impl Fn(&Vec<F>) -> F + 'a {
|
||||
move |evals: &Vec<F>| {
|
||||
gamma
|
||||
}
|
||||
}
|
||||
|
||||
pub fn prove(
|
||||
pp: &Pcs::ProverParam,
|
||||
transcript: &mut impl TranscriptWrite<Pcs::CommitmentChunk, F>,
|
||||
table: &Table<F>,
|
||||
witness: &Vec<F>,
|
||||
) -> Result<(), ProtocolError> {
|
||||
let witness_poly =
|
||||
MultilinearPolynomial::new(witness.clone(), vec![], witness.len().ilog2() as usize);
|
||||
let table_poly = table.polynomial();
|
||||
let num_vars = witness_poly.num_vars();
|
||||
let max_degree = 2;
|
||||
// get sigma_polys
|
||||
let timer = start_timer(|| "sigma_polys");
|
||||
let sigma_polys = Self::sigma_polys(table, witness)?;
|
||||
end_timer(timer);
|
||||
// commit to sigma_polys, witness polys, table polys
|
||||
let witness_poly_comm = Pcs::commit_and_write(pp, &witness_poly, transcript)?;
|
||||
let sigma_polys_comms = Pcs::batch_commit_and_write(pp, &sigma_polys, transcript)?;
|
||||
|
||||
// squeeze challenges
|
||||
let gamma = transcript.squeeze_challenge();
|
||||
let ys = transcript.squeeze_challenges(num_vars);
|
||||
let eq = MultilinearPolynomial::eq_xy(&ys);
|
||||
let h_function = Self::h_function(&table_poly, gamma);
|
||||
// proceed sumcheck
|
||||
let (x, evals) = {
|
||||
let virtual_poly = VirtualPolynomial::new(
|
||||
num_vars,
|
||||
iter::once(&witness_poly)
|
||||
.chain(sigma_polys.iter())
|
||||
.chain(&[eq])
|
||||
.collect_vec()
|
||||
.as_ref(),
|
||||
);
|
||||
let pp = <CudaSumcheck as SumCheck<F>>::generate_pp(num_vars, max_degree)?;
|
||||
CudaSumcheck::prove(&pp, &h_function, F::ZERO, virtual_poly, transcript)?
|
||||
};
|
||||
// open polynomials at x
|
||||
let witness_poly_x = evals.first().unwrap();
|
||||
let sigma_polys_x = evals
|
||||
.iter()
|
||||
.skip(1)
|
||||
.take(table_poly.num_vars())
|
||||
.collect_vec();
|
||||
let polys = iter::once(&witness_poly).chain(sigma_polys.iter());
|
||||
let comms = iter::once(&witness_poly_comm).chain(sigma_polys_comms.iter());
|
||||
let points = iter::repeat(x).take(1 + sigma_polys.len()).collect_vec();
|
||||
let evals = iter::once(witness_poly_x)
|
||||
.chain(sigma_polys_x)
|
||||
.enumerate()
|
||||
.map(|(poly, value)| Evaluation::new(poly, 0, *value))
|
||||
.collect_vec();
|
||||
Pcs::batch_open(pp, polys, comms, &points, &evals, transcript)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
mod test {
|
||||
use crate::core::cuda_prover::CudaProver;
|
||||
use crate::core::{precomputation::Table, prover::Prover, verifier::Verifier};
|
||||
use crate::pcs::multilinear::kzg::MultilinearKzg;
|
||||
use crate::pcs::PolynomialCommitmentScheme;
|
||||
@@ -16,21 +17,22 @@ mod test {
|
||||
use std::cmp::max;
|
||||
use std::io::Cursor;
|
||||
|
||||
type ClookupProver = Prover<Fr, MultilinearKzg<Bn256>, CudaSumcheck>;
|
||||
type ClookupProver = CudaProver<Fr, MultilinearKzg<Bn256>>;
|
||||
type ClookupVerifier = Verifier<Fr, MultilinearKzg<Bn256>, CudaSumcheck>;
|
||||
|
||||
#[test]
|
||||
pub fn test_clookup() -> Result<(), ProtocolError> {
|
||||
let table_dim = 8;
|
||||
let witness_dim = 4;
|
||||
let table_vec: Vec<Fr> = (0..1 << table_dim).map(|_| random_fe()).collect_vec();
|
||||
// Range table 0..1 << table_dim - 1
|
||||
let table_vec: Vec<Fr> = (0..1 << table_dim).map(|i| Fr::from(i)).collect_vec();
|
||||
let witness_vec = table_vec
|
||||
.iter()
|
||||
.take(1 << witness_dim)
|
||||
.cloned()
|
||||
.collect_vec();
|
||||
let table: Table<Fr> = table_vec.try_into()?;
|
||||
let max_degree = 1 + max(2, table_dim);
|
||||
let max_degree = 3;
|
||||
let (pp, vp) = {
|
||||
let rng = rand::thread_rng();
|
||||
let param = ClookupProver::setup(&table, &witness_vec, rng)?;
|
||||
|
||||
@@ -64,7 +64,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
|
||||
) -> Result<(Vec<F>, Vec<F>), ProtocolError> {
|
||||
let mut challenges = vec![];
|
||||
let mut gpu_api_wrapper =
|
||||
GPUApiWrapper::<F>::setup().map_err(|_| ProtocolError::CudaLibraryError)?;
|
||||
GPUApiWrapper::<F>::setup().map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
.load_ptx(
|
||||
@@ -72,7 +72,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
|
||||
"multilinear",
|
||||
&["convert_to_montgomery_form"],
|
||||
)
|
||||
.map_err(|_| ProtocolError::CudaLibraryError)?;
|
||||
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
|
||||
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
@@ -86,8 +86,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
|
||||
"sum",
|
||||
],
|
||||
)
|
||||
.map_err(|_| ProtocolError::CudaLibraryError)?;
|
||||
let mut cuda_transcript = Keccak256Transcript::<F>::new();
|
||||
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
|
||||
let polys: Vec<Vec<F>> = virtual_poly
|
||||
.polys()
|
||||
.iter()
|
||||
@@ -96,7 +95,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
|
||||
|
||||
let mut gpu_polys = gpu_api_wrapper
|
||||
.copy_to_device(&polys.concat())
|
||||
.map_err(|_| ProtocolError::CudaLibraryError)?;
|
||||
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
|
||||
let device_ks = (0..pp.max_degree + 1)
|
||||
.map(|k| {
|
||||
gpu_api_wrapper
|
||||
@@ -104,22 +103,28 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
|
||||
.htod_copy(vec![F::to_montgomery_form(F::from(k as u64))])
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|_| ProtocolError::CudaLibraryError)?;
|
||||
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
|
||||
|
||||
let mut buf = gpu_api_wrapper
|
||||
.malloc_on_device(polys.len() << (pp.num_vars - 1))
|
||||
.map_err(|_| ProtocolError::CudaLibraryError)?;
|
||||
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
|
||||
let buf_view = RefCell::new(buf.slice_mut(..));
|
||||
|
||||
let mut challenges_cuda = gpu_api_wrapper
|
||||
.malloc_on_device(1)
|
||||
.map_err(|_| ProtocolError::CudaLibraryError)?;
|
||||
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
|
||||
|
||||
let mut round_evals = gpu_api_wrapper
|
||||
.malloc_on_device(pp.max_degree + 1)
|
||||
.map_err(|e| ProtocolError::CudaLibraryError)?;
|
||||
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
|
||||
let round_evals_view = RefCell::new(round_evals.slice_mut(..));
|
||||
|
||||
// This is shit code
|
||||
let gamma = gpu_api_wrapper
|
||||
.gpu
|
||||
.htod_copy(vec![F::to_montgomery_form(combine_function(&vec![]))])
|
||||
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
|
||||
|
||||
gpu_api_wrapper
|
||||
.prove_sumcheck(
|
||||
pp.num_vars,
|
||||
@@ -135,13 +140,14 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
|
||||
&mut challenges_cuda,
|
||||
round_evals_view,
|
||||
&mut cuda_transcript,
|
||||
&gamma.slice(..),
|
||||
)
|
||||
.map_err(|_| ProtocolError::CudaLibraryError)?;
|
||||
.map_err(|e| ProtocolError::CudaLibraryError(String::from("library error")))?;
|
||||
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
.synchronize()
|
||||
.map_err(|_| ProtocolError::CudaLibraryError)?;
|
||||
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))?;
|
||||
|
||||
let evaluations = (0..polys.len())
|
||||
.map(|i| {
|
||||
@@ -151,10 +157,10 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
|
||||
true,
|
||||
)
|
||||
.map(|res| res.first().unwrap().clone())
|
||||
.map_err(|_| ProtocolError::CudaLibraryError)
|
||||
.map_err(|e| ProtocolError::CudaLibraryError(e.to_string()))
|
||||
})
|
||||
.collect::<Result<Vec<F>, _>>()
|
||||
.map_err(|_| ProtocolError::CudaLibraryError)?;
|
||||
.map_err(|e| ProtocolError::CudaLibraryError(String::from("")))?;
|
||||
let mut cuda_transcript = cuda_sumcheck::transcript::Transcript::<Keccak256, F>::from_proof(
|
||||
cuda_transcript.into_proof().as_slice(),
|
||||
);
|
||||
@@ -170,6 +176,8 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
|
||||
transcript.write_field_elements(evaluations.iter())?;
|
||||
challenges.reverse();
|
||||
|
||||
println!("{:?}", challenges);
|
||||
|
||||
Ok((challenges, evaluations))
|
||||
}
|
||||
|
||||
@@ -190,6 +198,8 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
|
||||
(msgs, challenges)
|
||||
};
|
||||
|
||||
|
||||
|
||||
let evaluations = transcript.read_field_elements(num_polys)?;
|
||||
let mut expected_sum = sum.clone();
|
||||
let points_vec: Vec<F> = (0..vp.max_degree + 1)
|
||||
@@ -212,6 +222,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> SumCheck<F> for Cu
|
||||
|
||||
// Check r_{i}(α_i) == r_{i+1}(0) + r_{i+1}(1)
|
||||
if computed_sum != expected_sum {
|
||||
println!("round : {:?}", round_index);
|
||||
return Err(ProtocolError::InvalidSumcheck(format!(
|
||||
"computed sum != expected sum"
|
||||
)));
|
||||
|
||||
@@ -15,7 +15,7 @@ pub enum ProtocolError {
|
||||
InvalidSumcheck(String),
|
||||
InvalidPcsParam(String),
|
||||
InvalidPcsOpen(String),
|
||||
CudaLibraryError,
|
||||
CudaLibraryError(String),
|
||||
SizeError,
|
||||
NotInclusion,
|
||||
Transcript(std::io::ErrorKind, String),
|
||||
|
||||
Reference in New Issue
Block a user