Malloc device_ks before sumcheck proving

This commit is contained in:
DoHoonKim8
2024-09-15 06:38:42 +00:00
parent ae8be89df1
commit 57bc9ac507
3 changed files with 41 additions and 17 deletions

View File

@@ -32,7 +32,7 @@ extern "C" __global__ void sum(fr* data, fr* result, unsigned int stride, unsign
}
extern "C" __global__ void fold_into_half(
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, fr* challenge
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, const fr* challenge
) {
int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
const int stride = 1 << (num_vars - 1);
@@ -47,7 +47,7 @@ extern "C" __global__ void fold_into_half(
}
extern "C" __global__ void fold_into_half_in_place(
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* challenge
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, const fr* challenge
) {
int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
const int stride = 1 << (num_vars - 1);

View File

@@ -1,7 +1,8 @@
use std::cell::{RefCell, RefMut};
use cudarc::driver::{CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
use cudarc::driver::{CudaSlice, CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
use ff::PrimeField;
use itertools::Itertools;
use crate::{
fieldbinding::{FromFieldBinding, ToFieldBinding},
@@ -16,6 +17,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
max_degree: usize,
sum: F,
polys: &mut CudaViewMut<FieldBinding>,
device_ks: &[CudaView<FieldBinding>],
buf: RefCell<CudaViewMut<FieldBinding>>,
challenges: &mut CudaViewMut<FieldBinding>,
round_evals: RefCell<CudaViewMut<FieldBinding>>,
@@ -28,6 +30,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
max_degree,
num_polys,
&polys.slice(..),
device_ks,
buf.borrow_mut(),
round_evals.borrow_mut(),
)?;
@@ -52,15 +55,13 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
max_degree: usize,
num_polys: usize,
polys: &CudaView<FieldBinding>,
device_ks: &[CudaView<FieldBinding>],
mut buf: RefMut<CudaViewMut<FieldBinding>>,
mut round_evals: RefMut<CudaViewMut<FieldBinding>>,
) -> Result<(), DriverError> {
let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
let num_threads_per_block = 1024;
for k in 0..max_degree + 1 {
let device_k = self
.gpu
.htod_copy(vec![F::to_montgomery_form(F::from(k as u64))])?;
let fold_into_half = self.gpu.get_func("sumcheck", "fold_into_half").unwrap();
let launch_config = LaunchConfig {
grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1),
@@ -76,7 +77,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
num_blocks_per_poly,
polys,
&mut *buf,
&device_k,
&device_ks[k],
),
)?;
};
@@ -173,7 +174,7 @@ mod tests {
use itertools::Itertools;
use rand::rngs::OsRng;
use crate::{cpu, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};
use crate::{cpu, fieldbinding::ToFieldBinding, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};
#[test]
fn test_eval_at_k_and_combine() -> Result<(), DriverError> {
@@ -228,6 +229,13 @@ mod tests {
// copy polynomials to device
let gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
let device_ks = (0..max_degree + 1)
.map(|k| {
gpu_api_wrapper
.gpu
.htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
})
.collect::<Result<Vec<_>, _>>()?;
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
let buf_view = RefCell::new(buf.slice_mut(..));
let mut round_evals = gpu_api_wrapper.malloc_on_device(max_degree as usize + 1)?;
@@ -240,6 +248,10 @@ mod tests {
max_degree as usize,
num_polys,
&gpu_polys.slice(..),
&device_ks
.iter()
.map(|device_k| device_k.slice(..))
.collect_vec(),
buf_view.borrow_mut(),
round_evals_view.borrow_mut(),
)?;
@@ -248,7 +260,7 @@ mod tests {
now.elapsed()
);
let gpu_round_evals = gpu_api_wrapper
.dtoh_sync_copy(round_evals.slice(0..(max_degree + 1) as usize), true)?;
.dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true)?;
cpu_round_evals
.iter()
.zip_eq(gpu_round_evals.iter())
@@ -312,7 +324,7 @@ mod tests {
let gpu_result = (0..num_polys)
.map(|i| {
gpu_api_wrapper.dtoh_sync_copy(
gpu_polys.slice(i << num_vars..(i * 2 + 1) << (num_vars - 1)),
&gpu_polys.slice(i << num_vars..(i * 2 + 1) << (num_vars - 1)),
true,
)
})
@@ -338,9 +350,9 @@ mod tests {
#[test]
fn test_prove_sumcheck() -> Result<(), DriverError> {
let num_vars = 23;
let num_polys = 4;
let max_degree = 4;
let num_vars = 25;
let num_polys = 2;
let max_degree = 2;
let rng = OsRng::default();
let polys = (0..num_polys)
@@ -380,6 +392,13 @@ mod tests {
let sum = (0..1 << num_vars).fold(Fr::ZERO, |acc, index| {
acc + polys.iter().map(|poly| poly[index]).product::<Fr>()
});
let device_ks = (0..max_degree + 1)
.map(|k| {
gpu_api_wrapper
.gpu
.htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
})
.collect::<Result<Vec<_>, _>>()?;
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
let buf_view = RefCell::new(buf.slice_mut(..));
@@ -395,20 +414,25 @@ mod tests {
max_degree,
sum,
&mut gpu_polys.slice_mut(..),
&device_ks
.iter()
.map(|device_k| device_k.slice(..))
.collect_vec(),
buf_view,
&mut challenges.slice_mut(..),
round_evals_view,
)?;
gpu_api_wrapper.gpu.synchronize()?;
println!(
"Time taken to prove sumcheck on gpu : {:.2?}",
now.elapsed()
);
let challenges = gpu_api_wrapper.dtoh_sync_copy(challenges.slice(..), true)?;
let challenges = gpu_api_wrapper.dtoh_sync_copy(&challenges.slice(..), true)?;
let round_evals = (0..num_vars)
.map(|i| {
gpu_api_wrapper.dtoh_sync_copy(
round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
&round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
true,
)
})

View File

@@ -78,10 +78,10 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
pub fn dtoh_sync_copy(
&self,
device_data: CudaView<FieldBinding>,
device_data: &CudaView<FieldBinding>,
convert_from_montgomery_form: bool,
) -> Result<Vec<F>, DriverError> {
let host_data = self.gpu.dtoh_sync_copy(&device_data)?;
let host_data = self.gpu.dtoh_sync_copy(device_data)?;
let mut target = vec![F::ZERO; host_data.len()];
if convert_from_montgomery_form {
parallelize(&mut target, |(target, start)| {