Write scalar multiplication test

This commit is contained in:
DoHoonKim8
2024-06-15 18:30:30 +00:00
committed by DoHoon Kim
parent d3bd36b696
commit 5957edb71c

View File

@@ -1,43 +1,44 @@
use std::marker::PhantomData;
use std::time::Instant;
use cudarc::driver::{CudaDevice, LaunchConfig, DeviceRepr, DriverError, LaunchAsync};
use cudarc::nvrtc::Ptx;
use ff::Field;
use ff::{Field, PrimeField};
use halo2curves::bn256::Fr;
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
// TODO : Replace this with our sumcheck struct
unsafe impl DeviceRepr for MyStruct {}
impl Default for MyStruct {
fn default() -> Self{
Self{ data: [0.0; 4]}
// include the compiled PTX code as string
const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/multilinear.ptx"));
unsafe impl DeviceRepr for FieldBinding {}
impl Default for FieldBinding {
fn default() -> Self {
Self{ data: [0; 4] }
}
}
// include the compiled PTX code as string
const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/sumcheck.ptx"));
/// Wrapper struct for APIs using GPU
#[derive(Default)]
pub struct GPUApiWrapper<F: Field> {}
pub struct GPUApiWrapper<F: Field>(PhantomData<F>);
impl<F: Field> GPUApiWrapper<F> {
pub fn evaluate_poly(&self, poly_coeffs: &[F], point: &[F]) -> Result<F, DriverError> {
todo!()
}
}
#[cfg(test)]
mod tests {
use std::default;
use std::{default, time::Instant};
use cudarc::driver::DriverError;
use ff::Field;
use cudarc::{driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig}, nvrtc::Ptx};
use ff::{Field, PrimeField};
use halo2curves::bn256::Fr;
use itertools::Itertools;
use rand::rngs::OsRng;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
use super::GPUApiWrapper;
use super::{GPUApiWrapper, FieldBinding, CUDA_KERNEL_MY_STRUCT};
fn evaluate_poly_cpu<F: Field>(poly_coeffs: &[F], point: &[F], num_vars: usize) -> F {
poly_coeffs
@@ -74,4 +75,64 @@ mod tests {
assert_eq!(eval_poly_result_by_cpu, eval_poly_result_by_gpu);
Ok(())
}
#[test]
fn test_scalar_multiplication() -> Result<(), DriverError> {
// setup GPU device
let now = Instant::now();
let gpu = CudaDevice::new(0)?;
println!("Time taken to initialise CUDA: {:.2?}", now.elapsed());
// compile ptx
let now = Instant::now();
let ptx = Ptx::from_src(CUDA_KERNEL_MY_STRUCT);
gpu.load_ptx(ptx, "my_module", &["mul"])?;
println!("Time taken to compile and load PTX: {:.2?}", now.elapsed());
let a = Fr::from(2);
let b = Fr::TWO_INV;
println!("a * b : {:?}", a * b);
let a_data = FieldBinding {
data: [2, 0, 0, 0]
};
let b_data = FieldBinding {
data: [
0xa1f0fac9f8000001,
0x9419f4243cdcb848,
0xdc2822db40c0ac2e,
0x183227397098d014,
]
};
// copy to GPU
let gpu_field_structs = gpu.htod_copy(vec![a_data, b_data])?;
let results = gpu.htod_copy(vec![FieldBinding::default(); 1024])?;
println!("Time taken to initialise data: {:.2?}", now.elapsed());
let now = Instant::now();
let f = gpu.get_func("my_module", "mul").unwrap();
unsafe { f.launch(LaunchConfig::for_num_elems(1024 as u32), (&gpu_field_structs, &results)) }?;
println!("Time taken to call kernel: {:.2?}", now.elapsed());
let results = gpu.sync_reclaim(results)?;
results.iter().for_each(|result| {
assert_eq!(result.data[0], 1);
assert_eq!(result.data[1], 0);
assert_eq!(result.data[1], 0);
assert_eq!(result.data[1], 0);
});
Ok(())
}
}