From b7967b85e8a3bc955c5826536e8899fb2842fd80 Mon Sep 17 00:00:00 2001 From: Sydhds Date: Wed, 1 Oct 2025 17:00:56 +0200 Subject: [PATCH] Initial code to use Zerokit 0.9 + disable parallel feature (#36) * Initial code to use Zerokit 0.9 + disable parallel feature * Support IdSecret for user identity secret hash * Fix clippy + bench * Use PmTreeConfig builder * Improve prover_bench perf * Fix prover_bench 2nd assert * Fix prover_bench 2nd assert 2 * Can now enable trace for bench prover_bench * Use anyhow for error handling (+ error context) in prover_cli (#42) * Use anyhow for error handling (+ error context) in prover_cli * Cargo fmt pass * Feature/feature/init user db ser de 2 (#45) * Add user db serializer && deserializer init & re-use --- Cargo.lock | 25 ++--- Cargo.toml | 16 +++- prover/Cargo.toml | 5 +- prover/benches/prover_bench.rs | 78 +++++++++++++--- prover/src/args.rs | 2 +- prover/src/error.rs | 16 +++- prover/src/lib.rs | 35 ++++--- prover/src/proof_service.rs | 17 +++- prover/src/proof_service_tests.rs | 4 +- prover/src/rocksdb_operands.rs | 2 + prover/src/user_db.rs | 137 ++++++++++++++-------------- prover/src/user_db_error.rs | 2 +- prover/src/user_db_serialization.rs | 19 ++-- prover_cli/Cargo.toml | 1 + prover_cli/src/main.rs | 14 ++- rln_proof/benches/generate_proof.rs | 44 ++++----- rln_proof/src/proof.rs | 47 ++++++---- 17 files changed, 273 insertions(+), 191 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d2847d7..17dba6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -850,7 +850,6 @@ dependencies = [ "digest 0.10.7", "fnv", "merlin", - "rayon", "sha2", ] @@ -883,7 +882,6 @@ dependencies = [ "num-bigint", "num-integer", "num-traits", - "rayon", "zeroize", ] @@ -942,7 +940,6 @@ dependencies = [ "num-bigint", "num-traits", "paste", - "rayon", "zeroize", ] @@ -1027,7 +1024,6 @@ dependencies = [ "ark-relations", "ark-serialize 0.5.0", "ark-std 0.5.0", - "rayon", ] [[package]] @@ -1043,7 +1039,6 @@ dependencies = [ "educe", "fnv", "hashbrown 0.15.5", - "rayon", ] [[package]] @@ -1107,7 +1102,6 @@ dependencies = [ "arrayvec", "digest 0.10.7", "num-bigint", - "rayon", ] [[package]] @@ -1161,7 +1155,6 @@ checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" dependencies = [ "num-traits", "rand 0.8.5", - "rayon", ] [[package]] @@ -3937,6 +3930,7 @@ dependencies = [ "derive_more", "futures", "http", + "lazy_static", "metrics", "metrics-exporter-prometheus", "metrics-util", @@ -3973,6 +3967,7 @@ dependencies = [ name = "prover_cli" version = "0.1.0" dependencies = [ + "anyhow", "clap", "opentelemetry", "opentelemetry-otlp", @@ -4388,8 +4383,7 @@ dependencies = [ [[package]] name = "rln" version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a03834bc168adfee6f49c885fabb0ad6897be11d97258e9375d98f30d2c9878" +source = "git+https://github.com/vacp2p/zerokit/#0b00c639a059a2cfde74bcf68fdf75db3b6898a4" dependencies = [ "ark-bn254", "ark-ec", @@ -4405,14 +4399,16 @@ dependencies = [ "num-bigint", "num-traits", "once_cell", - "prost 0.13.5", + "prost 0.14.1", "rand 0.8.5", "rand_chacha 0.3.1", "ruint", "serde", "serde_json", + "tempfile", "thiserror", "tiny-keccak", + "zeroize", "zerokit_utils", ] @@ -5125,15 +5121,15 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.20.0" +version = "3.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" +checksum = "15b61f8f20e3a6f7e0649d825294eaf317edce30f82cf6026e7e4cb9222a7d1e" dependencies = [ "fastrand", "getrandom 0.3.3", "once_cell", "rustix 1.0.8", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -6388,8 +6384,7 @@ dependencies = [ [[package]] name = "zerokit_utils" version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21d5ee8dd5cba6c9e39c7391e3fe968d479b6f1eb51c82556b2bb9b2924f572" +source = "git+https://github.com/vacp2p/zerokit/#0b00c639a059a2cfde74bcf68fdf75db3b6898a4" dependencies = [ "ark-ff 0.5.0", "hex", diff --git a/Cargo.toml b/Cargo.toml index 74ae87c..a0e4eea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,13 +9,19 @@ members = [ resolver = "2" [workspace.dependencies] -rln = { version = "0.8.0", features = ["pmtree-ft"] } -zerokit_utils = { version = "0.6.0", features = ["pmtree-ft"] } +# rln = { version = "0.8.0", features = ["pmtree-ft"] } +rln = { git = "https://github.com/vacp2p/zerokit/", default-features = false, features = ["pmtree-ft"] } +# zerokit_utils = { version = "0.6.0", features = ["pmtree-ft"] } +zerokit_utils = { git = "https://github.com/vacp2p/zerokit/", default-features = false, features = ["pmtree-ft"] } + ark-bn254 = { version = "0.5.0", features = ["std"] } ark-relations = { version = "0.5.1", features = ["std"] } -ark-ff = { version = "0.5.0", features = ["parallel"] } -ark-groth16 = { version = "0.5.0", features = ["parallel"] } -ark-serialize = { version = "0.5.0", features = ["parallel"] } +# ark-ff = { version = "0.5.0", features = ["parallel"] } +ark-ff = { version = "0.5.0", features = ["asm"] } +# ark-groth16 = { version = "0.5.0", features = ["parallel"] } +ark-groth16 = { version = "0.5.0", default-features = false, features = [] } +# ark-serialize = { version = "0.5.0", features = ["parallel"] } +ark-serialize = { version = "0.5.0", default-features = false, features = [] } tokio = { version = "1.47.1", features = ["macros", "rt-multi-thread"] } clap = { version = "4.5.46", features = ["derive", "wrap_help"] } url = { version = "2.5.7", features = ["serde"] } diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 663cf3c..2d49e77 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -44,7 +44,7 @@ clap_config = "0.1" metrics = "0.24" metrics-exporter-prometheus = "0.17" metrics-util = "0.20" -rayon = "1.7" +rayon = "1.10" [build-dependencies] tonic-prost-build.workspace = true @@ -52,8 +52,9 @@ tonic-prost-build.workspace = true [dev-dependencies] criterion.workspace = true ark-groth16.workspace = true -tempfile = "3.20" +tempfile = "3.21" tracing-test = "0.2.5" +lazy_static = "1.5.0" [[bench]] name = "prover_bench" diff --git a/prover/benches/prover_bench.rs b/prover/benches/prover_bench.rs index 96f060b..bab4c54 100644 --- a/prover/benches/prover_bench.rs +++ b/prover/benches/prover_bench.rs @@ -11,7 +11,6 @@ use std::time::Duration; // third-party use alloy::primitives::{Address, U256}; use futures::FutureExt; -use parking_lot::RwLock; use tempfile::NamedTempFile; use tokio::sync::Notify; use tokio::task::JoinSet; @@ -29,6 +28,29 @@ use prover_proto::{ SendTransactionRequest, U256 as GrpcU256, Wei as GrpcWei, rln_prover_client::RlnProverClient, }; +use lazy_static::lazy_static; +use std::sync::Once; + +lazy_static! { + static ref TRACING_INIT: Once = Once::new(); +} + +pub fn setup_tracing() { + TRACING_INIT.call_once(|| { + let filter = tracing_subscriber::EnvFilter::from_default_env() + .add_directive("h2=error".parse().unwrap()) + .add_directive("sled::pagecache=error".parse().unwrap()) + .add_directive("opentelemetry_sdk=error".parse().unwrap()); + + tracing_subscriber::fmt() + .with_env_filter(filter) + .with_line_number(true) + .with_file(true) + .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE) + .init(); + }); +} + async fn proof_sender(port: u16, addresses: Vec
, proof_count: usize) { let chain_id = GrpcU256 { // FIXME: LE or BE? @@ -59,12 +81,15 @@ async fn proof_sender(port: u16, addresses: Vec
, proof_count: usize) { let request = tonic::Request::new(request_0); let response: Response = client.send_transaction(request).await.unwrap(); + assert!(response.into_inner().result); } + // println!("[proof_sender] returning..."); } -async fn proof_collector(port: u16, proof_count: usize) -> Vec { - let result = Arc::new(RwLock::new(Vec::with_capacity(proof_count))); +async fn proof_collector(port: u16, proof_count: usize) -> Option> { + // let result = Arc::new(RwLock::new(Vec::with_capacity(proof_count))); + let mut result = Vec::with_capacity(proof_count); let url = format!("http://127.0.0.1:{port}"); let mut client = RlnProverClient::connect(url).await.unwrap(); @@ -74,20 +99,37 @@ async fn proof_collector(port: u16, proof_count: usize) -> Vec { let request = tonic::Request::new(request_0); let stream_ = client.get_proofs(request).await.unwrap(); let mut stream = stream_.into_inner(); - let result_2 = result.clone(); + // let result_2 = result.clone(); let mut proof_received = 0; - while let Some(response) = stream.message().await.unwrap() { - result_2.write().push(response); + + loop { + let response = stream.message().await; + if let Err(_e) = response { + // println!("[proof_collector] error: {:?}", _e); + break; + } + + let response = response.unwrap(); + + if response.is_none() { + // println!("[proof_collector] response is None"); + break; + } + + result.push(response.unwrap()); proof_received += 1; if proof_received >= proof_count { break; } } - std::mem::take(&mut *result.write()) + // println!("[proof_collector] returning after received: {:?} proof replies", result.len()); + Some(std::mem::take(&mut result)) } fn proof_generation_bench(c: &mut Criterion) { + // setup_tracing(); + let rayon_num_threads = std::env::var("RAYON_NUM_THREADS").unwrap_or("".to_string()); let proof_service_count_default = 4; let proof_service_count = std::env::var("PROOF_SERVICE_COUNT") @@ -140,10 +182,10 @@ fn proof_generation_bench(c: &mut Criterion) { no_config: true, metrics_ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), metrics_port: 30051, - broadcast_channel_size: 100, + broadcast_channel_size: 500, proof_service_count, - transaction_channel_size: 100, - proof_sender_channel_size: 100, + transaction_channel_size: 500, + proof_sender_channel_size: 500, }; // Tokio notify - wait for some time after spawning run_prover then notify it's ready to accept @@ -190,14 +232,20 @@ fn proof_generation_bench(c: &mut Criterion) { b.to_async(&rt).iter(|| { async { let mut set = JoinSet::new(); - set.spawn(proof_collector(port, proof_count)); - set.spawn(proof_sender(port, addresses.clone(), proof_count).map(|_r| vec![])); + set.spawn(proof_collector(port, proof_count)); // return Option> + set.spawn(proof_sender(port, addresses.clone(), proof_count).map(|_r| None)); // Map to None // Wait for proof_sender + proof_collector to complete let res = set.join_all().await; - // Check proof_sender return an empty vec - assert_eq!(res.iter().filter(|r| r.is_empty()).count(), 1); + assert_eq!(res.len(), 2); + // Check proof_sender return None + assert_eq!(res.iter().filter(|r| r.is_none()).count(), 1); // Check we receive enough proofs - assert_eq!(res.iter().filter(|r| r.len() == proof_count).count(), 1); + assert_eq!( + res.iter() + .filter(|r| { r.as_ref().map(|v| v.len()).unwrap_or(0) == proof_count }) + .count(), + 1 + ); } }); }, diff --git a/prover/src/args.rs b/prover/src/args.rs index abe30ac..82e618a 100644 --- a/prover/src/args.rs +++ b/prover/src/args.rs @@ -21,7 +21,7 @@ const ARGS_DEFAULT_PROOF_SERVICE_COUNT: &str = "8"; /// /// Used by grpc service to send the transaction to one of the proof services. A too low value could stall /// the grpc service when it receives a transaction. -const ARGS_DEFAULT_TRANSACTION_CHANNEL_SIZE: &str = "100"; +const ARGS_DEFAULT_TRANSACTION_CHANNEL_SIZE: &str = "256"; /// Proof sender channel size /// /// Used by grpc service to send the generated proof to the Verifier. A too low value could stall diff --git a/prover/src/error.rs b/prover/src/error.rs index ae24eba..c20ee19 100644 --- a/prover/src/error.rs +++ b/prover/src/error.rs @@ -1,10 +1,14 @@ +use alloy::signers::local::LocalSignerError; use alloy::transports::{RpcError, TransportErrorKind}; use ark_serialize::SerializationError; use rln::error::ProofError; use smart_contract::{KarmaScError, KarmaTiersError, RlnScError}; // internal use crate::epoch_service::WaitUntilError; -use crate::user_db_error::{RegisterError, UserMerkleTreeIndexError}; +use crate::tier::ValidateTierLimitsError; +use crate::user_db_error::{ + RegisterError, TxCounterError, UserDbOpenError, UserMerkleTreeIndexError, +}; #[derive(thiserror::Error, Debug)] pub enum AppError { @@ -26,6 +30,16 @@ pub enum AppError { KarmaTiersError(#[from] KarmaTiersError), #[error(transparent)] RlnScError(#[from] RlnScError), + #[error(transparent)] + SignerInitError(#[from] LocalSignerError), + #[error(transparent)] + ValidateTierError(#[from] ValidateTierLimitsError), + #[error(transparent)] + UserDbOpenError(#[from] UserDbOpenError), + #[error(transparent)] + MockUserRegisterError(#[from] RegisterError), + #[error(transparent)] + MockUserTxCounterError(#[from] TxCounterError), } #[derive(thiserror::Error, Debug)] diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 0339b3f..2965c77 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -32,15 +32,12 @@ use alloy::providers::{ProviderBuilder, WsConnect}; use alloy::signers::local::PrivateKeySigner; use chrono::{DateTime, Utc}; use tokio::task::JoinSet; -use tracing::{ - debug, - // error, - // info -}; +use tracing::{debug, info}; use zeroize::Zeroizing; // internal pub use crate::args::{AppArgs, AppArgsConfig}; use crate::epoch_service::EpochService; +use crate::error::AppError; use crate::grpc_service::GrpcProverService; pub use crate::mock::MockUser; use crate::mock::read_mock_user; @@ -61,9 +58,7 @@ const GENESIS: DateTime = DateTime::from_timestamp(1431648000, 0).unwrap(); const PROVER_MINIMAL_AMOUNT_FOR_REGISTRATION: U256 = U256::from_le_slice(10u64.to_le_bytes().as_slice()); -pub async fn run_prover( - app_args: AppArgs, -) -> Result<(), Box> { +pub async fn run_prover(app_args: AppArgs) -> Result<(), AppError> { // Epoch let epoch_service = EpochService::try_from((Duration::from_secs(60 * 2), GENESIS)) .expect("Failed to create epoch service"); @@ -71,10 +66,7 @@ pub async fn run_prover( // Alloy provider (Smart contract provider) let provider = if app_args.ws_rpc_url.is_some() { let ws = WsConnect::new(app_args.ws_rpc_url.clone().unwrap().as_str()); - let provider = ProviderBuilder::new() - .connect_ws(ws) - .await - .map_err(KarmaTiersError::RpcTransportError)?; + let provider = ProviderBuilder::new().connect_ws(ws).await?; Some(provider) } else { None @@ -146,7 +138,7 @@ pub async fn run_prover( debug!("User {} already registered", mock_user.address); } _ => { - return Err(Box::new(e)); + return Err(AppError::from(e)); } } } @@ -155,9 +147,8 @@ pub async fn run_prover( } // Smart contract - // FIXME: use provider let registry_listener = if app_args.mock_sc.is_some() { - // debug!("No registry listener when mock is enabled"); + // No registry listener when mock is enabled None } else { Some(RegistryListener::new( @@ -185,7 +176,7 @@ pub async fn run_prover( let rln_identifier = RlnIdentifier::new(RLN_IDENTIFIER_NAME); let addr = SocketAddr::new(app_args.ip, app_args.port); - debug!("Listening on: {}", addr); + info!("Listening on: {}", addr); let prover_grpc_service = { let mut service = GrpcProverService { proof_sender, @@ -242,11 +233,17 @@ pub async fn run_prover( if app_args.ws_rpc_url.is_some() { set.spawn(async move { prover_grpc_service.serve().await }); } else { - debug!("Grpc service started with mocked smart contracts"); + info!("Grpc service started with mocked smart contracts"); set.spawn(async move { prover_grpc_service.serve_with_mock().await }); } - // TODO: handle error - let _ = set.join_all().await; + let res = set.join_all().await; + // Print all errors from services (if any) + // We expect that the Prover should never stop unexpectedly, but printing error can help to debug + res.iter().for_each(|r| { + if r.is_err() { + info!("Error: {:?}", r); + } + }); Ok(()) } diff --git a/prover/src/proof_service.rs b/prover/src/proof_service.rs index 989bbb6..03ab5c2 100644 --- a/prover/src/proof_service.rs +++ b/prover/src/proof_service.rs @@ -6,7 +6,7 @@ use ark_serialize::CanonicalSerialize; use async_channel::Receiver; use metrics::{counter, histogram}; use parking_lot::RwLock; -use rln::hashers::hash_to_field; +use rln::hashers::hash_to_field_le; use rln::protocol::serialize_proof_values; use tracing::{ Instrument, // debug, @@ -101,7 +101,7 @@ impl ProofService { let rln_data = RlnData { message_id: Fr::from(message_id), - data: hash_to_field(proof_generation_data.tx_hash.as_slice()), + data: hash_to_field_le(proof_generation_data.tx_hash.as_slice()), }; let epoch_bytes = { @@ -109,7 +109,7 @@ impl ProofService { v.extend(current_epoch_slice.to_le_bytes()); v }; - let epoch = hash_to_field(epoch_bytes.as_slice()); + let epoch = hash_to_field_le(epoch_bytes.as_slice()); let merkle_proof = match user_db.get_merkle_proof(&proof_generation_data.tx_sender) { @@ -159,6 +159,15 @@ impl ProofService { let _ = send.send(Ok::, ProofGenerationError>( output_buffer.into_inner(), )); + + /* + std::thread::sleep(std::time::Duration::from_millis(100)); + let mut output_buffer = Cursor::new(Vec::with_capacity(PROOF_SIZE)); + // Send the result back to Tokio. + let _ = send.send(Ok::, ProofGenerationError>( + output_buffer.into_inner(), + )); + */ }); // Wait for the rayon task. @@ -271,7 +280,7 @@ mod tests { debug!("Starting broadcast receiver..."); tokio::time::sleep(std::time::Duration::from_secs(1)).await; let res = - tokio::time::timeout(std::time::Duration::from_secs(5), broadcast_receiver.recv()) + tokio::time::timeout(std::time::Duration::from_secs(7), broadcast_receiver.recv()) .await .map_err(|_e| AppErrorExt::Elapsed)?; debug!("res: {:?}", res); diff --git a/prover/src/proof_service_tests.rs b/prover/src/proof_service_tests.rs index 847e9e0..1b971d7 100644 --- a/prover/src/proof_service_tests.rs +++ b/prover/src/proof_service_tests.rs @@ -5,7 +5,6 @@ mod tests { use std::sync::Arc; // third-party use alloy::primitives::{Address, address}; - use ark_bn254::Fr; use ark_groth16::{Proof as ArkProof, Proof, VerifyingKey}; use ark_serialize::CanonicalDeserialize; use claims::assert_matches; @@ -14,6 +13,7 @@ mod tests { use rln::circuit::{Curve, zkey_from_folder}; use rln::error::ComputeIdSecretError; use rln::protocol::{compute_id_secret, deserialize_proof_values, verify_proof}; + use rln::utils::IdSecret; use tokio::sync::broadcast; use tracing::{debug, info}; // internal @@ -50,7 +50,7 @@ mod tests { #[error(transparent)] RecoverSecretFailed(ComputeIdSecretError), #[error("Recovered secret")] - RecoveredSecret(Fr), + RecoveredSecret(IdSecret), } async fn proof_sender( diff --git a/prover/src/rocksdb_operands.rs b/prover/src/rocksdb_operands.rs index 7f76447..0e62f3d 100644 --- a/prover/src/rocksdb_operands.rs +++ b/prover/src/rocksdb_operands.rs @@ -56,6 +56,7 @@ impl EpochCounterSerializer { } } +#[derive(Clone)] pub struct EpochCounterDeserializer {} impl EpochCounterDeserializer { @@ -86,6 +87,7 @@ pub struct EpochIncr { pub incr_value: u64, } +#[derive(Clone)] pub struct EpochIncrSerializer {} impl EpochIncrSerializer { diff --git a/prover/src/user_db.rs b/prover/src/user_db.rs index 857be46..cc8553b 100644 --- a/prover/src/user_db.rs +++ b/prover/src/user_db.rs @@ -1,5 +1,4 @@ use std::path::PathBuf; -use std::str::FromStr; use std::sync::Arc; // third-party use alloy::primitives::{Address, U256}; @@ -15,8 +14,8 @@ use rln::{ use rocksdb::{ ColumnFamily, ColumnFamilyDescriptor, DB, Options, ReadOptions, WriteBatch, WriteBatchWithIndex, }; -use serde::{Deserialize, Serialize}; use tracing::error; +use zerokit_utils::Mode::HighThroughput; use zerokit_utils::{ error::ZerokitMerkleTreeError, pmtree::{PmtreeErrorKind, TreeErrorKind}, @@ -62,22 +61,20 @@ pub struct UserTierInfo { pub(crate) tier_limit: Option, } -#[derive(Serialize, Deserialize)] -struct PmTreeConfigJson { - path: PathBuf, - temporary: bool, - cache_capacity: u64, - flush_every_ms: u64, - mode: String, - use_compression: bool, -} - #[derive(Clone)] pub(crate) struct UserDb { db: Arc, merkle_tree: Arc>, rate_limit: RateLimit, pub(crate) epoch_store: Arc>, + rln_identity_serializer: RlnUserIdentitySerializer, + rln_identity_deserializer: RlnUserIdentityDeserializer, + merkle_index_serializer: MerkleTreeIndexSerializer, + merkle_index_deserializer: MerkleTreeIndexDeserializer, + epoch_increase_serializer: EpochIncrSerializer, + epoch_counter_deserializer: EpochCounterDeserializer, + tier_limits_serializer: TierLimitsSerializer, + tier_limits_deserializer: TierLimitsDeserializer, } impl std::fmt::Debug for UserDb { @@ -144,7 +141,10 @@ impl UserDb { // merkle tree index let cf_mtree = db.cf_handle(MERKLE_TREE_COUNTER_CF).unwrap(); - if let Err(e) = Self::get_merkle_tree_index_(db.clone(), cf_mtree) { + let merkle_index_deserializer = MerkleTreeIndexDeserializer {}; + if let Err(e) = + Self::get_merkle_tree_index_(db.clone(), cf_mtree, &merkle_index_deserializer) + { match e { MerkleTreeIndexError::DbUninitialized => { // Check if the value is already there (e.g. after a restart) @@ -156,25 +156,32 @@ impl UserDb { } // merkle tree + let tree_config = PmtreeConfig::builder() + .path(merkle_tree_path) + .temporary(false) + .cache_capacity(100_000) + .flush_every_ms(12_000) + .mode(HighThroughput) + .use_compression(false) + .build()?; + let tree = PoseidonTree::new(MERKLE_TREE_HEIGHT, Default::default(), tree_config)?; - let config_ = PmTreeConfigJson { - path: merkle_tree_path, - temporary: false, - cache_capacity: 100_000, - flush_every_ms: 12_000, - mode: "HighThroughput".to_string(), - use_compression: false, + let tier_limits_deserializer = TierLimitsDeserializer { + tier_deserializer: TierDeserializer {}, }; - let config_str = serde_json::to_string(&config_)?; - // Note: in Zerokit 0.8 this is the only way to initialize a PmTreeConfig - let config = PmtreeConfig::from_str(config_str.as_str())?; - let tree = PoseidonTree::new(MERKLE_TREE_HEIGHT, Default::default(), config)?; - Ok(Self { db, merkle_tree: Arc::new(RwLock::new(tree)), rate_limit, epoch_store, + rln_identity_serializer: RlnUserIdentitySerializer {}, + rln_identity_deserializer: RlnUserIdentityDeserializer {}, + merkle_index_serializer: MerkleTreeIndexSerializer {}, + merkle_index_deserializer, + epoch_increase_serializer: EpochIncrSerializer {}, + epoch_counter_deserializer: EpochCounterDeserializer {}, + tier_limits_serializer, + tier_limits_deserializer, }) } @@ -199,24 +206,23 @@ impl UserDb { } pub(crate) fn register(&self, address: Address) -> Result { - let rln_identity_serializer = RlnUserIdentitySerializer {}; - let merkle_index_serializer = MerkleTreeIndexSerializer {}; - let merkle_index_deserializer = MerkleTreeIndexDeserializer {}; - let (identity_secret_hash, id_commitment) = keygen(); let rln_identity = RlnUserIdentity::from(( - identity_secret_hash, id_commitment, + identity_secret_hash, Fr::from(self.rate_limit), )); let key = address.as_slice(); - let mut buffer = - vec![0; rln_identity_serializer.size_hint() + merkle_index_serializer.size_hint()]; + let mut buffer = vec![ + 0; + self.rln_identity_serializer.size_hint() + + self.merkle_index_serializer.size_hint() + ]; // unwrap safe - this is serialized by the Prover + RlnUserIdentitySerializer is unit tested - rln_identity_serializer + self.rln_identity_serializer .serialize(&rln_identity, &mut buffer) .unwrap(); @@ -257,7 +263,8 @@ impl UserDb { // Increase merkle tree index db_batch.merge_cf(cf_mtree, MERKLE_TREE_INDEX_KEY, 1i64.to_le_bytes()); // Unwrap safe - serialization is handled by the prover - let (_, new_index) = merkle_index_deserializer + let (_, new_index) = self + .merkle_index_deserializer .deserialize(batch_read.as_slice()) .unwrap(); @@ -281,7 +288,8 @@ impl UserDb { })?; // Add index for user - merkle_index_serializer.serialize(&new_index, &mut buffer); + self.merkle_index_serializer + .serialize(&new_index, &mut buffer); // Put user db_batch.put_cf(cf_user, key, buffer.as_slice()); // Put user tx counter @@ -311,11 +319,10 @@ impl UserDb { pub fn get_user(&self, address: &Address) -> Option { let cf_user = self.get_user_cf(); - let rln_identity_deserializer = RlnUserIdentityDeserializer {}; match self.db.get_pinned_cf(cf_user, address.as_slice()) { Ok(Some(value)) => { // Here we silence the error - this is safe as the prover controls this - rln_identity_deserializer.deserialize(&value).ok() + self.rln_identity_deserializer.deserialize(&value).ok() } Ok(None) => None, Err(_e) => None, @@ -327,13 +334,12 @@ impl UserDb { address: &Address, ) -> Result { let cf_user = self.get_user_cf(); - let rln_identity_serializer = RlnUserIdentitySerializer {}; - let merkle_tree_index_deserializer = MerkleTreeIndexDeserializer {}; match self.db.get_pinned_cf(cf_user, address.as_slice()) { Ok(Some(buffer)) => { // Here we silence the error - this is safe as the prover controls this - let start = rln_identity_serializer.size_hint(); - let (_, index) = merkle_tree_index_deserializer + let start = self.rln_identity_serializer.size_hint(); + let (_, index) = self + .merkle_index_deserializer .deserialize(&buffer[start..]) .unwrap(); Ok(index) @@ -398,9 +404,8 @@ impl UserDb { epoch_slice, incr_value, }; - let incr_ser = EpochIncrSerializer {}; - let mut buffer = Vec::with_capacity(incr_ser.size_hint()); - incr_ser.serialize(&incr, &mut buffer); + let mut buffer = Vec::with_capacity(self.epoch_increase_serializer.size_hint()); + self.epoch_increase_serializer.serialize(&incr, &mut buffer); // Create a transaction // By using a WriteBatchWithIndex, we can "read your own writes" so here we incr then read the new value @@ -435,11 +440,9 @@ impl UserDb { address: &Address, key: Option>, ) -> Result<(EpochCounter, EpochSliceCounter), TxCounterError> { - let deserializer = EpochCounterDeserializer {}; - match key { Some(value) => { - let (_, counter) = deserializer.deserialize(&value).unwrap(); + let (_, counter) = self.epoch_counter_deserializer.deserialize(&value).unwrap(); let (epoch, epoch_slice) = *self.epoch_store.read(); let cmp = (counter.epoch == epoch, counter.epoch_slice == epoch_slice); @@ -490,19 +493,20 @@ impl UserDb { #[cfg(test)] pub(crate) fn get_merkle_tree_index(&self) -> Result { let cf_mtree = self.get_mtree_cf(); - Self::get_merkle_tree_index_(self.db.clone(), cf_mtree) + Self::get_merkle_tree_index_(self.db.clone(), cf_mtree, &self.merkle_index_deserializer) } fn get_merkle_tree_index_( db: Arc, cf: &ColumnFamily, + merkle_tree_index_deserializer: &MerkleTreeIndexDeserializer, ) -> Result { - let deserializer = MerkleTreeIndexDeserializer {}; - match db.get_cf(cf, MERKLE_TREE_INDEX_KEY) { Ok(Some(v)) => { // Unwrap safe - serialization is done by the prover - let (_, index) = deserializer.deserialize(v.as_slice()).unwrap(); + let (_, index) = merkle_tree_index_deserializer + .deserialize(v.as_slice()) + .unwrap(); Ok(index) } Ok(None) => Err(MerkleTreeIndexError::DbUninitialized), @@ -541,12 +545,8 @@ impl UserDb { let cf = self.get_tier_limits_cf(); // Unwrap safe - Db is initialized with valid tier limits let buffer = self.db.get_cf(cf, TIER_LIMITS_KEY.as_slice())?.unwrap(); - let tier_limits_deserializer = TierLimitsDeserializer { - tier_deserializer: TierDeserializer {}, - }; - // Unwrap safe - serialized by the prover (should always deserialize) - let (_, tier_limits) = tier_limits_deserializer.deserialize(&buffer).unwrap(); + let (_, tier_limits) = self.tier_limits_deserializer.deserialize(&buffer).unwrap(); Ok(tier_limits) } @@ -557,10 +557,10 @@ impl UserDb { tier_limits.validate()?; // Serialize - let tier_limits_serializer = TierLimitsSerializer::default(); - let mut buffer = Vec::with_capacity(tier_limits_serializer.size_hint(tier_limits.len())); + let mut buffer = + Vec::with_capacity(self.tier_limits_serializer.size_hint(tier_limits.len())); // Unwrap safe - already validated - should always serialize - tier_limits_serializer + self.tier_limits_serializer .serialize(&tier_limits, &mut buffer) .unwrap(); @@ -798,16 +798,15 @@ mod tests { .unwrap(); let temp_folder_tree_2 = tempfile::tempdir().unwrap(); - let config_ = PmTreeConfigJson { - path: temp_folder_tree_2.path().to_path_buf(), - temporary: false, - cache_capacity: 100_000, - flush_every_ms: 12_000, - mode: "HighThroughput".to_string(), - use_compression: false, - }; - let config_str = serde_json::to_string(&config_).unwrap(); - let config = PmtreeConfig::from_str(config_str.as_str()).unwrap(); + let config = PmtreeConfig::builder() + .path(temp_folder_tree_2.path().to_path_buf()) + .temporary(false) + .cache_capacity(100_000) + .flush_every_ms(12_000) + .mode(HighThroughput) + .use_compression(false) + .build() + .unwrap(); let tree = PoseidonTree::new(1, Default::default(), config).unwrap(); let tree = Arc::new(RwLock::new(tree)); user_db.merkle_tree = tree.clone(); diff --git a/prover/src/user_db_error.rs b/prover/src/user_db_error.rs index d0c8299..1583b95 100644 --- a/prover/src/user_db_error.rs +++ b/prover/src/user_db_error.rs @@ -6,7 +6,7 @@ use zerokit_utils::error::{FromConfigError, ZerokitMerkleTreeError}; use crate::tier::ValidateTierLimitsError; #[derive(Debug, thiserror::Error)] -pub(crate) enum UserDbOpenError { +pub enum UserDbOpenError { #[error(transparent)] RocksDb(#[from] rocksdb::Error), #[error("Serialization error: {0}")] diff --git a/prover/src/user_db_serialization.rs b/prover/src/user_db_serialization.rs index a3bdb6b..3bcbfaf 100644 --- a/prover/src/user_db_serialization.rs +++ b/prover/src/user_db_serialization.rs @@ -12,12 +12,14 @@ use nom::{ multi::length_count, number::complete::{le_u32, le_u64}, }; +use rln::utils::IdSecret; use rln_proof::RlnUserIdentity; // internal use crate::tier::TierLimits; use crate::user_db_types::MerkleTreeIndex; use smart_contract::Tier; +#[derive(Clone)] pub(crate) struct RlnUserIdentitySerializer {} impl RlnUserIdentitySerializer { @@ -41,6 +43,7 @@ impl RlnUserIdentitySerializer { } } +#[derive(Clone)] pub(crate) struct RlnUserIdentityDeserializer {} impl RlnUserIdentityDeserializer { @@ -49,8 +52,8 @@ impl RlnUserIdentityDeserializer { let (co_buffer, rem_buffer) = buffer.split_at(compressed_size); let commitment: Fr = CanonicalDeserialize::deserialize_compressed(co_buffer)?; let (secret_buffer, user_limit_buffer) = rem_buffer.split_at(compressed_size); - // TODO: IdSecret (wait for Zerokit PR: https://github.com/vacp2p/zerokit/pull/320) - let secret_hash: Fr = CanonicalDeserialize::deserialize_compressed(secret_buffer)?; + let mut secret_hash_: Fr = CanonicalDeserialize::deserialize_compressed(secret_buffer)?; + let secret_hash = IdSecret::from(&mut secret_hash_); let user_limit: Fr = CanonicalDeserialize::deserialize_compressed(user_limit_buffer)?; Ok({ @@ -63,6 +66,7 @@ impl RlnUserIdentityDeserializer { } } +#[derive(Clone)] pub(crate) struct MerkleTreeIndexSerializer {} impl MerkleTreeIndexSerializer { @@ -77,6 +81,7 @@ impl MerkleTreeIndexSerializer { } } +#[derive(Clone)] pub(crate) struct MerkleTreeIndexDeserializer {} impl MerkleTreeIndexDeserializer { @@ -88,7 +93,7 @@ impl MerkleTreeIndexDeserializer { } } -#[derive(Default)] +#[derive(Default, Clone)] pub(crate) struct TierSerializer {} impl TierSerializer { @@ -113,7 +118,7 @@ impl TierSerializer { } } -#[derive(Default)] +#[derive(Default, Clone)] pub(crate) struct TierDeserializer {} #[derive(Debug, PartialEq)] @@ -166,7 +171,7 @@ impl TierDeserializer { } } -#[derive(Default)] +#[derive(Default, Clone)] pub(crate) struct TierLimitsSerializer { tier_serializer: TierSerializer, } @@ -193,7 +198,7 @@ impl TierLimitsSerializer { } } -#[derive(Default)] +#[derive(Default, Clone)] pub(crate) struct TierLimitsDeserializer { pub(crate) tier_deserializer: TierDeserializer, } @@ -226,7 +231,7 @@ mod tests { fn test_rln_ser_der() { let rln_user_identity = RlnUserIdentity { commitment: Fr::from(42), - secret_hash: Fr::from(u16::MAX), + secret_hash: IdSecret::from(&mut Fr::from(u16::MAX)), user_limit: Fr::from(1_000_000), }; let serializer = RlnUserIdentitySerializer {}; diff --git a/prover_cli/Cargo.toml b/prover_cli/Cargo.toml index 400951a..c1f14d0 100644 --- a/prover_cli/Cargo.toml +++ b/prover_cli/Cargo.toml @@ -24,3 +24,4 @@ opentelemetry-otlp = { version = "0.30.0", features = [ "tls-roots", ] } tracing-opentelemetry = "0.31.0" +anyhow = "1.0.99" diff --git a/prover_cli/src/main.rs b/prover_cli/src/main.rs index 6a0123c..b900f57 100644 --- a/prover_cli/src/main.rs +++ b/prover_cli/src/main.rs @@ -11,6 +11,7 @@ use tracing::{ }; use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt}; +use anyhow::{Context, Result, anyhow}; use opentelemetry::trace::TracerProvider; use opentelemetry_otlp::WithTonicConfig; use opentelemetry_sdk::Resource; @@ -20,7 +21,7 @@ use prover::{AppArgs, AppArgsConfig, metrics::init_metrics, run_prover}; const APP_NAME: &str = "prover-cli"; #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> Result<()> { // install crypto provider for rustls - required for WebSocket TLS connections rustls::crypto::CryptoProvider::install_default(aws_lc_rs::default_provider()) .expect("Failed to install default CryptoProvider"); @@ -58,7 +59,8 @@ async fn main() -> Result<(), Box // Unwrap safe - default value provided let config_path = app_args.get_one::("config_path").unwrap(); debug!("Reading config path: {:?}...", config_path); - let config_str = std::fs::read_to_string(config_path)?; + let config_str = std::fs::read_to_string(config_path) + .context(format!("Failed to read config file: {:?}", config_path))?; let config: AppArgsConfig = toml::from_str(config_str.as_str())?; debug!("Config: {:?}", config); config @@ -76,15 +78,17 @@ async fn main() -> Result<(), Box || app_args.ksc_address.is_none() || app_args.tsc_address.is_none() { - return Err("Please provide smart contract addresses".into()); + return Err(anyhow!("Please provide smart contract addresses")); } } else if app_args.mock_sc.is_none() { - return Err("Please provide rpc url (--ws-rpc-url) or mock (--mock-sc)".into()); + return Err(anyhow!( + "Please provide rpc url (--ws-rpc-url) or mock (--mock-sc)" + )); } init_metrics(app_args.metrics_ip, &app_args.metrics_port); - run_prover(app_args).await + run_prover(app_args).await.map_err(anyhow::Error::new) } fn create_otlp_tracer_provider() -> Option { diff --git a/rln_proof/benches/generate_proof.rs b/rln_proof/benches/generate_proof.rs index aa3902b..65a211e 100644 --- a/rln_proof/benches/generate_proof.rs +++ b/rln_proof/benches/generate_proof.rs @@ -1,3 +1,4 @@ +use std::hint::black_box; // std use std::io::{Cursor, Write}; // criterion @@ -5,7 +6,7 @@ use criterion::{Criterion, criterion_group, criterion_main}; // third-party use ark_bn254::Fr; use ark_serialize::CanonicalSerialize; -use rln::hashers::{hash_to_field, poseidon_hash}; +use rln::hashers::{hash_to_field_le, poseidon_hash}; use rln::poseidon_tree::PoseidonTree; use rln::protocol::{keygen, serialize_proof_values}; // internal @@ -24,7 +25,7 @@ pub fn criterion_benchmark(c: &mut Criterion) { let rln_identifier = RlnIdentifier::new(b"test-test"); let rln_data = RlnData { message_id: Fr::from(user_limit - 2), - data: hash_to_field(b"data-from-message"), + data: hash_to_field_le(b"data-from-message"), }; // Merkle tree @@ -35,7 +36,7 @@ pub fn criterion_benchmark(c: &mut Criterion) { let merkle_proof = tree.proof(0).unwrap(); // Epoch - let epoch = hash_to_field(b"Today at noon, this year"); + let epoch = hash_to_field_le(b"Today at noon, this year"); { // Not a benchmark but print the proof size (serialized) @@ -61,17 +62,6 @@ pub fn criterion_benchmark(c: &mut Criterion) { } c.bench_function("compute proof and values", |b| { - /* - b.iter(|| { - compute_rln_proof_and_values( - &rln_identity, - &rln_identifier, - rln_data.clone(), - epoch, - &merkle_proof, - ) - }) - */ b.iter_batched( || { // generate setup data @@ -80,11 +70,11 @@ pub fn criterion_benchmark(c: &mut Criterion) { |data| { // function to benchmark compute_rln_proof_and_values( - &rln_identity, - &rln_identifier, - data, - epoch, - &merkle_proof, + black_box(&rln_identity), + black_box(&rln_identifier), + black_box(data), + black_box(epoch), + black_box(&merkle_proof), ) }, criterion::BatchSize::SmallInput, @@ -96,19 +86,21 @@ pub fn criterion_benchmark(c: &mut Criterion) { || { // generate setup data compute_rln_proof_and_values( - &rln_identity, - &rln_identifier, - rln_data.clone(), - epoch, - &merkle_proof, + black_box(&rln_identity), + black_box(&rln_identifier), + black_box(rln_data.clone()), + black_box(epoch), + black_box(&merkle_proof), ) .unwrap() }, |(proof, proof_values)| { let mut output_buffer = Cursor::new(Vec::with_capacity(320)); - proof.serialize_compressed(&mut output_buffer).unwrap(); + proof + .serialize_compressed(black_box(&mut output_buffer)) + .unwrap(); output_buffer - .write_all(&serialize_proof_values(&proof_values)) + .write_all(black_box(&serialize_proof_values(black_box(&proof_values)))) .unwrap(); }, criterion::BatchSize::SmallInput, diff --git a/rln_proof/src/proof.rs b/rln_proof/src/proof.rs index 827f503..e73c084 100644 --- a/rln_proof/src/proof.rs +++ b/rln_proof/src/proof.rs @@ -1,29 +1,31 @@ // std -use std::io::Cursor; +// use std::io::Cursor; // third-party use ark_bn254::{Bn254, Fr}; use ark_groth16::{Proof, ProvingKey}; use ark_relations::r1cs::ConstraintMatrices; +use rln::utils::IdSecret; use rln::{ - circuit::{ZKEY_BYTES, zkey::read_zkey}, + circuit::{ARKZKEY_BYTES, read_arkzkey_from_bytes_uncompressed as read_zkey}, error::ProofError, - hashers::{hash_to_field, poseidon_hash}, + hashers::{hash_to_field_le, poseidon_hash}, poseidon_tree::MerkleProof, protocol::{ RLNProofValues, generate_proof, proof_values_from_witness, rln_witness_from_values, }, }; +use zerokit_utils::ZerokitMerkleProof; /// A RLN user identity & limit #[derive(Debug, Clone, PartialEq)] pub struct RlnUserIdentity { pub commitment: Fr, - pub secret_hash: Fr, + pub secret_hash: IdSecret, pub user_limit: Fr, } -impl From<(Fr, Fr, Fr)> for RlnUserIdentity { - fn from((commitment, secret_hash, user_limit): (Fr, Fr, Fr)) -> Self { +impl From<(Fr, IdSecret, Fr)> for RlnUserIdentity { + fn from((commitment, secret_hash, user_limit): (Fr, IdSecret, Fr)) -> Self { Self { commitment, secret_hash, @@ -43,13 +45,13 @@ pub struct RlnIdentifier { impl RlnIdentifier { pub fn new(identifier: &[u8]) -> Self { let pk_and_matrices = { - let mut reader = Cursor::new(ZKEY_BYTES); - read_zkey(&mut reader).unwrap() + // let mut reader = Cursor::new(ARKZKEY_BYTES); + read_zkey(ARKZKEY_BYTES).unwrap() }; let graph_bytes = include_bytes!("../resources/graph.bin"); Self { - identifier: hash_to_field(identifier), + identifier: hash_to_field_le(identifier), pkey_and_constraints: pk_and_matrices, graph: graph_bytes.to_vec(), } @@ -74,9 +76,15 @@ pub fn compute_rln_proof_and_values( ) -> Result<(Proof, RLNProofValues), ProofError> { let external_nullifier = poseidon_hash(&[rln_identifier.identifier, epoch]); + let path_elements = merkle_proof.get_path_elements(); + let identity_path_index = merkle_proof.get_path_index(); + + // let mut id_s = user_identity.secret_hash; + let witness = rln_witness_from_values( - user_identity.secret_hash, - merkle_proof, + user_identity.secret_hash.clone(), + path_elements, + identity_path_index, rln_data.data, external_nullifier, user_identity.user_limit, @@ -101,8 +109,9 @@ mod tests { #[test] fn test_recover_secret_hash() { - let (user_co, user_sh) = keygen(); - let epoch = hash_to_field(b"foo"); + let (user_co, mut user_sh_) = keygen(); + let user_sh = IdSecret::from(&mut user_sh_); + let epoch = hash_to_field_le(b"foo"); let spam_limit = Fr::from(10); // let mut tree = OptimalMerkleTree::new(20, Default::default(), Default::default()).unwrap(); @@ -116,14 +125,14 @@ mod tests { let (_proof_0, proof_values_0) = compute_rln_proof_and_values( &RlnUserIdentity { - commitment: user_co, - secret_hash: user_sh, + commitment: *user_co, + secret_hash: user_sh.clone(), user_limit: spam_limit, }, &rln_identifier, RlnData { message_id, - data: hash_to_field(b"sig"), + data: hash_to_field_le(b"sig"), }, epoch, &m_proof, @@ -132,14 +141,14 @@ mod tests { let (_proof_1, proof_values_1) = compute_rln_proof_and_values( &RlnUserIdentity { - commitment: user_co, - secret_hash: user_sh, + commitment: *user_co, + secret_hash: user_sh.clone(), user_limit: spam_limit, }, &rln_identifier, RlnData { message_id, - data: hash_to_field(b"sig 2"), + data: hash_to_field_le(b"sig 2"), }, epoch, &m_proof,