mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-08 23:18:00 -05:00
Write scalar multiplication test
This commit is contained in:
@@ -1,43 +1,44 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::time::Instant;
|
||||
use cudarc::driver::{CudaDevice, LaunchConfig, DeviceRepr, DriverError, LaunchAsync};
|
||||
use cudarc::nvrtc::Ptx;
|
||||
use ff::Field;
|
||||
use ff::{Field, PrimeField};
|
||||
use halo2curves::bn256::Fr;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
|
||||
|
||||
// TODO : Replace this with our sumcheck struct
|
||||
unsafe impl DeviceRepr for MyStruct {}
|
||||
impl Default for MyStruct {
|
||||
fn default() -> Self{
|
||||
Self{ data: [0.0; 4]}
|
||||
// include the compiled PTX code as string
|
||||
const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/multilinear.ptx"));
|
||||
|
||||
unsafe impl DeviceRepr for FieldBinding {}
|
||||
impl Default for FieldBinding {
|
||||
fn default() -> Self {
|
||||
Self{ data: [0; 4] }
|
||||
}
|
||||
}
|
||||
|
||||
// include the compiled PTX code as string
|
||||
const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/sumcheck.ptx"));
|
||||
|
||||
/// Wrapper struct for APIs using GPU
|
||||
#[derive(Default)]
|
||||
pub struct GPUApiWrapper<F: Field> {}
|
||||
pub struct GPUApiWrapper<F: Field>(PhantomData<F>);
|
||||
|
||||
impl<F: Field> GPUApiWrapper<F> {
|
||||
pub fn evaluate_poly(&self, poly_coeffs: &[F], point: &[F]) -> Result<F, DriverError> {
|
||||
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::default;
|
||||
use std::{default, time::Instant};
|
||||
|
||||
use cudarc::driver::DriverError;
|
||||
use ff::Field;
|
||||
use cudarc::{driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig}, nvrtc::Ptx};
|
||||
use ff::{Field, PrimeField};
|
||||
use halo2curves::bn256::Fr;
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
|
||||
|
||||
use super::GPUApiWrapper;
|
||||
use super::{GPUApiWrapper, FieldBinding, CUDA_KERNEL_MY_STRUCT};
|
||||
|
||||
fn evaluate_poly_cpu<F: Field>(poly_coeffs: &[F], point: &[F], num_vars: usize) -> F {
|
||||
poly_coeffs
|
||||
@@ -74,4 +75,64 @@ mod tests {
|
||||
assert_eq!(eval_poly_result_by_cpu, eval_poly_result_by_gpu);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scalar_multiplication() -> Result<(), DriverError> {
|
||||
// setup GPU device
|
||||
let now = Instant::now();
|
||||
|
||||
let gpu = CudaDevice::new(0)?;
|
||||
|
||||
println!("Time taken to initialise CUDA: {:.2?}", now.elapsed());
|
||||
|
||||
// compile ptx
|
||||
let now = Instant::now();
|
||||
|
||||
let ptx = Ptx::from_src(CUDA_KERNEL_MY_STRUCT);
|
||||
gpu.load_ptx(ptx, "my_module", &["mul"])?;
|
||||
|
||||
println!("Time taken to compile and load PTX: {:.2?}", now.elapsed());
|
||||
|
||||
let a = Fr::from(2);
|
||||
let b = Fr::TWO_INV;
|
||||
|
||||
println!("a * b : {:?}", a * b);
|
||||
|
||||
let a_data = FieldBinding {
|
||||
data: [2, 0, 0, 0]
|
||||
};
|
||||
|
||||
let b_data = FieldBinding {
|
||||
data: [
|
||||
0xa1f0fac9f8000001,
|
||||
0x9419f4243cdcb848,
|
||||
0xdc2822db40c0ac2e,
|
||||
0x183227397098d014,
|
||||
]
|
||||
};
|
||||
|
||||
// copy to GPU
|
||||
let gpu_field_structs = gpu.htod_copy(vec![a_data, b_data])?;
|
||||
let results = gpu.htod_copy(vec![FieldBinding::default(); 1024])?;
|
||||
|
||||
println!("Time taken to initialise data: {:.2?}", now.elapsed());
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
let f = gpu.get_func("my_module", "mul").unwrap();
|
||||
|
||||
unsafe { f.launch(LaunchConfig::for_num_elems(1024 as u32), (&gpu_field_structs, &results)) }?;
|
||||
|
||||
println!("Time taken to call kernel: {:.2?}", now.elapsed());
|
||||
|
||||
let results = gpu.sync_reclaim(results)?;
|
||||
|
||||
results.iter().for_each(|result| {
|
||||
assert_eq!(result.data[0], 1);
|
||||
assert_eq!(result.data[1], 0);
|
||||
assert_eq!(result.data[1], 0);
|
||||
assert_eq!(result.data[1], 0);
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user