diff --git a/sumcheck/src/field.rs b/sumcheck/src/field.rs new file mode 100644 index 0000000..4eeadf8 --- /dev/null +++ b/sumcheck/src/field.rs @@ -0,0 +1,64 @@ +use crate::FieldBinding; +use ff::{Field, PrimeField}; +use halo2curves::{bn256::Fr, serde::SerdeObject}; +use itertools::Itertools; + +pub trait FromFieldBinding { + fn from_canonical_form(b: FieldBinding) -> F; + + fn from_montgomery_form(b: FieldBinding) -> F; +} + +pub trait ToFieldBinding { + fn to_canonical_form(f: F) -> FieldBinding; + + fn to_montgomery_form(f: F) -> FieldBinding; +} + +macro_rules! field_binding_conversion { + ($field:ident) => { + impl FromFieldBinding<$field> for $field { + fn from_canonical_form(b: FieldBinding) -> $field { + $field::from_raw(b.data) + } + + fn from_montgomery_form(b: FieldBinding) -> $field { + let bytes = b + .data + .into_iter() + .map(|data| data.to_le_bytes()) + .collect_vec() + .concat(); + $field::from_raw_bytes_unchecked(bytes.as_ref()) + } + } + + impl ToFieldBinding<$field> for $field { + fn to_canonical_form(f: $field) -> FieldBinding { + let repr = f.to_repr(); + let bytes = repr.as_ref(); + let data = bytes + .chunks(8) + .map(|bytes| u64::from_le_bytes(bytes.try_into().unwrap())) + .collect_vec(); + FieldBinding { + data: data.try_into().unwrap(), + } + } + + fn to_montgomery_form(f: $field) -> FieldBinding { + let mut buf = vec![]; + f.write_raw(&mut buf); + let data = buf + .chunks(8) + .map(|bytes| u64::from_le_bytes(bytes.try_into().unwrap())) + .collect_vec(); + FieldBinding { + data: data.try_into().unwrap(), + } + } + } + }; +} + +field_binding_conversion!(Fr); diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index ad25d1f..37806d2 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -1,58 +1,108 @@ -use std::marker::PhantomData; -use std::time::Instant; -use cudarc::driver::{CudaDevice, LaunchConfig, DeviceRepr, DriverError, LaunchAsync}; +use cudarc::driver::{CudaDevice, DeviceRepr, DriverError, LaunchAsync, LaunchConfig}; use cudarc::nvrtc::Ptx; use ff::{Field, PrimeField}; +use field::{FromFieldBinding, ToFieldBinding}; use halo2curves::bn256::Fr; use itertools::Itertools; +use std::marker::PhantomData; +use std::process::Output; +use std::time::Instant; + +pub mod field; include!(concat!(env!("OUT_DIR"), "/bindings.rs")); -// include the compiled PTX code as string -const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/multilinear.ptx")); - unsafe impl DeviceRepr for FieldBinding {} impl Default for FieldBinding { fn default() -> Self { - Self{ data: [0; 4] } + Self { data: [0; 4] } } } -impl From for FieldBinding { - fn from(value: F) -> Self { - let repr = value.to_repr(); - let bytes = repr.as_ref(); - let data = bytes.chunks(8).map(|bytes| { - u64::from_le_bytes(bytes.try_into().unwrap()) - }).collect_vec(); - FieldBinding { - data: data.try_into().unwrap() - } - } -} +// include the compiled PTX code as string +const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/multilinear.ptx")); /// Wrapper struct for APIs using GPU #[derive(Default)] -pub struct GPUApiWrapper(PhantomData); +pub struct GPUApiWrapper + ToFieldBinding>(PhantomData); -impl GPUApiWrapper { - pub fn evaluate_poly(&self, poly_coeffs: &[F], point: &[F]) -> Result { - todo!() +impl + ToFieldBinding> GPUApiWrapper { + pub fn evaluate_poly( + &self, + num_vars: usize, + poly_coeffs: &[F], + point: &[F], + ) -> Result { + // setup GPU device + let now = Instant::now(); + + let gpu = CudaDevice::new(0)?; + + println!("Time taken to initialise CUDA: {:.2?}", now.elapsed()); + + // compile ptx + let now = Instant::now(); + + let ptx = Ptx::from_src(CUDA_KERNEL_MY_STRUCT); + gpu.load_ptx(ptx, "multilinear", &["evaluate"])?; + + println!("Time taken to compile and load PTX: {:.2?}", now.elapsed()); + + let point = point + .into_iter() + .map(|f| F::to_canonical_form(*f)) + .collect_vec(); + + // copy to GPU + let gpu_coeffs = gpu.htod_copy( + poly_coeffs + .into_iter() + .map(|&coeff| F::to_canonical_form(coeff)) + .collect_vec(), + )?; + let gpu_eval_point = gpu.htod_copy(point)?; + let monomial_evals = gpu.htod_copy(vec![FieldBinding::default(); 1 << num_vars])?; + + println!("Time taken to initialise data: {:.2?}", now.elapsed()); + + let now = Instant::now(); + + let f = gpu.get_func("multilinear", "evaluate").unwrap(); + + unsafe { + f.launch( + LaunchConfig::for_num_elems(1 << num_vars as u32), + (&gpu_coeffs, &gpu_eval_point, num_vars, &monomial_evals), + ) + }?; + + println!("Time taken to call kernel: {:.2?}", now.elapsed()); + + let monomial_evals = gpu.sync_reclaim(monomial_evals)?; + + let result = monomial_evals + .into_iter() + .map(|eval| F::from_canonical_form(eval)) + .sum::(); + Ok(result) } } #[cfg(test)] mod tests { - use std::{default, time::Instant}; + use std::{default, fmt::Error, time::Instant}; - use cudarc::{driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig}, nvrtc::Ptx}; + use cudarc::{ + driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig}, + nvrtc::Ptx, + }; use ff::{Field, PrimeField}; use halo2curves::bn256::Fr; use itertools::Itertools; use rand::rngs::OsRng; use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; - use super::{GPUApiWrapper, FieldBinding, CUDA_KERNEL_MY_STRUCT}; + use super::{FieldBinding, GPUApiWrapper, CUDA_KERNEL_MY_STRUCT}; fn evaluate_poly_cpu(poly_coeffs: &[F], point: &[F], num_vars: usize) -> F { poly_coeffs @@ -63,11 +113,11 @@ mod tests { F::ZERO } else { let indices = (0..num_vars).map(|j| (i >> j) & 1).collect_vec(); - let mut result = F::ONE; + let mut result = coeff.clone(); for (index, point) in indices.iter().zip(point.iter()) { result *= if *index == 1 { *point } else { F::ONE }; } - result * coeff + result } }) .sum() @@ -75,17 +125,14 @@ mod tests { #[test] fn test_evaluate_poly() -> Result<(), DriverError> { - let num_vars = 16; + let num_vars = 6; let rng = OsRng::default(); - let poly_coeffs = (0..1 << num_vars).map(|_| { - Fr::random(rng) - }).collect_vec(); - let point = (0..num_vars).map(|_| { - Fr::random(rng) - }).collect_vec(); + let poly_coeffs = (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec(); + let point = (0..num_vars).map(|_| Fr::random(rng)).collect_vec(); let gpu_api_wrapper = GPUApiWrapper::::default(); let eval_poly_result_by_cpu = evaluate_poly_cpu(&poly_coeffs, &point, num_vars); - let eval_poly_result_by_gpu = gpu_api_wrapper.evaluate_poly(&poly_coeffs, &point)?; + let eval_poly_result_by_gpu = + gpu_api_wrapper.evaluate_poly(num_vars, &poly_coeffs, &point)?; assert_eq!(eval_poly_result_by_cpu, eval_poly_result_by_gpu); Ok(()) } @@ -112,9 +159,7 @@ mod tests { println!("a * b : {:?}", a * b); - let a_data = FieldBinding { - data: [2, 0, 0, 0] - }; + let a_data = FieldBinding { data: [2, 0, 0, 0] }; let b_data = FieldBinding { data: [ @@ -122,7 +167,7 @@ mod tests { 0x9419f4243cdcb848, 0xdc2822db40c0ac2e, 0x183227397098d014, - ] + ], }; // copy to GPU @@ -135,7 +180,12 @@ mod tests { let f = gpu.get_func("my_module", "mul").unwrap(); - unsafe { f.launch(LaunchConfig::for_num_elems(1024 as u32), (&gpu_field_structs, &results)) }?; + unsafe { + f.launch( + LaunchConfig::for_num_elems(1024 as u32), + (&gpu_field_structs, &results), + ) + }?; println!("Time taken to call kernel: {:.2?}", now.elapsed());