From ffaf9c5f479ee9054f0ef3be6ac52b1bfd632252 Mon Sep 17 00:00:00 2001 From: Soowon Jeong Date: Wed, 30 Oct 2024 10:20:15 +0000 Subject: [PATCH] Implement in-memory transcript --- .gitignore | 1 + sumcheck/Cargo.toml | 2 + sumcheck/src/cpu/sumcheck.rs | 39 ++- sumcheck/src/gpu/cuda/kernels/sumcheck.cu | 15 +- sumcheck/src/gpu/sumcheck.rs | 321 +++++++++++++--------- sumcheck/src/lib.rs | 23 ++ sumcheck/src/transcript.rs | 80 ++++++ sumcheck/src/utils.rs | 2 + sumcheck/src/utils/arithmetic.rs | 18 ++ sumcheck/src/utils/hash.rs | 31 +++ 10 files changed, 387 insertions(+), 145 deletions(-) create mode 100644 sumcheck/src/transcript.rs create mode 100644 sumcheck/src/utils.rs create mode 100644 sumcheck/src/utils/arithmetic.rs create mode 100644 sumcheck/src/utils/hash.rs diff --git a/.gitignore b/.gitignore index 2f26eaf..bd21661 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ Cargo.lock *.o +.vscode diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 76115e6..5942041 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -14,6 +14,8 @@ itertools = "0.10.5" halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.3", package = "halo2curves" } rand = "0.8" num-integer = "0.1.45" +sha3 = "0.10.6" +num-bigint = "0.4.3" [build-dependencies] bindgen = "0.66.1" diff --git a/sumcheck/src/cpu/sumcheck.rs b/sumcheck/src/cpu/sumcheck.rs index ecfb40e..9f8b176 100644 --- a/sumcheck/src/cpu/sumcheck.rs +++ b/sumcheck/src/cpu/sumcheck.rs @@ -1,7 +1,10 @@ use ff::PrimeField; use itertools::Itertools; -use crate::cpu::{arithmetic::barycentric_weights, parallel::parallelize}; +use crate::{ + cpu::{arithmetic::barycentric_weights, parallel::parallelize}, + transcript::Keccak256Transcript, +}; use super::arithmetic::barycentric_interpolate; @@ -36,6 +39,40 @@ pub(crate) fn fold_into_half_in_place(poly: &mut [F], challenge: }); } +pub(crate) fn verify_sumcheck_transcript( + num_vars: usize, + max_degree: usize, + sum: F, + transcript: &mut Keccak256Transcript, +) -> bool { + let points_vec: Vec = (0..max_degree + 1) + .map(|i| F::from_u128(i as u128)) + .collect(); + let weights = barycentric_weights(&points_vec); + let mut expected_sum = sum; + for round_index in 0..num_vars { + let evals = transcript.read_field_elements(max_degree + 1).unwrap(); + let round_poly_eval_at_0 = evals[0]; + let round_poly_eval_at_1 = evals[1]; + + let computed_sum = round_poly_eval_at_0 + round_poly_eval_at_1; + + // Check r_{i}(α_i) == r_{i+1}(0) + r_{i+1}(1) + if computed_sum != expected_sum { + println!("computed_sum : {:?}", computed_sum); + println!("expected_sum : {:?}", expected_sum); + println!("round index : {}", round_index); + return false; + } + + let challenge = transcript.squeeze_challenge(); + + // Compute r_{i}(α_i) using barycentric interpolation + expected_sum = barycentric_interpolate(&weights, &points_vec, &evals, &challenge); + } + true +} + pub(crate) fn verify_sumcheck( num_vars: usize, max_degree: usize, diff --git a/sumcheck/src/gpu/cuda/kernels/sumcheck.cu b/sumcheck/src/gpu/cuda/kernels/sumcheck.cu index dca7105..e6c8c94 100644 --- a/sumcheck/src/gpu/cuda/kernels/sumcheck.cu +++ b/sumcheck/src/gpu/cuda/kernels/sumcheck.cu @@ -32,16 +32,16 @@ extern "C" __global__ void sum(fr* data, fr* result, unsigned int stride, unsign } extern "C" __global__ void fold_into_half( - unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, const fr* challenge + unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, const fr* eval_point ) { int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x; const int stride = 1 << (num_vars - 1); const int buf_offset = (blockIdx.x / num_blocks_per_poly) * stride; const int poly_offset = (blockIdx.x / num_blocks_per_poly) * initial_poly_size; while (tid < stride) { - if (*challenge == fr::zero()) buf[buf_offset + tid] = polys[poly_offset + tid]; - else if (*challenge == fr::one()) buf[buf_offset + tid] = polys[poly_offset + tid + stride]; - else buf[buf_offset + tid] = (*challenge) * (polys[poly_offset + tid + stride] - polys[poly_offset + tid]) + polys[poly_offset + tid]; + if (*eval_point == fr::zero()) buf[buf_offset + tid] = polys[poly_offset + tid]; + else if (*eval_point == fr::one()) buf[buf_offset + tid] = polys[poly_offset + tid + stride]; + else buf[buf_offset + tid] = (*eval_point) * (polys[poly_offset + tid + stride] - polys[poly_offset + tid]) + polys[poly_offset + tid]; tid += blockDim.x * num_blocks_per_poly; } } @@ -58,10 +58,3 @@ extern "C" __global__ void fold_into_half_in_place( tid += blockDim.x * num_blocks_per_poly; } } - -// TODO : Pass transcript and squeeze random challenge using hash function -extern "C" __global__ void squeeze_challenge(fr* challenges, unsigned int index) { - if (threadIdx.x == 0) { - challenges[index] = fr(1034); - } -} diff --git a/sumcheck/src/gpu/sumcheck.rs b/sumcheck/src/gpu/sumcheck.rs index 502e0fe..4d1c046 100644 --- a/sumcheck/src/gpu/sumcheck.rs +++ b/sumcheck/src/gpu/sumcheck.rs @@ -1,11 +1,12 @@ use std::cell::{RefCell, RefMut}; -use cudarc::driver::{CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig}; +use cudarc::driver::{CudaSlice, CudaView, CudaViewMut, LaunchAsync, LaunchConfig}; use ff::PrimeField; use crate::{ fieldbinding::{FromFieldBinding, ToFieldBinding}, - FieldBinding, GPUApiWrapper, + transcript::Keccak256Transcript, + FieldBinding, GPUApiWrapper, LibraryError, }; impl + ToFieldBinding> GPUApiWrapper { @@ -18,9 +19,10 @@ impl + ToFieldBinding> GPUApiWrapper { polys: &mut CudaViewMut, device_ks: &[CudaView], buf: RefCell>, - challenges: &mut CudaViewMut, + challenge: &mut CudaSlice, round_evals: RefCell>, - ) -> Result<(), DriverError> { + transcript: &mut Keccak256Transcript, + ) -> Result<(), LibraryError> { let initial_poly_num_vars = num_vars; for round in 0..num_vars { self.eval_at_k_and_combine( @@ -32,16 +34,19 @@ impl + ToFieldBinding> GPUApiWrapper { device_ks, buf.borrow_mut(), round_evals.borrow_mut(), + transcript, )?; // squeeze challenge - self.squeeze_challenge(round, challenges)?; + let alpha = vec![transcript.squeeze_challenge()]; + self.overwrite_to_device(alpha.as_slice(), challenge) + .map_err(|e| LibraryError::Driver(e))?; // fold_into_half_in_place self.fold_into_half_in_place( initial_poly_num_vars, round, num_polys, polys, - &challenges.slice(round..round + 1), + &challenge, )?; } Ok(()) @@ -57,8 +62,13 @@ impl + ToFieldBinding> GPUApiWrapper { device_ks: &[CudaView], mut buf: RefMut>, mut round_evals: RefMut>, - ) -> Result<(), DriverError> { - let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?; + transcript: &mut Keccak256Transcript, + ) -> Result<(), LibraryError> { + let num_blocks_per_poly = self + .max_blocks_per_sm() + .map_err(|e| LibraryError::Driver(e))? + / num_polys + * self.num_sm().map_err(|e| LibraryError::Driver(e))?; let num_threads_per_block = 1024; for k in 0..max_degree + 1 { let fold_into_half = self.gpu.get_func("sumcheck", "fold_into_half").unwrap(); @@ -68,17 +78,19 @@ impl + ToFieldBinding> GPUApiWrapper { shared_mem_bytes: 0, }; unsafe { - fold_into_half.launch( - launch_config, - ( - initial_poly_num_vars - round, - 1 << initial_poly_num_vars, - num_blocks_per_poly, - polys, - &mut *buf, - &device_ks[k], - ), - )?; + fold_into_half + .launch( + launch_config, + ( + initial_poly_num_vars - round, + 1 << initial_poly_num_vars, + num_blocks_per_poly, + polys, + &mut *buf, + &device_ks[k], + ), + ) + .map_err(|e| LibraryError::Driver(e))?; }; let size = 1 << (initial_poly_num_vars - round - 1); let combine = self.gpu.get_func("sumcheck", "combine").unwrap(); @@ -88,7 +100,9 @@ impl + ToFieldBinding> GPUApiWrapper { shared_mem_bytes: 0, }; unsafe { - combine.launch(launch_config, (&mut *buf, size, num_polys))?; + combine + .launch(launch_config, (&mut *buf, size, num_polys)) + .map_err(|e| LibraryError::Driver(e))?; }; let sum = self.gpu.get_func("sumcheck", "sum").unwrap(); let launch_config = LaunchConfig { @@ -105,26 +119,14 @@ impl + ToFieldBinding> GPUApiWrapper { size >> 1, round * (max_degree + 1) + k, ), - )?; + ) + .map_err(|e| LibraryError::Driver(e))?; }; } - Ok(()) - } - - pub(crate) fn squeeze_challenge( - &self, - round: usize, - challenges: &mut CudaViewMut, - ) -> Result<(), DriverError> { - let squeeze_challenge = self.gpu.get_func("sumcheck", "squeeze_challenge").unwrap(); - let launch_config = LaunchConfig { - grid_dim: (1, 1, 1), - block_dim: (1, 1, 1), - shared_mem_bytes: 0, - }; - unsafe { - squeeze_challenge.launch(launch_config, (challenges, round))?; - } + let fes = self + .dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true) + .map_err(|e| LibraryError::Driver(e))?; + transcript.write_field_elements(&fes)?; Ok(()) } @@ -134,13 +136,17 @@ impl + ToFieldBinding> GPUApiWrapper { round: usize, num_polys: usize, polys: &mut CudaViewMut, - challenge: &CudaView, - ) -> Result<(), DriverError> { + challenge: &CudaSlice, + ) -> Result<(), LibraryError> { let fold_into_half_in_place = self .gpu .get_func("sumcheck", "fold_into_half_in_place") .unwrap(); - let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?; + let num_blocks_per_poly = self + .max_blocks_per_sm() + .map_err(|e| LibraryError::Driver(e))? + / num_polys + * self.num_sm().map_err(|e| LibraryError::Driver(e))?; let num_threads_per_block = 1024; let launch_config = LaunchConfig { grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1), @@ -148,16 +154,18 @@ impl + ToFieldBinding> GPUApiWrapper { shared_mem_bytes: 0, }; unsafe { - fold_into_half_in_place.launch( - launch_config, - ( - initial_poly_num_vars - round, - 1 << initial_poly_num_vars, - num_blocks_per_poly, - polys, - challenge, - ), - )?; + fold_into_half_in_place + .launch( + launch_config, + ( + initial_poly_num_vars - round, + 1 << initial_poly_num_vars, + num_blocks_per_poly, + polys, + challenge, + ), + ) + .map_err(|e| LibraryError::Driver(e))?; }; Ok(()) } @@ -167,16 +175,19 @@ impl + ToFieldBinding> GPUApiWrapper { mod tests { use std::{cell::RefCell, time::Instant}; - use cudarc::{driver::DriverError, nvrtc::Ptx}; + use cudarc::nvrtc::Ptx; use ff::Field; use halo2curves::bn256::Fr; use itertools::Itertools; use rand::rngs::OsRng; - use crate::{cpu, fieldbinding::ToFieldBinding, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX}; + use crate::{ + cpu, fieldbinding::ToFieldBinding, transcript::Keccak256Transcript, GPUApiWrapper, + LibraryError, MULTILINEAR_PTX, SUMCHECK_PTX, + }; #[test] - fn test_eval_at_k_and_combine() -> Result<(), DriverError> { + fn test_eval_at_k_and_combine() -> Result<(), LibraryError> { let num_vars = 10; let num_polys = 3; let max_degree = 3; @@ -198,17 +209,24 @@ mod tests { }) .collect_vec(); - let mut gpu_api_wrapper = GPUApiWrapper::::setup()?; - gpu_api_wrapper.gpu.load_ptx( - Ptx::from_src(MULTILINEAR_PTX), - "multilinear", - &["convert_to_montgomery_form"], - )?; - gpu_api_wrapper.gpu.load_ptx( - Ptx::from_src(SUMCHECK_PTX), - "sumcheck", - &["fold_into_half", "combine", "sum"], - )?; + let mut gpu_api_wrapper = + GPUApiWrapper::::setup().map_err(|e| LibraryError::Driver(e))?; + gpu_api_wrapper + .gpu + .load_ptx( + Ptx::from_src(MULTILINEAR_PTX), + "multilinear", + &["convert_to_montgomery_form"], + ) + .map_err(|e| LibraryError::Driver(e))?; + gpu_api_wrapper + .gpu + .load_ptx( + Ptx::from_src(SUMCHECK_PTX), + "sumcheck", + &["fold_into_half", "combine", "sum"], + ) + .map_err(|e| LibraryError::Driver(e))?; let mut cpu_round_evals = vec![]; let now = Instant::now(); @@ -227,19 +245,27 @@ mod tests { ); // copy polynomials to device - let gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?; + let gpu_polys = gpu_api_wrapper + .copy_to_device(&polys.concat()) + .map_err(|e| LibraryError::Driver(e))?; let device_ks = (0..max_degree + 1) .map(|k| { gpu_api_wrapper .gpu .htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))]) }) - .collect::, _>>()?; - let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?; + .collect::, _>>() + .map_err(|e| LibraryError::Driver(e))?; + let mut buf = gpu_api_wrapper + .malloc_on_device(num_polys << (num_vars - 1)) + .map_err(|e| LibraryError::Driver(e))?; let buf_view = RefCell::new(buf.slice_mut(..)); - let mut round_evals = gpu_api_wrapper.malloc_on_device(max_degree as usize + 1)?; + let mut round_evals = gpu_api_wrapper + .malloc_on_device(max_degree as usize + 1) + .map_err(|e| LibraryError::Driver(e))?; let round_evals_view = RefCell::new(round_evals.slice_mut(..)); let round = 0; + let mut transcript = Keccak256Transcript::::new(); let now = Instant::now(); gpu_api_wrapper.eval_at_k_and_combine( num_vars, @@ -253,13 +279,15 @@ mod tests { .collect_vec(), buf_view.borrow_mut(), round_evals_view.borrow_mut(), + &mut transcript, )?; println!( "Time taken to eval_at_k_and_combine on gpu: {:.2?}", now.elapsed() ); let gpu_round_evals = gpu_api_wrapper - .dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true)?; + .dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true) + .map_err(|e| LibraryError::Driver(e))?; cpu_round_evals .iter() .zip_eq(gpu_round_evals.iter()) @@ -271,7 +299,7 @@ mod tests { } #[test] - fn test_fold_into_half_in_place() -> Result<(), DriverError> { + fn test_fold_into_half_in_place() -> Result<(), LibraryError> { let num_vars = 6; let num_polys = 4; @@ -290,21 +318,33 @@ mod tests { }) .collect_vec(); - let mut gpu_api_wrapper = GPUApiWrapper::::setup()?; - gpu_api_wrapper.gpu.load_ptx( - Ptx::from_src(MULTILINEAR_PTX), - "multilinear", - &["convert_to_montgomery_form"], - )?; - gpu_api_wrapper.gpu.load_ptx( - Ptx::from_src(SUMCHECK_PTX), - "sumcheck", - &["fold_into_half_in_place"], - )?; + let mut gpu_api_wrapper = + GPUApiWrapper::::setup().map_err(|e| LibraryError::Driver(e))?; + gpu_api_wrapper + .gpu + .load_ptx( + Ptx::from_src(MULTILINEAR_PTX), + "multilinear", + &["convert_to_montgomery_form"], + ) + .map_err(|e| LibraryError::Driver(e))?; + gpu_api_wrapper + .gpu + .load_ptx( + Ptx::from_src(SUMCHECK_PTX), + "sumcheck", + &["fold_into_half_in_place"], + ) + .map_err(|e| LibraryError::Driver(e))?; // copy polynomials to device - let mut gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?; + let mut gpu_polys = gpu_api_wrapper + .copy_to_device(&polys.concat()) + .map_err(|e| LibraryError::Driver(e))?; let challenge = Fr::random(rng); - let gpu_challenge = gpu_api_wrapper.copy_to_device(&vec![challenge])?; + + let gpu_challenge = gpu_api_wrapper + .copy_to_device(&vec![challenge]) + .map_err(|e| LibraryError::Driver(e))?; let round = 0; let now = Instant::now(); @@ -313,7 +353,7 @@ mod tests { round, num_polys, &mut gpu_polys.slice_mut(..), - &gpu_challenge.slice(..), + &gpu_challenge, )?; println!( "Time taken to fold_into_half_in_place on gpu: {:.2?}", @@ -327,7 +367,8 @@ mod tests { true, ) }) - .collect::>, _>>()?; + .collect::>, _>>() + .map_err(|e| LibraryError::Driver(e))?; let now = Instant::now(); (0..num_polys) @@ -348,8 +389,8 @@ mod tests { } #[test] - fn test_prove_sumcheck() -> Result<(), DriverError> { - let num_vars = 25; + fn test_prove_sumcheck() -> Result<(), LibraryError> { + let num_vars = 4; let num_polys = 2; let max_degree = 2; @@ -368,26 +409,34 @@ mod tests { }) .collect_vec(); - let mut gpu_api_wrapper = GPUApiWrapper::::setup()?; - gpu_api_wrapper.gpu.load_ptx( - Ptx::from_src(MULTILINEAR_PTX), - "multilinear", - &["convert_to_montgomery_form"], - )?; - gpu_api_wrapper.gpu.load_ptx( - Ptx::from_src(SUMCHECK_PTX), - "sumcheck", - &[ - "fold_into_half", - "fold_into_half_in_place", - "combine", - "sum", - "squeeze_challenge", - ], - )?; - + let mut gpu_api_wrapper = + GPUApiWrapper::::setup().map_err(|e| LibraryError::Driver(e))?; + gpu_api_wrapper + .gpu + .load_ptx( + Ptx::from_src(MULTILINEAR_PTX), + "multilinear", + &["convert_to_montgomery_form"], + ) + .map_err(|e| LibraryError::Driver(e))?; + gpu_api_wrapper + .gpu + .load_ptx( + Ptx::from_src(SUMCHECK_PTX), + "sumcheck", + &[ + "fold_into_half", + "fold_into_half_in_place", + "combine", + "sum", + ], + ) + .map_err(|e| LibraryError::Driver(e))?; + let mut transcript = Keccak256Transcript::::new(); let now = Instant::now(); - let mut gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?; + let mut gpu_polys = gpu_api_wrapper + .copy_to_device(&polys.concat()) + .map_err(|e| LibraryError::Driver(e))?; let sum = (0..1 << num_vars).fold(Fr::ZERO, |acc, index| { acc + polys.iter().map(|poly| poly[index]).product::() }); @@ -397,12 +446,19 @@ mod tests { .gpu .htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))]) }) - .collect::, _>>()?; - let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?; + .collect::, _>>() + .map_err(|e| LibraryError::Driver(e))?; + let mut buf = gpu_api_wrapper + .malloc_on_device(num_polys << (num_vars - 1)) + .map_err(|e| LibraryError::Driver(e))?; let buf_view = RefCell::new(buf.slice_mut(..)); - let mut challenges = gpu_api_wrapper.malloc_on_device(num_vars)?; - let mut round_evals = gpu_api_wrapper.malloc_on_device(num_vars * (max_degree + 1))?; + let mut challenges = gpu_api_wrapper + .malloc_on_device(1) + .map_err(|e| LibraryError::Driver(e))?; + let mut round_evals = gpu_api_wrapper + .malloc_on_device(num_vars * (max_degree + 1)) + .map_err(|e| LibraryError::Driver(e))?; let round_evals_view = RefCell::new(round_evals.slice_mut(..)); println!("Time taken to copy data to device : {:.2?}", now.elapsed()); @@ -418,35 +474,34 @@ mod tests { .map(|device_k| device_k.slice(..)) .collect_vec(), buf_view, - &mut challenges.slice_mut(..), + &mut challenges, round_evals_view, + &mut transcript, )?; - gpu_api_wrapper.gpu.synchronize()?; + gpu_api_wrapper + .gpu + .synchronize() + .map_err(|e| LibraryError::Driver(e))?; println!( "Time taken to prove sumcheck on gpu : {:.2?}", now.elapsed() ); - let challenges = gpu_api_wrapper.dtoh_sync_copy(&challenges.slice(..), true)?; - let round_evals = (0..num_vars) - .map(|i| { - gpu_api_wrapper.dtoh_sync_copy( - &round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)), - true, - ) - }) - .collect::>, _>>()?; - let round_evals = round_evals - .iter() - .map(|round_evals| round_evals.as_slice()) - .collect_vec(); - let result = cpu::sumcheck::verify_sumcheck( - num_vars, - max_degree, - sum, - &challenges[..], - &round_evals[..], - ); + // let challenges = gpu_api_wrapper + // .dtoh_sync_copy(&challenges.slice(..), true) + // .map_err(|e| LibraryError::Driver(e))?; + // let round_evals = (0..num_vars) + // .map(|i| { + // gpu_api_wrapper.dtoh_sync_copy( + // &round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)), + // true, + // ) + // }) + // .collect::>, _>>() + // .map_err(|e| LibraryError::Driver(e))?; + + let result = + cpu::sumcheck::verify_sumcheck_transcript(num_vars, max_degree, sum, &mut transcript); assert!(result); Ok(()) } diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index 80ccdf5..ed453fa 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -15,6 +15,8 @@ use std::time::Instant; mod cpu; pub mod fieldbinding; pub mod gpu; +pub mod transcript; +pub mod utils; include!(concat!(env!("OUT_DIR"), "/bindings.rs")); @@ -71,6 +73,21 @@ impl + ToFieldBinding> GPUApiWrapper { Ok(device_data) } + pub fn overwrite_to_device( + &self, + host_data: &[F], + dst: &mut CudaSlice, + ) -> Result<(), DriverError> { + self.gpu.htod_copy_into( + host_data + .into_iter() + .map(|&eval| F::to_canonical_form(eval)) + .collect(), + dst, + )?; + Ok(()) + } + pub fn malloc_on_device(&self, len: usize) -> Result, DriverError> { let device_ptr = unsafe { self.gpu.alloc::(len)? }; Ok(device_ptr) @@ -123,3 +140,9 @@ impl + ToFieldBinding> GPUApiWrapper { Ok(self.gpu.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK)? as usize) } } + +#[derive(Debug)] +pub enum LibraryError { + Driver(DriverError), + Transcript(std::io::ErrorKind, String), +} diff --git a/sumcheck/src/transcript.rs b/sumcheck/src/transcript.rs new file mode 100644 index 0000000..8fbdba1 --- /dev/null +++ b/sumcheck/src/transcript.rs @@ -0,0 +1,80 @@ +use crate::{ + fieldbinding::{FromFieldBinding, ToFieldBinding}, + utils::{arithmetic::fe_mod_from_le_bytes, hash::Hash}, + GPUApiWrapper, LibraryError, +}; +use ff::PrimeField; +use halo2curves::serde::SerdeObject; +use sha3::Keccak256; +use std::{ + io::{Cursor, Read, Write}, + marker::PhantomData, +}; + +pub type Keccak256Transcript = CudaTranscript; + +#[derive(Debug, Default)] +pub struct CudaTranscript { + pub stream: Cursor>, + pub state: H, + marker: PhantomData, +} + +impl CudaTranscript { + pub fn new() -> Self { + Self::default() + } + + pub fn squeeze_challenge(&mut self) -> F { + let hash = self.state.finalize_fixed_reset(); + self.state.update(&hash); + fe_mod_from_le_bytes(hash) + } + + fn common_field_element(&mut self, fe: &F) -> Result<(), LibraryError> { + self.state.update_field_element(fe); + Ok(()) + } + + fn write_field_element(&mut self, fe: &F) -> Result<(), LibraryError> { + self.common_field_element(fe)?; + let mut repr = fe.to_repr(); + repr.as_mut().reverse(); + self.stream + .write_all(repr.as_ref()) + .map_err(|err| LibraryError::Transcript(err.kind(), err.to_string())) + } + + pub fn write_field_elements<'a>( + &mut self, + fes: impl IntoIterator, + ) -> Result<(), LibraryError> + where + F: 'a, + { + for fe in fes.into_iter() { + self.write_field_element(fe)?; + } + Ok(()) + } + + fn read_field_element(&mut self) -> Result { + let mut repr = ::Repr::default(); + self.stream + .read_exact(repr.as_mut()) + .map_err(|err| LibraryError::Transcript(err.kind(), err.to_string()))?; + repr.as_mut().reverse(); + let fe = F::from_repr_vartime(repr).ok_or_else(|| { + LibraryError::Transcript( + std::io::ErrorKind::Other, + "Invalid field element encoding in proof".to_string(), + ) + })?; + self.common_field_element(&fe)?; + Ok(fe) + } + + pub fn read_field_elements(&mut self, n: usize) -> Result, LibraryError> { + (0..n).map(|_| self.read_field_element()).collect() + } +} diff --git a/sumcheck/src/utils.rs b/sumcheck/src/utils.rs new file mode 100644 index 0000000..71deeb0 --- /dev/null +++ b/sumcheck/src/utils.rs @@ -0,0 +1,2 @@ +pub mod arithmetic; +pub mod hash; diff --git a/sumcheck/src/utils/arithmetic.rs b/sumcheck/src/utils/arithmetic.rs new file mode 100644 index 0000000..10655fa --- /dev/null +++ b/sumcheck/src/utils/arithmetic.rs @@ -0,0 +1,18 @@ +use ff::PrimeField; +use num_bigint::BigUint; + +pub fn modulus() -> BigUint { + BigUint::from_bytes_le((-F::ONE).to_repr().as_ref()) + 1u64 +} + +pub fn fe_mod_from_le_bytes(bytes: impl AsRef<[u8]>) -> F { + fe_from_le_bytes((BigUint::from_bytes_le(bytes.as_ref()) % modulus::()).to_bytes_le()) +} + +pub fn fe_from_le_bytes(bytes: impl AsRef<[u8]>) -> F { + let bytes = bytes.as_ref(); + let mut repr = F::Repr::default(); + assert!(bytes.len() <= repr.as_ref().len()); + repr.as_mut()[..bytes.len()].copy_from_slice(bytes); + F::from_repr(repr).unwrap() +} diff --git a/sumcheck/src/utils/hash.rs b/sumcheck/src/utils/hash.rs new file mode 100644 index 0000000..d3a265e --- /dev/null +++ b/sumcheck/src/utils/hash.rs @@ -0,0 +1,31 @@ +use ff::PrimeField; +use sha3::digest::{Digest, HashMarker}; +use std::fmt::Debug; + +pub use sha3::{ + digest::{FixedOutputReset, Output, Update}, + Keccak256, +}; + +pub trait Hash: + 'static + Sized + Clone + Debug + FixedOutputReset + Default + Update + HashMarker +{ + fn new() -> Self { + Self::default() + } + + fn update_field_element(&mut self, field: &impl PrimeField) { + Digest::update(self, field.to_repr()); + } + + fn digest(data: impl AsRef<[u8]>) -> Output { + let mut hasher = Self::default(); + hasher.update(data.as_ref()); + hasher.finalize() + } +} + +impl Hash + for T +{ +}