mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-08 23:18:00 -05:00
Malloc device_ks before sumcheck proving
This commit is contained in:
@@ -32,7 +32,7 @@ extern "C" __global__ void sum(fr* data, fr* result, unsigned int stride, unsign
|
||||
}
|
||||
|
||||
extern "C" __global__ void fold_into_half(
|
||||
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, fr* challenge
|
||||
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, const fr* challenge
|
||||
) {
|
||||
int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
|
||||
const int stride = 1 << (num_vars - 1);
|
||||
@@ -47,7 +47,7 @@ extern "C" __global__ void fold_into_half(
|
||||
}
|
||||
|
||||
extern "C" __global__ void fold_into_half_in_place(
|
||||
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* challenge
|
||||
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, const fr* challenge
|
||||
) {
|
||||
int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
|
||||
const int stride = 1 << (num_vars - 1);
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use std::cell::{RefCell, RefMut};
|
||||
|
||||
use cudarc::driver::{CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
|
||||
use cudarc::driver::{CudaSlice, CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
|
||||
use ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::{
|
||||
fieldbinding::{FromFieldBinding, ToFieldBinding},
|
||||
@@ -16,6 +17,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
max_degree: usize,
|
||||
sum: F,
|
||||
polys: &mut CudaViewMut<FieldBinding>,
|
||||
device_ks: &[CudaView<FieldBinding>],
|
||||
buf: RefCell<CudaViewMut<FieldBinding>>,
|
||||
challenges: &mut CudaViewMut<FieldBinding>,
|
||||
round_evals: RefCell<CudaViewMut<FieldBinding>>,
|
||||
@@ -28,6 +30,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
max_degree,
|
||||
num_polys,
|
||||
&polys.slice(..),
|
||||
device_ks,
|
||||
buf.borrow_mut(),
|
||||
round_evals.borrow_mut(),
|
||||
)?;
|
||||
@@ -52,15 +55,13 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
max_degree: usize,
|
||||
num_polys: usize,
|
||||
polys: &CudaView<FieldBinding>,
|
||||
device_ks: &[CudaView<FieldBinding>],
|
||||
mut buf: RefMut<CudaViewMut<FieldBinding>>,
|
||||
mut round_evals: RefMut<CudaViewMut<FieldBinding>>,
|
||||
) -> Result<(), DriverError> {
|
||||
let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
|
||||
let num_threads_per_block = 1024;
|
||||
for k in 0..max_degree + 1 {
|
||||
let device_k = self
|
||||
.gpu
|
||||
.htod_copy(vec![F::to_montgomery_form(F::from(k as u64))])?;
|
||||
let fold_into_half = self.gpu.get_func("sumcheck", "fold_into_half").unwrap();
|
||||
let launch_config = LaunchConfig {
|
||||
grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1),
|
||||
@@ -76,7 +77,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
num_blocks_per_poly,
|
||||
polys,
|
||||
&mut *buf,
|
||||
&device_k,
|
||||
&device_ks[k],
|
||||
),
|
||||
)?;
|
||||
};
|
||||
@@ -173,7 +174,7 @@ mod tests {
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
|
||||
use crate::{cpu, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};
|
||||
use crate::{cpu, fieldbinding::ToFieldBinding, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};
|
||||
|
||||
#[test]
|
||||
fn test_eval_at_k_and_combine() -> Result<(), DriverError> {
|
||||
@@ -228,6 +229,13 @@ mod tests {
|
||||
|
||||
// copy polynomials to device
|
||||
let gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
|
||||
let device_ks = (0..max_degree + 1)
|
||||
.map(|k| {
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
.htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
|
||||
let buf_view = RefCell::new(buf.slice_mut(..));
|
||||
let mut round_evals = gpu_api_wrapper.malloc_on_device(max_degree as usize + 1)?;
|
||||
@@ -240,6 +248,10 @@ mod tests {
|
||||
max_degree as usize,
|
||||
num_polys,
|
||||
&gpu_polys.slice(..),
|
||||
&device_ks
|
||||
.iter()
|
||||
.map(|device_k| device_k.slice(..))
|
||||
.collect_vec(),
|
||||
buf_view.borrow_mut(),
|
||||
round_evals_view.borrow_mut(),
|
||||
)?;
|
||||
@@ -248,7 +260,7 @@ mod tests {
|
||||
now.elapsed()
|
||||
);
|
||||
let gpu_round_evals = gpu_api_wrapper
|
||||
.dtoh_sync_copy(round_evals.slice(0..(max_degree + 1) as usize), true)?;
|
||||
.dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true)?;
|
||||
cpu_round_evals
|
||||
.iter()
|
||||
.zip_eq(gpu_round_evals.iter())
|
||||
@@ -312,7 +324,7 @@ mod tests {
|
||||
let gpu_result = (0..num_polys)
|
||||
.map(|i| {
|
||||
gpu_api_wrapper.dtoh_sync_copy(
|
||||
gpu_polys.slice(i << num_vars..(i * 2 + 1) << (num_vars - 1)),
|
||||
&gpu_polys.slice(i << num_vars..(i * 2 + 1) << (num_vars - 1)),
|
||||
true,
|
||||
)
|
||||
})
|
||||
@@ -338,9 +350,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_prove_sumcheck() -> Result<(), DriverError> {
|
||||
let num_vars = 23;
|
||||
let num_polys = 4;
|
||||
let max_degree = 4;
|
||||
let num_vars = 25;
|
||||
let num_polys = 2;
|
||||
let max_degree = 2;
|
||||
|
||||
let rng = OsRng::default();
|
||||
let polys = (0..num_polys)
|
||||
@@ -380,6 +392,13 @@ mod tests {
|
||||
let sum = (0..1 << num_vars).fold(Fr::ZERO, |acc, index| {
|
||||
acc + polys.iter().map(|poly| poly[index]).product::<Fr>()
|
||||
});
|
||||
let device_ks = (0..max_degree + 1)
|
||||
.map(|k| {
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
.htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
|
||||
let buf_view = RefCell::new(buf.slice_mut(..));
|
||||
|
||||
@@ -395,20 +414,25 @@ mod tests {
|
||||
max_degree,
|
||||
sum,
|
||||
&mut gpu_polys.slice_mut(..),
|
||||
&device_ks
|
||||
.iter()
|
||||
.map(|device_k| device_k.slice(..))
|
||||
.collect_vec(),
|
||||
buf_view,
|
||||
&mut challenges.slice_mut(..),
|
||||
round_evals_view,
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.synchronize()?;
|
||||
println!(
|
||||
"Time taken to prove sumcheck on gpu : {:.2?}",
|
||||
now.elapsed()
|
||||
);
|
||||
|
||||
let challenges = gpu_api_wrapper.dtoh_sync_copy(challenges.slice(..), true)?;
|
||||
let challenges = gpu_api_wrapper.dtoh_sync_copy(&challenges.slice(..), true)?;
|
||||
let round_evals = (0..num_vars)
|
||||
.map(|i| {
|
||||
gpu_api_wrapper.dtoh_sync_copy(
|
||||
round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
|
||||
&round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
|
||||
true,
|
||||
)
|
||||
})
|
||||
|
||||
@@ -78,10 +78,10 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
|
||||
pub fn dtoh_sync_copy(
|
||||
&self,
|
||||
device_data: CudaView<FieldBinding>,
|
||||
device_data: &CudaView<FieldBinding>,
|
||||
convert_from_montgomery_form: bool,
|
||||
) -> Result<Vec<F>, DriverError> {
|
||||
let host_data = self.gpu.dtoh_sync_copy(&device_data)?;
|
||||
let host_data = self.gpu.dtoh_sync_copy(device_data)?;
|
||||
let mut target = vec![F::ZERO; host_data.len()];
|
||||
if convert_from_montgomery_form {
|
||||
parallelize(&mut target, |(target, start)| {
|
||||
|
||||
Reference in New Issue
Block a user