diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index d4e8c50..ad25d1f 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -4,6 +4,7 @@ use cudarc::driver::{CudaDevice, LaunchConfig, DeviceRepr, DriverError, LaunchAs use cudarc::nvrtc::Ptx; use ff::{Field, PrimeField}; use halo2curves::bn256::Fr; +use itertools::Itertools; include!(concat!(env!("OUT_DIR"), "/bindings.rs")); @@ -17,6 +18,19 @@ impl Default for FieldBinding { } } +impl From for FieldBinding { + fn from(value: F) -> Self { + let repr = value.to_repr(); + let bytes = repr.as_ref(); + let data = bytes.chunks(8).map(|bytes| { + u64::from_le_bytes(bytes.try_into().unwrap()) + }).collect_vec(); + FieldBinding { + data: data.try_into().unwrap() + } + } +} + /// Wrapper struct for APIs using GPU #[derive(Default)] pub struct GPUApiWrapper(PhantomData);