mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-08 23:18:00 -05:00
Add field binding conversion impl
This commit is contained in:
64
sumcheck/src/field.rs
Normal file
64
sumcheck/src/field.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use crate::FieldBinding;
|
||||
use ff::{Field, PrimeField};
|
||||
use halo2curves::{bn256::Fr, serde::SerdeObject};
|
||||
use itertools::Itertools;
|
||||
|
||||
pub trait FromFieldBinding<F> {
|
||||
fn from_canonical_form(b: FieldBinding) -> F;
|
||||
|
||||
fn from_montgomery_form(b: FieldBinding) -> F;
|
||||
}
|
||||
|
||||
pub trait ToFieldBinding<F> {
|
||||
fn to_canonical_form(f: F) -> FieldBinding;
|
||||
|
||||
fn to_montgomery_form(f: F) -> FieldBinding;
|
||||
}
|
||||
|
||||
/// Packs 32 little-endian bytes into four 64-bit limbs, least-significant
/// limb first.
///
/// # Panics
///
/// Panics if `bytes` is not exactly 32 bytes long.
fn limbs_from_le_bytes(bytes: &[u8]) -> [u64; 4] {
    assert_eq!(
        bytes.len(),
        32,
        "a field element must serialize to exactly 32 bytes"
    );
    let mut limbs = [0u64; 4];
    for (limb, chunk) in limbs.iter_mut().zip(bytes.chunks_exact(8)) {
        let mut word = [0u8; 8];
        word.copy_from_slice(chunk);
        *limb = u64::from_le_bytes(word);
    }
    limbs
}

/// Unpacks four 64-bit limbs into 32 little-endian bytes, least-significant
/// limb first. Inverse of [`limbs_from_le_bytes`].
fn limbs_to_le_bytes(limbs: &[u64; 4]) -> [u8; 32] {
    let mut bytes = [0u8; 32];
    for (chunk, limb) in bytes.chunks_exact_mut(8).zip(limbs.iter()) {
        chunk.copy_from_slice(&limb.to_le_bytes());
    }
    bytes
}

/// Implements [`FromFieldBinding`] and [`ToFieldBinding`] for a field type.
///
/// The macro body calls `from_raw`, `from_raw_bytes_unchecked`, `to_repr` and
/// `write_raw` on `$field`, so the type must provide all four (in this file
/// they come from `ff::PrimeField` and `halo2curves::serde::SerdeObject`).
/// NOTE(review): the helpers above assume a 32-byte little-endian encoding —
/// confirm before instantiating the macro for a field other than `bn256::Fr`.
macro_rules! field_binding_conversion {
    ($field:ident) => {
        impl FromFieldBinding<$field> for $field {
            /// Interprets the limbs as a plain little-endian integer.
            fn from_canonical_form(b: FieldBinding) -> $field {
                $field::from_raw(b.data)
            }

            /// Reinterprets the limbs as the field's raw byte encoding.
            /// No validity check is performed (`*_unchecked`).
            fn from_montgomery_form(b: FieldBinding) -> $field {
                let bytes = limbs_to_le_bytes(&b.data);
                $field::from_raw_bytes_unchecked(&bytes)
            }
        }

        impl ToFieldBinding<$field> for $field {
            /// Encodes the `to_repr` bytes of `f` as four limbs.
            fn to_canonical_form(f: $field) -> FieldBinding {
                let repr = f.to_repr();
                FieldBinding {
                    data: limbs_from_le_bytes(repr.as_ref()),
                }
            }

            /// Encodes the raw bytes written by `write_raw` as four limbs.
            fn to_montgomery_form(f: $field) -> FieldBinding {
                let mut buf: Vec<u8> = Vec::with_capacity(32);
                // `write_raw` returns an `io::Result`; it was previously
                // dropped silently. Writing into a `Vec<u8>` cannot fail, so
                // a failure here would be a bug worth surfacing.
                f.write_raw(&mut buf)
                    .expect("writing to an in-memory Vec<u8> cannot fail");
                FieldBinding {
                    data: limbs_from_le_bytes(&buf),
                }
            }
        }
    };
}
|
||||
|
||||
// Generate the `FromFieldBinding` / `ToFieldBinding` impls for
// `halo2curves::bn256::Fr`.
field_binding_conversion!(Fr);
|
||||
@@ -1,58 +1,108 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::time::Instant;
|
||||
use cudarc::driver::{CudaDevice, LaunchConfig, DeviceRepr, DriverError, LaunchAsync};
|
||||
use cudarc::driver::{CudaDevice, DeviceRepr, DriverError, LaunchAsync, LaunchConfig};
|
||||
use cudarc::nvrtc::Ptx;
|
||||
use ff::{Field, PrimeField};
|
||||
use field::{FromFieldBinding, ToFieldBinding};
|
||||
use halo2curves::bn256::Fr;
|
||||
use itertools::Itertools;
|
||||
use std::marker::PhantomData;
|
||||
use std::process::Output;
|
||||
use std::time::Instant;
|
||||
|
||||
pub mod field;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
|
||||
|
||||
// include the compiled PTX code as string
|
||||
const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/multilinear.ptx"));
|
||||
|
||||
unsafe impl DeviceRepr for FieldBinding {}
|
||||
impl Default for FieldBinding {
|
||||
fn default() -> Self {
|
||||
Self{ data: [0; 4] }
|
||||
Self { data: [0; 4] }
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField> From<F> for FieldBinding {
|
||||
fn from(value: F) -> Self {
|
||||
let repr = value.to_repr();
|
||||
let bytes = repr.as_ref();
|
||||
let data = bytes.chunks(8).map(|bytes| {
|
||||
u64::from_le_bytes(bytes.try_into().unwrap())
|
||||
}).collect_vec();
|
||||
FieldBinding {
|
||||
data: data.try_into().unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
// include the compiled PTX code as string
|
||||
const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/multilinear.ptx"));
|
||||
|
||||
/// Wrapper struct for APIs using GPU
|
||||
#[derive(Default)]
|
||||
pub struct GPUApiWrapper<F: Field>(PhantomData<F>);
|
||||
pub struct GPUApiWrapper<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>>(PhantomData<F>);
|
||||
|
||||
impl<F: Field> GPUApiWrapper<F> {
|
||||
pub fn evaluate_poly(&self, poly_coeffs: &[F], point: &[F]) -> Result<F, DriverError> {
|
||||
todo!()
|
||||
impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
pub fn evaluate_poly(
|
||||
&self,
|
||||
num_vars: usize,
|
||||
poly_coeffs: &[F],
|
||||
point: &[F],
|
||||
) -> Result<F, DriverError> {
|
||||
// setup GPU device
|
||||
let now = Instant::now();
|
||||
|
||||
let gpu = CudaDevice::new(0)?;
|
||||
|
||||
println!("Time taken to initialise CUDA: {:.2?}", now.elapsed());
|
||||
|
||||
// compile ptx
|
||||
let now = Instant::now();
|
||||
|
||||
let ptx = Ptx::from_src(CUDA_KERNEL_MY_STRUCT);
|
||||
gpu.load_ptx(ptx, "multilinear", &["evaluate"])?;
|
||||
|
||||
println!("Time taken to compile and load PTX: {:.2?}", now.elapsed());
|
||||
|
||||
let point = point
|
||||
.into_iter()
|
||||
.map(|f| F::to_canonical_form(*f))
|
||||
.collect_vec();
|
||||
|
||||
// copy to GPU
|
||||
let gpu_coeffs = gpu.htod_copy(
|
||||
poly_coeffs
|
||||
.into_iter()
|
||||
.map(|&coeff| F::to_canonical_form(coeff))
|
||||
.collect_vec(),
|
||||
)?;
|
||||
let gpu_eval_point = gpu.htod_copy(point)?;
|
||||
let monomial_evals = gpu.htod_copy(vec![FieldBinding::default(); 1 << num_vars])?;
|
||||
|
||||
println!("Time taken to initialise data: {:.2?}", now.elapsed());
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
let f = gpu.get_func("multilinear", "evaluate").unwrap();
|
||||
|
||||
unsafe {
|
||||
f.launch(
|
||||
LaunchConfig::for_num_elems(1 << num_vars as u32),
|
||||
(&gpu_coeffs, &gpu_eval_point, num_vars, &monomial_evals),
|
||||
)
|
||||
}?;
|
||||
|
||||
println!("Time taken to call kernel: {:.2?}", now.elapsed());
|
||||
|
||||
let monomial_evals = gpu.sync_reclaim(monomial_evals)?;
|
||||
|
||||
let result = monomial_evals
|
||||
.into_iter()
|
||||
.map(|eval| F::from_canonical_form(eval))
|
||||
.sum::<F>();
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{default, time::Instant};
|
||||
use std::{default, fmt::Error, time::Instant};
|
||||
|
||||
use cudarc::{driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig}, nvrtc::Ptx};
|
||||
use cudarc::{
|
||||
driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig},
|
||||
nvrtc::Ptx,
|
||||
};
|
||||
use ff::{Field, PrimeField};
|
||||
use halo2curves::bn256::Fr;
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
|
||||
|
||||
use super::{GPUApiWrapper, FieldBinding, CUDA_KERNEL_MY_STRUCT};
|
||||
use super::{FieldBinding, GPUApiWrapper, CUDA_KERNEL_MY_STRUCT};
|
||||
|
||||
fn evaluate_poly_cpu<F: Field>(poly_coeffs: &[F], point: &[F], num_vars: usize) -> F {
|
||||
poly_coeffs
|
||||
@@ -63,11 +113,11 @@ mod tests {
|
||||
F::ZERO
|
||||
} else {
|
||||
let indices = (0..num_vars).map(|j| (i >> j) & 1).collect_vec();
|
||||
let mut result = F::ONE;
|
||||
let mut result = coeff.clone();
|
||||
for (index, point) in indices.iter().zip(point.iter()) {
|
||||
result *= if *index == 1 { *point } else { F::ONE };
|
||||
}
|
||||
result * coeff
|
||||
result
|
||||
}
|
||||
})
|
||||
.sum()
|
||||
@@ -75,17 +125,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_evaluate_poly() -> Result<(), DriverError> {
|
||||
let num_vars = 16;
|
||||
let num_vars = 6;
|
||||
let rng = OsRng::default();
|
||||
let poly_coeffs = (0..1 << num_vars).map(|_| {
|
||||
Fr::random(rng)
|
||||
}).collect_vec();
|
||||
let point = (0..num_vars).map(|_| {
|
||||
Fr::random(rng)
|
||||
}).collect_vec();
|
||||
let poly_coeffs = (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec();
|
||||
let point = (0..num_vars).map(|_| Fr::random(rng)).collect_vec();
|
||||
let gpu_api_wrapper = GPUApiWrapper::<Fr>::default();
|
||||
let eval_poly_result_by_cpu = evaluate_poly_cpu(&poly_coeffs, &point, num_vars);
|
||||
let eval_poly_result_by_gpu = gpu_api_wrapper.evaluate_poly(&poly_coeffs, &point)?;
|
||||
let eval_poly_result_by_gpu =
|
||||
gpu_api_wrapper.evaluate_poly(num_vars, &poly_coeffs, &point)?;
|
||||
assert_eq!(eval_poly_result_by_cpu, eval_poly_result_by_gpu);
|
||||
Ok(())
|
||||
}
|
||||
@@ -112,9 +159,7 @@ mod tests {
|
||||
|
||||
println!("a * b : {:?}", a * b);
|
||||
|
||||
let a_data = FieldBinding {
|
||||
data: [2, 0, 0, 0]
|
||||
};
|
||||
let a_data = FieldBinding { data: [2, 0, 0, 0] };
|
||||
|
||||
let b_data = FieldBinding {
|
||||
data: [
|
||||
@@ -122,7 +167,7 @@ mod tests {
|
||||
0x9419f4243cdcb848,
|
||||
0xdc2822db40c0ac2e,
|
||||
0x183227397098d014,
|
||||
]
|
||||
],
|
||||
};
|
||||
|
||||
// copy to GPU
|
||||
@@ -135,7 +180,12 @@ mod tests {
|
||||
|
||||
let f = gpu.get_func("my_module", "mul").unwrap();
|
||||
|
||||
unsafe { f.launch(LaunchConfig::for_num_elems(1024 as u32), (&gpu_field_structs, &results)) }?;
|
||||
unsafe {
|
||||
f.launch(
|
||||
LaunchConfig::for_num_elems(1024 as u32),
|
||||
(&gpu_field_structs, &results),
|
||||
)
|
||||
}?;
|
||||
|
||||
println!("Time taken to call kernel: {:.2?}", now.elapsed());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user