diff --git a/sumcheck/src/gpu/cuda/kernels/sumcheck.cu b/sumcheck/src/gpu/cuda/kernels/sumcheck.cu
index 12061e3..dca7105 100644
--- a/sumcheck/src/gpu/cuda/kernels/sumcheck.cu
+++ b/sumcheck/src/gpu/cuda/kernels/sumcheck.cu
@@ -32,7 +32,7 @@ extern "C" __global__ void sum(fr* data, fr* result, unsigned int stride, unsign
 }
 
 extern "C" __global__ void fold_into_half(
-    unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, fr* challenge
+    unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, const fr* challenge
 ) {
     int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
     const int stride = 1 << (num_vars - 1);
@@ -47,7 +47,7 @@ extern "C" __global__ void fold_into_half(
 }
 
 extern "C" __global__ void fold_into_half_in_place(
-    unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* challenge
+    unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, const fr* challenge
 ) {
     int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
     const int stride = 1 << (num_vars - 1);
diff --git a/sumcheck/src/gpu/sumcheck.rs b/sumcheck/src/gpu/sumcheck.rs
index 3e622f7..8ffc8f6 100644
--- a/sumcheck/src/gpu/sumcheck.rs
+++ b/sumcheck/src/gpu/sumcheck.rs
@@ -1,7 +1,8 @@
 use std::cell::{RefCell, RefMut};
 
-use cudarc::driver::{CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
+use cudarc::driver::{CudaSlice, CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
 use ff::PrimeField;
+use itertools::Itertools;
 
 use crate::{
     fieldbinding::{FromFieldBinding, ToFieldBinding},
@@ -16,6 +17,7 @@ impl + ToFieldBinding> GPUApiWrapper {
         max_degree: usize,
         sum: F,
         polys: &mut CudaViewMut,
+        device_ks: &[CudaView],
         buf: RefCell>,
         challenges: &mut CudaViewMut,
         round_evals: RefCell>,
@@ -28,6 +30,7 @@ impl + ToFieldBinding> GPUApiWrapper {
                 max_degree,
                 num_polys,
                 &polys.slice(..),
+                device_ks,
                 buf.borrow_mut(),
                 round_evals.borrow_mut(),
             )?;
@@ -52,15 +55,13 @@ impl + ToFieldBinding> GPUApiWrapper {
         max_degree: usize,
         num_polys: usize,
         polys: &CudaView,
+        device_ks: &[CudaView],
         mut buf: RefMut>,
         mut round_evals: RefMut>,
     ) -> Result<(), DriverError> {
         let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
         let num_threads_per_block = 1024;
         for k in 0..max_degree + 1 {
-            let device_k = self
-                .gpu
-                .htod_copy(vec![F::to_montgomery_form(F::from(k as u64))])?;
             let fold_into_half = self.gpu.get_func("sumcheck", "fold_into_half").unwrap();
             let launch_config = LaunchConfig {
                 grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1),
@@ -76,7 +77,7 @@ impl + ToFieldBinding> GPUApiWrapper {
                     num_blocks_per_poly,
                     polys,
                     &mut *buf,
-                    &device_k,
+                    &device_ks[k],
                 ),
             )?;
         };
@@ -173,7 +174,7 @@ mod tests {
     use itertools::Itertools;
     use rand::rngs::OsRng;
 
-    use crate::{cpu, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};
+    use crate::{cpu, fieldbinding::ToFieldBinding, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};
 
     #[test]
     fn test_eval_at_k_and_combine() -> Result<(), DriverError> {
@@ -228,6 +229,13 @@ mod tests {
         // copy polynomials to device
         let gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
+        let device_ks = (0..max_degree + 1)
+            .map(|k| {
+                gpu_api_wrapper
+                    .gpu
+                    .htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
+            })
+            .collect::, _>>()?;
 
         let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
         let buf_view = RefCell::new(buf.slice_mut(..));
         let mut round_evals = gpu_api_wrapper.malloc_on_device(max_degree as usize + 1)?;
@@ -240,6 +248,10 @@ mod tests {
             max_degree as usize,
             num_polys,
             &gpu_polys.slice(..),
+            &device_ks
+                .iter()
+                .map(|device_k| device_k.slice(..))
+                .collect_vec(),
             buf_view.borrow_mut(),
             round_evals_view.borrow_mut(),
         )?;
@@ -248,7 +260,7 @@ mod tests {
             now.elapsed()
         );
         let gpu_round_evals = gpu_api_wrapper
-            .dtoh_sync_copy(round_evals.slice(0..(max_degree + 1) as usize), true)?;
+            .dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true)?;
         cpu_round_evals
             .iter()
             .zip_eq(gpu_round_evals.iter())
@@ -312,7 +324,7 @@ mod tests {
         let gpu_result = (0..num_polys)
             .map(|i| {
                 gpu_api_wrapper.dtoh_sync_copy(
-                    gpu_polys.slice(i << num_vars..(i * 2 + 1) << (num_vars - 1)),
+                    &gpu_polys.slice(i << num_vars..(i * 2 + 1) << (num_vars - 1)),
                     true,
                 )
             })
@@ -338,9 +350,9 @@ mod tests {
 
     #[test]
     fn test_prove_sumcheck() -> Result<(), DriverError> {
-        let num_vars = 23;
-        let num_polys = 4;
-        let max_degree = 4;
+        let num_vars = 25;
+        let num_polys = 2;
+        let max_degree = 2;
         let rng = OsRng::default();
 
         let polys = (0..num_polys)
@@ -380,6 +392,13 @@ mod tests {
         let sum = (0..1 << num_vars).fold(Fr::ZERO, |acc, index| {
             acc + polys.iter().map(|poly| poly[index]).product::()
         });
+        let device_ks = (0..max_degree + 1)
+            .map(|k| {
+                gpu_api_wrapper
+                    .gpu
+                    .htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
+            })
+            .collect::, _>>()?;
 
         let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
         let buf_view = RefCell::new(buf.slice_mut(..));
@@ -395,20 +414,25 @@ mod tests {
             max_degree,
             sum,
             &mut gpu_polys.slice_mut(..),
+            &device_ks
+                .iter()
+                .map(|device_k| device_k.slice(..))
+                .collect_vec(),
             buf_view,
             &mut challenges.slice_mut(..),
             round_evals_view,
         )?;
+        gpu_api_wrapper.gpu.synchronize()?;
 
         println!(
             "Time taken to prove sumcheck on gpu : {:.2?}",
             now.elapsed()
        );
-        let challenges = gpu_api_wrapper.dtoh_sync_copy(challenges.slice(..), true)?;
+        let challenges = gpu_api_wrapper.dtoh_sync_copy(&challenges.slice(..), true)?;
         let round_evals = (0..num_vars)
             .map(|i| {
                 gpu_api_wrapper.dtoh_sync_copy(
-                    round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
+                    &round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
                     true,
                 )
             })
diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs
index f2e1a12..2675a64 100644
--- a/sumcheck/src/lib.rs
+++ b/sumcheck/src/lib.rs
@@ -78,10 +78,10 @@ impl + ToFieldBinding> GPUApiWrapper {
 
     pub fn dtoh_sync_copy(
         &self,
-        device_data: CudaView,
+        device_data: &CudaView,
         convert_from_montgomery_form: bool,
     ) -> Result, DriverError> {
-        let host_data = self.gpu.dtoh_sync_copy(&device_data)?;
+        let host_data = self.gpu.dtoh_sync_copy(device_data)?;
         let mut target = vec![F::ZERO; host_data.len()];
         if convert_from_montgomery_form {
             parallelize(&mut target, |(target, start)| {