Initial code to use Zerokit 0.9 + disable parallel feature (#36)

* Initial code to use Zerokit 0.9 + disable parallel feature

* Support IdSecret for user identity secret hash

* Fix clippy + bench

* Use PmTreeConfig builder

* Improve prover_bench perf

* Fix prover_bench 2nd assert

* Fix prover_bench 2nd assert 2

* Can now enable trace for bench prover_bench

* Use anyhow for error handling (+ error context) in prover_cli (#42)

* Use anyhow for error handling (+ error context) in prover_cli

* Cargo fmt pass

* Feature/feature/init user db ser de 2 (#45)

* Add user db serializer && deserializer init & re-use
Author: Sydhds
Date: 2025-10-01 17:00:56 +02:00
Committed by: GitHub
Parent: 55e7f6c3e2
Commit: b7967b85e8
17 changed files with 273 additions and 191 deletions

Cargo.lock (generated)

@@ -850,7 +850,6 @@ dependencies = [
"digest 0.10.7",
"fnv",
"merlin",
"rayon",
"sha2",
]
@@ -883,7 +882,6 @@ dependencies = [
"num-bigint",
"num-integer",
"num-traits",
"rayon",
"zeroize",
]
@@ -942,7 +940,6 @@ dependencies = [
"num-bigint",
"num-traits",
"paste",
"rayon",
"zeroize",
]
@@ -1027,7 +1024,6 @@ dependencies = [
"ark-relations",
"ark-serialize 0.5.0",
"ark-std 0.5.0",
"rayon",
]
[[package]]
@@ -1043,7 +1039,6 @@ dependencies = [
"educe",
"fnv",
"hashbrown 0.15.5",
"rayon",
]
[[package]]
@@ -1107,7 +1102,6 @@ dependencies = [
"arrayvec",
"digest 0.10.7",
"num-bigint",
"rayon",
]
[[package]]
@@ -1161,7 +1155,6 @@ checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a"
dependencies = [
"num-traits",
"rand 0.8.5",
"rayon",
]
[[package]]
@@ -3937,6 +3930,7 @@ dependencies = [
"derive_more",
"futures",
"http",
"lazy_static",
"metrics",
"metrics-exporter-prometheus",
"metrics-util",
@@ -3973,6 +3967,7 @@ dependencies = [
name = "prover_cli"
version = "0.1.0"
dependencies = [
"anyhow",
"clap",
"opentelemetry",
"opentelemetry-otlp",
@@ -4388,8 +4383,7 @@ dependencies = [
[[package]]
name = "rln"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a03834bc168adfee6f49c885fabb0ad6897be11d97258e9375d98f30d2c9878"
source = "git+https://github.com/vacp2p/zerokit/#0b00c639a059a2cfde74bcf68fdf75db3b6898a4"
dependencies = [
"ark-bn254",
"ark-ec",
@@ -4405,14 +4399,16 @@ dependencies = [
"num-bigint",
"num-traits",
"once_cell",
"prost 0.13.5",
"prost 0.14.1",
"rand 0.8.5",
"rand_chacha 0.3.1",
"ruint",
"serde",
"serde_json",
"tempfile",
"thiserror",
"tiny-keccak",
"zeroize",
"zerokit_utils",
]
@@ -5125,15 +5121,15 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "tempfile"
version = "3.20.0"
version = "3.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1"
checksum = "15b61f8f20e3a6f7e0649d825294eaf317edce30f82cf6026e7e4cb9222a7d1e"
dependencies = [
"fastrand",
"getrandom 0.3.3",
"once_cell",
"rustix 1.0.8",
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
@@ -6388,8 +6384,7 @@ dependencies = [
[[package]]
name = "zerokit_utils"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a21d5ee8dd5cba6c9e39c7391e3fe968d479b6f1eb51c82556b2bb9b2924f572"
source = "git+https://github.com/vacp2p/zerokit/#0b00c639a059a2cfde74bcf68fdf75db3b6898a4"
dependencies = [
"ark-ff 0.5.0",
"hex",


@@ -9,13 +9,19 @@ members = [
resolver = "2"
[workspace.dependencies]
rln = { version = "0.8.0", features = ["pmtree-ft"] }
zerokit_utils = { version = "0.6.0", features = ["pmtree-ft"] }
# rln = { version = "0.8.0", features = ["pmtree-ft"] }
rln = { git = "https://github.com/vacp2p/zerokit/", default-features = false, features = ["pmtree-ft"] }
# zerokit_utils = { version = "0.6.0", features = ["pmtree-ft"] }
zerokit_utils = { git = "https://github.com/vacp2p/zerokit/", default-features = false, features = ["pmtree-ft"] }
ark-bn254 = { version = "0.5.0", features = ["std"] }
ark-relations = { version = "0.5.1", features = ["std"] }
ark-ff = { version = "0.5.0", features = ["parallel"] }
ark-groth16 = { version = "0.5.0", features = ["parallel"] }
ark-serialize = { version = "0.5.0", features = ["parallel"] }
# ark-ff = { version = "0.5.0", features = ["parallel"] }
ark-ff = { version = "0.5.0", features = ["asm"] }
# ark-groth16 = { version = "0.5.0", features = ["parallel"] }
ark-groth16 = { version = "0.5.0", default-features = false, features = [] }
# ark-serialize = { version = "0.5.0", features = ["parallel"] }
ark-serialize = { version = "0.5.0", default-features = false, features = [] }
tokio = { version = "1.47.1", features = ["macros", "rt-multi-thread"] }
clap = { version = "4.5.46", features = ["derive", "wrap_help"] }
url = { version = "2.5.7", features = ["serde"] }


@@ -44,7 +44,7 @@ clap_config = "0.1"
metrics = "0.24"
metrics-exporter-prometheus = "0.17"
metrics-util = "0.20"
rayon = "1.7"
rayon = "1.10"
[build-dependencies]
tonic-prost-build.workspace = true
@@ -52,8 +52,9 @@ tonic-prost-build.workspace = true
[dev-dependencies]
criterion.workspace = true
ark-groth16.workspace = true
tempfile = "3.20"
tempfile = "3.21"
tracing-test = "0.2.5"
lazy_static = "1.5.0"
[[bench]]
name = "prover_bench"


@@ -11,7 +11,6 @@ use std::time::Duration;
// third-party
use alloy::primitives::{Address, U256};
use futures::FutureExt;
use parking_lot::RwLock;
use tempfile::NamedTempFile;
use tokio::sync::Notify;
use tokio::task::JoinSet;
@@ -29,6 +28,29 @@ use prover_proto::{
SendTransactionRequest, U256 as GrpcU256, Wei as GrpcWei, rln_prover_client::RlnProverClient,
};
use lazy_static::lazy_static;
use std::sync::Once;
lazy_static! {
static ref TRACING_INIT: Once = Once::new();
}
pub fn setup_tracing() {
TRACING_INIT.call_once(|| {
let filter = tracing_subscriber::EnvFilter::from_default_env()
.add_directive("h2=error".parse().unwrap())
.add_directive("sled::pagecache=error".parse().unwrap())
.add_directive("opentelemetry_sdk=error".parse().unwrap());
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_line_number(true)
.with_file(true)
.with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
.init();
});
}
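Note that `Once::new()` is `const`, so the `lazy_static!` wrapper above is not strictly required; a minimal equivalent sketch (`setup_tracing_once` and `TRACING_INIT_ONCE` are hypothetical names, not part of this change):

use std::sync::Once;

static TRACING_INIT_ONCE: Once = Once::new();

fn setup_tracing_once() {
    TRACING_INIT_ONCE.call_once(|| {
        // ... same EnvFilter + fmt subscriber setup as above ...
    });
}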
async fn proof_sender(port: u16, addresses: Vec<Address>, proof_count: usize) {
let chain_id = GrpcU256 {
// FIXME: LE or BE?
@@ -59,12 +81,15 @@ async fn proof_sender(port: u16, addresses: Vec<Address>, proof_count: usize) {
let request = tonic::Request::new(request_0);
let response: Response<SendTransactionReply> =
client.send_transaction(request).await.unwrap();
assert!(response.into_inner().result);
}
// println!("[proof_sender] returning...");
}
async fn proof_collector(port: u16, proof_count: usize) -> Vec<RlnProofReply> {
let result = Arc::new(RwLock::new(Vec::with_capacity(proof_count)));
async fn proof_collector(port: u16, proof_count: usize) -> Option<Vec<RlnProofReply>> {
// let result = Arc::new(RwLock::new(Vec::with_capacity(proof_count)));
let mut result = Vec::with_capacity(proof_count);
let url = format!("http://127.0.0.1:{port}");
let mut client = RlnProverClient::connect(url).await.unwrap();
@@ -74,20 +99,37 @@ async fn proof_collector(port: u16, proof_count: usize) -> Vec<RlnProofReply> {
let request = tonic::Request::new(request_0);
let stream_ = client.get_proofs(request).await.unwrap();
let mut stream = stream_.into_inner();
let result_2 = result.clone();
// let result_2 = result.clone();
let mut proof_received = 0;
while let Some(response) = stream.message().await.unwrap() {
result_2.write().push(response);
loop {
let response = stream.message().await;
if let Err(_e) = response {
// println!("[proof_collector] error: {:?}", _e);
break;
}
let response = response.unwrap();
if response.is_none() {
// println!("[proof_collector] response is None");
break;
}
result.push(response.unwrap());
proof_received += 1;
if proof_received >= proof_count {
break;
}
}
std::mem::take(&mut *result.write())
// println!("[proof_collector] returning after received: {:?} proof replies", result.len());
Some(std::mem::take(&mut result))
}
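Since `result` is now a plain local `Vec`, the trailing `std::mem::take(&mut result)` could simply be `Some(result)`, and the nested error/None checks can be flattened. A sketch against the same tonic streaming API (`collect_replies` is a hypothetical name; `RlnProofReply` comes from the surrounding file):

// A sketch, not the commit's code: any transport error (Err) or
// end-of-stream (Ok(None)) ends the loop, matching the branches above.
async fn collect_replies(
    stream: &mut tonic::Streaming<RlnProofReply>,
    proof_count: usize,
) -> Option<Vec<RlnProofReply>> {
    let mut result = Vec::with_capacity(proof_count);
    while let Ok(Some(response)) = stream.message().await {
        result.push(response);
        if result.len() >= proof_count {
            break;
        }
    }
    Some(result)
}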
fn proof_generation_bench(c: &mut Criterion) {
// setup_tracing();
let rayon_num_threads = std::env::var("RAYON_NUM_THREADS").unwrap_or("".to_string());
let proof_service_count_default = 4;
let proof_service_count = std::env::var("PROOF_SERVICE_COUNT")
@@ -140,10 +182,10 @@ fn proof_generation_bench(c: &mut Criterion) {
no_config: true,
metrics_ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
metrics_port: 30051,
broadcast_channel_size: 100,
broadcast_channel_size: 500,
proof_service_count,
transaction_channel_size: 100,
proof_sender_channel_size: 100,
transaction_channel_size: 500,
proof_sender_channel_size: 500,
};
// Tokio notify - wait for some time after spawning run_prover then notify it's ready to accept
@@ -190,14 +232,20 @@ fn proof_generation_bench(c: &mut Criterion) {
b.to_async(&rt).iter(|| {
async {
let mut set = JoinSet::new();
set.spawn(proof_collector(port, proof_count));
set.spawn(proof_sender(port, addresses.clone(), proof_count).map(|_r| vec![]));
set.spawn(proof_collector(port, proof_count)); // return Option<Vec<...>>
set.spawn(proof_sender(port, addresses.clone(), proof_count).map(|_r| None)); // Map to None
// Wait for proof_sender + proof_collector to complete
let res = set.join_all().await;
// Check proof_sender returned an empty vec
assert_eq!(res.iter().filter(|r| r.is_empty()).count(), 1);
assert_eq!(res.len(), 2);
// Check proof_sender returned None
assert_eq!(res.iter().filter(|r| r.is_none()).count(), 1);
// Check we receive enough proofs
assert_eq!(res.iter().filter(|r| r.len() == proof_count).count(), 1);
assert_eq!(
res.iter()
.filter(|r| { r.as_ref().map(|v| v.len()).unwrap_or(0) == proof_count })
.count(),
1
);
}
});
},
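The final assertion can also be written with `Option::map_or`, avoiding the `map(...).unwrap_or(0)` chain; a sketch (`count_full` is a hypothetical helper):

fn count_full(res: &[Option<Vec<RlnProofReply>>], proof_count: usize) -> usize {
    res.iter()
        .filter(|r| r.as_ref().map_or(0, |v| v.len()) == proof_count)
        .count()
}

// usage: assert_eq!(count_full(&res, proof_count), 1);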


@@ -21,7 +21,7 @@ const ARGS_DEFAULT_PROOF_SERVICE_COUNT: &str = "8";
///
/// Used by grpc service to send the transaction to one of the proof services. Too low a value could stall
/// the grpc service when it receives a transaction.
const ARGS_DEFAULT_TRANSACTION_CHANNEL_SIZE: &str = "100";
const ARGS_DEFAULT_TRANSACTION_CHANNEL_SIZE: &str = "256";
/// Proof sender channel size
///
/// Used by grpc service to send the generated proof to the Verifier. Too low a value could stall


@@ -1,10 +1,14 @@
use alloy::signers::local::LocalSignerError;
use alloy::transports::{RpcError, TransportErrorKind};
use ark_serialize::SerializationError;
use rln::error::ProofError;
use smart_contract::{KarmaScError, KarmaTiersError, RlnScError};
// internal
use crate::epoch_service::WaitUntilError;
use crate::user_db_error::{RegisterError, UserMerkleTreeIndexError};
use crate::tier::ValidateTierLimitsError;
use crate::user_db_error::{
RegisterError, TxCounterError, UserDbOpenError, UserMerkleTreeIndexError,
};
#[derive(thiserror::Error, Debug)]
pub enum AppError {
@@ -26,6 +30,16 @@ pub enum AppError {
KarmaTiersError(#[from] KarmaTiersError),
#[error(transparent)]
RlnScError(#[from] RlnScError),
#[error(transparent)]
SignerInitError(#[from] LocalSignerError),
#[error(transparent)]
ValidateTierError(#[from] ValidateTierLimitsError),
#[error(transparent)]
UserDbOpenError(#[from] UserDbOpenError),
#[error(transparent)]
MockUserRegisterError(#[from] RegisterError),
#[error(transparent)]
MockUserTxCounterError(#[from] TxCounterError),
}
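With the `#[from]` attributes above, `?` converts these error types into `AppError` automatically; a minimal sketch (`open_user_db` and `init_db` are hypothetical helpers, not part of this change):

// Hypothetical helper for illustration only.
fn open_user_db() -> Result<(), UserDbOpenError> {
    Ok(())
}

fn init_db() -> Result<(), AppError> {
    open_user_db()?; // UserDbOpenError -> AppError via #[from]
    Ok(())
}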
#[derive(thiserror::Error, Debug)]


@@ -32,15 +32,12 @@ use alloy::providers::{ProviderBuilder, WsConnect};
use alloy::signers::local::PrivateKeySigner;
use chrono::{DateTime, Utc};
use tokio::task::JoinSet;
use tracing::{
debug,
// error,
// info
};
use tracing::{debug, info};
use zeroize::Zeroizing;
// internal
pub use crate::args::{AppArgs, AppArgsConfig};
use crate::epoch_service::EpochService;
use crate::error::AppError;
use crate::grpc_service::GrpcProverService;
pub use crate::mock::MockUser;
use crate::mock::read_mock_user;
@@ -61,9 +58,7 @@ const GENESIS: DateTime<Utc> = DateTime::from_timestamp(1431648000, 0).unwrap();
const PROVER_MINIMAL_AMOUNT_FOR_REGISTRATION: U256 =
U256::from_le_slice(10u64.to_le_bytes().as_slice());
pub async fn run_prover(
app_args: AppArgs,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
pub async fn run_prover(app_args: AppArgs) -> Result<(), AppError> {
// Epoch
let epoch_service = EpochService::try_from((Duration::from_secs(60 * 2), GENESIS))
.expect("Failed to create epoch service");
@@ -71,10 +66,7 @@ pub async fn run_prover(
// Alloy provider (Smart contract provider)
let provider = if app_args.ws_rpc_url.is_some() {
let ws = WsConnect::new(app_args.ws_rpc_url.clone().unwrap().as_str());
let provider = ProviderBuilder::new()
.connect_ws(ws)
.await
.map_err(KarmaTiersError::RpcTransportError)?;
let provider = ProviderBuilder::new().connect_ws(ws).await?;
Some(provider)
} else {
None
@@ -146,7 +138,7 @@ pub async fn run_prover(
debug!("User {} already registered", mock_user.address);
}
_ => {
return Err(Box::new(e));
return Err(AppError::from(e));
}
}
}
@@ -155,9 +147,8 @@ pub async fn run_prover(
}
// Smart contract
// FIXME: use provider
let registry_listener = if app_args.mock_sc.is_some() {
// debug!("No registry listener when mock is enabled");
// No registry listener when mock is enabled
None
} else {
Some(RegistryListener::new(
@@ -185,7 +176,7 @@ pub async fn run_prover(
let rln_identifier = RlnIdentifier::new(RLN_IDENTIFIER_NAME);
let addr = SocketAddr::new(app_args.ip, app_args.port);
debug!("Listening on: {}", addr);
info!("Listening on: {}", addr);
let prover_grpc_service = {
let mut service = GrpcProverService {
proof_sender,
@@ -242,11 +233,17 @@ pub async fn run_prover(
if app_args.ws_rpc_url.is_some() {
set.spawn(async move { prover_grpc_service.serve().await });
} else {
debug!("Grpc service started with mocked smart contracts");
info!("Grpc service started with mocked smart contracts");
set.spawn(async move { prover_grpc_service.serve_with_mock().await });
}
// TODO: handle error
let _ = set.join_all().await;
let res = set.join_all().await;
// Print all errors from services (if any)
// We expect the Prover to never stop unexpectedly, but printing errors can help debugging
res.iter().for_each(|r| {
if r.is_err() {
info!("Error: {:?}", r);
}
});
Ok(())
}
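The error-reporting loop could equivalently filter straight to the failures; a sketch over the same shape of `join_all` results (`log_failures` is a hypothetical helper):

fn log_failures<E: std::fmt::Debug>(results: &[Result<(), E>]) {
    for err in results.iter().filter_map(|r| r.as_ref().err()) {
        tracing::info!("Service terminated with error: {:?}", err);
    }
}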


@@ -6,7 +6,7 @@ use ark_serialize::CanonicalSerialize;
use async_channel::Receiver;
use metrics::{counter, histogram};
use parking_lot::RwLock;
use rln::hashers::hash_to_field;
use rln::hashers::hash_to_field_le;
use rln::protocol::serialize_proof_values;
use tracing::{
Instrument, // debug,
@@ -101,7 +101,7 @@ impl ProofService {
let rln_data = RlnData {
message_id: Fr::from(message_id),
data: hash_to_field(proof_generation_data.tx_hash.as_slice()),
data: hash_to_field_le(proof_generation_data.tx_hash.as_slice()),
};
let epoch_bytes = {
@@ -109,7 +109,7 @@ impl ProofService {
v.extend(current_epoch_slice.to_le_bytes());
v
};
let epoch = hash_to_field(epoch_bytes.as_slice());
let epoch = hash_to_field_le(epoch_bytes.as_slice());
let merkle_proof = match user_db.get_merkle_proof(&proof_generation_data.tx_sender)
{
@@ -159,6 +159,15 @@ impl ProofService {
let _ = send.send(Ok::<Vec<u8>, ProofGenerationError>(
output_buffer.into_inner(),
));
/*
std::thread::sleep(std::time::Duration::from_millis(100));
let mut output_buffer = Cursor::new(Vec::with_capacity(PROOF_SIZE));
// Send the result back to Tokio.
let _ = send.send(Ok::<Vec<u8>, ProofGenerationError>(
output_buffer.into_inner(),
));
*/
});
// Wait for the rayon task.
@@ -271,7 +280,7 @@ mod tests {
debug!("Starting broadcast receiver...");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let res =
tokio::time::timeout(std::time::Duration::from_secs(5), broadcast_receiver.recv())
tokio::time::timeout(std::time::Duration::from_secs(7), broadcast_receiver.recv())
.await
.map_err(|_e| AppErrorExt::Elapsed)?;
debug!("res: {:?}", res);


@@ -5,7 +5,6 @@ mod tests {
use std::sync::Arc;
// third-party
use alloy::primitives::{Address, address};
use ark_bn254::Fr;
use ark_groth16::{Proof as ArkProof, Proof, VerifyingKey};
use ark_serialize::CanonicalDeserialize;
use claims::assert_matches;
@@ -14,6 +13,7 @@ mod tests {
use rln::circuit::{Curve, zkey_from_folder};
use rln::error::ComputeIdSecretError;
use rln::protocol::{compute_id_secret, deserialize_proof_values, verify_proof};
use rln::utils::IdSecret;
use tokio::sync::broadcast;
use tracing::{debug, info};
// internal
@@ -50,7 +50,7 @@ mod tests {
#[error(transparent)]
RecoverSecretFailed(ComputeIdSecretError),
#[error("Recovered secret")]
RecoveredSecret(Fr),
RecoveredSecret(IdSecret),
}
async fn proof_sender(


@@ -56,6 +56,7 @@ impl EpochCounterSerializer {
}
}
#[derive(Clone)]
pub struct EpochCounterDeserializer {}
impl EpochCounterDeserializer {
@@ -86,6 +87,7 @@ pub struct EpochIncr {
pub incr_value: u64,
}
#[derive(Clone)]
pub struct EpochIncrSerializer {}
impl EpochIncrSerializer {


@@ -1,5 +1,4 @@
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
// third-party
use alloy::primitives::{Address, U256};
@@ -15,8 +14,8 @@ use rln::{
use rocksdb::{
ColumnFamily, ColumnFamilyDescriptor, DB, Options, ReadOptions, WriteBatch, WriteBatchWithIndex,
};
use serde::{Deserialize, Serialize};
use tracing::error;
use zerokit_utils::Mode::HighThroughput;
use zerokit_utils::{
error::ZerokitMerkleTreeError,
pmtree::{PmtreeErrorKind, TreeErrorKind},
@@ -62,22 +61,20 @@ pub struct UserTierInfo {
pub(crate) tier_limit: Option<TierLimit>,
}
#[derive(Serialize, Deserialize)]
struct PmTreeConfigJson {
path: PathBuf,
temporary: bool,
cache_capacity: u64,
flush_every_ms: u64,
mode: String,
use_compression: bool,
}
#[derive(Clone)]
pub(crate) struct UserDb {
db: Arc<DB>,
merkle_tree: Arc<RwLock<PoseidonTree>>,
rate_limit: RateLimit,
pub(crate) epoch_store: Arc<RwLock<(Epoch, EpochSlice)>>,
rln_identity_serializer: RlnUserIdentitySerializer,
rln_identity_deserializer: RlnUserIdentityDeserializer,
merkle_index_serializer: MerkleTreeIndexSerializer,
merkle_index_deserializer: MerkleTreeIndexDeserializer,
epoch_increase_serializer: EpochIncrSerializer,
epoch_counter_deserializer: EpochCounterDeserializer,
tier_limits_serializer: TierLimitsSerializer,
tier_limits_deserializer: TierLimitsDeserializer,
}
impl std::fmt::Debug for UserDb {
@@ -144,7 +141,10 @@ impl UserDb {
// merkle tree index
let cf_mtree = db.cf_handle(MERKLE_TREE_COUNTER_CF).unwrap();
if let Err(e) = Self::get_merkle_tree_index_(db.clone(), cf_mtree) {
let merkle_index_deserializer = MerkleTreeIndexDeserializer {};
if let Err(e) =
Self::get_merkle_tree_index_(db.clone(), cf_mtree, &merkle_index_deserializer)
{
match e {
MerkleTreeIndexError::DbUninitialized => {
// Check if the value is already there (e.g. after a restart)
@@ -156,25 +156,32 @@ impl UserDb {
}
// merkle tree
let tree_config = PmtreeConfig::builder()
.path(merkle_tree_path)
.temporary(false)
.cache_capacity(100_000)
.flush_every_ms(12_000)
.mode(HighThroughput)
.use_compression(false)
.build()?;
let tree = PoseidonTree::new(MERKLE_TREE_HEIGHT, Default::default(), tree_config)?;
let config_ = PmTreeConfigJson {
path: merkle_tree_path,
temporary: false,
cache_capacity: 100_000,
flush_every_ms: 12_000,
mode: "HighThroughput".to_string(),
use_compression: false,
};
let config_str = serde_json::to_string(&config_)?;
// Note: in Zerokit 0.8 this is the only way to initialize a PmTreeConfig
let config = PmtreeConfig::from_str(config_str.as_str())?;
let tree = PoseidonTree::new(MERKLE_TREE_HEIGHT, Default::default(), config)?;
let tier_limits_deserializer = TierLimitsDeserializer {
tier_deserializer: TierDeserializer {},
};
Ok(Self {
db,
merkle_tree: Arc::new(RwLock::new(tree)),
rate_limit,
epoch_store,
rln_identity_serializer: RlnUserIdentitySerializer {},
rln_identity_deserializer: RlnUserIdentityDeserializer {},
merkle_index_serializer: MerkleTreeIndexSerializer {},
merkle_index_deserializer,
epoch_increase_serializer: EpochIncrSerializer {},
epoch_counter_deserializer: EpochCounterDeserializer {},
tier_limits_serializer,
tier_limits_deserializer,
})
}
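A note on caching the (de)serializers as struct fields: they are stateless unit structs, hence zero-sized types, so storing them on `UserDb` costs no memory; the refactor only removes per-call construction. A quick sketch:

fn zst_check() {
    let ser = MerkleTreeIndexSerializer {};
    // Unit structs occupy zero bytes at runtime.
    assert_eq!(std::mem::size_of_val(&ser), 0);
}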
@@ -199,24 +206,23 @@ impl UserDb {
}
pub(crate) fn register(&self, address: Address) -> Result<Fr, RegisterError> {
let rln_identity_serializer = RlnUserIdentitySerializer {};
let merkle_index_serializer = MerkleTreeIndexSerializer {};
let merkle_index_deserializer = MerkleTreeIndexDeserializer {};
let (identity_secret_hash, id_commitment) = keygen();
let rln_identity = RlnUserIdentity::from((
identity_secret_hash,
id_commitment,
identity_secret_hash,
Fr::from(self.rate_limit),
));
let key = address.as_slice();
let mut buffer =
vec![0; rln_identity_serializer.size_hint() + merkle_index_serializer.size_hint()];
let mut buffer = vec![
0;
self.rln_identity_serializer.size_hint()
+ self.merkle_index_serializer.size_hint()
];
// unwrap safe - this is serialized by the Prover + RlnUserIdentitySerializer is unit tested
rln_identity_serializer
self.rln_identity_serializer
.serialize(&rln_identity, &mut buffer)
.unwrap();
@@ -257,7 +263,8 @@ impl UserDb {
// Increase merkle tree index
db_batch.merge_cf(cf_mtree, MERKLE_TREE_INDEX_KEY, 1i64.to_le_bytes());
// Unwrap safe - serialization is handled by the prover
let (_, new_index) = merkle_index_deserializer
let (_, new_index) = self
.merkle_index_deserializer
.deserialize(batch_read.as_slice())
.unwrap();
@@ -281,7 +288,8 @@ impl UserDb {
})?;
// Add index for user
merkle_index_serializer.serialize(&new_index, &mut buffer);
self.merkle_index_serializer
.serialize(&new_index, &mut buffer);
// Put user
db_batch.put_cf(cf_user, key, buffer.as_slice());
// Put user tx counter
@@ -311,11 +319,10 @@ impl UserDb {
pub fn get_user(&self, address: &Address) -> Option<RlnUserIdentity> {
let cf_user = self.get_user_cf();
let rln_identity_deserializer = RlnUserIdentityDeserializer {};
match self.db.get_pinned_cf(cf_user, address.as_slice()) {
Ok(Some(value)) => {
// Here we silence the error - this is safe as the prover controls this
rln_identity_deserializer.deserialize(&value).ok()
self.rln_identity_deserializer.deserialize(&value).ok()
}
Ok(None) => None,
Err(_e) => None,
@@ -327,13 +334,12 @@ impl UserDb {
address: &Address,
) -> Result<MerkleTreeIndex, UserMerkleTreeIndexError> {
let cf_user = self.get_user_cf();
let rln_identity_serializer = RlnUserIdentitySerializer {};
let merkle_tree_index_deserializer = MerkleTreeIndexDeserializer {};
match self.db.get_pinned_cf(cf_user, address.as_slice()) {
Ok(Some(buffer)) => {
// Here we silence the error - this is safe as the prover controls this
let start = rln_identity_serializer.size_hint();
let (_, index) = merkle_tree_index_deserializer
let start = self.rln_identity_serializer.size_hint();
let (_, index) = self
.merkle_index_deserializer
.deserialize(&buffer[start..])
.unwrap();
Ok(index)
@@ -398,9 +404,8 @@ impl UserDb {
epoch_slice,
incr_value,
};
let incr_ser = EpochIncrSerializer {};
let mut buffer = Vec::with_capacity(incr_ser.size_hint());
incr_ser.serialize(&incr, &mut buffer);
let mut buffer = Vec::with_capacity(self.epoch_increase_serializer.size_hint());
self.epoch_increase_serializer.serialize(&incr, &mut buffer);
// Create a transaction
// By using a WriteBatchWithIndex, we can "read your own writes" so here we incr then read the new value
@@ -435,11 +440,9 @@ impl UserDb {
address: &Address,
key: Option<Vec<u8>>,
) -> Result<(EpochCounter, EpochSliceCounter), TxCounterError> {
let deserializer = EpochCounterDeserializer {};
match key {
Some(value) => {
let (_, counter) = deserializer.deserialize(&value).unwrap();
let (_, counter) = self.epoch_counter_deserializer.deserialize(&value).unwrap();
let (epoch, epoch_slice) = *self.epoch_store.read();
let cmp = (counter.epoch == epoch, counter.epoch_slice == epoch_slice);
@@ -490,19 +493,20 @@ impl UserDb {
#[cfg(test)]
pub(crate) fn get_merkle_tree_index(&self) -> Result<MerkleTreeIndex, MerkleTreeIndexError> {
let cf_mtree = self.get_mtree_cf();
Self::get_merkle_tree_index_(self.db.clone(), cf_mtree)
Self::get_merkle_tree_index_(self.db.clone(), cf_mtree, &self.merkle_index_deserializer)
}
fn get_merkle_tree_index_(
db: Arc<DB>,
cf: &ColumnFamily,
merkle_tree_index_deserializer: &MerkleTreeIndexDeserializer,
) -> Result<MerkleTreeIndex, MerkleTreeIndexError> {
let deserializer = MerkleTreeIndexDeserializer {};
match db.get_cf(cf, MERKLE_TREE_INDEX_KEY) {
Ok(Some(v)) => {
// Unwrap safe - serialization is done by the prover
let (_, index) = deserializer.deserialize(v.as_slice()).unwrap();
let (_, index) = merkle_tree_index_deserializer
.deserialize(v.as_slice())
.unwrap();
Ok(index)
}
Ok(None) => Err(MerkleTreeIndexError::DbUninitialized),
@@ -541,12 +545,8 @@ impl UserDb {
let cf = self.get_tier_limits_cf();
// Unwrap safe - Db is initialized with valid tier limits
let buffer = self.db.get_cf(cf, TIER_LIMITS_KEY.as_slice())?.unwrap();
let tier_limits_deserializer = TierLimitsDeserializer {
tier_deserializer: TierDeserializer {},
};
// Unwrap safe - serialized by the prover (should always deserialize)
let (_, tier_limits) = tier_limits_deserializer.deserialize(&buffer).unwrap();
let (_, tier_limits) = self.tier_limits_deserializer.deserialize(&buffer).unwrap();
Ok(tier_limits)
}
@@ -557,10 +557,10 @@ impl UserDb {
tier_limits.validate()?;
// Serialize
let tier_limits_serializer = TierLimitsSerializer::default();
let mut buffer = Vec::with_capacity(tier_limits_serializer.size_hint(tier_limits.len()));
let mut buffer =
Vec::with_capacity(self.tier_limits_serializer.size_hint(tier_limits.len()));
// Unwrap safe - already validated - should always serialize
tier_limits_serializer
self.tier_limits_serializer
.serialize(&tier_limits, &mut buffer)
.unwrap();
@@ -798,16 +798,15 @@ mod tests {
.unwrap();
let temp_folder_tree_2 = tempfile::tempdir().unwrap();
let config_ = PmTreeConfigJson {
path: temp_folder_tree_2.path().to_path_buf(),
temporary: false,
cache_capacity: 100_000,
flush_every_ms: 12_000,
mode: "HighThroughput".to_string(),
use_compression: false,
};
let config_str = serde_json::to_string(&config_).unwrap();
let config = PmtreeConfig::from_str(config_str.as_str()).unwrap();
let config = PmtreeConfig::builder()
.path(temp_folder_tree_2.path().to_path_buf())
.temporary(false)
.cache_capacity(100_000)
.flush_every_ms(12_000)
.mode(HighThroughput)
.use_compression(false)
.build()
.unwrap();
let tree = PoseidonTree::new(1, Default::default(), config).unwrap();
let tree = Arc::new(RwLock::new(tree));
user_db.merkle_tree = tree.clone();


@@ -6,7 +6,7 @@ use zerokit_utils::error::{FromConfigError, ZerokitMerkleTreeError};
use crate::tier::ValidateTierLimitsError;
#[derive(Debug, thiserror::Error)]
pub(crate) enum UserDbOpenError {
pub enum UserDbOpenError {
#[error(transparent)]
RocksDb(#[from] rocksdb::Error),
#[error("Serialization error: {0}")]


@@ -12,12 +12,14 @@ use nom::{
multi::length_count,
number::complete::{le_u32, le_u64},
};
use rln::utils::IdSecret;
use rln_proof::RlnUserIdentity;
// internal
use crate::tier::TierLimits;
use crate::user_db_types::MerkleTreeIndex;
use smart_contract::Tier;
#[derive(Clone)]
pub(crate) struct RlnUserIdentitySerializer {}
impl RlnUserIdentitySerializer {
@@ -41,6 +43,7 @@ impl RlnUserIdentitySerializer {
}
}
#[derive(Clone)]
pub(crate) struct RlnUserIdentityDeserializer {}
impl RlnUserIdentityDeserializer {
@@ -49,8 +52,8 @@ impl RlnUserIdentityDeserializer {
let (co_buffer, rem_buffer) = buffer.split_at(compressed_size);
let commitment: Fr = CanonicalDeserialize::deserialize_compressed(co_buffer)?;
let (secret_buffer, user_limit_buffer) = rem_buffer.split_at(compressed_size);
// TODO: IdSecret (wait for Zerokit PR: https://github.com/vacp2p/zerokit/pull/320)
let secret_hash: Fr = CanonicalDeserialize::deserialize_compressed(secret_buffer)?;
let mut secret_hash_: Fr = CanonicalDeserialize::deserialize_compressed(secret_buffer)?;
let secret_hash = IdSecret::from(&mut secret_hash_);
let user_limit: Fr = CanonicalDeserialize::deserialize_compressed(user_limit_buffer)?;
Ok({
@@ -63,6 +66,7 @@ impl RlnUserIdentityDeserializer {
}
}
#[derive(Clone)]
pub(crate) struct MerkleTreeIndexSerializer {}
impl MerkleTreeIndexSerializer {
@@ -77,6 +81,7 @@ impl MerkleTreeIndexSerializer {
}
}
#[derive(Clone)]
pub(crate) struct MerkleTreeIndexDeserializer {}
impl MerkleTreeIndexDeserializer {
@@ -88,7 +93,7 @@ impl MerkleTreeIndexDeserializer {
}
}
#[derive(Default)]
#[derive(Default, Clone)]
pub(crate) struct TierSerializer {}
impl TierSerializer {
@@ -113,7 +118,7 @@ impl TierSerializer {
}
}
#[derive(Default)]
#[derive(Default, Clone)]
pub(crate) struct TierDeserializer {}
#[derive(Debug, PartialEq)]
@@ -166,7 +171,7 @@ impl TierDeserializer {
}
}
#[derive(Default)]
#[derive(Default, Clone)]
pub(crate) struct TierLimitsSerializer {
tier_serializer: TierSerializer,
}
@@ -193,7 +198,7 @@ impl TierLimitsSerializer {
}
}
#[derive(Default)]
#[derive(Default, Clone)]
pub(crate) struct TierLimitsDeserializer {
pub(crate) tier_deserializer: TierDeserializer,
}
@@ -226,7 +231,7 @@ mod tests {
fn test_rln_ser_der() {
let rln_user_identity = RlnUserIdentity {
commitment: Fr::from(42),
secret_hash: Fr::from(u16::MAX),
secret_hash: IdSecret::from(&mut Fr::from(u16::MAX)),
user_limit: Fr::from(1_000_000),
};
let serializer = RlnUserIdentitySerializer {};
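The `IdSecret::from(&mut fr)` conversions above take a mutable reference, presumably so the source `Fr` can be wiped during conversion (zeroize-on-convert is an assumption about the Zerokit API, not stated in this diff). A sketch (`wrap_secret` is a hypothetical helper):

use ark_bn254::Fr;
use rln::utils::IdSecret;

fn wrap_secret(mut raw: Fr) -> IdSecret {
    // &mut lets the conversion clear `raw` in place (assumed semantics)
    IdSecret::from(&mut raw)
}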


@@ -24,3 +24,4 @@ opentelemetry-otlp = { version = "0.30.0", features = [
"tls-roots",
] }
tracing-opentelemetry = "0.31.0"
anyhow = "1.0.99"


@@ -11,6 +11,7 @@ use tracing::{
};
use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
use anyhow::{Context, Result, anyhow};
use opentelemetry::trace::TracerProvider;
use opentelemetry_otlp::WithTonicConfig;
use opentelemetry_sdk::Resource;
@@ -20,7 +21,7 @@ use prover::{AppArgs, AppArgsConfig, metrics::init_metrics, run_prover};
const APP_NAME: &str = "prover-cli";
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
async fn main() -> Result<()> {
// install crypto provider for rustls - required for WebSocket TLS connections
rustls::crypto::CryptoProvider::install_default(aws_lc_rs::default_provider())
.expect("Failed to install default CryptoProvider");
@@ -58,7 +59,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>
// Unwrap safe - default value provided
let config_path = app_args.get_one::<PathBuf>("config_path").unwrap();
debug!("Reading config path: {:?}...", config_path);
let config_str = std::fs::read_to_string(config_path)?;
let config_str = std::fs::read_to_string(config_path)
.context(format!("Failed to read config file: {:?}", config_path))?;
let config: AppArgsConfig = toml::from_str(config_str.as_str())?;
debug!("Config: {:?}", config);
config
@@ -76,15 +78,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>
|| app_args.ksc_address.is_none()
|| app_args.tsc_address.is_none()
{
return Err("Please provide smart contract addresses".into());
return Err(anyhow!("Please provide smart contract addresses"));
}
} else if app_args.mock_sc.is_none() {
return Err("Please provide rpc url (--ws-rpc-url) or mock (--mock-sc)".into());
return Err(anyhow!(
"Please provide rpc url (--ws-rpc-url) or mock (--mock-sc)"
));
}
init_metrics(app_args.metrics_ip, &app_args.metrics_port);
run_prover(app_args).await
run_prover(app_args).await.map_err(anyhow::Error::new)
}
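One nuance in the `context(format!(...))` call above: the message string is built even when the read succeeds. anyhow's `with_context` defers it to the failure path; a sketch (`read_config` is a hypothetical wrapper):

use std::path::Path;
use anyhow::{Context, Result};

fn read_config(config_path: &Path) -> Result<String> {
    // The closure runs only if read_to_string fails.
    std::fs::read_to_string(config_path)
        .with_context(|| format!("Failed to read config file: {config_path:?}"))
}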
fn create_otlp_tracer_provider() -> Option<opentelemetry_sdk::trace::SdkTracerProvider> {


@@ -1,3 +1,4 @@
use std::hint::black_box;
// std
use std::io::{Cursor, Write};
// criterion
@@ -5,7 +6,7 @@ use criterion::{Criterion, criterion_group, criterion_main};
// third-party
use ark_bn254::Fr;
use ark_serialize::CanonicalSerialize;
use rln::hashers::{hash_to_field, poseidon_hash};
use rln::hashers::{hash_to_field_le, poseidon_hash};
use rln::poseidon_tree::PoseidonTree;
use rln::protocol::{keygen, serialize_proof_values};
// internal
@@ -24,7 +25,7 @@ pub fn criterion_benchmark(c: &mut Criterion) {
let rln_identifier = RlnIdentifier::new(b"test-test");
let rln_data = RlnData {
message_id: Fr::from(user_limit - 2),
data: hash_to_field(b"data-from-message"),
data: hash_to_field_le(b"data-from-message"),
};
// Merkle tree
@@ -35,7 +36,7 @@ pub fn criterion_benchmark(c: &mut Criterion) {
let merkle_proof = tree.proof(0).unwrap();
// Epoch
let epoch = hash_to_field(b"Today at noon, this year");
let epoch = hash_to_field_le(b"Today at noon, this year");
{
// Not a benchmark but print the proof size (serialized)
@@ -61,17 +62,6 @@ pub fn criterion_benchmark(c: &mut Criterion) {
}
c.bench_function("compute proof and values", |b| {
/*
b.iter(|| {
compute_rln_proof_and_values(
&rln_identity,
&rln_identifier,
rln_data.clone(),
epoch,
&merkle_proof,
)
})
*/
b.iter_batched(
|| {
// generate setup data
@@ -80,11 +70,11 @@ pub fn criterion_benchmark(c: &mut Criterion) {
|data| {
// function to benchmark
compute_rln_proof_and_values(
&rln_identity,
&rln_identifier,
data,
epoch,
&merkle_proof,
black_box(&rln_identity),
black_box(&rln_identifier),
black_box(data),
black_box(epoch),
black_box(&merkle_proof),
)
},
criterion::BatchSize::SmallInput,
@@ -96,19 +86,21 @@ pub fn criterion_benchmark(c: &mut Criterion) {
|| {
// generate setup data
compute_rln_proof_and_values(
&rln_identity,
&rln_identifier,
rln_data.clone(),
epoch,
&merkle_proof,
black_box(&rln_identity),
black_box(&rln_identifier),
black_box(rln_data.clone()),
black_box(epoch),
black_box(&merkle_proof),
)
.unwrap()
},
|(proof, proof_values)| {
let mut output_buffer = Cursor::new(Vec::with_capacity(320));
proof.serialize_compressed(&mut output_buffer).unwrap();
proof
.serialize_compressed(black_box(&mut output_buffer))
.unwrap();
output_buffer
.write_all(&serialize_proof_values(&proof_values))
.write_all(black_box(&serialize_proof_values(black_box(&proof_values))))
.unwrap();
},
criterion::BatchSize::SmallInput,
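On the `black_box` wrappers: they stop the optimizer from treating the wrapped values as compile-time constants, so the benchmarked call cannot be folded away or dead-code-eliminated. A minimal sketch:

use std::hint::black_box;

fn add(a: u64, b: u64) -> u64 {
    a + b
}

fn demo() {
    // Without black_box, add(2, 3) could be constant-folded at compile time.
    assert_eq!(add(black_box(2), black_box(3)), 5);
}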


@@ -1,29 +1,31 @@
// std
use std::io::Cursor;
// use std::io::Cursor;
// third-party
use ark_bn254::{Bn254, Fr};
use ark_groth16::{Proof, ProvingKey};
use ark_relations::r1cs::ConstraintMatrices;
use rln::utils::IdSecret;
use rln::{
circuit::{ZKEY_BYTES, zkey::read_zkey},
circuit::{ARKZKEY_BYTES, read_arkzkey_from_bytes_uncompressed as read_zkey},
error::ProofError,
hashers::{hash_to_field, poseidon_hash},
hashers::{hash_to_field_le, poseidon_hash},
poseidon_tree::MerkleProof,
protocol::{
RLNProofValues, generate_proof, proof_values_from_witness, rln_witness_from_values,
},
};
use zerokit_utils::ZerokitMerkleProof;
/// A RLN user identity & limit
#[derive(Debug, Clone, PartialEq)]
pub struct RlnUserIdentity {
pub commitment: Fr,
pub secret_hash: Fr,
pub secret_hash: IdSecret,
pub user_limit: Fr,
}
impl From<(Fr, Fr, Fr)> for RlnUserIdentity {
fn from((commitment, secret_hash, user_limit): (Fr, Fr, Fr)) -> Self {
impl From<(Fr, IdSecret, Fr)> for RlnUserIdentity {
fn from((commitment, secret_hash, user_limit): (Fr, IdSecret, Fr)) -> Self {
Self {
commitment,
secret_hash,
@@ -43,13 +45,13 @@ pub struct RlnIdentifier {
impl RlnIdentifier {
pub fn new(identifier: &[u8]) -> Self {
let pk_and_matrices = {
let mut reader = Cursor::new(ZKEY_BYTES);
read_zkey(&mut reader).unwrap()
// let mut reader = Cursor::new(ARKZKEY_BYTES);
read_zkey(ARKZKEY_BYTES).unwrap()
};
let graph_bytes = include_bytes!("../resources/graph.bin");
Self {
identifier: hash_to_field(identifier),
identifier: hash_to_field_le(identifier),
pkey_and_constraints: pk_and_matrices,
graph: graph_bytes.to_vec(),
}
@@ -74,9 +76,15 @@ pub fn compute_rln_proof_and_values(
) -> Result<(Proof<Bn254>, RLNProofValues), ProofError> {
let external_nullifier = poseidon_hash(&[rln_identifier.identifier, epoch]);
let path_elements = merkle_proof.get_path_elements();
let identity_path_index = merkle_proof.get_path_index();
// let mut id_s = user_identity.secret_hash;
let witness = rln_witness_from_values(
user_identity.secret_hash,
merkle_proof,
user_identity.secret_hash.clone(),
path_elements,
identity_path_index,
rln_data.data,
external_nullifier,
user_identity.user_limit,
@@ -101,8 +109,9 @@ mod tests {
#[test]
fn test_recover_secret_hash() {
let (user_co, user_sh) = keygen();
let epoch = hash_to_field(b"foo");
let (user_co, mut user_sh_) = keygen();
let user_sh = IdSecret::from(&mut user_sh_);
let epoch = hash_to_field_le(b"foo");
let spam_limit = Fr::from(10);
// let mut tree = OptimalMerkleTree::new(20, Default::default(), Default::default()).unwrap();
@@ -116,14 +125,14 @@ mod tests {
let (_proof_0, proof_values_0) = compute_rln_proof_and_values(
&RlnUserIdentity {
commitment: user_co,
secret_hash: user_sh,
commitment: *user_co,
secret_hash: user_sh.clone(),
user_limit: spam_limit,
},
&rln_identifier,
RlnData {
message_id,
data: hash_to_field(b"sig"),
data: hash_to_field_le(b"sig"),
},
epoch,
&m_proof,
@@ -132,14 +141,14 @@ mod tests {
let (_proof_1, proof_values_1) = compute_rln_proof_and_values(
&RlnUserIdentity {
commitment: user_co,
secret_hash: user_sh,
commitment: *user_co,
secret_hash: user_sh.clone(),
user_limit: spam_limit,
},
&rln_identifier,
RlnData {
message_id,
data: hash_to_field(b"sig 2"),
data: hash_to_field_le(b"sig 2"),
},
epoch,
&m_proof,