Malloc device_ks before sumcheck proving

Author: DoHoonKim8
Date:   2024-09-15 06:38:42 +00:00
Parent: ae8be89df1
Commit: 57bc9ac507
3 changed files with 41 additions and 17 deletions
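
Before this change, `eval_at_k_and_combine` re-uploaded the evaluation point to the device with `htod_copy` on every iteration of its loop over k, allocating a fresh device buffer each time. The points 0..=max_degree are now uploaded once, before proving begins, and threaded through `prove_sumcheck` and `eval_at_k_and_combine` as `device_ks: &[CudaView<FieldBinding>]`, so the loop body only indexes `&device_ks[k]`. A minimal sketch of the new calling convention, using only names that appear in the diff below (a `GPUApiWrapper` instance and `max_degree` are assumed to be in scope):

    // Upload the evaluation points 0..=max_degree once, converted to
    // Montgomery form exactly as the removed in-loop code did, keeping the
    // owned `CudaSlice`s alive for the whole proof.
    let device_ks = (0..max_degree + 1)
        .map(|k| {
            gpu_api_wrapper
                .gpu
                .htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
        })
        .collect::<Result<Vec<_>, _>>()?;

    // The proving entry points take `&[CudaView<FieldBinding>]`, so borrow a
    // full-range view of each owned slice at the call site.
    let device_k_views = device_ks
        .iter()
        .map(|device_k| device_k.slice(..))
        .collect_vec();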

File 1 of 3: CUDA sumcheck kernels

@@ -32,7 +32,7 @@ extern "C" __global__ void sum(fr* data, fr* result, unsigned int stride, unsign
 }

 extern "C" __global__ void fold_into_half(
-    unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, fr* challenge
+    unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, const fr* challenge
 ) {
     int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
     const int stride = 1 << (num_vars - 1);
@@ -47,7 +47,7 @@ extern "C" __global__ void fold_into_half(
 }

 extern "C" __global__ void fold_into_half_in_place(
-    unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* challenge
+    unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, const fr* challenge
 ) {
     int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
     const int stride = 1 << (num_vars - 1);
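
Both folding kernels now take the challenge pointer as `const fr*`: they only read the point they fold with, and the const qualifier matches the shared, read-only device views (`&device_ks[k]`) now passed in from the Rust side.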

File 2 of 3: Rust sumcheck prover (`prove_sumcheck` / `eval_at_k_and_combine` and tests)

@@ -1,7 +1,8 @@
 use std::cell::{RefCell, RefMut};

-use cudarc::driver::{CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
+use cudarc::driver::{CudaSlice, CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
 use ff::PrimeField;
+use itertools::Itertools;

 use crate::{
     fieldbinding::{FromFieldBinding, ToFieldBinding},
@@ -16,6 +17,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
         max_degree: usize,
         sum: F,
         polys: &mut CudaViewMut<FieldBinding>,
+        device_ks: &[CudaView<FieldBinding>],
         buf: RefCell<CudaViewMut<FieldBinding>>,
         challenges: &mut CudaViewMut<FieldBinding>,
         round_evals: RefCell<CudaViewMut<FieldBinding>>,
@@ -28,6 +30,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
             max_degree,
             num_polys,
             &polys.slice(..),
+            device_ks,
             buf.borrow_mut(),
             round_evals.borrow_mut(),
         )?;
@@ -52,15 +55,13 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
         max_degree: usize,
         num_polys: usize,
         polys: &CudaView<FieldBinding>,
+        device_ks: &[CudaView<FieldBinding>],
         mut buf: RefMut<CudaViewMut<FieldBinding>>,
         mut round_evals: RefMut<CudaViewMut<FieldBinding>>,
     ) -> Result<(), DriverError> {
         let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
         let num_threads_per_block = 1024;
         for k in 0..max_degree + 1 {
-            let device_k = self
-                .gpu
-                .htod_copy(vec![F::to_montgomery_form(F::from(k as u64))])?;
             let fold_into_half = self.gpu.get_func("sumcheck", "fold_into_half").unwrap();
             let launch_config = LaunchConfig {
                 grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1),
@@ -76,7 +77,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
                     num_blocks_per_poly,
                     polys,
                     &mut *buf,
-                    &device_k,
+                    &device_ks[k],
                 ),
             )?;
         };
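
This hunk is the point of the commit: the blocking `htod_copy` (one device allocation plus one small host-to-device transfer per k) is gone from the loop, replaced by an index into the preallocated `device_ks`. Assuming `eval_at_k_and_combine` runs once per sumcheck round, the old code issued on the order of `num_vars * (max_degree + 1)` tiny uploads per proof; now there are `max_degree + 1`, all done before the timed proving section begins.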
@@ -173,7 +174,7 @@ mod tests {
     use itertools::Itertools;
     use rand::rngs::OsRng;

-    use crate::{cpu, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};
+    use crate::{cpu, fieldbinding::ToFieldBinding, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};

     #[test]
     fn test_eval_at_k_and_combine() -> Result<(), DriverError> {
@@ -228,6 +229,13 @@ mod tests {
         // copy polynomials to device
         let gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
+        let device_ks = (0..max_degree + 1)
+            .map(|k| {
+                gpu_api_wrapper
+                    .gpu
+                    .htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
+            })
+            .collect::<Result<Vec<_>, _>>()?;
         let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
         let buf_view = RefCell::new(buf.slice_mut(..));
         let mut round_evals = gpu_api_wrapper.malloc_on_device(max_degree as usize + 1)?;
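
As in the removed in-loop code, the points are uploaded already in Montgomery form via `Fr::to_montgomery_form`, which is presumably why the test module now imports `fieldbinding::ToFieldBinding`.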
@@ -240,6 +248,10 @@ mod tests {
             max_degree as usize,
             num_polys,
             &gpu_polys.slice(..),
+            &device_ks
+                .iter()
+                .map(|device_k| device_k.slice(..))
+                .collect_vec(),
             buf_view.borrow_mut(),
             round_evals_view.borrow_mut(),
         )?;
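
`htod_copy` returns an owned `CudaSlice`, while the proving API takes `&[CudaView<FieldBinding>]`, so the test borrows a full-range view of each slice with `device_k.slice(..)` and gathers them with Itertools' `collect_vec`; this is what the `CudaSlice` and `Itertools` imports added at the top of the file are for.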
@@ -248,7 +260,7 @@ mod tests {
             now.elapsed()
         );
         let gpu_round_evals = gpu_api_wrapper
-            .dtoh_sync_copy(round_evals.slice(0..(max_degree + 1) as usize), true)?;
+            .dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true)?;
         cpu_round_evals
             .iter()
             .zip_eq(gpu_round_evals.iter())
@@ -312,7 +324,7 @@ mod tests {
         let gpu_result = (0..num_polys)
             .map(|i| {
                 gpu_api_wrapper.dtoh_sync_copy(
-                    gpu_polys.slice(i << num_vars..(i * 2 + 1) << (num_vars - 1)),
+                    &gpu_polys.slice(i << num_vars..(i * 2 + 1) << (num_vars - 1)),
                     true,
                 )
             })
@@ -338,9 +350,9 @@ mod tests {
     #[test]
     fn test_prove_sumcheck() -> Result<(), DriverError> {
-        let num_vars = 23;
-        let num_polys = 4;
-        let max_degree = 4;
+        let num_vars = 25;
+        let num_polys = 2;
+        let max_degree = 2;
         let rng = OsRng::default();
         let polys = (0..num_polys)
@@ -380,6 +392,13 @@ mod tests {
         let sum = (0..1 << num_vars).fold(Fr::ZERO, |acc, index| {
             acc + polys.iter().map(|poly| poly[index]).product::<Fr>()
         });
+        let device_ks = (0..max_degree + 1)
+            .map(|k| {
+                gpu_api_wrapper
+                    .gpu
+                    .htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
+            })
+            .collect::<Result<Vec<_>, _>>()?;
         let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
         let buf_view = RefCell::new(buf.slice_mut(..));
@@ -395,20 +414,25 @@ mod tests {
             max_degree,
             sum,
             &mut gpu_polys.slice_mut(..),
+            &device_ks
+                .iter()
+                .map(|device_k| device_k.slice(..))
+                .collect_vec(),
             buf_view,
             &mut challenges.slice_mut(..),
             round_evals_view,
         )?;
+        gpu_api_wrapper.gpu.synchronize()?;
         println!(
             "Time taken to prove sumcheck on gpu : {:.2?}",
             now.elapsed()
         );
-        let challenges = gpu_api_wrapper.dtoh_sync_copy(challenges.slice(..), true)?;
+        let challenges = gpu_api_wrapper.dtoh_sync_copy(&challenges.slice(..), true)?;
         let round_evals = (0..num_vars)
             .map(|i| {
                 gpu_api_wrapper.dtoh_sync_copy(
-                    round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
+                    &round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
                     true,
                 )
             })
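
Note the `gpu_api_wrapper.gpu.synchronize()?` added before the timer is read. CUDA kernel launches are asynchronous, so without a device-wide synchronization `now.elapsed()` could stop the clock while GPU work is still in flight and under-report the proving time. The pattern, in outline:

    let now = std::time::Instant::now();
    // ... launch sumcheck kernels (asynchronous) ...
    gpu_api_wrapper.gpu.synchronize()?; // wait for queued GPU work to complete
    println!("Time taken to prove sumcheck on gpu : {:.2?}", now.elapsed());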

File 3 of 3: Rust GPU API wrapper (`dtoh_sync_copy`)

@@ -78,10 +78,10 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
     pub fn dtoh_sync_copy(
         &self,
-        device_data: CudaView<FieldBinding>,
+        device_data: &CudaView<FieldBinding>,
         convert_from_montgomery_form: bool,
     ) -> Result<Vec<F>, DriverError> {
-        let host_data = self.gpu.dtoh_sync_copy(&device_data)?;
+        let host_data = self.gpu.dtoh_sync_copy(device_data)?;
         let mut target = vec![F::ZERO; host_data.len()];
         if convert_from_montgomery_form {
             parallelize(&mut target, |(target, start)| {
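
`dtoh_sync_copy` now borrows its device view rather than consuming it. cudarc's underlying `CudaDevice::dtoh_sync_copy` only needs a shared reference, so the wrapper had no reason to take ownership; every call site above simply adds a `&`. A caller can now also reuse a view after copying, for example (a sketch, assuming a populated `round_evals` device buffer and `max_degree` from the tests above):

    let view = round_evals.slice(0..(max_degree + 1) as usize);
    let on_host = gpu_api_wrapper.dtoh_sync_copy(&view, true)?; // `view` is only borrowed
    let again = gpu_api_wrapper.dtoh_sync_copy(&view, false)?;  // still usable afterwards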