diff --git a/Cargo.lock b/Cargo.lock
index 3f20aa6..839be72 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -322,6 +322,7 @@ dependencies = [
"const-hex",
"derive_more",
"foldhash",
+ "getrandom 0.3.2",
"hashbrown 0.15.2",
"indexmap 2.9.0",
"itoa",
@@ -1557,6 +1558,12 @@ dependencies = [
"half",
]
+[[package]]
+name = "claims"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bba18ee93d577a8428902687bcc2b6b45a56b1981a1f6d779731c86cc4c5db18"
+
[[package]]
name = "clap"
version = "4.5.37"
@@ -1753,6 +1760,29 @@ dependencies = [
"walkdir",
]
+[[package]]
+name = "criterion"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3bf7af66b0989381bd0be551bd7cc91912a655a58c6918420c9527b1fd8b4679"
+dependencies = [
+ "anes",
+ "cast",
+ "ciborium",
+ "clap",
+ "criterion-plot",
+ "itertools 0.13.0",
+ "num-traits",
+ "oorandom",
+ "plotters",
+ "rayon",
+ "regex",
+ "serde",
+ "serde_json",
+ "tinytemplate",
+ "walkdir",
+]
+
[[package]]
name = "criterion-plot"
version = "0.5.0"
@@ -3886,9 +3916,8 @@ dependencies = [
"ark-groth16",
"ark-relations",
"ark-serialize 0.5.0",
- "criterion",
+ "criterion 0.5.1",
"rln",
- "serde_json",
"zerokit_utils",
]
@@ -4057,6 +4086,15 @@ dependencies = [
"winapi-util",
]
+[[package]]
+name = "scc"
+version = "2.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "22b2d775fb28f245817589471dd49c5edf64237f4a19d10ce9a92ff4651a27f4"
+dependencies = [
+ "sdd",
+]
+
[[package]]
name = "schannel"
version = "0.1.27"
@@ -4072,6 +4110,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
+[[package]]
+name = "sdd"
+version = "3.0.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "584e070911c7017da6cb2eb0788d09f43d789029b5877d3e5ecc8acf86ceee21"
+
[[package]]
name = "sec1"
version = "0.7.3"
@@ -4367,15 +4411,19 @@ dependencies = [
"async-channel",
"bytesize",
"chrono",
+ "claims",
"clap",
+ "criterion 0.6.0",
"dashmap",
+ "derive_more",
"futures",
"http",
"parking_lot 0.12.3",
"prost",
+ "rand 0.8.5",
"rln",
"rln_proof",
- "serde_json",
+ "scc",
"thiserror 2.0.12",
"tokio",
"tonic",
diff --git a/proto/net/vac/prover/prover.proto b/proto/net/vac/prover/prover.proto
index eb67850..a12f7db 100644
--- a/proto/net/vac/prover/prover.proto
+++ b/proto/net/vac/prover/prover.proto
@@ -14,6 +14,8 @@ service RlnProver {
// Server side streaming RPC: 1 request -> X responses (stream)
rpc GetProofs(RlnProofFilter) returns (stream RlnProof);
+ rpc GetUserTierInfo(GetUserTierInfoRequest) returns (GetUserTierInfoReply);
+ rpc SetTierLimits(SetTierLimitsRequest) returns (SetTierLimitsReply);
}
// TransactionType: https://github.com/Consensys/linea-besu/blob/09cbed1142cfe4d29b50ecf2f156639a4bc8c854/datatypes/src/main/java/org/hyperledger/besu/datatypes/TransactionType.java#L22
@@ -142,3 +144,41 @@ message RegisterUserReply {
RegistrationStatus status = 1;
}
+message GetUserTierInfoRequest {
+ Address user = 1;
+}
+
+message GetUserTierInfoReply {
+ oneof resp {
+ // variant for success
+ UserTierInfoResult res = 1;
+ // variant for error
+ UserTierInfoError error = 2;
+ }
+}
+
+message UserTierInfoError {
+ string message = 1;
+}
+
+message UserTierInfoResult {
+ sint64 current_epoch = 1;
+ sint64 current_epoch_slice = 2;
+ uint64 tx_count = 3;
+ optional Tier tier = 4;
+}
+
+message Tier {
+ string name = 1;
+ uint64 quota = 2;
+}
+
+message SetTierLimitsRequest {
+ repeated U256 karmaAmounts = 1;
+ repeated Tier tiers = 2;
+}
+
+message SetTierLimitsReply {
+ bool status = 1;
+ string error = 2;
+}
\ No newline at end of file
diff --git a/prover/Cargo.toml b/prover/Cargo.toml
index a00080a..a64a51d 100644
--- a/prover/Cargo.toml
+++ b/prover/Cargo.toml
@@ -13,14 +13,14 @@ tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
tracing = "0.1.41"
tracing-test = "0.2.5"
-alloy = { version = "0.15", features = ["full"] }
+alloy = { version = "0.15", features = ["full", "getrandom"] }
thiserror = "2.0"
futures = "0.3"
rln = { git = "https://github.com/vacp2p/zerokit" }
ark-bn254 = { version = "0.5", features = ["std"] }
ark-serialize = "0.5.0"
-serde_json = "1.0"
dashmap = "6.1.0"
+scc = "2.3"
bytesize = "2.0.1"
rln_proof = { path = "../rln_proof" }
chrono = "0.4.41"
@@ -28,6 +28,16 @@ parking_lot = "0.12.3"
tower-http = { version = "0.6.4", features = ["cors"] }
http = "*"
async-channel = "2.3.1"
+rand = "0.8.5"
+derive_more = "2.0.1"
[build-dependencies]
tonic-build = "*"
+
+[dev-dependencies]
+criterion = "0.6.0"
+claims = "0.8"
+
+[[bench]]
+name = "user_db_heavy_write"
+harness = false
diff --git a/prover/benches/user_db_heavy_write.rs b/prover/benches/user_db_heavy_write.rs
new file mode 100644
index 0000000..ed56c21
--- /dev/null
+++ b/prover/benches/user_db_heavy_write.rs
@@ -0,0 +1,59 @@
+use alloy::primitives::Address;
+use std::hint::black_box;
+// criterion
+use criterion::{
+ BenchmarkId,
+ Criterion,
+ Throughput,
+ criterion_group,
+ criterion_main,
+ // black_box
+};
+use dashmap::DashMap;
+use rand::Rng;
+use scc::HashMap;
+
+pub fn criterion_benchmark(c: &mut Criterion) {
+ let size = 1_250_000;
+ let mut rng = rand::thread_rng();
+ let d_1m: DashMap
= DashMap::with_capacity(size as usize);
+ let scc_1m: HashMap = HashMap::with_capacity(size as usize);
+
+ (0..size).into_iter().for_each(|_i| {
+ let mut addr = Address::new([0; 20]);
+ addr.0.randomize();
+ let n1 = rng.r#gen();
+ let n2 = rng.r#gen();
+ d_1m.insert(addr, (n1, n2));
+ scc_1m.insert(addr, (n1, n2)).unwrap();
+ });
+
+ let mut group = c.benchmark_group("Scc versus DashMap alter_all");
+
+ group.throughput(Throughput::Elements(size));
+ group.bench_function(BenchmarkId::new("Dashmap", size), |b| {
+ b.iter(|| {
+ black_box(d_1m.alter_all(|_, v| black_box((v.0, 0))));
+ })
+ });
+
+ group.bench_function(BenchmarkId::new("Scc", size), |b| {
+ b.iter(|| {
+ black_box(scc_1m.retain(|_, v| {
+ black_box(*v = (v.0, 0));
+ black_box(true)
+ }));
+ })
+ });
+
+ group.finish();
+}
+
+criterion_group! {
+ name = benches;
+ config = Criterion::default()
+ .measurement_time(std::time::Duration::from_secs(45))
+ .sample_size(10);
+ targets = criterion_benchmark
+}
+criterion_main!(benches);
diff --git a/prover/src/epoch_service.rs b/prover/src/epoch_service.rs
index 8de7eeb..be25c75 100644
--- a/prover/src/epoch_service.rs
+++ b/prover/src/epoch_service.rs
@@ -4,6 +4,7 @@ use std::time::Duration;
// third-party
use chrono::{DateTime, NaiveDate, NaiveDateTime, OutOfRangeError, TimeDelta, Utc};
use parking_lot::RwLock;
+use tokio::sync::Notify;
use tracing::debug;
// internal
use crate::error::AppError;
@@ -24,6 +25,8 @@ pub struct EpochService {
pub current_epoch: Arc>,
/// Genesis time (aka when the service has been started at the first time)
genesis: DateTime,
+ /// Channel to notify when an epoch / epoch slice has just changed
+ pub epoch_changes: Arc,
}
impl EpochService {
@@ -66,7 +69,13 @@ impl EpochService {
current_epoch += 1;
}
*self.current_epoch.write() = (current_epoch.into(), current_epoch_slice.into());
- debug!("epoch: {}, epoch slice: {}", current_epoch, current_epoch_slice);
+ debug!(
+ "epoch: {}, epoch slice: {}",
+ current_epoch, current_epoch_slice
+ );
+
+ // println!("Epoch changed: {}", current_epoch);
+ self.epoch_changes.notify_one();
}
// Ok(())
@@ -186,10 +195,9 @@ impl TryFrom<(Duration, DateTime)> for EpochService {
Ok(Self {
epoch_slice_duration,
- // current_epoch: Arc::new(AtomicI64::new(0)),
- // current_epoch_slice: Arc::new(AtomicI64::new(0)),
current_epoch: Arc::new(Default::default()),
genesis,
+ epoch_changes: Arc::new(Default::default()),
})
}
}
@@ -203,8 +211,17 @@ pub enum WaitUntilError {
}
/// An Epoch (wrapper type over i64)
-#[derive(Debug, Clone, Copy, Default)]
+#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct Epoch(pub(crate) i64);
+
+impl Add for Epoch {
+ type Output = Self;
+
+ fn add(self, rhs: i64) -> Self::Output {
+ Self(self.0 + rhs)
+ }
+}
+
impl From for Epoch {
fn from(value: i64) -> Self {
Self(value)
@@ -218,9 +235,17 @@ impl From for i64 {
}
/// An Epoch slice (wrapper type over i64)
-#[derive(Debug, Clone, Copy, Default)]
+#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct EpochSlice(pub(crate) i64);
+impl Add for EpochSlice {
+ type Output = Self;
+
+ fn add(self, rhs: i64) -> Self::Output {
+ Self(self.0 + rhs)
+ }
+}
+
impl From for EpochSlice {
fn from(value: i64) -> Self {
Self(value)
@@ -237,6 +262,9 @@ impl From for i64 {
mod tests {
use super::*;
use chrono::{NaiveDate, NaiveDateTime, TimeDelta};
+ use futures::TryFutureExt;
+ use std::sync::atomic::{AtomicU64, Ordering};
+ use tracing_test::traced_test;
/*
#[tokio::test]
@@ -251,14 +279,15 @@ mod tests {
#[test]
fn test_wait_until() {
+ // Check wait_until is correctly computed
+
let date_0 = NaiveDate::from_ymd_opt(2025, 5, 14).unwrap();
let datetime_0 = date_0.and_hms_opt(0, 0, 0).unwrap();
{
// standard wait until in epoch 0, epoch slice 1
- let genesis: DateTime =
- chrono::DateTime::from_naive_utc_and_offset(datetime_0, chrono::Utc);
+ let genesis: DateTime = DateTime::from_naive_utc_and_offset(datetime_0, Utc);
let epoch_slice_duration = Duration::from_secs(60);
let epoch_service = EpochService::try_from((epoch_slice_duration, genesis)).unwrap();
@@ -266,7 +295,7 @@ mod tests {
let mut now_0: NaiveDateTime = date_0.and_hms_opt(0, 0, 0).unwrap();
// Set now_0 to be in epoch slice 1
now_0 += epoch_slice_duration;
- chrono::DateTime::from_naive_utc_and_offset(now_0, chrono::Utc)
+ DateTime::from_naive_utc_and_offset(now_0, chrono::Utc)
};
let (epoch, epoch_slice, wait_until): (_, _, DateTime) =
@@ -276,7 +305,7 @@ mod tests {
assert_eq!(epoch_slice, 1);
assert_eq!(
wait_until,
- chrono::DateTime::::from_naive_utc_and_offset(datetime_0, chrono::Utc)
+ DateTime::::from_naive_utc_and_offset(datetime_0, Utc)
+ 2 * epoch_slice_duration
);
}
@@ -284,8 +313,7 @@ mod tests {
{
// standard wait until (but in epoch 1)
- let genesis: DateTime =
- chrono::DateTime::from_naive_utc_and_offset(datetime_0, chrono::Utc);
+ let genesis: DateTime = DateTime::from_naive_utc_and_offset(datetime_0, Utc);
let epoch_slice_duration = Duration::from_secs(60);
let epoch_service = EpochService::try_from((epoch_slice_duration, genesis)).unwrap();
@@ -297,7 +325,7 @@ mod tests {
now_0 += epoch_slice_duration;
// Add 30 secs (but should still wait until epoch slice 2 starts)
now_0 += epoch_slice_duration / 2;
- chrono::DateTime::from_naive_utc_and_offset(now_0, chrono::Utc)
+ chrono::DateTime::from_naive_utc_and_offset(now_0, Utc)
};
let (epoch, epoch_slice, wait_until): (_, _, DateTime) =
@@ -307,7 +335,7 @@ mod tests {
assert_eq!(epoch_slice, 1);
assert_eq!(
wait_until,
- chrono::DateTime::::from_naive_utc_and_offset(datetime_0, chrono::Utc)
+ DateTime::::from_naive_utc_and_offset(datetime_0, Utc)
+ EPOCH_DURATION
+ 2 * epoch_slice_duration
);
@@ -317,7 +345,7 @@ mod tests {
// Check for WaitUntilError::TooLow
let genesis: DateTime =
- chrono::DateTime::from_naive_utc_and_offset(datetime_0, chrono::Utc);
+ chrono::DateTime::from_naive_utc_and_offset(datetime_0, Utc);
let epoch_slice_duration = Duration::from_secs(60);
let epoch_service = EpochService::try_from((epoch_slice_duration, genesis)).unwrap();
let epoch_slice_duration_minus_1 =
@@ -389,17 +417,20 @@ mod tests {
let now_f = move || {
let now_0: NaiveDateTime = day.and_hms_opt(0, 4, 0).unwrap();
- let now: DateTime = chrono::DateTime::from_utc(now_0, chrono::Utc);
+ let now: DateTime =
+ chrono::DateTime::from_naive_utc_and_offset(now_0, chrono::Utc);
now
};
let now_f_2 = move || {
let now_0: NaiveDateTime = day.and_hms_opt(0, 5, 59).unwrap();
- let now: DateTime = chrono::DateTime::from_utc(now_0, chrono::Utc);
+ let now: DateTime =
+ chrono::DateTime::from_naive_utc_and_offset(now_0, chrono::Utc);
now
};
let now_f_3 = move || {
let now_0: NaiveDateTime = day.and_hms_opt(0, 6, 0).unwrap();
- let now: DateTime = chrono::DateTime::from_utc(now_0, chrono::Utc);
+ let now: DateTime =
+ chrono::DateTime::from_naive_utc_and_offset(now_0, chrono::Utc);
now
};
@@ -430,4 +461,45 @@ mod tests {
3
);
}
+
+ #[derive(thiserror::Error, Debug)]
+ enum AppErrorExt {
+ #[error("AppError: {0}")]
+ AppError(#[from] AppError),
+ #[error("Future timeout")]
+ Elapsed,
+ }
+
+ #[tokio::test]
+ #[traced_test]
+ async fn test_notify() {
+ // Test epoch_service is really notifying when an epoch or epoch slice has just changed
+
+ let epoch_slice_duration = Duration::from_secs(10);
+ let epoch_service = EpochService::try_from((epoch_slice_duration, Utc::now())).unwrap();
+ let notifier = epoch_service.epoch_changes.clone();
+ let counter_0 = Arc::new(AtomicU64::new(0));
+ let counter = counter_0.clone();
+
+ let res = tokio::try_join!(
+ epoch_service
+ .listen_for_new_epoch()
+ .map_err(|e| AppErrorExt::AppError(e)),
+ // Wait for 3 epoch slices + 100 ms (to wait to receive notif + counter incr)
+ tokio::time::timeout(
+ epoch_slice_duration * 3 + Duration::from_millis(100),
+ async move {
+ loop {
+ notifier.notified().await;
+ debug!("[Notified] Epoch update...");
+ let _v = counter.fetch_add(1, Ordering::SeqCst);
+ }
+ Ok::<(), AppErrorExt>(())
+ }
+ )
+ .map_err(|_e| AppErrorExt::Elapsed)
+ );
+ assert!(matches!(res, Err(AppErrorExt::Elapsed)));
+ assert_eq!(counter_0.load(Ordering::SeqCst), 3);
+ }
}
diff --git a/prover/src/grpc_service.rs b/prover/src/grpc_service.rs
index 617e6a9..f152d2f 100644
--- a/prover/src/grpc_service.rs
+++ b/prover/src/grpc_service.rs
@@ -1,21 +1,16 @@
+use std::collections::BTreeMap;
// std
-use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
// third-party
-use alloy::primitives::Address;
+use alloy::primitives::{Address, U256};
use ark_bn254::Fr;
use async_channel::Sender;
use bytesize::ByteSize;
use futures::TryFutureExt;
use http::Method;
-use tokio::sync::{
- RwLock,
- broadcast,
- // broadcast::{Receiver, Sender},
- mpsc,
-};
+use tokio::sync::{broadcast, mpsc};
use tonic::{
Request,
Response,
@@ -32,13 +27,12 @@ use tracing::{
// info
};
// internal
-use crate::{
- error::{
- AppError,
- // RegistrationError
- },
- registry::UserRegistry,
+use crate::error::{
+ AppError,
+ // RegistrationError
};
+use crate::user_db_service::{KarmaAmountExt, UserDb, UserTierInfo};
+use rln_proof::{RlnIdentifier, RlnUserIdentity};
pub mod prover_proto {
@@ -48,18 +42,14 @@ pub mod prover_proto {
pub(crate) const FILE_DESCRIPTOR_SET: &[u8] =
tonic::include_file_descriptor_set!("prover_descriptor");
}
+use crate::tier::{KarmaAmount, TierLimit, TierName};
use prover_proto::{
- RegisterUserReply, RegisterUserRequest, RlnProof, RlnProofFilter, SendTransactionReply,
- SendTransactionRequest,
+ GetUserTierInfoReply, GetUserTierInfoRequest, RegisterUserReply, RegisterUserRequest, RlnProof,
+ RlnProofFilter, SendTransactionReply, SendTransactionRequest, SetTierLimitsReply,
+ SetTierLimitsRequest, Tier, UserTierInfoError, UserTierInfoResult,
+ get_user_tier_info_reply::Resp,
rln_prover_server::{RlnProver, RlnProverServer},
};
-use rln_proof::{
- // RlnData,
- RlnIdentifier,
- RlnUserIdentity,
- // ZerokitMerkleTree,
- // compute_rln_proof_and_values,
-};
const PROVER_SERVICE_LIMIT_PER_CONNECTION: usize = 16;
// Timeout for all handlers of a request
@@ -77,14 +67,10 @@ const PROVER_SPAM_LIMIT: u64 = 10_000;
#[derive(Debug)]
pub struct ProverService {
proof_sender: Sender<(RlnUserIdentity, Arc, u64)>,
- registry: UserRegistry,
+ user_db: UserDb,
rln_identifier: Arc,
- message_counters: RwLock>,
spam_limit: u64,
- broadcast_channel: (
- tokio::sync::broadcast::Sender>,
- tokio::sync::broadcast::Receiver>,
- ),
+ broadcast_channel: (broadcast::Sender>, broadcast::Receiver>),
}
#[tonic::async_trait]
@@ -106,23 +92,18 @@ impl RlnProver for ProverService {
return Err(Status::invalid_argument("No sender address"));
};
- // Update the counter as soon as possible (should help to prevent spamming...)
- let mut message_counter_guard = self.message_counters.write().await;
- let counter = *message_counter_guard
- .entry(sender)
- .and_modify(|e| *e += 1)
- .or_insert(1);
- drop(message_counter_guard);
-
- let user_id = if let Some(id) = self.registry.get(&sender) {
- *id
+ let user_id = if let Some(id) = self.user_db.get_user(&sender) {
+ id.clone()
} else {
return Err(Status::not_found("Sender not registered"));
};
+ // Update the counter as soon as possible (should help to prevent spamming...)
+ let counter = self.user_db.on_new_tx(&sender).unwrap_or_default();
+
let user_identity = RlnUserIdentity {
- secret_hash: user_id.0,
- commitment: user_id.1,
+ secret_hash: user_id.secret_hash,
+ commitment: user_id.commitment,
user_limit: Fr::from(self.spam_limit),
};
@@ -131,7 +112,7 @@ impl RlnProver for ProverService {
// Send some data to one of the proof services
self.proof_sender
- .send((user_identity, rln_identifier, counter))
+ .send((user_identity, rln_identifier, counter.into()))
.await
.map_err(|e| Status::from_error(Box::new(e)))?;
@@ -173,39 +154,100 @@ impl RlnProver for ProverService {
Ok(Response::new(ReceiverStream::new(rx)))
}
-}
-pub(crate) struct GrpcProverService {
- proof_sender: Sender<(RlnUserIdentity, Arc, u64)>,
- broadcast_channel: (broadcast::Sender>, broadcast::Receiver>),
- addr: SocketAddr,
- rln_identifier: RlnIdentifier,
- // epoch_counter: Arc,
-}
+ async fn get_user_tier_info(
+ &self,
+ request: Request,
+ ) -> Result, Status> {
+ debug!("request: {:?}", request);
-impl GrpcProverService {
- pub(crate) fn new(
- proof_sender: Sender<(RlnUserIdentity, Arc, u64)>,
- broadcast_channel: (broadcast::Sender>, broadcast::Receiver>),
- addr: SocketAddr,
- rln_identifier: RlnIdentifier,
- /* epoch_counter: Arc */
- ) -> Self {
- Self {
- proof_sender,
- broadcast_channel,
- addr,
- rln_identifier,
- // epoch_counter,
+ let req = request.into_inner();
+
+ let user = if let Some(user) = req.user {
+ if let Ok(user) = Address::try_from(user.value.as_slice()) {
+ user
+ } else {
+ return Err(Status::invalid_argument("Invalid user address"));
+ }
+ } else {
+ return Err(Status::invalid_argument("No user address"));
+ };
+
+ // TODO: SC call
+ struct MockKarmaSc {}
+
+ impl KarmaAmountExt for MockKarmaSc {
+ async fn karma_amount(&self, _address: &Address) -> U256 {
+ U256::from(10)
+ }
+ }
+ let tier_info = self.user_db.user_tier_info(&user, MockKarmaSc {}).await;
+
+ match tier_info {
+ Ok(tier_info) => Ok(Response::new(GetUserTierInfoReply {
+ resp: Some(Resp::Res(tier_info.into())),
+ })),
+ Err(e) => Ok(Response::new(GetUserTierInfoReply {
+ resp: Some(Resp::Error(e.into())),
+ })),
}
}
+ async fn set_tier_limits(
+ &self,
+ request: Request,
+ ) -> Result, Status> {
+ debug!("request: {:?}", request);
+
+ let request = request.into_inner();
+ let tier_limits: Option> = request
+ .karma_amounts
+ .iter()
+ .zip(request.tiers)
+ .map(|(k, tier)| {
+ let karma_amount = U256::try_from_le_slice(k.value.as_slice())?;
+ let karma_amount = KarmaAmount::from(karma_amount);
+ let tier_info = (
+ TierLimit::from(tier.quota),
+ TierName::from(tier.name.clone()),
+ );
+ Some((karma_amount, tier_info))
+ })
+ .collect();
+
+ if tier_limits.is_none() {
+ return Err(Status::invalid_argument("Invalid tier limits"));
+ }
+
+ // unwrap safe - just tested if None
+ let reply = match self.user_db.on_new_tier_limits(tier_limits.unwrap()) {
+ Ok(_) => SetTierLimitsReply {
+ status: true,
+ error: "".to_string(),
+ },
+ Err(e) => SetTierLimitsReply {
+ status: false,
+ error: e.to_string(),
+ },
+ };
+ Ok(Response::new(reply))
+ }
+}
+
+pub(crate) struct GrpcProverService {
+ pub proof_sender: Sender<(RlnUserIdentity, Arc, u64)>,
+ pub broadcast_channel: (broadcast::Sender>, broadcast::Receiver>),
+ pub addr: SocketAddr,
+ pub rln_identifier: RlnIdentifier,
+ pub user_db: UserDb,
+}
+
+impl GrpcProverService {
pub(crate) async fn serve(&self) -> Result<(), AppError> {
let prover_service = ProverService {
proof_sender: self.proof_sender.clone(),
- registry: Default::default(),
+ user_db: self.user_db.clone(),
rln_identifier: Arc::new(self.rln_identifier.clone()),
- message_counters: Default::default(),
spam_limit: PROVER_SPAM_LIMIT,
broadcast_channel: (
self.broadcast_channel.0.clone(),
@@ -262,6 +304,36 @@ impl GrpcProverService {
}
}
+/// UserTierInfo to UserTierInfoResult (Grpc message) conversion
+impl From for UserTierInfoResult {
+ fn from(tier_info: UserTierInfo) -> Self {
+ let mut res = UserTierInfoResult {
+ current_epoch: tier_info.current_epoch.into(),
+ current_epoch_slice: tier_info.current_epoch_slice.into(),
+ tx_count: tier_info.epoch_tx_count,
+ tier: None,
+ };
+
+ if tier_info.tier_name.is_some() && tier_info.tier_limit.is_some() {
+ res.tier = Some(Tier {
+ name: tier_info.tier_name.unwrap().into(),
+ quota: tier_info.tier_limit.unwrap().into(),
+ })
+ }
+
+ res
+ }
+}
+
+/// UserTierInfoError to UserTierInfoError (Grpc message) conversion
+impl From for UserTierInfoError {
+ fn from(value: crate::user_db_service::UserTierInfoError) -> Self {
+ UserTierInfoError {
+ message: value.to_string(),
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use crate::grpc_service::prover_proto::Address;
@@ -270,6 +342,7 @@ mod tests {
const MAX_ADDRESS_SIZE_BYTES: usize = 20;
#[test]
+ #[ignore]
#[should_panic]
fn test_address_size_limit() {
// Check if an invalid address can be encoded (as Address grpc type)
diff --git a/prover/src/main.rs b/prover/src/main.rs
index 7b47e7e..c355842 100644
--- a/prover/src/main.rs
+++ b/prover/src/main.rs
@@ -6,17 +6,14 @@ mod grpc_service;
mod proof_service;
mod registry;
mod registry_listener;
+mod tier;
+mod user_db_service;
// std
use std::net::SocketAddr;
use std::time::Duration;
// third-party
-use alloy::primitives::address;
use chrono::{DateTime, Utc};
-// use chrono::{
-// DateTime,
-// Utc
-// };
use clap::Parser;
use rln_proof::RlnIdentifier;
use tokio::task::JoinSet;
@@ -32,7 +29,7 @@ use crate::args::AppArgs;
use crate::epoch_service::EpochService;
use crate::grpc_service::GrpcProverService;
use crate::proof_service::ProofService;
-use crate::registry_listener::RegistryListener;
+use crate::user_db_service::UserDbService;
const RLN_IDENTIFIER_NAME: &[u8] = b"test-rln-identifier";
const PROOF_SERVICE_COUNT: u8 = 8;
@@ -53,15 +50,21 @@ async fn main() -> Result<(), Box> {
// Smart contract
- let uniswap_token_address = address!("1f9840a85d5aF5bf1D1762F925BDADdC4201F984");
- let event = "Transfer(address,address,uint256)";
- let registry_listener =
- RegistryListener::new(app_args.rpc_url.as_str(), uniswap_token_address, event);
+ // let uniswap_token_address = address!("1f9840a85d5aF5bf1D1762F925BDADdC4201F984");
+ // let event = "Transfer(address,address,uint256)";
+ // let registry_listener =
+ // RegistryListener::new(app_args.rpc_url.as_str(), uniswap_token_address, event);
// Epoch
let epoch_service = EpochService::try_from((Duration::from_secs(60 * 2), GENESIS))
.expect("Failed to create epoch service");
+ // User db service
+ let user_db_service = UserDbService::new(
+ epoch_service.epoch_changes.clone(),
+ epoch_service.current_epoch.clone(),
+ );
+
// proof service
// FIXME: bound
let (tx, rx) = tokio::sync::broadcast::channel(2);
@@ -73,13 +76,13 @@ async fn main() -> Result<(), Box> {
let rln_identifier = RlnIdentifier::new(RLN_IDENTIFIER_NAME);
let addr = SocketAddr::new(app_args.ip, app_args.port);
debug!("Listening on: {}", addr);
- let prover_service = GrpcProverService::new(
+ let prover_grpc_service = GrpcProverService {
proof_sender,
- (tx.clone(), rx),
+ broadcast_channel: (tx.clone(), rx),
addr,
rln_identifier,
- // epoch_service.current_epoch.clone()
- );
+ user_db: user_db_service.get_user_db(),
+ };
let mut set = JoinSet::new();
for _i in 0..PROOF_SERVICE_COUNT {
@@ -94,7 +97,8 @@ async fn main() -> Result<(), Box> {
}
// set.spawn(async move { registry_listener.listen().await });
set.spawn(async move { epoch_service.listen_for_new_epoch().await });
- set.spawn(async move { prover_service.serve().await });
+ set.spawn(async move { user_db_service.listen_for_epoch_changes().await });
+ set.spawn(async move { prover_grpc_service.serve().await });
let _ = set.join_all().await;
Ok(())
diff --git a/prover/src/proof_service.rs b/prover/src/proof_service.rs
index 7e59bb3..253620b 100644
--- a/prover/src/proof_service.rs
+++ b/prover/src/proof_service.rs
@@ -28,6 +28,7 @@ enum ProofGenerationError {
Misc(String),
}
+/// A service to generate a RLN proof (and then to broadcast it)
#[derive(Debug)]
pub struct ProofService {
receiver: Receiver<(RlnUserIdentity, Arc, u64)>,
diff --git a/prover/src/registry.rs b/prover/src/registry.rs
index 0f37df8..fee0ace 100644
--- a/prover/src/registry.rs
+++ b/prover/src/registry.rs
@@ -1,9 +1,10 @@
-use alloy::primitives::Address;
-use ark_bn254::Fr;
-use dashmap::DashMap;
-use dashmap::mapref::one::Ref;
-use rln::protocol::keygen;
+// use alloy::primitives::Address;
+// use ark_bn254::Fr;
+// use dashmap::DashMap;
+// use dashmap::mapref::one::Ref;
+// use rln::protocol::keygen;
+/*
#[derive(Debug)]
pub(crate) struct UserRegistry {
inner: DashMap,
@@ -47,3 +48,4 @@ mod tests {
assert!(reg.get(&address).is_some());
}
}
+*/
diff --git a/prover/src/tier.rs b/prover/src/tier.rs
new file mode 100644
index 0000000..a315836
--- /dev/null
+++ b/prover/src/tier.rs
@@ -0,0 +1,60 @@
+use std::collections::BTreeMap;
+use std::sync::LazyLock;
+// third-party
+use alloy::primitives::U256;
+use derive_more::{From, Into};
+
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, From)]
+pub struct KarmaAmount(U256);
+
+impl KarmaAmount {
+ pub(crate) const ZERO: KarmaAmount = KarmaAmount(U256::ZERO);
+}
+
+impl From for KarmaAmount {
+ fn from(value: u64) -> Self {
+ Self(U256::from(value))
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, From, Into)]
+pub struct TierLimit(u64);
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash, From, Into)]
+pub struct TierName(String);
+
+impl From<&str> for TierName {
+ fn from(value: &str) -> Self {
+ Self(value.to_string())
+ }
+}
+
+pub static TIER_LIMITS: LazyLock> =
+ LazyLock::new(|| {
+ BTreeMap::from([
+ (
+ KarmaAmount::from(10),
+ (TierLimit(6), TierName::from("Basic")),
+ ),
+ (
+ KarmaAmount::from(50),
+ (TierLimit(120), TierName::from("Active")),
+ ),
+ (
+ KarmaAmount::from(100),
+ (TierLimit(720), TierName::from("Regular")),
+ ),
+ (
+ KarmaAmount::from(500),
+ (TierLimit(14440), TierName::from("Regular")),
+ ),
+ (
+ KarmaAmount::from(1000),
+ (TierLimit(86400), TierName::from("Power User")),
+ ),
+ (
+ KarmaAmount::from(5000),
+ (TierLimit(432000), TierName::from("S-Tier")),
+ ),
+ ])
+ });
diff --git a/prover/src/user_db_service.rs b/prover/src/user_db_service.rs
new file mode 100644
index 0000000..04d4df9
--- /dev/null
+++ b/prover/src/user_db_service.rs
@@ -0,0 +1,546 @@
+use std::collections::{BTreeMap, HashSet};
+use std::ops::Bound::Included;
+use std::ops::Deref;
+use std::sync::Arc;
+// third-party
+use alloy::primitives::{Address, U256};
+use derive_more::{Add, From, Into};
+use parking_lot::RwLock;
+use rln::protocol::keygen;
+use scc::HashMap;
+use tokio::sync::Notify;
+use tracing::debug;
+// internal
+use crate::epoch_service::{Epoch, EpochSlice};
+use crate::error::AppError;
+use crate::tier::{KarmaAmount, TIER_LIMITS, TierLimit, TierName};
+use rln_proof::RlnUserIdentity;
+
+#[derive(Debug, Default, Clone)]
+pub(crate) struct UserRegistry {
+ inner: HashMap,
+}
+impl UserRegistry {
+ fn register(&self, address: Address) -> bool {
+ let (identity_secret_hash, id_commitment) = keygen();
+ self.inner
+ .insert(
+ address,
+ RlnUserIdentity::from((identity_secret_hash, id_commitment)),
+ )
+ .is_ok()
+ }
+
+ fn has_user(&self, address: &Address) -> bool {
+ self.inner.contains(address)
+ }
+
+ fn get_user(&self, address: &Address) -> Option {
+ self.inner.get(address).map(|entry| entry.clone())
+ }
+}
+
+#[derive(Debug, Default, Clone, Copy, PartialEq, From, Into, Add)]
+pub(crate) struct EpochCounter(u64);
+
+#[derive(Debug, Default, Clone, Copy, PartialEq, From, Into, Add)]
+pub(crate) struct EpochSliceCounter(u64);
+
+#[derive(Debug, Default, Clone)]
+pub(crate) struct TxRegistry {
+ inner: HashMap,
+}
+
+impl Deref for TxRegistry {
+ type Target = HashMap;
+
+ fn deref(&self) -> &Self::Target {
+ &self.inner
+ }
+}
+
+impl TxRegistry {
+ /// Update the transaction counter for the given address
+ ///
+ /// If incr_value is None, the counter will be incremented by 1
+ /// If incr_value is Some(x), the counter will be incremented by x
+ ///
+ /// Returns the new value of the counter
+ pub fn incr_counter(&self, address: &Address, incr_value: Option) -> EpochSliceCounter {
+ let incr_value = incr_value.unwrap_or(1);
+ let mut entry = self.inner.entry(*address).or_default();
+ *entry = (
+ entry.0 + EpochCounter(incr_value),
+ entry.1 + EpochSliceCounter(incr_value),
+ );
+ entry.1
+ }
+}
+
+#[derive(Debug, PartialEq)]
+pub struct UserTierInfo {
+ pub(crate) current_epoch: Epoch,
+ pub(crate) current_epoch_slice: EpochSlice,
+ pub(crate) epoch_tx_count: u64,
+ pub(crate) epoch_slice_tx_count: u64,
+ karma_amount: U256,
+ pub(crate) tier_name: Option,
+ pub(crate) tier_limit: Option,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum UserTierInfoError {
+ #[error("User {0} not registered")]
+ NotRegistered(Address),
+}
+
+pub trait KarmaAmountExt {
+ async fn karma_amount(&self, address: &Address) -> U256;
+}
+
+/// User registration + tx counters + tier limits storage
+#[derive(Debug, Clone)]
+pub struct UserDb {
+ user_registry: Arc,
+ tx_registry: Arc,
+ tier_limits: Arc>>,
+ tier_limits_next: Arc>>,
+ epoch_store: Arc>,
+}
+
+impl UserDb {
+ fn on_new_epoch(&self) {
+ self.tx_registry.clear()
+ }
+
+ fn on_new_epoch_slice(&self) {
+ self.tx_registry.retain(|_a, v| {
+ *v = (v.0, Default::default());
+ true
+ });
+
+ let tier_limits_next_has_updates = !self.tier_limits_next.read().is_empty();
+ if tier_limits_next_has_updates {
+ let mut guard = self.tier_limits_next.write();
+ // mem::take will clear the BTreeMap in tier_limits_next
+ let new_tier_limits = std::mem::take(&mut *guard);
+ debug!("Installing new tier limits: {:?}", new_tier_limits);
+ *self.tier_limits.write() = new_tier_limits;
+ }
+ }
+
+ pub fn get_user(&self, address: &Address) -> Option {
+ self.user_registry.get_user(address)
+ }
+
+ pub(crate) fn on_new_tx(&self, address: &Address) -> Option {
+ if self.user_registry.has_user(address) {
+ Some(self.tx_registry.incr_counter(address, None))
+ } else {
+ None
+ }
+ }
+
+ pub(crate) fn on_new_tier_limits(
+ &self,
+ tier_limits: BTreeMap,
+ ) -> Result<(), SetTierLimitsError> {
+ #[derive(Default)]
+ struct Context<'a> {
+ tier_names: HashSet,
+ prev_karma_amount: Option<&'a KarmaAmount>,
+ prev_tier_limit: Option<&'a TierLimit>,
+ i: usize,
+ }
+
+ let _context = tier_limits.iter().try_fold(
+ Context::default(),
+ |mut state, (karma_amount, (tier_limit, tier_name))| {
+ if karma_amount <= state.prev_karma_amount.unwrap_or(&KarmaAmount::ZERO) {
+ return Err(SetTierLimitsError::InvalidKarmaAmount);
+ }
+
+ if tier_limit <= state.prev_tier_limit.unwrap_or(&TierLimit::from(0)) {
+ return Err(SetTierLimitsError::InvalidTierLimit);
+ }
+
+ if state.tier_names.contains(tier_name) {
+ return Err(SetTierLimitsError::NonUniqueTierName);
+ }
+
+ state.prev_karma_amount = Some(karma_amount);
+ state.prev_tier_limit = Some(tier_limit);
+ state.tier_names.insert(tier_name.clone());
+ state.i += 1;
+ Ok(state)
+ },
+ )?;
+
+ *self.tier_limits_next.write() = tier_limits;
+ Ok(())
+ }
+
+ /// Get user tier info
+ pub(crate) async fn user_tier_info(
+ &self,
+ address: &Address,
+ karma_sc: KSC,
+ ) -> Result {
+ if self.user_registry.has_user(address) {
+ let (epoch_tx_count, epoch_slice_tx_count) = self
+ .tx_registry
+ .get(address)
+ .map(|ref_v| (ref_v.0, ref_v.1))
+ .unwrap_or_default();
+
+ let karma_amount = karma_sc.karma_amount(address).await;
+ let guard = self.tier_limits.read();
+ let range_res = guard.range((
+ Included(&KarmaAmount::ZERO),
+ Included(&KarmaAmount::from(karma_amount)),
+ ));
+ let tier_info: Option<(TierLimit, TierName)> =
+ range_res.into_iter().last().map(|o| o.1.clone());
+ drop(guard);
+
+ let user_tier_info = {
+ let (current_epoch, current_epoch_slice) = *self.epoch_store.read();
+ let mut t = UserTierInfo {
+ current_epoch,
+ current_epoch_slice,
+ epoch_tx_count: epoch_tx_count.into(),
+ epoch_slice_tx_count: epoch_slice_tx_count.into(),
+ karma_amount,
+ tier_name: None,
+ tier_limit: None,
+ };
+ if let Some((tier_limit, tier_name)) = tier_info {
+ t.tier_name = Some(tier_name);
+ t.tier_limit = Some(tier_limit);
+ }
+ t
+ };
+
+ Ok(user_tier_info)
+ } else {
+ Err(UserTierInfoError::NotRegistered(*address))
+ }
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum SetTierLimitsError {
+ #[error("Invalid Karma amount (must be increasing)")]
+ InvalidKarmaAmount,
+ #[error("Invalid Tier limit (must be increasing)")]
+ InvalidTierLimit,
+ #[error("Non unique Tier name")]
+ NonUniqueTierName,
+}
+
+/// Async service to update a UserDb on epoch changes
+#[derive(Debug)]
+pub struct UserDbService {
+ user_db: UserDb,
+ epoch_changes: Arc,
+}
+
+impl UserDbService {
+ pub(crate) fn new(
+ epoch_changes_notifier: Arc,
+ epoch_store: Arc>,
+ ) -> Self {
+ Self {
+ user_db: UserDb {
+ user_registry: Default::default(),
+ tx_registry: Default::default(),
+ tier_limits: Arc::new(RwLock::new(TIER_LIMITS.clone())),
+ tier_limits_next: Arc::new(Default::default()),
+ epoch_store,
+ },
+ epoch_changes: epoch_changes_notifier,
+ }
+ }
+
+ pub fn get_user_db(&self) -> UserDb {
+ self.user_db.clone()
+ }
+
+ pub async fn listen_for_epoch_changes(&self) -> Result<(), AppError> {
+ let (mut current_epoch, mut current_epoch_slice) = *self.user_db.epoch_store.read();
+
+ loop {
+ self.epoch_changes.notified().await;
+ let (new_epoch, new_epoch_slice) = *self.user_db.epoch_store.read();
+ debug!(
+ "new epoch: {:?}, new epoch slice: {:?}",
+ new_epoch, new_epoch_slice
+ );
+ self.update_on_epoch_changes(
+ &mut current_epoch,
+ new_epoch,
+ &mut current_epoch_slice,
+ new_epoch_slice,
+ );
+ }
+ }
+
+ /// Internal - used by listen_for_epoch_changes
+ fn update_on_epoch_changes(
+ &self,
+ current_epoch: &mut Epoch,
+ new_epoch: Epoch,
+ current_epoch_slice: &mut EpochSlice,
+ new_epoch_slice: EpochSlice,
+ ) {
+ if new_epoch > *current_epoch {
+ self.user_db.on_new_epoch()
+ } else if new_epoch_slice > *current_epoch_slice {
+ self.user_db.on_new_epoch_slice()
+ }
+
+ *current_epoch = new_epoch;
+ *current_epoch_slice = new_epoch_slice;
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use alloy::primitives::address;
+ use claims::assert_err;
+
+ struct MockKarmaSc {}
+
+ impl KarmaAmountExt for MockKarmaSc {
+ async fn karma_amount(&self, _address: &Address) -> U256 {
+ U256::from(10)
+ }
+ }
+
+ const ADDR_1: Address = address!("0xd8da6bf26964af9d7eed9e03e53415d37aa96045");
+ const ADDR_2: Address = address!("0xb20a608c624Ca5003905aA834De7156C68b2E1d0");
+
+ struct MockKarmaSc2 {}
+
+ impl KarmaAmountExt for MockKarmaSc2 {
+ async fn karma_amount(&self, address: &Address) -> U256 {
+ if address == &ADDR_1 {
+ U256::from(10)
+ } else if address == &ADDR_2 {
+ U256::from(2000)
+ } else {
+ U256::ZERO
+ }
+ }
+ }
+
+ #[tokio::test]
+ async fn test_incr_tx_counter() {
+ let user_db = UserDb {
+ user_registry: Default::default(),
+ tx_registry: Default::default(),
+ tier_limits: Arc::new(RwLock::new(TIER_LIMITS.clone())),
+ tier_limits_next: Arc::new(Default::default()),
+ epoch_store: Arc::new(RwLock::new(Default::default())),
+ };
+ let addr = Address::new([0; 20]);
+
+ // Try to update tx counter without registering first
+ assert_eq!(user_db.on_new_tx(&addr), None);
+ let tier_info = user_db.user_tier_info(&addr, MockKarmaSc {}).await;
+ // User is not registered -> no tier info
+ assert!(matches!(
+ tier_info,
+ Err(UserTierInfoError::NotRegistered(_))
+ ));
+ // Register user + update tx counter
+ user_db.user_registry.register(addr);
+ assert_eq!(user_db.on_new_tx(&addr), Some(EpochSliceCounter(1)));
+ let tier_info = user_db.user_tier_info(&addr, MockKarmaSc {}).await.unwrap();
+ assert_eq!(tier_info.epoch_tx_count, 1);
+ assert_eq!(tier_info.epoch_slice_tx_count, 1);
+ }
+
+ #[tokio::test]
+ async fn test_update_on_epoch_changes() {
+ let mut epoch = Epoch::from(11);
+ let mut epoch_slice = EpochSlice::from(42);
+ let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
+ let user_db_service = UserDbService::new(Default::default(), epoch_store);
+ let user_db = user_db_service.get_user_db();
+
+ let addr_1_tx_count = 2;
+ let addr_2_tx_count = 820;
+ user_db.user_registry.register(ADDR_1);
+ user_db
+ .tx_registry
+ .incr_counter(&ADDR_1, Some(addr_1_tx_count));
+ user_db.user_registry.register(ADDR_2);
+ user_db
+ .tx_registry
+ .incr_counter(&ADDR_2, Some(addr_2_tx_count));
+
+ // incr epoch slice (42 -> 43)
+ {
+ let new_epoch = epoch;
+ let new_epoch_slice = epoch_slice + 1;
+ user_db_service.update_on_epoch_changes(
+ &mut epoch,
+ new_epoch,
+ &mut epoch_slice,
+ new_epoch_slice,
+ );
+ let addr_1_tier_info = user_db
+ .user_tier_info(&ADDR_1, MockKarmaSc2 {})
+ .await
+ .unwrap();
+ assert_eq!(addr_1_tier_info.epoch_tx_count, addr_1_tx_count);
+ assert_eq!(addr_1_tier_info.epoch_slice_tx_count, 0);
+ assert_eq!(addr_1_tier_info.tier_name, Some(TierName::from("Basic")));
+
+ let addr_2_tier_info = user_db
+ .user_tier_info(&ADDR_2, MockKarmaSc2 {})
+ .await
+ .unwrap();
+ assert_eq!(addr_2_tier_info.epoch_tx_count, addr_2_tx_count);
+ assert_eq!(addr_2_tier_info.epoch_slice_tx_count, 0);
+ assert_eq!(
+ addr_2_tier_info.tier_name,
+ Some(TierName::from("Power User"))
+ );
+ }
+
+ // incr epoch (11 -> 12, epoch slice reset)
+ {
+ let new_epoch = epoch + 1;
+ let new_epoch_slice = EpochSlice::from(0);
+ user_db_service.update_on_epoch_changes(
+ &mut epoch,
+ new_epoch,
+ &mut epoch_slice,
+ new_epoch_slice,
+ );
+ let addr_1_tier_info = user_db
+ .user_tier_info(&ADDR_1, MockKarmaSc2 {})
+ .await
+ .unwrap();
+ assert_eq!(addr_1_tier_info.epoch_tx_count, 0);
+ assert_eq!(addr_1_tier_info.epoch_slice_tx_count, 0);
+ assert_eq!(addr_1_tier_info.tier_name, Some(TierName::from("Basic")));
+
+ let addr_2_tier_info = user_db
+ .user_tier_info(&ADDR_2, MockKarmaSc2 {})
+ .await
+ .unwrap();
+ assert_eq!(addr_2_tier_info.epoch_tx_count, 0);
+ assert_eq!(addr_2_tier_info.epoch_slice_tx_count, 0);
+ assert_eq!(
+ addr_2_tier_info.tier_name,
+ Some(TierName::from("Power User"))
+ );
+ }
+ }
+
+ #[test]
+ #[tracing_test::traced_test]
+ fn test_set_tier_limits() {
+ // Check if we can update tier limits (and it updates after an epoch slice change)
+
+ let mut epoch = Epoch::from(11);
+ let mut epoch_slice = EpochSlice::from(42);
+ let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
+ let user_db_service = UserDbService::new(Default::default(), epoch_store);
+ let user_db = user_db_service.get_user_db();
+ let tier_limits_original = user_db.tier_limits.read().clone();
+
+ let tier_limits = BTreeMap::from([
+ (
+ KarmaAmount::from(100),
+ (TierLimit::from(100), TierName::from("Basic")),
+ ),
+ (
+ KarmaAmount::from(200),
+ (TierLimit::from(200), TierName::from("Power User")),
+ ),
+ (
+ KarmaAmount::from(300),
+ (TierLimit::from(300), TierName::from("Elite User")),
+ ),
+ ]);
+
+ user_db.on_new_tier_limits(tier_limits.clone()).unwrap();
+ // Check it is not yet installed
+ assert_ne!(*user_db.tier_limits.read(), tier_limits);
+ assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
+ assert_eq!(*user_db.tier_limits_next.read(), tier_limits);
+
+ let new_epoch = epoch;
+ let new_epoch_slice = epoch_slice + 1;
+ user_db_service.update_on_epoch_changes(
+ &mut epoch,
+ new_epoch,
+ &mut epoch_slice,
+ new_epoch_slice,
+ );
+
+ // Should be installed now
+ assert_eq!(*user_db.tier_limits.read(), tier_limits);
+ // And the tier_limits_next field is expected to be empty
+ assert_eq!(*user_db.tier_limits_next.read(), BTreeMap::new());
+ }
+
+ #[test]
+ #[tracing_test::traced_test]
+ fn test_set_invalid_tier_limits() {
+ // Check we cannot update with invalid tier limits
+
+ let epoch = Epoch::from(11);
+ let epoch_slice = EpochSlice::from(42);
+ let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
+ let user_db_service = UserDbService::new(Default::default(), epoch_store);
+ let user_db = user_db_service.get_user_db();
+
+ let tier_limits_original = user_db.tier_limits.read().clone();
+
+ {
+ let tier_limits = BTreeMap::from([
+ (
+ KarmaAmount::from(100),
+ (TierLimit::from(100), TierName::from("Basic")),
+ ),
+ (
+ KarmaAmount::from(200),
+ (TierLimit::from(200), TierName::from("Power User")),
+ ),
+ (
+ KarmaAmount::from(199),
+ (TierLimit::from(300), TierName::from("Elite User")),
+ ),
+ ]);
+
+ assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
+ assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
+ }
+
+ {
+ let tier_limits = BTreeMap::from([
+ (
+ KarmaAmount::from(100),
+ (TierLimit::from(100), TierName::from("Basic")),
+ ),
+ (
+ KarmaAmount::from(200),
+ (TierLimit::from(200), TierName::from("Power User")),
+ ),
+ (
+ KarmaAmount::from(300),
+ (TierLimit::from(300), TierName::from("Basic")),
+ ),
+ ]);
+
+ assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
+ assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
+ }
+ }
+}
diff --git a/rln_proof/Cargo.toml b/rln_proof/Cargo.toml
index f831246..3edc0ed 100644
--- a/rln_proof/Cargo.toml
+++ b/rln_proof/Cargo.toml
@@ -4,23 +4,9 @@ version = "0.1.0"
edition = "2024"
[dependencies]
-# clap = { version = "4.5.37", features = ["derive"] }
-# tonic = { version = "0.13", features = ["gzip"] }
-# tonic-reflection = "0.13"
-# prost = "0.13"
-# tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
-# tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
-# tracing = "0.1.41"
-# tracing-test = "0.2.5"
-# alloy = { version = "0.15", features = ["full"] }
-# thiserror = "2.0"
-# futures = "0.3"
rln = { git = "https://github.com/vacp2p/zerokit", package = "rln", features = ["default"] }
zerokit_utils = { git = "https://github.com/vacp2p/zerokit", package = "zerokit_utils", features = ["default"] }
ark-bn254 = { version = "0.5", features = ["std"] }
-serde_json = "1.0"
-# dashmap = "6.1.0"
-# bytesize = "2.0.1"
ark-groth16 = "*"
ark-relations = "*"
ark-serialize = "0.5.0"
diff --git a/rln_proof/src/proof.rs b/rln_proof/src/proof.rs
index c008389..bf422ed 100644
--- a/rln_proof/src/proof.rs
+++ b/rln_proof/src/proof.rs
@@ -12,12 +12,23 @@ use rln::protocol::{
};
/// A RLN user identity & limit
+#[derive(Debug, Clone)]
pub struct RlnUserIdentity {
pub commitment: Fr,
pub secret_hash: Fr,
pub user_limit: Fr,
}
+impl From<(Fr, Fr)> for RlnUserIdentity {
+ fn from((commitment, secret_hash): (Fr, Fr)) -> Self {
+ Self {
+ commitment,
+ secret_hash,
+ user_limit: Fr::from(0),
+ }
+ }
+}
+
/// RLN info for a channel / group
#[derive(Debug, Clone)]
pub struct RlnIdentifier {