diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index 115f1bf..d4e8c50 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -1,43 +1,44 @@ +use std::marker::PhantomData; use std::time::Instant; use cudarc::driver::{CudaDevice, LaunchConfig, DeviceRepr, DriverError, LaunchAsync}; use cudarc::nvrtc::Ptx; -use ff::Field; +use ff::{Field, PrimeField}; +use halo2curves::bn256::Fr; include!(concat!(env!("OUT_DIR"), "/bindings.rs")); -// TODO : Replace this with our sumcheck struct -unsafe impl DeviceRepr for MyStruct {} -impl Default for MyStruct { - fn default() -> Self{ - Self{ data: [0.0; 4]} +// include the compiled PTX code as string +const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/multilinear.ptx")); + +unsafe impl DeviceRepr for FieldBinding {} +impl Default for FieldBinding { + fn default() -> Self { + Self{ data: [0; 4] } } } -// include the compiled PTX code as string -const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/sumcheck.ptx")); - /// Wrapper struct for APIs using GPU #[derive(Default)] -pub struct GPUApiWrapper {} +pub struct GPUApiWrapper(PhantomData); impl GPUApiWrapper { pub fn evaluate_poly(&self, poly_coeffs: &[F], point: &[F]) -> Result { - + todo!() } } #[cfg(test)] mod tests { - use std::default; + use std::{default, time::Instant}; - use cudarc::driver::DriverError; - use ff::Field; + use cudarc::{driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig}, nvrtc::Ptx}; + use ff::{Field, PrimeField}; use halo2curves::bn256::Fr; use itertools::Itertools; use rand::rngs::OsRng; use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; - use super::GPUApiWrapper; + use super::{GPUApiWrapper, FieldBinding, CUDA_KERNEL_MY_STRUCT}; fn evaluate_poly_cpu(poly_coeffs: &[F], point: &[F], num_vars: usize) -> F { poly_coeffs @@ -74,4 +75,64 @@ mod tests { assert_eq!(eval_poly_result_by_cpu, eval_poly_result_by_gpu); Ok(()) } + + #[test] + fn test_scalar_multiplication() -> Result<(), DriverError> { + // setup GPU device + let now = Instant::now(); + + let gpu = CudaDevice::new(0)?; + + println!("Time taken to initialise CUDA: {:.2?}", now.elapsed()); + + // compile ptx + let now = Instant::now(); + + let ptx = Ptx::from_src(CUDA_KERNEL_MY_STRUCT); + gpu.load_ptx(ptx, "my_module", &["mul"])?; + + println!("Time taken to compile and load PTX: {:.2?}", now.elapsed()); + + let a = Fr::from(2); + let b = Fr::TWO_INV; + + println!("a * b : {:?}", a * b); + + let a_data = FieldBinding { + data: [2, 0, 0, 0] + }; + + let b_data = FieldBinding { + data: [ + 0xa1f0fac9f8000001, + 0x9419f4243cdcb848, + 0xdc2822db40c0ac2e, + 0x183227397098d014, + ] + }; + + // copy to GPU + let gpu_field_structs = gpu.htod_copy(vec![a_data, b_data])?; + let results = gpu.htod_copy(vec![FieldBinding::default(); 1024])?; + + println!("Time taken to initialise data: {:.2?}", now.elapsed()); + + let now = Instant::now(); + + let f = gpu.get_func("my_module", "mul").unwrap(); + + unsafe { f.launch(LaunchConfig::for_num_elems(1024 as u32), (&gpu_field_structs, &results)) }?; + + println!("Time taken to call kernel: {:.2?}", now.elapsed()); + + let results = gpu.sync_reclaim(results)?; + + results.iter().for_each(|result| { + assert_eq!(result.data[0], 1); + assert_eq!(result.data[1], 0); + assert_eq!(result.data[1], 0); + assert_eq!(result.data[1], 0); + }); + Ok(()) + } }