diff --git a/sumcheck/src/gpu/cuda/kernels/sumcheck.cu b/sumcheck/src/gpu/cuda/kernels/sumcheck.cu
index 0f0c39e..eecdaf4 100644
--- a/sumcheck/src/gpu/cuda/kernels/sumcheck.cu
+++ b/sumcheck/src/gpu/cuda/kernels/sumcheck.cu
@@ -39,11 +39,11 @@ extern "C" __global__ void combine_and_sum(
 extern "C" __global__ void fold_into_half(
     unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, fr* challenge
 ) {
-    int tid = threadIdx.x;
+    int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
     const int stride = 1 << (num_vars - 1);
-    const int buf_offset = (blockIdx.x / num_blocks_per_poly) * stride + (blockIdx.x % num_blocks_per_poly) * blockDim.x;
-    const int poly_offset = (blockIdx.x / num_blocks_per_poly) * initial_poly_size + (blockIdx.x % num_blocks_per_poly) * blockDim.x;
-    while (tid < stride && buf_offset < (blockIdx.x / num_blocks_per_poly + 1) * stride) {
+    const int buf_offset = (blockIdx.x / num_blocks_per_poly) * stride;
+    const int poly_offset = (blockIdx.x / num_blocks_per_poly) * initial_poly_size;
+    while (tid < stride) {
         buf[buf_offset + tid] = (*challenge) * (polys[poly_offset + tid + stride] - polys[poly_offset + tid]) + polys[poly_offset + tid];
         tid += blockDim.x * num_blocks_per_poly;
     }
@@ -52,9 +52,9 @@
 extern "C" __global__ void fold_into_half_in_place(
     unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* challenge
 ) {
-    int tid = threadIdx.x;
+    int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
     const int stride = 1 << (num_vars - 1);
-    const int offset = (blockIdx.x / num_blocks_per_poly) * initial_poly_size + (blockIdx.x % num_blocks_per_poly) * blockDim.x;
+    const int offset = (blockIdx.x / num_blocks_per_poly) * initial_poly_size;
     while (tid < stride) {
         int idx = offset + tid;
         polys[idx] = (*challenge) * (polys[idx + stride] - polys[idx]) + polys[idx];
diff --git a/sumcheck/src/gpu/sumcheck.rs b/sumcheck/src/gpu/sumcheck.rs
index f2b6b2d..7949e5d 100644
--- a/sumcheck/src/gpu/sumcheck.rs
+++ b/sumcheck/src/gpu/sumcheck.rs
@@ -55,7 +55,8 @@ impl + ToFieldBinding> GPUApiWrapper {
         mut buf: RefMut>,
         mut round_evals: RefMut>,
     ) -> Result<(), DriverError> {
-        let num_blocks_per_poly = self.max_blocks_per_sm()?;
+        let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
+        let num_threads_per_block = 1024;
         for k in 0..max_degree + 1 {
             let device_k = self
                 .gpu
@@ -63,7 +64,7 @@ impl + ToFieldBinding> GPUApiWrapper {
             let fold_into_half = self.gpu.get_func("sumcheck", "fold_into_half").unwrap();
             let launch_config = LaunchConfig {
                 grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1),
-                block_dim: (1024, 1, 1),
+                block_dim: (num_threads_per_block as u32, 1, 1),
                 shared_mem_bytes: 0,
             };
             unsafe {
@@ -130,10 +131,11 @@ impl + ToFieldBinding> GPUApiWrapper {
             .gpu
             .get_func("sumcheck", "fold_into_half_in_place")
             .unwrap();
-        let num_blocks_per_poly = self.max_blocks_per_sm()?;
+        let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
+        let num_threads_per_block = 1024;
         let launch_config = LaunchConfig {
             grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1),
-            block_dim: (1024, 1, 1),
+            block_dim: (num_threads_per_block as u32, 1, 1),
             shared_mem_bytes: 0,
         };
         unsafe {
@@ -171,7 +173,7 @@ mod tests {
         let max_degree = 4;
         let rng = OsRng::default();

-        let combine_function = |args: &Vec| args.iter().product();
+        let combine_function = |args: &Vec| args.iter().sum();

         let polys = (0..num_polys)
             .map(|_| (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec())
@@ -207,10 +209,9 @@ mod tests {

         // copy polynomials to device
         let gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
-        let mut buf = gpu_api_wrapper.copy_to_device(&vec![Fr::ZERO; num_polys << num_vars])?;
+        let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
         let buf_view = RefCell::new(buf.slice_mut(..));
-        let mut round_evals =
-            gpu_api_wrapper.copy_to_device(&vec![Fr::ZERO; max_degree as usize + 1])?;
+        let mut round_evals = gpu_api_wrapper.malloc_on_device(max_degree as usize + 1)?;
         let round_evals_view = RefCell::new(round_evals.slice_mut(..));
         let round = 0;
         let now = Instant::now();
@@ -228,7 +229,7 @@ mod tests {
             "Time taken to eval_at_k_and_combine on gpu: {:.2?}",
             now.elapsed()
         );
-        let gpu_round_evals = gpu_api_wrapper.dtoh_sync_copy(round_evals.slice(..), true)?;
+        let gpu_round_evals = gpu_api_wrapper.dtoh_sync_copy(round_evals.slice(0..(max_degree + 1) as usize), true)?;
         cpu_round_evals
             .iter()
             .zip_eq(gpu_round_evals.iter())
@@ -241,7 +242,7 @@ mod tests {

     #[test]
     fn test_fold_into_half_in_place() -> Result<(), DriverError> {
-        let num_vars = 20;
+        let num_vars = 15;
         let num_polys = 4;

         let rng = OsRng::default();
@@ -310,8 +311,8 @@ mod tests {
     #[test]
     fn test_prove_sumcheck() -> Result<(), DriverError> {
         let num_vars = 19;
-        let num_polys = 9;
-        let max_degree = 1;
+        let num_polys = 4;
+        let max_degree = 4;
         let rng = OsRng::default();

         let polys = (0..num_polys)
diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs
index 2302382..010997b 100644
--- a/sumcheck/src/lib.rs
+++ b/sumcheck/src/lib.rs
@@ -109,6 +109,10 @@ impl + ToFieldBinding> GPUApiWrapper {
         Ok(self.gpu.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR)? as usize)
     }

+    pub fn num_sm(&self) -> Result {
+        Ok(self.gpu.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)? as usize)
+    }
+
     pub fn max_threads_per_sm(&self) -> Result {
         Ok(self.gpu.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR)? as usize)
     }
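
For context, below is a minimal standalone sketch of the indexing scheme this patch introduces: each polynomial gets num_blocks_per_poly consecutive blocks, every thread derives a per-polynomial index from (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x, and then strides by blockDim.x * num_blocks_per_poly. It is not the repo's kernel: "double" stands in for the field type fr, the kernel name is hypothetical, the block size is 256 rather than the 1024 used above, and the attribute-based grid sizing only mirrors the new max_blocks_per_sm()? / num_polys * self.num_sm()? computation with whatever values the local device reports.

// Sketch only: "double" stands in for fr; fold_into_half_in_place_sketch is a
// hypothetical name, not the kernel from the repository.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void fold_into_half_in_place_sketch(
    unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly,
    double* polys, const double* challenge
) {
    // Thread index within the group of blocks assigned to this polynomial.
    int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
    // Entry i is folded with entry i + stride, halving the evaluation table.
    const int stride = 1 << (num_vars - 1);
    // All blocks working on one polynomial share a single base offset into polys.
    const int offset = (blockIdx.x / num_blocks_per_poly) * initial_poly_size;
    while (tid < stride) {
        int idx = offset + tid;
        // p'(i) = p(i) + r * (p(i + stride) - p(i))
        polys[idx] = (*challenge) * (polys[idx + stride] - polys[idx]) + polys[idx];
        // Advance by the total number of threads cooperating on this polynomial.
        tid += blockDim.x * num_blocks_per_poly;
    }
}

int main() {
    const unsigned int num_vars = 10, num_polys = 2, threads = 256;
    const unsigned int poly_size = 1u << num_vars, stride = poly_size >> 1;

    // Grid sizing analogous to the Rust-side change: split the per-SM block
    // budget across polynomials, then scale by the SM count (device-dependent).
    int max_blocks_per_sm = 0, num_sm = 0;
    cudaDeviceGetAttribute(&max_blocks_per_sm, cudaDevAttrMaxBlocksPerMultiprocessor, 0);
    cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, 0);
    unsigned int num_blocks_per_poly = (unsigned int)(max_blocks_per_sm / (int)num_polys * num_sm);
    if (num_blocks_per_poly == 0) num_blocks_per_poly = 1; // fallback if attributes are unavailable

    std::vector<double> h(num_polys * poly_size);
    for (size_t i = 0; i < h.size(); ++i) h[i] = (double)(i % 97);
    double challenge = 0.5;

    double *d_polys, *d_challenge;
    cudaMalloc(&d_polys, h.size() * sizeof(double));
    cudaMalloc(&d_challenge, sizeof(double));
    cudaMemcpy(d_polys, h.data(), h.size() * sizeof(double), cudaMemcpyHostToDevice);
    cudaMemcpy(d_challenge, &challenge, sizeof(double), cudaMemcpyHostToDevice);

    fold_into_half_in_place_sketch<<<num_blocks_per_poly * num_polys, threads>>>(
        num_vars, poly_size, num_blocks_per_poly, d_polys, d_challenge);
    cudaDeviceSynchronize();

    std::vector<double> out(h.size());
    cudaMemcpy(out.data(), d_polys, h.size() * sizeof(double), cudaMemcpyDeviceToHost);

    // Spot-check one folded entry of the second polynomial against the host formula.
    size_t idx = 1 * poly_size + 3;
    double expected = challenge * (h[idx + stride] - h[idx]) + h[idx];
    printf("gpu = %f, cpu = %f\n", out[idx], expected);

    cudaFree(d_polys);
    cudaFree(d_challenge);
    return 0;
}

The out-of-place variant follows the same pattern, except that it writes the folded half into buf at buf_offset = poly_index * stride, so the packed result needs only num_polys * 2^(num_vars - 1) field elements; that is why the test now allocates the buffer with malloc_on_device(num_polys << (num_vars - 1)) instead of copying a full-size zero vector to the device.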