Add field binding conversion impl

This commit is contained in:
DoHoonKim8
2024-06-19 17:11:17 +00:00
committed by DoHoon Kim
parent 1a57a0179d
commit 62b2444811
2 changed files with 155 additions and 41 deletions

64
sumcheck/src/field.rs Normal file
View File

@@ -0,0 +1,64 @@
use crate::FieldBinding;
use ff::{Field, PrimeField};
use halo2curves::{bn256::Fr, serde::SerdeObject};
use itertools::Itertools;
pub trait FromFieldBinding<F> {
fn from_canonical_form(b: FieldBinding) -> F;
fn from_montgomery_form(b: FieldBinding) -> F;
}
pub trait ToFieldBinding<F> {
fn to_canonical_form(f: F) -> FieldBinding;
fn to_montgomery_form(f: F) -> FieldBinding;
}
macro_rules! field_binding_conversion {
($field:ident) => {
impl FromFieldBinding<$field> for $field {
fn from_canonical_form(b: FieldBinding) -> $field {
$field::from_raw(b.data)
}
fn from_montgomery_form(b: FieldBinding) -> $field {
let bytes = b
.data
.into_iter()
.map(|data| data.to_le_bytes())
.collect_vec()
.concat();
$field::from_raw_bytes_unchecked(bytes.as_ref())
}
}
impl ToFieldBinding<$field> for $field {
fn to_canonical_form(f: $field) -> FieldBinding {
let repr = f.to_repr();
let bytes = repr.as_ref();
let data = bytes
.chunks(8)
.map(|bytes| u64::from_le_bytes(bytes.try_into().unwrap()))
.collect_vec();
FieldBinding {
data: data.try_into().unwrap(),
}
}
fn to_montgomery_form(f: $field) -> FieldBinding {
let mut buf = vec![];
f.write_raw(&mut buf);
let data = buf
.chunks(8)
.map(|bytes| u64::from_le_bytes(bytes.try_into().unwrap()))
.collect_vec();
FieldBinding {
data: data.try_into().unwrap(),
}
}
}
};
}
field_binding_conversion!(Fr);

View File

@@ -1,58 +1,108 @@
use std::marker::PhantomData;
use std::time::Instant;
use cudarc::driver::{CudaDevice, LaunchConfig, DeviceRepr, DriverError, LaunchAsync};
use cudarc::driver::{CudaDevice, DeviceRepr, DriverError, LaunchAsync, LaunchConfig};
use cudarc::nvrtc::Ptx;
use ff::{Field, PrimeField};
use field::{FromFieldBinding, ToFieldBinding};
use halo2curves::bn256::Fr;
use itertools::Itertools;
use std::marker::PhantomData;
use std::process::Output;
use std::time::Instant;
pub mod field;
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
// include the compiled PTX code as string
const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/multilinear.ptx"));
unsafe impl DeviceRepr for FieldBinding {}
impl Default for FieldBinding {
fn default() -> Self {
Self{ data: [0; 4] }
Self { data: [0; 4] }
}
}
impl<F: PrimeField> From<F> for FieldBinding {
fn from(value: F) -> Self {
let repr = value.to_repr();
let bytes = repr.as_ref();
let data = bytes.chunks(8).map(|bytes| {
u64::from_le_bytes(bytes.try_into().unwrap())
}).collect_vec();
FieldBinding {
data: data.try_into().unwrap()
}
}
}
// include the compiled PTX code as string
const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/multilinear.ptx"));
/// Wrapper struct for APIs using GPU
#[derive(Default)]
pub struct GPUApiWrapper<F: Field>(PhantomData<F>);
pub struct GPUApiWrapper<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>>(PhantomData<F>);
impl<F: Field> GPUApiWrapper<F> {
pub fn evaluate_poly(&self, poly_coeffs: &[F], point: &[F]) -> Result<F, DriverError> {
todo!()
impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
pub fn evaluate_poly(
&self,
num_vars: usize,
poly_coeffs: &[F],
point: &[F],
) -> Result<F, DriverError> {
// setup GPU device
let now = Instant::now();
let gpu = CudaDevice::new(0)?;
println!("Time taken to initialise CUDA: {:.2?}", now.elapsed());
// compile ptx
let now = Instant::now();
let ptx = Ptx::from_src(CUDA_KERNEL_MY_STRUCT);
gpu.load_ptx(ptx, "multilinear", &["evaluate"])?;
println!("Time taken to compile and load PTX: {:.2?}", now.elapsed());
let point = point
.into_iter()
.map(|f| F::to_canonical_form(*f))
.collect_vec();
// copy to GPU
let gpu_coeffs = gpu.htod_copy(
poly_coeffs
.into_iter()
.map(|&coeff| F::to_canonical_form(coeff))
.collect_vec(),
)?;
let gpu_eval_point = gpu.htod_copy(point)?;
let monomial_evals = gpu.htod_copy(vec![FieldBinding::default(); 1 << num_vars])?;
println!("Time taken to initialise data: {:.2?}", now.elapsed());
let now = Instant::now();
let f = gpu.get_func("multilinear", "evaluate").unwrap();
unsafe {
f.launch(
LaunchConfig::for_num_elems(1 << num_vars as u32),
(&gpu_coeffs, &gpu_eval_point, num_vars, &monomial_evals),
)
}?;
println!("Time taken to call kernel: {:.2?}", now.elapsed());
let monomial_evals = gpu.sync_reclaim(monomial_evals)?;
let result = monomial_evals
.into_iter()
.map(|eval| F::from_canonical_form(eval))
.sum::<F>();
Ok(result)
}
}
#[cfg(test)]
mod tests {
use std::{default, time::Instant};
use std::{default, fmt::Error, time::Instant};
use cudarc::{driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig}, nvrtc::Ptx};
use cudarc::{
driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig},
nvrtc::Ptx,
};
use ff::{Field, PrimeField};
use halo2curves::bn256::Fr;
use itertools::Itertools;
use rand::rngs::OsRng;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
use super::{GPUApiWrapper, FieldBinding, CUDA_KERNEL_MY_STRUCT};
use super::{FieldBinding, GPUApiWrapper, CUDA_KERNEL_MY_STRUCT};
fn evaluate_poly_cpu<F: Field>(poly_coeffs: &[F], point: &[F], num_vars: usize) -> F {
poly_coeffs
@@ -63,11 +113,11 @@ mod tests {
F::ZERO
} else {
let indices = (0..num_vars).map(|j| (i >> j) & 1).collect_vec();
let mut result = F::ONE;
let mut result = coeff.clone();
for (index, point) in indices.iter().zip(point.iter()) {
result *= if *index == 1 { *point } else { F::ONE };
}
result * coeff
result
}
})
.sum()
@@ -75,17 +125,14 @@ mod tests {
#[test]
fn test_evaluate_poly() -> Result<(), DriverError> {
let num_vars = 16;
let num_vars = 6;
let rng = OsRng::default();
let poly_coeffs = (0..1 << num_vars).map(|_| {
Fr::random(rng)
}).collect_vec();
let point = (0..num_vars).map(|_| {
Fr::random(rng)
}).collect_vec();
let poly_coeffs = (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec();
let point = (0..num_vars).map(|_| Fr::random(rng)).collect_vec();
let gpu_api_wrapper = GPUApiWrapper::<Fr>::default();
let eval_poly_result_by_cpu = evaluate_poly_cpu(&poly_coeffs, &point, num_vars);
let eval_poly_result_by_gpu = gpu_api_wrapper.evaluate_poly(&poly_coeffs, &point)?;
let eval_poly_result_by_gpu =
gpu_api_wrapper.evaluate_poly(num_vars, &poly_coeffs, &point)?;
assert_eq!(eval_poly_result_by_cpu, eval_poly_result_by_gpu);
Ok(())
}
@@ -112,9 +159,7 @@ mod tests {
println!("a * b : {:?}", a * b);
let a_data = FieldBinding {
data: [2, 0, 0, 0]
};
let a_data = FieldBinding { data: [2, 0, 0, 0] };
let b_data = FieldBinding {
data: [
@@ -122,7 +167,7 @@ mod tests {
0x9419f4243cdcb848,
0xdc2822db40c0ac2e,
0x183227397098d014,
]
],
};
// copy to GPU
@@ -135,7 +180,12 @@ mod tests {
let f = gpu.get_func("my_module", "mul").unwrap();
unsafe { f.launch(LaunchConfig::for_num_elems(1024 as u32), (&gpu_field_structs, &results)) }?;
unsafe {
f.launch(
LaunchConfig::for_num_elems(1024 as u32),
(&gpu_field_structs, &results),
)
}?;
println!("Time taken to call kernel: {:.2?}", now.elapsed());