Add throughput measurement for prover benchmark (#22)

* Add throughput measurement for prover benchmark
* Use rayon instead of tokio spawn blocking
* Add new user_db unit tests for user tx counterx
This commit is contained in:
Sydhds
2025-07-23 12:26:34 +02:00
committed by GitHub
parent 1ec9f1f48d
commit b802b80664
7 changed files with 199 additions and 44 deletions

1
Cargo.lock generated
View File

@@ -3861,6 +3861,7 @@ dependencies = [
"parking_lot 0.12.4",
"prost",
"rand 0.8.5",
"rayon",
"rln",
"rln_proof",
"rocksdb",

View File

@@ -44,6 +44,7 @@ rln = { git = "https://github.com/vacp2p/zerokit", features = ["pmtree-ft"] }
zerokit_utils = { git = "https://github.com/vacp2p/zerokit", package = "zerokit_utils", features = ["default"] }
rln_proof = { path = "../rln_proof" }
smart_contract = { path = "../smart_contract" }
rayon = "1.7"
[build-dependencies]
tonic-build = "*"

View File

@@ -1,4 +1,4 @@
use criterion::BenchmarkId;
use criterion::{BenchmarkId, Throughput};
use criterion::Criterion;
use criterion::{criterion_group, criterion_main};
@@ -134,7 +134,7 @@ fn proof_generation_bench(c: &mut Criterion) {
metrics_ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
metrics_port: 30051,
broadcast_channel_size: 100,
proof_service_count: 32,
proof_service_count: 4,
transaction_channel_size: 100,
proof_sender_channel_size: 100,
};
@@ -167,9 +167,15 @@ fn proof_generation_bench(c: &mut Criterion) {
});
println!("Starting benchmark...");
let size: usize = 1024;
let proof_count = 100;
c.bench_with_input(BenchmarkId::new("input_example", size), &size, |b, &_s| {
// let size: usize = 1024;
let mut group = c.benchmark_group("prover_bench");
// group.sampling_mode(criterion::SamplingMode::Flat);
let proof_count = 5;
group.throughput(Throughput::Elements(proof_count as u64));
group.bench_with_input(BenchmarkId::new("proof generation", proof_count), &proof_count, |b, &_s| {
b.to_async(&rt).iter(|| {
async {
let mut set = JoinSet::new();
@@ -182,13 +188,16 @@ fn proof_generation_bench(c: &mut Criterion) {
}
});
});
group.finish();
}
criterion_group!(
name = benches;
config = Criterion::default()
.sample_size(10)
.measurement_time(Duration::from_secs(150));
.sample_size(20)
// .measurement_time(Duration::from_secs(150))
;
targets = proof_generation_bench
);
criterion_main!(benches);

View File

@@ -8,7 +8,9 @@ use metrics::{counter, histogram};
use parking_lot::RwLock;
use rln::hashers::hash_to_field;
use rln::protocol::serialize_proof_values;
use tracing::{Instrument, debug, debug_span, info};
use tracing::{Instrument, // debug,
debug_span, info
};
// internal
use crate::epoch_service::{Epoch, EpochSlice};
use crate::error::{AppError, ProofGenerationError, ProofGenerationStringError};
@@ -71,12 +73,18 @@ impl ProofService {
let proof_generation_data_ = proof_generation_data.clone();
let rate_limit = self.rate_limit;
// let counter_label = Arc::new(format!("proof service (id: {})", self.id));
// let counter_label_ref = counter_label.clone();
let counter_id = self.id;
// println!("[proof service {counter_id}] starting to generate proof...");
// Communicate between rayon & current task
let (send, recv) = tokio::sync::oneshot::channel();
// Move to a task (as generating the proof can take quite some time) - avoid blocking the tokio runtime
// Note: avoid tokio spawn_blocking as it does not perform great for CPU bounds tasks
// see https://ryhl.io/blog/async-what-is-blocking/
rayon::spawn(move || {
// Move to a task (as generating the proof can take quite some time)
let blocking_task = tokio::task::spawn_blocking(move || {
let proof_generation_start = std::time::Instant::now();
let message_id = {
@@ -102,41 +110,65 @@ impl ProofService {
};
let epoch = hash_to_field(epoch_bytes.as_slice());
let merkle_proof = user_db.get_merkle_proof(&proof_generation_data.tx_sender)?;
let merkle_proof = match user_db.get_merkle_proof(&proof_generation_data.tx_sender) {
Ok(merkle_proof) => merkle_proof,
Err(e) => {
let _ = send.send(Err(ProofGenerationError::MerkleProofError(e)));
return;
}
};
let (proof, proof_values) = compute_rln_proof_and_values(
// let compute_proof_start = std::time::Instant::now();
let (proof, proof_values) = match compute_rln_proof_and_values(
&proof_generation_data.user_identity,
&proof_generation_data.rln_identifier,
rln_data,
epoch,
&merkle_proof,
)
.map_err(ProofGenerationError::Proof)?;
) {
Ok((proof, proof_values)) => (proof, proof_values),
Err(e) => {
let _ = send.send(Err(ProofGenerationError::Proof(e)));
return;
}
};
debug!("proof: {:?}", proof);
debug!("proof_values: {:?}", proof_values);
// debug!("proof: {:?}", proof);
// debug!("proof_values: {:?}", proof_values);
// Serialize proof
let mut output_buffer = Cursor::new(Vec::with_capacity(PROOF_SIZE));
proof
if let Err(e) = proof
.serialize_compressed(&mut output_buffer)
.map_err(ProofGenerationError::Serialization)?;
output_buffer
.write_all(&serialize_proof_values(&proof_values))
.map_err(ProofGenerationError::SerializationWrite)?;
{
let _ = send.send(Err(ProofGenerationError::Serialization(e)));
return;
}
if let Err(e) = output_buffer
.write_all(&serialize_proof_values(&proof_values)) {
let _ = send.send(Err(ProofGenerationError::SerializationWrite(e)));
return;
}
histogram!(PROOF_SERVICE_GEN_PROOF_TIME.name, "prover" => "proof service")
.record(proof_generation_start.elapsed().as_secs_f64());
.record(proof_generation_start.elapsed().as_secs_f64());
// println!("[proof service {counter_id}] proof generation time: {:?} secs", proof_generation_start.elapsed().as_secs_f64());
let labels = [("prover", format!("proof service id: {counter_id}"))];
counter!(PROOF_SERVICE_PROOF_COMPUTED.name, &labels).increment(1);
Ok::<Vec<u8>, ProofGenerationError>(output_buffer.into_inner())
// Send the result back to Tokio.
let _ = send.send(
Ok::<Vec<u8>, ProofGenerationError>(output_buffer.into_inner())
);
});
let result = blocking_task.instrument(debug_span!("compute proof")).await;
// Result (1st) is a JoinError (and should not happen)
// Result (2nd) is a ProofGenerationError
let result = result.unwrap(); // Should never happen (but should panic if it does)
// Wait for the rayon task.
// Result 1st is from send channel (no errors expected)
// Result 2nd can be a ProofGenerationError
let result = recv
.instrument(debug_span!("compute proof"))
.await
.expect("Panic in rayon::spawn"); // Should never happen (but panic if it does)
let proof_sending_data = result
.map(|r| ProofSendingData {
@@ -173,7 +205,7 @@ mod tests {
use claims::assert_matches;
use futures::TryFutureExt;
use tokio::sync::broadcast;
use tracing::info;
use tracing::{info, debug};
// third-party: zerokit
use rln::{
circuit::{Curve, zkey_from_folder},

View File

@@ -146,7 +146,7 @@ pub fn epoch_counters_operands(
// FIXME: assert when reload from disk
// debug_assert_ge!(epoch_incr.epoch, acc.epoch);
debug_assert!(
epoch_incr.epoch_slice > acc.epoch_slice
epoch_incr.epoch_slice >= acc.epoch_slice
|| epoch_incr.epoch_slice == EpochSlice::from(0)
);

View File

@@ -5,14 +5,87 @@ mod user_db_tests {
use std::sync::Arc;
// third-party
use alloy::primitives::{Address, address};
use claims::assert_matches;
use parking_lot::RwLock;
use crate::epoch_service::{Epoch, EpochSlice};
// internal
use crate::user_db::UserDb;
use crate::user_db_types::{EpochSliceCounter, MerkleTreeIndex};
use crate::user_db_types::{EpochCounter, EpochSliceCounter, MerkleTreeIndex};
const ADDR_1: Address = address!("0xd8da6bf26964af9d7eed9e03e53415d37aa96045");
const ADDR_2: Address = address!("0xb20a608c624Ca5003905aA834De7156C68b2E1d0");
#[tokio::test]
async fn test_incr_tx_counter_2() {
// Same as test_incr_tx_counter but multi users AND multi incr
let temp_folder = tempfile::tempdir().unwrap();
let temp_folder_tree = tempfile::tempdir().unwrap();
let epoch_store = Arc::new(RwLock::new(Default::default()));
let epoch = 1;
let epoch_slice = 42;
*epoch_store.write() = (Epoch::from(epoch), EpochSlice::from(epoch_slice));
let user_db = UserDb::new(
PathBuf::from(temp_folder.path()),
PathBuf::from(temp_folder_tree.path()),
epoch_store,
Default::default(),
Default::default(),
)
.unwrap();
// Register users
user_db.register(ADDR_1).unwrap();
user_db.register(ADDR_2).unwrap();
assert_eq!(
user_db.get_tx_counter(&ADDR_1),
Ok((EpochCounter::from(0), EpochSliceCounter::from(0)))
);
assert_eq!(
user_db.get_tx_counter(&ADDR_2),
Ok((EpochCounter::from(0), EpochSliceCounter::from(0)))
);
// Now update user tx counter
assert_eq!(
user_db.on_new_tx(&ADDR_1, None),
Ok(EpochSliceCounter::from(1))
);
assert_eq!(
user_db.on_new_tx(&ADDR_1, None),
Ok(EpochSliceCounter::from(2))
);
assert_eq!(
user_db.on_new_tx(&ADDR_1, Some(2)),
Ok(EpochSliceCounter::from(4))
);
assert_eq!(
user_db.on_new_tx(&ADDR_2, None),
Ok(EpochSliceCounter::from(1))
);
assert_eq!(
user_db.on_new_tx(&ADDR_2, None),
Ok(EpochSliceCounter::from(2))
);
assert_eq!(
user_db.get_tx_counter(&ADDR_1),
Ok((EpochCounter::from(4), EpochSliceCounter::from(4)))
);
assert_eq!(
user_db.get_tx_counter(&ADDR_2),
Ok((EpochCounter::from(2), EpochSliceCounter::from(2)))
);
}
#[tokio::test]
async fn test_persistent_storage() {
let temp_folder = tempfile::tempdir().unwrap();

View File

@@ -90,7 +90,7 @@ async fn test_grpc_register_users() {
metrics_ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
metrics_port: 30031,
broadcast_channel_size: 100,
proof_service_count: 8,
proof_service_count: 16,
transaction_channel_size: 100,
proof_sender_channel_size: 100,
};
@@ -110,7 +110,10 @@ async fn test_grpc_register_users() {
tokio::time::sleep(Duration::from_secs(1)).await;
}
async fn proof_sender(port: u16, addresses: Vec<Address>, _proof_count: usize) {
async fn proof_sender(port: u16, addresses: Vec<Address>, proof_count: usize) {
let start = std::time::Instant::now();
let chain_id = GrpcU256 {
// FIXME: LE or BE?
value: U256::from(1).to_le_bytes::<32>().to_vec(),
@@ -126,8 +129,29 @@ async fn proof_sender(port: u16, addresses: Vec<Address>, _proof_count: usize) {
// FIXME: LE or BE?
value: U256::from(1000).to_le_bytes::<32>().to_vec(),
};
let tx_hash = U256::from(42).to_le_bytes::<32>().to_vec();
let mut count = 0;
for i in 0..proof_count {
let tx_hash = U256::from(42 + i).to_le_bytes::<32>().to_vec();
let request_0 = SendTransactionRequest {
gas_price: Some(wei.clone()),
sender: Some(addr.clone()),
chain_id: Some(chain_id.clone()),
transaction_hash: tx_hash,
};
let request = tonic::Request::new(request_0);
let response: Response<SendTransactionReply> =
client.send_transaction(request).await.unwrap();
assert_eq!(response.into_inner().result, true);
count += 1;
}
println!("[proof_sender] sent {} tx - elapsed: {} secs", count, start.elapsed().as_secs_f64());
/*
let tx_hash = U256::from(42).to_le_bytes::<32>().to_vec();
let request_0 = SendTransactionRequest {
gas_price: Some(wei),
sender: Some(addr),
@@ -138,9 +162,12 @@ async fn proof_sender(port: u16, addresses: Vec<Address>, _proof_count: usize) {
let request = tonic::Request::new(request_0);
let response: Response<SendTransactionReply> = client.send_transaction(request).await.unwrap();
assert_eq!(response.into_inner().result, true);
*/
}
async fn proof_collector(port: u16) -> Vec<RlnProofReply> {
async fn proof_collector(port: u16, proof_count: usize) -> Vec<RlnProofReply> {
let start = std::time::Instant::now();
let result = Arc::new(RwLock::new(vec![]));
let url = format!("http://127.0.0.1:{}", port);
@@ -154,14 +181,25 @@ async fn proof_collector(port: u16) -> Vec<RlnProofReply> {
let mut stream = stream_.into_inner();
let result_2 = result.clone();
let mut count = 0;
let mut start_per_message = std::time::Instant::now();
let receiver = async move {
while let Some(response) = stream.message().await.unwrap() {
result_2.write().push(response);
count += 1;
if count >= proof_count {
break;
}
println!("count {count} - elapsed: {} secs", start_per_message.elapsed().as_secs_f64());
start_per_message = std::time::Instant::now();
}
};
let _res = tokio::time::timeout(Duration::from_secs(10), receiver).await;
std::mem::take(&mut *result.write())
let _res = tokio::time::timeout(Duration::from_secs(500), receiver).await;
println!("_res: {:?}", _res);
let res = std::mem::take(&mut *result.write());
println!("[proof_collector] elapsed: {} secs", start.elapsed().as_secs_f64());
res
}
#[tokio::test]
@@ -191,10 +229,10 @@ async fn test_grpc_gen_proof() {
no_config: Some(true),
metrics_ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
metrics_port: 30031,
broadcast_channel_size: 100,
broadcast_channel_size: 500,
proof_service_count: 8,
transaction_channel_size: 100,
proof_sender_channel_size: 100,
transaction_channel_size: 500,
proof_sender_channel_size: 500,
};
info!("Starting prover...");
@@ -206,15 +244,16 @@ async fn test_grpc_gen_proof() {
register_users(port, addresses.clone()).await;
info!("Sending tx and collecting proofs...");
let proof_count = 1;
let proof_count = 10;
let mut set = JoinSet::new();
set.spawn(
proof_sender(port, addresses.clone(), proof_count).map(|_| vec![]), // JoinSet require having the same return type
);
set.spawn(proof_collector(port));
set.spawn(proof_collector(port, proof_count));
let res = set.join_all().await;
assert_eq!(res[1].len(), proof_count);
println!("res lengths: {} {}", res[0].len(), res[1].len());
assert_eq!(res[0].len() + res[1].len(), proof_count);
info!("Aborting prover...");
prover_handle.abort();