Improve test_notify unit test (more reliable) (#17)

* Improve test_notify unit test (more reliable)
This commit is contained in:
Sydhds
2025-07-11 09:57:38 +02:00
committed by GitHub
parent 88678afdb2
commit e5f88726ae
4 changed files with 115 additions and 127 deletions

View File

@@ -4,23 +4,18 @@ use criterion::{criterion_group, criterion_main};
// std
use std::net::{IpAddr, Ipv4Addr};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::str::FromStr;
// third-party
use alloy::{
primitives::{Address, U256},
};
use alloy::primitives::{Address, U256};
use futures::FutureExt;
use parking_lot::RwLock;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tonic::Response;
use futures::FutureExt;
// internal
use prover::{
AppArgs,
run_prover
};
use prover::{AppArgs, run_prover};
// grpc
pub mod prover_proto {
@@ -28,40 +23,32 @@ pub mod prover_proto {
tonic::include_proto!("prover");
}
use prover_proto::{
Address as GrpcAddress,
U256 as GrpcU256,
Wei as GrpcWei,
RegisterUserReply, RegisterUserRequest, RegistrationStatus,
SendTransactionRequest, SendTransactionReply,
RlnProofFilter, RlnProofReply,
rln_prover_client::RlnProverClient
Address as GrpcAddress, RegisterUserReply, RegisterUserRequest, RegistrationStatus,
RlnProofFilter, RlnProofReply, SendTransactionReply, SendTransactionRequest, U256 as GrpcU256,
Wei as GrpcWei, rln_prover_client::RlnProverClient,
};
async fn register_users(port: u16, addresses: Vec<Address>) {
let url = format!("http://127.0.0.1:{}", port);
let mut client = RlnProverClient::connect(url).await.unwrap();
for address in addresses {
let addr = GrpcAddress {
value: address.to_vec(),
};
let request_0 = RegisterUserRequest {
user: Some(addr),
};
let request_0 = RegisterUserRequest { user: Some(addr) };
let request = tonic::Request::new(request_0);
let response: Response<RegisterUserReply> = client.register_user(request).await.unwrap();
assert_eq!(
RegistrationStatus::try_from(response.into_inner().status).unwrap(),
RegistrationStatus::Success);
RegistrationStatus::Success
);
}
}
async fn proof_sender(port: u16, addresses: Vec<Address>, proof_count: usize) {
let chain_id = GrpcU256 {
// FIXME: LE or BE?
value: U256::from(1).to_le_bytes::<32>().to_vec(),
@@ -79,7 +66,7 @@ async fn proof_sender(port: u16, addresses: Vec<Address>, proof_count: usize) {
};
for i in 0..proof_count {
let tx_hash = U256::from(42+i).to_le_bytes::<32>().to_vec();
let tx_hash = U256::from(42 + i).to_le_bytes::<32>().to_vec();
let request_0 = SendTransactionRequest {
gas_price: Some(wei.clone()),
@@ -89,21 +76,19 @@ async fn proof_sender(port: u16, addresses: Vec<Address>, proof_count: usize) {
};
let request = tonic::Request::new(request_0);
let response: Response<SendTransactionReply> = client.send_transaction(request).await.unwrap();
let response: Response<SendTransactionReply> =
client.send_transaction(request).await.unwrap();
assert_eq!(response.into_inner().result, true);
}
}
async fn proof_collector(port: u16, proof_count: usize) -> Vec<RlnProofReply> {
let result= Arc::new(RwLock::new(vec![]));
let result = Arc::new(RwLock::new(vec![]));
let url = format!("http://127.0.0.1:{}", port);
let mut client = RlnProverClient::connect(url).await.unwrap();
let request_0 = RlnProofFilter {
address: None,
};
let request_0 = RlnProofFilter { address: None };
let request = tonic::Request::new(request_0);
let stream_ = client.get_proofs(request).await.unwrap();
@@ -125,7 +110,6 @@ async fn proof_collector(port: u16, proof_count: usize) -> Vec<RlnProofReply> {
}
fn proof_generation_bench(c: &mut Criterion) {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
@@ -163,26 +147,22 @@ fn proof_generation_bench(c: &mut Criterion) {
// Spawn prover
let notify_start_1 = notify_start.clone();
rt.spawn(
async move {
tokio::spawn(run_prover(app_args));
tokio::time::sleep(Duration::from_secs(10)).await;
println!("Prover is ready, notifying it...");
notify_start_1.clone().notify_one();
}
);
rt.spawn(async move {
tokio::spawn(run_prover(app_args));
tokio::time::sleep(Duration::from_secs(10)).await;
println!("Prover is ready, notifying it...");
notify_start_1.clone().notify_one();
});
let notify_start_2 = notify_start.clone();
let addresses_0 = addresses.clone();
// Wait for proof_collector to be connected and waiting for some proofs
let _res = rt.block_on(
async move {
notify_start_2.notified().await;
println!("Prover is ready, registering users...");
register_users(port, addresses_0).await;
}
);
let _res = rt.block_on(async move {
notify_start_2.notified().await;
println!("Prover is ready, registering users...");
register_users(port, addresses_0).await;
});
println!("Starting benchmark...");
let size: usize = 1024;
@@ -192,9 +172,7 @@ fn proof_generation_bench(c: &mut Criterion) {
async {
let mut set = JoinSet::new();
set.spawn(proof_collector(port, proof_count));
set.spawn(
proof_sender(port, addresses.clone(), proof_count).map(|_r| vec![])
);
set.spawn(proof_sender(port, addresses.clone(), proof_count).map(|_r| vec![]));
// Wait for proof_sender + proof_collector to complete
let res = set.join_all().await;
// Check we receive enough proof
@@ -211,4 +189,4 @@ criterion_group!(
.measurement_time(Duration::from_secs(150));
targets = proof_generation_bench
);
criterion_main!(benches);
criterion_main!(benches);

View File

@@ -6,19 +6,25 @@ use chrono::{DateTime, NaiveDate, NaiveDateTime, OutOfRangeError, TimeDelta, Utc
use derive_more::{Deref, From, Into};
use parking_lot::RwLock;
use tokio::sync::Notify;
use tracing::debug;
use tracing::{debug, error};
// internal
use crate::error::AppError;
/// Duration of an epoch (1 day)
const EPOCH_DURATION: Duration = Duration::from_secs(TimeDelta::days(1).num_seconds() as u64);
/// Minimum duration returned by EpochService::compute_wait_until()
const WAIT_UNTIL_MIN_DURATION: Duration = Duration::from_secs(5);
const WAIT_UNTIL_MIN_DURATION: Duration = Duration::from_secs(2);
/// EpochService::compute_wait_until() can return an error like TooLow (see WAIT_UNTIL_MIN_DURATION)
/// so the epoch service will retry X many times.
const WAIT_UNTIL_MAX_COMPUTE_ERROR: usize = 10;
/// An Epoch tracking service
///
/// The service keeps track of the current epoch (duration: 1 day) and the current epoch slice
/// (duration: configurable, < 1 day, usually in minutes)
///
/// Use TryFrom impl to initialize an EpochService. Note that initial epoch & epoch slice is
/// initialized to Default values. Calling listen_for_new_epoch will initialize these fields.
pub struct EpochService {
/// A subdivision of an epoch (in minutes or seconds)
epoch_slice_duration: Duration,
@@ -36,20 +42,45 @@ impl EpochService {
Self::compute_epoch_slice_count(EPOCH_DURATION, self.epoch_slice_duration);
debug!("epoch slice in an epoch: {}", epoch_slice_count);
let (mut current_epoch, mut current_epoch_slice, mut wait_until) =
let mut retry_counter = 0;
let (mut current_epoch, mut current_epoch_slice, mut wait_until): (
i64,
i64,
tokio::time::Instant,
) = loop {
match self.compute_wait_until(&|| Utc::now(), &|| tokio::time::Instant::now()) {
Ok((current_epoch, current_epoch_slice, wait_until)) => {
(current_epoch, current_epoch_slice, wait_until)
break (current_epoch, current_epoch_slice, wait_until);
}
Err(_e) => {
// sleep and try again (only one retry)
Err(WaitUntilError::TooLow(d1, d2)) => {
// Wait until is too low (according to const WAIT_UNTIL_MIN_DURATION)
// so we will retry (WAIT_UNTIL_MAX_COMPUTE_ERROR many times) after a short sleep
debug!("compute_wait_until return TooLow, will retry after a sleep...");
tokio::time::sleep(WAIT_UNTIL_MIN_DURATION).await;
self.compute_wait_until(&|| Utc::now(), &|| tokio::time::Instant::now())?
retry_counter += 1;
if retry_counter > WAIT_UNTIL_MAX_COMPUTE_ERROR {
error!(
"Too many errors while computing the initial wait until, aborting..."
);
return Err(AppError::EpochError(WaitUntilError::TooLow(d1, d2)));
}
}
Err(e) => {
// Another error (like OutOfRange) - exiting...
error!("Error computing the initial wait until: {}", e);
return Err(AppError::EpochError(e));
}
};
};
debug!("wait until: {:?}", wait_until);
// debug!("wait until: {:?}", wait_until);
*self.current_epoch.write() = (current_epoch.into(), current_epoch_slice.into());
debug!(
"Initial epoch: {}, epoch slice: {}",
current_epoch, current_epoch_slice
);
loop {
debug!("wait until: {:?}", wait_until);
@@ -146,7 +177,7 @@ impl EpochService {
fn compute_current_epoch_slice<F: Fn() -> DateTime<Utc>>(
now_date: NaiveDate,
epoch_slice_duration: Duration,
now: F,
now: &F,
) -> i64 {
debug_assert!(epoch_slice_duration.as_secs() > 0);
debug_assert!(i32::try_from(epoch_slice_duration.as_secs()).is_ok());
@@ -239,21 +270,11 @@ impl Add<i64> for EpochSlice {
mod tests {
use super::*;
use chrono::{NaiveDate, NaiveDateTime, TimeDelta};
use claims::assert_ge;
use futures::TryFutureExt;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing_test::traced_test;
/*
#[tokio::test]
async fn test_wait_until_0() {
let wait_until = tokio::time::Instant::now() + Duration::from_secs(10);
println!("Should wait until: {:?}", wait_until);
tokio::time::sleep(Duration::from_secs(3)).await;
tokio::time::sleep_until(wait_until).await;
println!("Wake up at: {:?}", tokio::time::Instant::now());
}
*/
#[test]
fn test_wait_until() {
// Check wait_until is correctly computed
@@ -416,7 +437,7 @@ mod tests {
// Note: 4-minute diff -> expect == 2
assert_eq!(
EpochService::compute_current_epoch_slice(now_date, epoch_slice_duration, now_f),
EpochService::compute_current_epoch_slice(now_date, epoch_slice_duration, &now_f),
2
);
// Note: 5 minutes and 59 seconds diff -> still expect == 2
@@ -424,7 +445,7 @@ mod tests {
EpochService::compute_current_epoch_slice(
now_date,
epoch_slice_duration,
Box::new(now_f_2)
&Box::new(now_f_2)
),
2
);
@@ -433,7 +454,7 @@ mod tests {
EpochService::compute_current_epoch_slice(
now_date,
epoch_slice_duration,
Box::new(now_f_3)
&Box::new(now_f_3)
),
3
);
@@ -458,17 +479,21 @@ mod tests {
let counter_0 = Arc::new(AtomicU64::new(0));
let counter = counter_0.clone();
let start = std::time::Instant::now();
let res = tokio::try_join!(
epoch_service
.listen_for_new_epoch()
.map_err(|e| AppErrorExt::AppError(e)),
// Wait for 3 epoch slices + 100 ms (to wait to receive notif + counter incr)
// Wait for 3 epoch slices
// + WAIT_UNTIL_MIN_DURATION * 2 (expect a maximum of 2 retry)
// + 500 ms (to wait to receive notif + counter incr)
// Note: this might fail if there is more retry (see list_for_new_epoch code)
tokio::time::timeout(
epoch_slice_duration * 3 + Duration::from_millis(500),
epoch_slice_duration * 3 + WAIT_UNTIL_MIN_DURATION * 2 + Duration::from_millis(500),
async move {
loop {
notifier.notified().await;
debug!("[Notified] Epoch update...");
// debug!("[Notified] Epoch update...");
let _v = counter.fetch_add(1, Ordering::SeqCst);
}
// Ok::<(), AppErrorExt>(())
@@ -476,7 +501,11 @@ mod tests {
)
.map_err(|_e| AppErrorExt::Elapsed)
);
debug!("Elapsed time: {}", start.elapsed().as_millis());
// debug!("res: {:?}", res);
assert!(matches!(res, Err(AppErrorExt::Elapsed)));
assert_eq!(counter_0.load(Ordering::SeqCst), 3);
// Because the timeout is quite large - it is expected that sometimes the counter == 4
assert_ge!(counter_0.load(Ordering::SeqCst), 3);
}
}

View File

@@ -29,9 +29,6 @@ use tracing::{
// info
};
// internal
use rln_proof::RlnIdentifier;
use smart_contract::KarmaTiersSC::KarmaTiersSCInstance;
use smart_contract::TIER_LIMITS;
pub use crate::args::{AppArgs, AppArgsConfig};
use crate::epoch_service::EpochService;
use crate::grpc_service::GrpcProverService;
@@ -42,6 +39,9 @@ use crate::tier::TierLimits;
use crate::tiers_listener::TiersListener;
use crate::user_db_service::UserDbService;
use crate::user_db_types::RateLimit;
use rln_proof::RlnIdentifier;
use smart_contract::KarmaTiersSC::KarmaTiersSCInstance;
use smart_contract::TIER_LIMITS;
const RLN_IDENTIFIER_NAME: &[u8] = b"test-rln-identifier";
const PROVER_SPAM_LIMIT: RateLimit = RateLimit::new(10_000u64);
@@ -49,8 +49,9 @@ const GENESIS: DateTime<Utc> = DateTime::from_timestamp(1431648000, 0).unwrap();
const PROVER_MINIMAL_AMOUNT_FOR_REGISTRATION: U256 =
U256::from_le_slice(10u64.to_le_bytes().as_slice());
pub async fn run_prover(app_args: AppArgs) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
pub async fn run_prover(
app_args: AppArgs,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
// Epoch
let epoch_service = EpochService::try_from((Duration::from_secs(60 * 2), GENESIS))
.expect("Failed to create epoch service");
@@ -61,7 +62,7 @@ pub async fn run_prover(app_args: AppArgs) -> Result<(), Box<dyn std::error::Err
app_args.ws_rpc_url.clone().unwrap(),
app_args.tsc_address.unwrap(),
)
.await?,
.await?,
)
} else {
// mock
@@ -199,29 +200,22 @@ mod tests {
use std::str::FromStr;
use std::sync::Arc;
// third-party
use tokio::task;
use tonic::Response;
use alloy::{
primitives::{Address, U256},
};
use tracing::info;
use tracing_test::traced_test;
use alloy::primitives::{Address, U256};
use futures::FutureExt;
use parking_lot::RwLock;
use tokio::task;
use tonic::Response;
use tracing::info;
use tracing_test::traced_test;
// internal
use crate::grpc_service::prover_proto::{
Address as GrpcAddress,
U256 as GrpcU256,
Wei as GrpcWei,
GetUserTierInfoReply, GetUserTierInfoRequest,
RegisterUserReply, RegisterUserRequest, RegistrationStatus,
SendTransactionRequest, SendTransactionReply,
RlnProofFilter, RlnProofReply
};
use crate::grpc_service::prover_proto::rln_prover_client::RlnProverClient;
use crate::grpc_service::prover_proto::{
Address as GrpcAddress, GetUserTierInfoReply, GetUserTierInfoRequest, RegisterUserReply,
RegisterUserRequest, RegistrationStatus, RlnProofFilter, RlnProofReply,
SendTransactionReply, SendTransactionRequest, U256 as GrpcU256, Wei as GrpcWei,
};
async fn proof_sender(port: u16, addresses: Vec<Address>, proof_count: usize) {
let chain_id = GrpcU256 {
// FIXME: LE or BE?
value: U256::from(1).to_le_bytes::<32>().to_vec(),
@@ -247,20 +241,18 @@ mod tests {
};
let request = tonic::Request::new(request_0);
let response: Response<SendTransactionReply> = client.send_transaction(request).await.unwrap();
let response: Response<SendTransactionReply> =
client.send_transaction(request).await.unwrap();
assert_eq!(response.into_inner().result, true);
}
async fn proof_collector(port: u16) -> Vec<RlnProofReply> {
let result= Arc::new(RwLock::new(vec![]));
let result = Arc::new(RwLock::new(vec![]));
let url = format!("http://127.0.0.1:{}", port);
let mut client = RlnProverClient::connect(url).await.unwrap();
let request_0 = RlnProofFilter {
address: None,
};
let request_0 = RlnProofFilter { address: None };
let request = tonic::Request::new(request_0);
let stream_ = client.get_proofs(request).await.unwrap();
@@ -279,30 +271,27 @@ mod tests {
}
async fn register_users(port: u16, addresses: Vec<Address>) {
let url = format!("http://127.0.0.1:{}", port);
let mut client = RlnProverClient::connect(url).await.unwrap();
for address in addresses {
let addr = GrpcAddress {
value: address.to_vec(),
};
let request_0 = RegisterUserRequest {
user: Some(addr),
};
let request_0 = RegisterUserRequest { user: Some(addr) };
let request = tonic::Request::new(request_0);
let response: Response<RegisterUserReply> = client.register_user(request).await.unwrap();
let response: Response<RegisterUserReply> =
client.register_user(request).await.unwrap();
assert_eq!(
RegistrationStatus::try_from(response.into_inner().status).unwrap(),
RegistrationStatus::Success);
RegistrationStatus::Success
);
}
}
async fn query_user_info(port: u16, addresses: Vec<Address>) -> Vec<GetUserTierInfoReply> {
let url = format!("http://127.0.0.1:{}", port);
let mut client = RlnProverClient::connect(url).await.unwrap();
@@ -311,11 +300,10 @@ mod tests {
let addr = GrpcAddress {
value: address.to_vec(),
};
let request_0 = GetUserTierInfoRequest {
user: Some(addr),
};
let request_0 = GetUserTierInfoRequest { user: Some(addr) };
let request = tonic::Request::new(request_0);
let resp: Response<GetUserTierInfoReply> = client.get_user_tier_info(request).await.unwrap();
let resp: Response<GetUserTierInfoReply> =
client.get_user_tier_info(request).await.unwrap();
result.push(resp.into_inner());
}
@@ -326,7 +314,6 @@ mod tests {
#[tokio::test]
#[traced_test]
async fn test_grpc_register_users() {
let addresses = vec![
Address::from_str("0xd8da6bf26964af9d7eed9e03e53415d37aa96045").unwrap(),
Address::from_str("0xb20a608c624Ca5003905aA834De7156C68b2E1d0").unwrap(),
@@ -373,7 +360,6 @@ mod tests {
#[tokio::test]
#[traced_test]
async fn test_grpc_gen_proof() {
let addresses = vec![
Address::from_str("0xd8da6bf26964af9d7eed9e03e53415d37aa96045").unwrap(),
Address::from_str("0xb20a608c624Ca5003905aA834De7156C68b2E1d0").unwrap(),
@@ -414,8 +400,7 @@ mod tests {
let proof_count = 1;
let mut set = JoinSet::new();
set.spawn(
proof_sender(port, addresses.clone(), proof_count)
.map(|_| vec![]) // JoinSet require having the same return type
proof_sender(port, addresses.clone(), proof_count).map(|_| vec![]), // JoinSet require having the same return type
);
set.spawn(proof_collector(port));
let res = set.join_all().await;

View File

@@ -10,11 +10,7 @@ use tracing::{
};
use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
// internal
use prover::{
run_prover,
AppArgs,
AppArgsConfig,
};
use prover::{AppArgs, AppArgsConfig, run_prover};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {