mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-09 23:47:57 -05:00
Malloc device_ks before sumcheck proving
This commit is contained in:
@@ -32,7 +32,7 @@ extern "C" __global__ void sum(fr* data, fr* result, unsigned int stride, unsign
|
|||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void fold_into_half(
|
extern "C" __global__ void fold_into_half(
|
||||||
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, fr* challenge
|
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, const fr* challenge
|
||||||
) {
|
) {
|
||||||
int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
|
int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
|
||||||
const int stride = 1 << (num_vars - 1);
|
const int stride = 1 << (num_vars - 1);
|
||||||
@@ -47,7 +47,7 @@ extern "C" __global__ void fold_into_half(
|
|||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void fold_into_half_in_place(
|
extern "C" __global__ void fold_into_half_in_place(
|
||||||
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* challenge
|
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, const fr* challenge
|
||||||
) {
|
) {
|
||||||
int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
|
int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
|
||||||
const int stride = 1 << (num_vars - 1);
|
const int stride = 1 << (num_vars - 1);
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
use std::cell::{RefCell, RefMut};
|
use std::cell::{RefCell, RefMut};
|
||||||
|
|
||||||
use cudarc::driver::{CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
|
use cudarc::driver::{CudaSlice, CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
|
||||||
use ff::PrimeField;
|
use ff::PrimeField;
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
fieldbinding::{FromFieldBinding, ToFieldBinding},
|
fieldbinding::{FromFieldBinding, ToFieldBinding},
|
||||||
@@ -16,6 +17,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
|||||||
max_degree: usize,
|
max_degree: usize,
|
||||||
sum: F,
|
sum: F,
|
||||||
polys: &mut CudaViewMut<FieldBinding>,
|
polys: &mut CudaViewMut<FieldBinding>,
|
||||||
|
device_ks: &[CudaView<FieldBinding>],
|
||||||
buf: RefCell<CudaViewMut<FieldBinding>>,
|
buf: RefCell<CudaViewMut<FieldBinding>>,
|
||||||
challenges: &mut CudaViewMut<FieldBinding>,
|
challenges: &mut CudaViewMut<FieldBinding>,
|
||||||
round_evals: RefCell<CudaViewMut<FieldBinding>>,
|
round_evals: RefCell<CudaViewMut<FieldBinding>>,
|
||||||
@@ -28,6 +30,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
|||||||
max_degree,
|
max_degree,
|
||||||
num_polys,
|
num_polys,
|
||||||
&polys.slice(..),
|
&polys.slice(..),
|
||||||
|
device_ks,
|
||||||
buf.borrow_mut(),
|
buf.borrow_mut(),
|
||||||
round_evals.borrow_mut(),
|
round_evals.borrow_mut(),
|
||||||
)?;
|
)?;
|
||||||
@@ -52,15 +55,13 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
|||||||
max_degree: usize,
|
max_degree: usize,
|
||||||
num_polys: usize,
|
num_polys: usize,
|
||||||
polys: &CudaView<FieldBinding>,
|
polys: &CudaView<FieldBinding>,
|
||||||
|
device_ks: &[CudaView<FieldBinding>],
|
||||||
mut buf: RefMut<CudaViewMut<FieldBinding>>,
|
mut buf: RefMut<CudaViewMut<FieldBinding>>,
|
||||||
mut round_evals: RefMut<CudaViewMut<FieldBinding>>,
|
mut round_evals: RefMut<CudaViewMut<FieldBinding>>,
|
||||||
) -> Result<(), DriverError> {
|
) -> Result<(), DriverError> {
|
||||||
let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
|
let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
|
||||||
let num_threads_per_block = 1024;
|
let num_threads_per_block = 1024;
|
||||||
for k in 0..max_degree + 1 {
|
for k in 0..max_degree + 1 {
|
||||||
let device_k = self
|
|
||||||
.gpu
|
|
||||||
.htod_copy(vec![F::to_montgomery_form(F::from(k as u64))])?;
|
|
||||||
let fold_into_half = self.gpu.get_func("sumcheck", "fold_into_half").unwrap();
|
let fold_into_half = self.gpu.get_func("sumcheck", "fold_into_half").unwrap();
|
||||||
let launch_config = LaunchConfig {
|
let launch_config = LaunchConfig {
|
||||||
grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1),
|
grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1),
|
||||||
@@ -76,7 +77,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
|||||||
num_blocks_per_poly,
|
num_blocks_per_poly,
|
||||||
polys,
|
polys,
|
||||||
&mut *buf,
|
&mut *buf,
|
||||||
&device_k,
|
&device_ks[k],
|
||||||
),
|
),
|
||||||
)?;
|
)?;
|
||||||
};
|
};
|
||||||
@@ -173,7 +174,7 @@ mod tests {
|
|||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use rand::rngs::OsRng;
|
use rand::rngs::OsRng;
|
||||||
|
|
||||||
use crate::{cpu, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};
|
use crate::{cpu, fieldbinding::ToFieldBinding, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_eval_at_k_and_combine() -> Result<(), DriverError> {
|
fn test_eval_at_k_and_combine() -> Result<(), DriverError> {
|
||||||
@@ -228,6 +229,13 @@ mod tests {
|
|||||||
|
|
||||||
// copy polynomials to device
|
// copy polynomials to device
|
||||||
let gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
|
let gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
|
||||||
|
let device_ks = (0..max_degree + 1)
|
||||||
|
.map(|k| {
|
||||||
|
gpu_api_wrapper
|
||||||
|
.gpu
|
||||||
|
.htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
|
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
|
||||||
let buf_view = RefCell::new(buf.slice_mut(..));
|
let buf_view = RefCell::new(buf.slice_mut(..));
|
||||||
let mut round_evals = gpu_api_wrapper.malloc_on_device(max_degree as usize + 1)?;
|
let mut round_evals = gpu_api_wrapper.malloc_on_device(max_degree as usize + 1)?;
|
||||||
@@ -240,6 +248,10 @@ mod tests {
|
|||||||
max_degree as usize,
|
max_degree as usize,
|
||||||
num_polys,
|
num_polys,
|
||||||
&gpu_polys.slice(..),
|
&gpu_polys.slice(..),
|
||||||
|
&device_ks
|
||||||
|
.iter()
|
||||||
|
.map(|device_k| device_k.slice(..))
|
||||||
|
.collect_vec(),
|
||||||
buf_view.borrow_mut(),
|
buf_view.borrow_mut(),
|
||||||
round_evals_view.borrow_mut(),
|
round_evals_view.borrow_mut(),
|
||||||
)?;
|
)?;
|
||||||
@@ -248,7 +260,7 @@ mod tests {
|
|||||||
now.elapsed()
|
now.elapsed()
|
||||||
);
|
);
|
||||||
let gpu_round_evals = gpu_api_wrapper
|
let gpu_round_evals = gpu_api_wrapper
|
||||||
.dtoh_sync_copy(round_evals.slice(0..(max_degree + 1) as usize), true)?;
|
.dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true)?;
|
||||||
cpu_round_evals
|
cpu_round_evals
|
||||||
.iter()
|
.iter()
|
||||||
.zip_eq(gpu_round_evals.iter())
|
.zip_eq(gpu_round_evals.iter())
|
||||||
@@ -312,7 +324,7 @@ mod tests {
|
|||||||
let gpu_result = (0..num_polys)
|
let gpu_result = (0..num_polys)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
gpu_api_wrapper.dtoh_sync_copy(
|
gpu_api_wrapper.dtoh_sync_copy(
|
||||||
gpu_polys.slice(i << num_vars..(i * 2 + 1) << (num_vars - 1)),
|
&gpu_polys.slice(i << num_vars..(i * 2 + 1) << (num_vars - 1)),
|
||||||
true,
|
true,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
@@ -338,9 +350,9 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_prove_sumcheck() -> Result<(), DriverError> {
|
fn test_prove_sumcheck() -> Result<(), DriverError> {
|
||||||
let num_vars = 23;
|
let num_vars = 25;
|
||||||
let num_polys = 4;
|
let num_polys = 2;
|
||||||
let max_degree = 4;
|
let max_degree = 2;
|
||||||
|
|
||||||
let rng = OsRng::default();
|
let rng = OsRng::default();
|
||||||
let polys = (0..num_polys)
|
let polys = (0..num_polys)
|
||||||
@@ -380,6 +392,13 @@ mod tests {
|
|||||||
let sum = (0..1 << num_vars).fold(Fr::ZERO, |acc, index| {
|
let sum = (0..1 << num_vars).fold(Fr::ZERO, |acc, index| {
|
||||||
acc + polys.iter().map(|poly| poly[index]).product::<Fr>()
|
acc + polys.iter().map(|poly| poly[index]).product::<Fr>()
|
||||||
});
|
});
|
||||||
|
let device_ks = (0..max_degree + 1)
|
||||||
|
.map(|k| {
|
||||||
|
gpu_api_wrapper
|
||||||
|
.gpu
|
||||||
|
.htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
|
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
|
||||||
let buf_view = RefCell::new(buf.slice_mut(..));
|
let buf_view = RefCell::new(buf.slice_mut(..));
|
||||||
|
|
||||||
@@ -395,20 +414,25 @@ mod tests {
|
|||||||
max_degree,
|
max_degree,
|
||||||
sum,
|
sum,
|
||||||
&mut gpu_polys.slice_mut(..),
|
&mut gpu_polys.slice_mut(..),
|
||||||
|
&device_ks
|
||||||
|
.iter()
|
||||||
|
.map(|device_k| device_k.slice(..))
|
||||||
|
.collect_vec(),
|
||||||
buf_view,
|
buf_view,
|
||||||
&mut challenges.slice_mut(..),
|
&mut challenges.slice_mut(..),
|
||||||
round_evals_view,
|
round_evals_view,
|
||||||
)?;
|
)?;
|
||||||
|
gpu_api_wrapper.gpu.synchronize()?;
|
||||||
println!(
|
println!(
|
||||||
"Time taken to prove sumcheck on gpu : {:.2?}",
|
"Time taken to prove sumcheck on gpu : {:.2?}",
|
||||||
now.elapsed()
|
now.elapsed()
|
||||||
);
|
);
|
||||||
|
|
||||||
let challenges = gpu_api_wrapper.dtoh_sync_copy(challenges.slice(..), true)?;
|
let challenges = gpu_api_wrapper.dtoh_sync_copy(&challenges.slice(..), true)?;
|
||||||
let round_evals = (0..num_vars)
|
let round_evals = (0..num_vars)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
gpu_api_wrapper.dtoh_sync_copy(
|
gpu_api_wrapper.dtoh_sync_copy(
|
||||||
round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
|
&round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
|
||||||
true,
|
true,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -78,10 +78,10 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
|||||||
|
|
||||||
pub fn dtoh_sync_copy(
|
pub fn dtoh_sync_copy(
|
||||||
&self,
|
&self,
|
||||||
device_data: CudaView<FieldBinding>,
|
device_data: &CudaView<FieldBinding>,
|
||||||
convert_from_montgomery_form: bool,
|
convert_from_montgomery_form: bool,
|
||||||
) -> Result<Vec<F>, DriverError> {
|
) -> Result<Vec<F>, DriverError> {
|
||||||
let host_data = self.gpu.dtoh_sync_copy(&device_data)?;
|
let host_data = self.gpu.dtoh_sync_copy(device_data)?;
|
||||||
let mut target = vec![F::ZERO; host_data.len()];
|
let mut target = vec![F::ZERO; host_data.len()];
|
||||||
if convert_from_montgomery_form {
|
if convert_from_montgomery_form {
|
||||||
parallelize(&mut target, |(target, start)| {
|
parallelize(&mut target, |(target, start)| {
|
||||||
|
|||||||
Reference in New Issue
Block a user