mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-09 15:38:01 -05:00
Make frame for testing polynomial evaluation using GPU
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
use std::time::Instant;
|
||||
use cudarc::driver::{CudaDevice, LaunchConfig, DeviceRepr, DriverError, LaunchAsync};
|
||||
use cudarc::nvrtc::Ptx;
|
||||
use ff::Field;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
|
||||
|
||||
@@ -15,44 +16,62 @@ impl Default for MyStruct {
|
||||
// include the compiled PTX code as string
|
||||
const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/sumcheck.ptx"));
|
||||
|
||||
fn main() -> Result<(), DriverError> {
|
||||
// setup GPU device
|
||||
let now = Instant::now();
|
||||
/// Wrapper struct for APIs using GPU
|
||||
#[derive(Default)]
|
||||
pub struct GPUApiWrapper<F: Field> {}
|
||||
|
||||
let gpu = CudaDevice::new(0)?;
|
||||
impl<F: Field> GPUApiWrapper<F> {
|
||||
pub fn evaluate_poly(&self, poly_coeffs: &[F], point: &[F]) -> Result<F, DriverError> {
|
||||
|
||||
println!("Time taken to initialise CUDA: {:.2?}", now.elapsed());
|
||||
|
||||
// compile ptx
|
||||
let now = Instant::now();
|
||||
|
||||
let ptx = Ptx::from_src(CUDA_KERNEL_MY_STRUCT);
|
||||
gpu.load_ptx(ptx, "my_module", &["my_struct_kernel"])?;
|
||||
|
||||
println!("Time taken to compile and load PTX: {:.2?}", now.elapsed());
|
||||
|
||||
// create data
|
||||
let now = Instant::now();
|
||||
|
||||
let n = 10_usize;
|
||||
let my_structs = vec![MyStruct { data: [1.0; 4] }; n];
|
||||
|
||||
// copy to GPU
|
||||
let gpu_my_structs = gpu.htod_copy(my_structs)?;
|
||||
|
||||
println!("Time taken to initialise data: {:.2?}", now.elapsed());
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
let f = gpu.get_func("my_module", "my_struct_kernel").unwrap();
|
||||
|
||||
unsafe { f.launch(LaunchConfig::for_num_elems(n as u32), (&gpu_my_structs, n)) }?;
|
||||
|
||||
println!("Time taken to call kernel: {:.2?}", now.elapsed());
|
||||
|
||||
let my_structs = gpu.sync_reclaim(gpu_my_structs)?;
|
||||
|
||||
assert!(my_structs.iter().all(|i| i.data == [2.0; 4]));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::default;
|
||||
|
||||
use cudarc::driver::DriverError;
|
||||
use ff::Field;
|
||||
use halo2curves::bn256::Fr;
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
|
||||
|
||||
use super::GPUApiWrapper;
|
||||
|
||||
fn evaluate_poly_cpu<F: Field>(poly_coeffs: &[F], point: &[F], num_vars: usize) -> F {
|
||||
poly_coeffs
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(i, coeff)| {
|
||||
if *coeff == F::ZERO {
|
||||
F::ZERO
|
||||
} else {
|
||||
let indices = (0..num_vars).map(|j| (i >> j) & 1).collect_vec();
|
||||
let mut result = F::ONE;
|
||||
for (index, point) in indices.iter().zip(point.iter()) {
|
||||
result *= if *index == 1 { *point } else { F::ONE };
|
||||
}
|
||||
result * coeff
|
||||
}
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Evaluates a random 16-variable polynomial at a random point on both the
/// CPU reference implementation and the GPU wrapper, and checks that the two
/// results agree.
#[test]
fn test_evaluate_poly() -> Result<(), DriverError> {
    let num_vars = 16;
    let rng = OsRng::default();

    // Draws `count` random field elements; coefficients are sampled first,
    // then the evaluation point.
    let sample = |count: usize| (0..count).map(|_| Fr::random(rng)).collect_vec();
    let poly_coeffs = sample(1 << num_vars);
    let point = sample(num_vars);

    let gpu_api_wrapper = GPUApiWrapper::<Fr>::default();
    let cpu_eval = evaluate_poly_cpu(&poly_coeffs, &point, num_vars);
    let gpu_eval = gpu_api_wrapper.evaluate_poly(&poly_coeffs, &point)?;

    assert_eq!(cpu_eval, gpu_eval);
    Ok(())
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user