diff --git a/Cargo.lock b/Cargo.lock index 35248a8..d2aced6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3861,6 +3861,7 @@ dependencies = [ "parking_lot 0.12.4", "prost", "rand 0.8.5", + "rayon", "rln", "rln_proof", "rocksdb", diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 845ddeb..cad67d3 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -44,6 +44,7 @@ rln = { git = "https://github.com/vacp2p/zerokit", features = ["pmtree-ft"] } zerokit_utils = { git = "https://github.com/vacp2p/zerokit", package = "zerokit_utils", features = ["default"] } rln_proof = { path = "../rln_proof" } smart_contract = { path = "../smart_contract" } +rayon = "1.7" [build-dependencies] tonic-build = "*" diff --git a/prover/benches/prover_bench.rs b/prover/benches/prover_bench.rs index 8073a42..38d805b 100644 --- a/prover/benches/prover_bench.rs +++ b/prover/benches/prover_bench.rs @@ -1,4 +1,4 @@ -use criterion::BenchmarkId; +use criterion::{BenchmarkId, Throughput}; use criterion::Criterion; use criterion::{criterion_group, criterion_main}; @@ -134,7 +134,7 @@ fn proof_generation_bench(c: &mut Criterion) { metrics_ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), metrics_port: 30051, broadcast_channel_size: 100, - proof_service_count: 32, + proof_service_count: 4, transaction_channel_size: 100, proof_sender_channel_size: 100, }; @@ -167,9 +167,15 @@ fn proof_generation_bench(c: &mut Criterion) { }); println!("Starting benchmark..."); - let size: usize = 1024; - let proof_count = 100; - c.bench_with_input(BenchmarkId::new("input_example", size), &size, |b, &_s| { + // let size: usize = 1024; + + let mut group = c.benchmark_group("prover_bench"); + // group.sampling_mode(criterion::SamplingMode::Flat); + + let proof_count = 5; + + group.throughput(Throughput::Elements(proof_count as u64)); + group.bench_with_input(BenchmarkId::new("proof generation", proof_count), &proof_count, |b, &_s| { b.to_async(&rt).iter(|| { async { let mut set = JoinSet::new(); @@ -182,13 +188,16 @@ fn proof_generation_bench(c: &mut Criterion) { } }); }); + + group.finish(); } criterion_group!( name = benches; config = Criterion::default() - .sample_size(10) - .measurement_time(Duration::from_secs(150)); + .sample_size(20) + // .measurement_time(Duration::from_secs(150)) + ; targets = proof_generation_bench ); criterion_main!(benches); diff --git a/prover/src/proof_service.rs b/prover/src/proof_service.rs index 9d9fd95..cc4a161 100644 --- a/prover/src/proof_service.rs +++ b/prover/src/proof_service.rs @@ -8,7 +8,9 @@ use metrics::{counter, histogram}; use parking_lot::RwLock; use rln::hashers::hash_to_field; use rln::protocol::serialize_proof_values; -use tracing::{Instrument, debug, debug_span, info}; +use tracing::{Instrument, // debug, + debug_span, info +}; // internal use crate::epoch_service::{Epoch, EpochSlice}; use crate::error::{AppError, ProofGenerationError, ProofGenerationStringError}; @@ -71,12 +73,18 @@ impl ProofService { let proof_generation_data_ = proof_generation_data.clone(); let rate_limit = self.rate_limit; - // let counter_label = Arc::new(format!("proof service (id: {})", self.id)); - // let counter_label_ref = counter_label.clone(); let counter_id = self.id; + // println!("[proof service {counter_id}] starting to generate proof..."); + + // Communicate between rayon & current task + let (send, recv) = tokio::sync::oneshot::channel(); + + // Move to a task (as generating the proof can take quite some time) - avoid blocking the tokio runtime + // Note: avoid tokio spawn_blocking as it does not perform great for CPU bounds tasks + // see https://ryhl.io/blog/async-what-is-blocking/ + + rayon::spawn(move || { - // Move to a task (as generating the proof can take quite some time) - let blocking_task = tokio::task::spawn_blocking(move || { let proof_generation_start = std::time::Instant::now(); let message_id = { @@ -102,41 +110,65 @@ impl ProofService { }; let epoch = hash_to_field(epoch_bytes.as_slice()); - let merkle_proof = user_db.get_merkle_proof(&proof_generation_data.tx_sender)?; + let merkle_proof = match user_db.get_merkle_proof(&proof_generation_data.tx_sender) { + Ok(merkle_proof) => merkle_proof, + Err(e) => { + let _ = send.send(Err(ProofGenerationError::MerkleProofError(e))); + return; + } + }; - let (proof, proof_values) = compute_rln_proof_and_values( + // let compute_proof_start = std::time::Instant::now(); + let (proof, proof_values) = match compute_rln_proof_and_values( &proof_generation_data.user_identity, &proof_generation_data.rln_identifier, rln_data, epoch, &merkle_proof, - ) - .map_err(ProofGenerationError::Proof)?; + ) { + Ok((proof, proof_values)) => (proof, proof_values), + Err(e) => { + let _ = send.send(Err(ProofGenerationError::Proof(e))); + return; + } + }; - debug!("proof: {:?}", proof); - debug!("proof_values: {:?}", proof_values); + // debug!("proof: {:?}", proof); + // debug!("proof_values: {:?}", proof_values); // Serialize proof let mut output_buffer = Cursor::new(Vec::with_capacity(PROOF_SIZE)); - proof + if let Err(e) = proof .serialize_compressed(&mut output_buffer) - .map_err(ProofGenerationError::Serialization)?; - output_buffer - .write_all(&serialize_proof_values(&proof_values)) - .map_err(ProofGenerationError::SerializationWrite)?; + { + let _ = send.send(Err(ProofGenerationError::Serialization(e))); + return; + } + if let Err(e) = output_buffer + .write_all(&serialize_proof_values(&proof_values)) { + let _ = send.send(Err(ProofGenerationError::SerializationWrite(e))); + return; + } histogram!(PROOF_SERVICE_GEN_PROOF_TIME.name, "prover" => "proof service") - .record(proof_generation_start.elapsed().as_secs_f64()); + .record(proof_generation_start.elapsed().as_secs_f64()); + // println!("[proof service {counter_id}] proof generation time: {:?} secs", proof_generation_start.elapsed().as_secs_f64()); let labels = [("prover", format!("proof service id: {counter_id}"))]; counter!(PROOF_SERVICE_PROOF_COMPUTED.name, &labels).increment(1); - Ok::, ProofGenerationError>(output_buffer.into_inner()) + // Send the result back to Tokio. + let _ = send.send( + Ok::, ProofGenerationError>(output_buffer.into_inner()) + ); }); - let result = blocking_task.instrument(debug_span!("compute proof")).await; - // Result (1st) is a JoinError (and should not happen) - // Result (2nd) is a ProofGenerationError - let result = result.unwrap(); // Should never happen (but should panic if it does) + // Wait for the rayon task. + // Result 1st is from send channel (no errors expected) + // Result 2nd can be a ProofGenerationError + let result = recv + .instrument(debug_span!("compute proof")) + .await + .expect("Panic in rayon::spawn"); // Should never happen (but panic if it does) let proof_sending_data = result .map(|r| ProofSendingData { @@ -173,7 +205,7 @@ mod tests { use claims::assert_matches; use futures::TryFutureExt; use tokio::sync::broadcast; - use tracing::info; + use tracing::{info, debug}; // third-party: zerokit use rln::{ circuit::{Curve, zkey_from_folder}, diff --git a/prover/src/rocksdb_operands.rs b/prover/src/rocksdb_operands.rs index d5ed1d2..f00dac5 100644 --- a/prover/src/rocksdb_operands.rs +++ b/prover/src/rocksdb_operands.rs @@ -146,7 +146,7 @@ pub fn epoch_counters_operands( // FIXME: assert when reload from disk // debug_assert_ge!(epoch_incr.epoch, acc.epoch); debug_assert!( - epoch_incr.epoch_slice > acc.epoch_slice + epoch_incr.epoch_slice >= acc.epoch_slice || epoch_incr.epoch_slice == EpochSlice::from(0) ); diff --git a/prover/src/user_db_tests.rs b/prover/src/user_db_tests.rs index 4c06f5e..5287687 100644 --- a/prover/src/user_db_tests.rs +++ b/prover/src/user_db_tests.rs @@ -5,14 +5,87 @@ mod user_db_tests { use std::sync::Arc; // third-party use alloy::primitives::{Address, address}; + use claims::assert_matches; use parking_lot::RwLock; + use crate::epoch_service::{Epoch, EpochSlice}; // internal use crate::user_db::UserDb; - use crate::user_db_types::{EpochSliceCounter, MerkleTreeIndex}; + use crate::user_db_types::{EpochCounter, EpochSliceCounter, MerkleTreeIndex}; const ADDR_1: Address = address!("0xd8da6bf26964af9d7eed9e03e53415d37aa96045"); const ADDR_2: Address = address!("0xb20a608c624Ca5003905aA834De7156C68b2E1d0"); + #[tokio::test] + async fn test_incr_tx_counter_2() { + + // Same as test_incr_tx_counter but multi users AND multi incr + + let temp_folder = tempfile::tempdir().unwrap(); + let temp_folder_tree = tempfile::tempdir().unwrap(); + + let epoch_store = Arc::new(RwLock::new(Default::default())); + let epoch = 1; + let epoch_slice = 42; + *epoch_store.write() = (Epoch::from(epoch), EpochSlice::from(epoch_slice)); + + let user_db = UserDb::new( + PathBuf::from(temp_folder.path()), + PathBuf::from(temp_folder_tree.path()), + epoch_store, + Default::default(), + Default::default(), + ) + .unwrap(); + + // Register users + user_db.register(ADDR_1).unwrap(); + user_db.register(ADDR_2).unwrap(); + + assert_eq!( + user_db.get_tx_counter(&ADDR_1), + Ok((EpochCounter::from(0), EpochSliceCounter::from(0))) + ); + assert_eq!( + user_db.get_tx_counter(&ADDR_2), + Ok((EpochCounter::from(0), EpochSliceCounter::from(0))) + ); + + // Now update user tx counter + assert_eq!( + user_db.on_new_tx(&ADDR_1, None), + Ok(EpochSliceCounter::from(1)) + ); + assert_eq!( + user_db.on_new_tx(&ADDR_1, None), + Ok(EpochSliceCounter::from(2)) + ); + assert_eq!( + user_db.on_new_tx(&ADDR_1, Some(2)), + Ok(EpochSliceCounter::from(4)) + ); + + assert_eq!( + user_db.on_new_tx(&ADDR_2, None), + Ok(EpochSliceCounter::from(1)) + ); + + assert_eq!( + user_db.on_new_tx(&ADDR_2, None), + Ok(EpochSliceCounter::from(2)) + ); + + assert_eq!( + user_db.get_tx_counter(&ADDR_1), + Ok((EpochCounter::from(4), EpochSliceCounter::from(4))) + ); + + assert_eq!( + user_db.get_tx_counter(&ADDR_2), + Ok((EpochCounter::from(2), EpochSliceCounter::from(2))) + ); + + } + #[tokio::test] async fn test_persistent_storage() { let temp_folder = tempfile::tempdir().unwrap(); diff --git a/prover/tests/grpc_e2e.rs b/prover/tests/grpc_e2e.rs index b04b3d8..7da23b1 100644 --- a/prover/tests/grpc_e2e.rs +++ b/prover/tests/grpc_e2e.rs @@ -90,7 +90,7 @@ async fn test_grpc_register_users() { metrics_ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), metrics_port: 30031, broadcast_channel_size: 100, - proof_service_count: 8, + proof_service_count: 16, transaction_channel_size: 100, proof_sender_channel_size: 100, }; @@ -110,7 +110,10 @@ async fn test_grpc_register_users() { tokio::time::sleep(Duration::from_secs(1)).await; } -async fn proof_sender(port: u16, addresses: Vec
, _proof_count: usize) { +async fn proof_sender(port: u16, addresses: Vec
, proof_count: usize) { + + let start = std::time::Instant::now(); + let chain_id = GrpcU256 { // FIXME: LE or BE? value: U256::from(1).to_le_bytes::<32>().to_vec(), @@ -126,8 +129,29 @@ async fn proof_sender(port: u16, addresses: Vec
, _proof_count: usize) { // FIXME: LE or BE? value: U256::from(1000).to_le_bytes::<32>().to_vec(), }; - let tx_hash = U256::from(42).to_le_bytes::<32>().to_vec(); + let mut count = 0; + for i in 0..proof_count { + let tx_hash = U256::from(42 + i).to_le_bytes::<32>().to_vec(); + + let request_0 = SendTransactionRequest { + gas_price: Some(wei.clone()), + sender: Some(addr.clone()), + chain_id: Some(chain_id.clone()), + transaction_hash: tx_hash, + }; + + let request = tonic::Request::new(request_0); + let response: Response = + client.send_transaction(request).await.unwrap(); + assert_eq!(response.into_inner().result, true); + count += 1; + } + + println!("[proof_sender] sent {} tx - elapsed: {} secs", count, start.elapsed().as_secs_f64()); + + /* + let tx_hash = U256::from(42).to_le_bytes::<32>().to_vec(); let request_0 = SendTransactionRequest { gas_price: Some(wei), sender: Some(addr), @@ -138,9 +162,12 @@ async fn proof_sender(port: u16, addresses: Vec
, _proof_count: usize) { let request = tonic::Request::new(request_0); let response: Response = client.send_transaction(request).await.unwrap(); assert_eq!(response.into_inner().result, true); + */ } -async fn proof_collector(port: u16) -> Vec { +async fn proof_collector(port: u16, proof_count: usize) -> Vec { + + let start = std::time::Instant::now(); let result = Arc::new(RwLock::new(vec![])); let url = format!("http://127.0.0.1:{}", port); @@ -154,14 +181,25 @@ async fn proof_collector(port: u16) -> Vec { let mut stream = stream_.into_inner(); let result_2 = result.clone(); + let mut count = 0; + let mut start_per_message = std::time::Instant::now(); let receiver = async move { while let Some(response) = stream.message().await.unwrap() { result_2.write().push(response); + count += 1; + if count >= proof_count { + break; + } + println!("count {count} - elapsed: {} secs", start_per_message.elapsed().as_secs_f64()); + start_per_message = std::time::Instant::now(); } }; - let _res = tokio::time::timeout(Duration::from_secs(10), receiver).await; - std::mem::take(&mut *result.write()) + let _res = tokio::time::timeout(Duration::from_secs(500), receiver).await; + println!("_res: {:?}", _res); + let res = std::mem::take(&mut *result.write()); + println!("[proof_collector] elapsed: {} secs", start.elapsed().as_secs_f64()); + res } #[tokio::test] @@ -191,10 +229,10 @@ async fn test_grpc_gen_proof() { no_config: Some(true), metrics_ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), metrics_port: 30031, - broadcast_channel_size: 100, + broadcast_channel_size: 500, proof_service_count: 8, - transaction_channel_size: 100, - proof_sender_channel_size: 100, + transaction_channel_size: 500, + proof_sender_channel_size: 500, }; info!("Starting prover..."); @@ -206,15 +244,16 @@ async fn test_grpc_gen_proof() { register_users(port, addresses.clone()).await; info!("Sending tx and collecting proofs..."); - let proof_count = 1; + let proof_count = 10; let mut set = JoinSet::new(); set.spawn( proof_sender(port, addresses.clone(), proof_count).map(|_| vec![]), // JoinSet require having the same return type ); - set.spawn(proof_collector(port)); + set.spawn(proof_collector(port, proof_count)); let res = set.join_all().await; - assert_eq!(res[1].len(), proof_count); + println!("res lengths: {} {}", res[0].len(), res[1].len()); + assert_eq!(res[0].len() + res[1].len(), proof_count); info!("Aborting prover..."); prover_handle.abort();