mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-08 23:18:00 -05:00
Split combine_and_sum into separate combine and sum kernels
This commit is contained in:
@@ -3,7 +3,23 @@
|
||||
|
||||
using namespace bb;
|
||||
|
||||
__device__ void sum(fr* data, const int stride) {
|
||||
// TODO
|
||||
// Multiplies together `num_args` entries of `evals`, taken at
// evals[start], evals[start + stride], ..., evals[start + (num_args - 1) * stride].
//
// evals    : buffer the operands are read from (not written here)
// start    : index of the first operand
// stride   : distance, in elements, between consecutive operands
// num_args : number of operands; with 0 the result is fr::one()
//
// Returns the product of the selected entries.
__device__ fr combine_function(fr* evals, unsigned int start, unsigned int stride, unsigned int num_args) {
    fr result = fr::one();
    // The running offset is kept in 64 bits so that start + i * stride cannot
    // wrap around the 32-bit range, and the counter is unsigned so it is
    // compared against num_args without a signed/unsigned conversion.
    unsigned long long offset = start;
    for (unsigned int i = 0; i < num_args; ++i) {
        result *= evals[offset];
        offset += stride;
    }
    return result;
}
|
||||
|
||||
// Kernel: for every idx in [0, size), overwrites buf[idx] with
// combine_function(buf, idx, size, num_args), which reads the num_args entries
// buf[idx], buf[idx + size], ..., buf[idx + (num_args - 1) * size].
//
// Each thread starts at blockIdx.x * blockDim.x + threadIdx.x and advances by
// blockDim.x * gridDim.x, so only the x dimension of the launch is used.
//
// buf      : input/output buffer; the first `size` entries are overwritten
// size     : number of outputs, also the stride between operands of one output
// num_args : number of operands combined per output
extern "C" __global__ void combine(fr* buf, unsigned int size, unsigned int num_args) {
    // Index arithmetic is done in 64-bit unsigned so that neither the start
    // index nor the step can overflow, and the bound check against the
    // unsigned `size` involves no signed value.
    const unsigned long long step = (unsigned long long)blockDim.x * gridDim.x;
    for (unsigned long long idx = (unsigned long long)blockIdx.x * blockDim.x + threadIdx.x;
         idx < size;
         idx += step) {
        // idx < size here, so the narrowing cast is lossless.
        buf[idx] = combine_function(buf, (unsigned int)idx, size, num_args);
    }
}
|
||||
|
||||
extern "C" __global__ void sum(fr* data, fr* result, unsigned int stride, unsigned int index) {
|
||||
const int tid = threadIdx.x;
|
||||
for (unsigned int s = stride; s > 0; s >>= 1) {
|
||||
int idx = tid;
|
||||
@@ -13,27 +29,7 @@ __device__ void sum(fr* data, const int stride) {
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// TODO
|
||||
// Adds together `num_args` entries of `evals`, taken at
// evals[start], evals[start + stride], ..., evals[start + (num_args - 1) * stride].
//
// evals    : buffer the operands are read from (not written here)
// start    : index of the first operand
// stride   : distance, in elements, between consecutive operands
// num_args : number of operands; with 0 the result is fr::zero()
//
// Returns the sum of the selected entries.
__device__ fr combine_function(fr* evals, unsigned int start, unsigned int stride, unsigned int num_args) {
    fr result = fr::zero();
    // The running offset is kept in 64 bits so that start + i * stride cannot
    // wrap around the 32-bit range, and the counter is unsigned so it is
    // compared against num_args without a signed/unsigned conversion.
    unsigned long long offset = start;
    for (unsigned int i = 0; i < num_args; ++i) {
        result += evals[offset];
        offset += stride;
    }
    return result;
}
|
||||
|
||||
extern "C" __global__ void combine_and_sum(
|
||||
fr* buf, fr* result, unsigned int size, unsigned int num_args, unsigned int index
|
||||
) {
|
||||
const int tid = threadIdx.x;
|
||||
int idx = tid;
|
||||
while (idx < size) {
|
||||
buf[idx] = combine_function(buf, idx, size, num_args);
|
||||
idx += blockDim.x;
|
||||
}
|
||||
__syncthreads();
|
||||
sum(buf, size >> 1);
|
||||
if (tid == 0) result[index] = buf[0];
|
||||
if (tid == 0) result[index] = data[0];
|
||||
}
|
||||
|
||||
extern "C" __global__ void fold_into_half(
|
||||
|
||||
@@ -80,20 +80,29 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
),
|
||||
)?;
|
||||
};
|
||||
let combine_and_sum = self.gpu.get_func("sumcheck", "combine_and_sum").unwrap();
|
||||
let size = 1 << (initial_poly_num_vars - round - 1);
|
||||
let combine = self.gpu.get_func("sumcheck", "combine").unwrap();
|
||||
let launch_config = LaunchConfig {
|
||||
grid_dim: (1, 1, 1),
|
||||
block_dim: (1024, 1, 1),
|
||||
grid_dim: (num_blocks_per_poly as u32, 1, 1),
|
||||
block_dim: (num_threads_per_block, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe {
|
||||
combine_and_sum.launch(
|
||||
combine.launch(launch_config, (&mut *buf, size, num_polys))?;
|
||||
};
|
||||
let sum = self.gpu.get_func("sumcheck", "sum").unwrap();
|
||||
let launch_config = LaunchConfig {
|
||||
grid_dim: (1, 1, 1),
|
||||
block_dim: (num_threads_per_block, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe {
|
||||
sum.launch(
|
||||
launch_config,
|
||||
(
|
||||
&mut *buf,
|
||||
&mut *round_evals,
|
||||
1 << (initial_poly_num_vars - round - 1),
|
||||
num_polys,
|
||||
size >> 1,
|
||||
round * (max_degree + 1) + k,
|
||||
),
|
||||
)?;
|
||||
@@ -168,12 +177,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_eval_at_k_and_combine() -> Result<(), DriverError> {
|
||||
let num_vars = 20;
|
||||
let num_vars = 10;
|
||||
let num_polys = 4;
|
||||
let max_degree = 4;
|
||||
let rng = OsRng::default();
|
||||
|
||||
let combine_function = |args: &Vec<Fr>| args.iter().sum();
|
||||
let combine_function = |args: &Vec<Fr>| args.iter().product();
|
||||
|
||||
let polys = (0..num_polys)
|
||||
.map(|_| (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec())
|
||||
@@ -188,7 +197,7 @@ mod tests {
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(SUMCHECK_PTX),
|
||||
"sumcheck",
|
||||
&["fold_into_half", "combine_and_sum"],
|
||||
&["fold_into_half", "combine", "sum"],
|
||||
)?;
|
||||
|
||||
let mut cpu_round_evals = vec![];
|
||||
@@ -229,7 +238,8 @@ mod tests {
|
||||
"Time taken to eval_at_k_and_combine on gpu: {:.2?}",
|
||||
now.elapsed()
|
||||
);
|
||||
let gpu_round_evals = gpu_api_wrapper.dtoh_sync_copy(round_evals.slice(0..(max_degree + 1) as usize), true)?;
|
||||
let gpu_round_evals = gpu_api_wrapper
|
||||
.dtoh_sync_copy(round_evals.slice(0..(max_degree + 1) as usize), true)?;
|
||||
cpu_round_evals
|
||||
.iter()
|
||||
.zip_eq(gpu_round_evals.iter())
|
||||
@@ -242,7 +252,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_fold_into_half_in_place() -> Result<(), DriverError> {
|
||||
let num_vars = 15;
|
||||
let num_vars = 6;
|
||||
let num_polys = 4;
|
||||
|
||||
let rng = OsRng::default();
|
||||
@@ -310,7 +320,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_prove_sumcheck() -> Result<(), DriverError> {
|
||||
let num_vars = 19;
|
||||
let num_vars = 12;
|
||||
let num_polys = 4;
|
||||
let max_degree = 4;
|
||||
|
||||
@@ -331,14 +341,16 @@ mod tests {
|
||||
&[
|
||||
"fold_into_half",
|
||||
"fold_into_half_in_place",
|
||||
"combine_and_sum",
|
||||
"combine",
|
||||
"sum",
|
||||
"squeeze_challenge",
|
||||
],
|
||||
)?;
|
||||
|
||||
let now = Instant::now();
|
||||
let mut gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
|
||||
let sum = (0..1 << num_vars).fold(Fr::ZERO, |acc, index| {
|
||||
acc + polys.iter().map(|poly| poly[index]).sum::<Fr>()
|
||||
acc + polys.iter().map(|poly| poly[index]).product::<Fr>()
|
||||
});
|
||||
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
|
||||
let buf_view = RefCell::new(buf.slice_mut(..));
|
||||
@@ -346,6 +358,7 @@ mod tests {
|
||||
let mut challenges = gpu_api_wrapper.malloc_on_device(num_vars)?;
|
||||
let mut round_evals = gpu_api_wrapper.malloc_on_device(num_vars * (max_degree + 1))?;
|
||||
let round_evals_view = RefCell::new(round_evals.slice_mut(..));
|
||||
println!("Time taken to copy data to device : {:.2?}", now.elapsed());
|
||||
|
||||
let now = Instant::now();
|
||||
gpu_api_wrapper.prove_sumcheck(
|
||||
|
||||
@@ -110,10 +110,16 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
}
|
||||
|
||||
pub fn num_sm(&self) -> Result<usize, DriverError> {
|
||||
Ok(self.gpu.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)? as usize)
|
||||
Ok(self.gpu.attribute(
|
||||
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
|
||||
)? as usize)
|
||||
}
|
||||
|
||||
pub fn max_threads_per_sm(&self) -> Result<usize, DriverError> {
|
||||
Ok(self.gpu.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR)? as usize)
|
||||
}
|
||||
|
||||
pub fn shared_mem_bytes_per_block(&self) -> Result<usize, DriverError> {
|
||||
Ok(self.gpu.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK)? as usize)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user