Remove unnecessary synchronization and time consuming data generation from test code

This commit is contained in:
DoHoonKim8
2024-09-13 15:51:40 +00:00
parent 6ca58b8458
commit eb683eb1a8
2 changed files with 38 additions and 11 deletions

View File

@@ -178,14 +178,24 @@ mod tests {
#[test]
fn test_eval_at_k_and_combine() -> Result<(), DriverError> {
let num_vars = 10;
let num_polys = 4;
let max_degree = 4;
let num_polys = 3;
let max_degree = 3;
let rng = OsRng::default();
let combine_function = |args: &Vec<Fr>| args.iter().product();
let polys = (0..num_polys)
.map(|_| (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec())
.map(|_| {
(0..1 << num_vars)
.map(|i| {
if i < 1024 {
Fr::random(rng)
} else {
Fr::from(i)
}
})
.collect_vec()
})
.collect_vec();
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
@@ -233,7 +243,6 @@ mod tests {
buf_view.borrow_mut(),
round_evals_view.borrow_mut(),
)?;
gpu_api_wrapper.gpu.synchronize()?;
println!(
"Time taken to eval_at_k_and_combine on gpu: {:.2?}",
now.elapsed()
@@ -257,7 +266,17 @@ mod tests {
let rng = OsRng::default();
let mut polys = (0..num_polys)
.map(|_| (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec())
.map(|_| {
(0..1 << num_vars)
.map(|i| {
if i < 1024 {
Fr::random(rng)
} else {
Fr::from(i)
}
})
.collect_vec()
})
.collect_vec();
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
@@ -285,7 +304,6 @@ mod tests {
&mut gpu_polys.slice_mut(..),
&gpu_challenge.slice(..),
)?;
gpu_api_wrapper.gpu.synchronize()?;
println!(
"Time taken to fold_into_half_in_place on gpu: {:.2?}",
now.elapsed()
@@ -320,13 +338,23 @@ mod tests {
#[test]
fn test_prove_sumcheck() -> Result<(), DriverError> {
let num_vars = 12;
let num_vars = 23;
let num_polys = 4;
let max_degree = 4;
let rng = OsRng::default();
let polys = (0..num_polys)
.map(|_| (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec())
.map(|_| {
(0..1 << num_vars)
.map(|i| {
if i < 1024 {
Fr::random(rng)
} else {
Fr::from(i)
}
})
.collect_vec()
})
.collect_vec();
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
@@ -371,7 +399,6 @@ mod tests {
&mut challenges.slice_mut(..),
round_evals_view,
)?;
gpu_api_wrapper.gpu.synchronize()?;
println!(
"Time taken to prove sumcheck on gpu : {:.2?}",
now.elapsed()

View File

@@ -79,11 +79,11 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
pub fn dtoh_sync_copy(
&self,
device_data: CudaView<FieldBinding>,
convert_to_montgomery_form: bool,
convert_from_montgomery_form: bool,
) -> Result<Vec<F>, DriverError> {
let host_data = self.gpu.dtoh_sync_copy(&device_data)?;
let mut target = vec![F::ZERO; host_data.len()];
if convert_to_montgomery_form {
if convert_from_montgomery_form {
parallelize(&mut target, |(target, start)| {
target
.iter_mut()