mirror of
https://github.com/vacp2p/status-rln-prover.git
synced 2026-01-08 05:03:54 -05:00
Add initial code for user db (user reg + tx counter + tier info) (#2)
* Add initial code for user db (user reg + tx counter + tier info) * Switch to scc hashmap (+ updated benchmark) * Add KSC into UserDB * Start user db service in prover main * Separate between user db service & user db * Add set tier limits grpc endpoint * Set tier limits unit test * Use derive_more
This commit is contained in:
54
Cargo.lock
generated
54
Cargo.lock
generated
@@ -322,6 +322,7 @@ dependencies = [
|
||||
"const-hex",
|
||||
"derive_more",
|
||||
"foldhash",
|
||||
"getrandom 0.3.2",
|
||||
"hashbrown 0.15.2",
|
||||
"indexmap 2.9.0",
|
||||
"itoa",
|
||||
@@ -1557,6 +1558,12 @@ dependencies = [
|
||||
"half",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "claims"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bba18ee93d577a8428902687bcc2b6b45a56b1981a1f6d779731c86cc4c5db18"
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.37"
|
||||
@@ -1753,6 +1760,29 @@ dependencies = [
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3bf7af66b0989381bd0be551bd7cc91912a655a58c6918420c9527b1fd8b4679"
|
||||
dependencies = [
|
||||
"anes",
|
||||
"cast",
|
||||
"ciborium",
|
||||
"clap",
|
||||
"criterion-plot",
|
||||
"itertools 0.13.0",
|
||||
"num-traits",
|
||||
"oorandom",
|
||||
"plotters",
|
||||
"rayon",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tinytemplate",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion-plot"
|
||||
version = "0.5.0"
|
||||
@@ -3886,9 +3916,8 @@ dependencies = [
|
||||
"ark-groth16",
|
||||
"ark-relations",
|
||||
"ark-serialize 0.5.0",
|
||||
"criterion",
|
||||
"criterion 0.5.1",
|
||||
"rln",
|
||||
"serde_json",
|
||||
"zerokit_utils",
|
||||
]
|
||||
|
||||
@@ -4057,6 +4086,15 @@ dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scc"
|
||||
version = "2.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "22b2d775fb28f245817589471dd49c5edf64237f4a19d10ce9a92ff4651a27f4"
|
||||
dependencies = [
|
||||
"sdd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.27"
|
||||
@@ -4072,6 +4110,12 @@ version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "sdd"
|
||||
version = "3.0.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "584e070911c7017da6cb2eb0788d09f43d789029b5877d3e5ecc8acf86ceee21"
|
||||
|
||||
[[package]]
|
||||
name = "sec1"
|
||||
version = "0.7.3"
|
||||
@@ -4367,15 +4411,19 @@ dependencies = [
|
||||
"async-channel",
|
||||
"bytesize",
|
||||
"chrono",
|
||||
"claims",
|
||||
"clap",
|
||||
"criterion 0.6.0",
|
||||
"dashmap",
|
||||
"derive_more",
|
||||
"futures",
|
||||
"http",
|
||||
"parking_lot 0.12.3",
|
||||
"prost",
|
||||
"rand 0.8.5",
|
||||
"rln",
|
||||
"rln_proof",
|
||||
"serde_json",
|
||||
"scc",
|
||||
"thiserror 2.0.12",
|
||||
"tokio",
|
||||
"tonic",
|
||||
|
||||
@@ -14,6 +14,8 @@ service RlnProver {
|
||||
// Server side streaming RPC: 1 request -> X responses (stream)
|
||||
rpc GetProofs(RlnProofFilter) returns (stream RlnProof);
|
||||
|
||||
rpc GetUserTierInfo(GetUserTierInfoRequest) returns (GetUserTierInfoReply);
|
||||
rpc SetTierLimits(SetTierLimitsRequest) returns (SetTierLimitsReply);
|
||||
}
|
||||
|
||||
// TransactionType: https://github.com/Consensys/linea-besu/blob/09cbed1142cfe4d29b50ecf2f156639a4bc8c854/datatypes/src/main/java/org/hyperledger/besu/datatypes/TransactionType.java#L22
|
||||
@@ -142,3 +144,41 @@ message RegisterUserReply {
|
||||
RegistrationStatus status = 1;
|
||||
}
|
||||
|
||||
message GetUserTierInfoRequest {
|
||||
Address user = 1;
|
||||
}
|
||||
|
||||
message GetUserTierInfoReply {
|
||||
oneof resp {
|
||||
// variant for success
|
||||
UserTierInfoResult res = 1;
|
||||
// variant for error
|
||||
UserTierInfoError error = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message UserTierInfoError {
|
||||
string message = 1;
|
||||
}
|
||||
|
||||
message UserTierInfoResult {
|
||||
sint64 current_epoch = 1;
|
||||
sint64 current_epoch_slice = 2;
|
||||
uint64 tx_count = 3;
|
||||
optional Tier tier = 4;
|
||||
}
|
||||
|
||||
message Tier {
|
||||
string name = 1;
|
||||
uint64 quota = 2;
|
||||
}
|
||||
|
||||
message SetTierLimitsRequest {
|
||||
repeated U256 karmaAmounts = 1;
|
||||
repeated Tier tiers = 2;
|
||||
}
|
||||
|
||||
message SetTierLimitsReply {
|
||||
bool status = 1;
|
||||
string error = 2;
|
||||
}
|
||||
@@ -13,14 +13,14 @@ tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
|
||||
tracing = "0.1.41"
|
||||
tracing-test = "0.2.5"
|
||||
alloy = { version = "0.15", features = ["full"] }
|
||||
alloy = { version = "0.15", features = ["full", "getrandom"] }
|
||||
thiserror = "2.0"
|
||||
futures = "0.3"
|
||||
rln = { git = "https://github.com/vacp2p/zerokit" }
|
||||
ark-bn254 = { version = "0.5", features = ["std"] }
|
||||
ark-serialize = "0.5.0"
|
||||
serde_json = "1.0"
|
||||
dashmap = "6.1.0"
|
||||
scc = "2.3"
|
||||
bytesize = "2.0.1"
|
||||
rln_proof = { path = "../rln_proof" }
|
||||
chrono = "0.4.41"
|
||||
@@ -28,6 +28,16 @@ parking_lot = "0.12.3"
|
||||
tower-http = { version = "0.6.4", features = ["cors"] }
|
||||
http = "*"
|
||||
async-channel = "2.3.1"
|
||||
rand = "0.8.5"
|
||||
derive_more = "2.0.1"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "*"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.6.0"
|
||||
claims = "0.8"
|
||||
|
||||
[[bench]]
|
||||
name = "user_db_heavy_write"
|
||||
harness = false
|
||||
|
||||
59
prover/benches/user_db_heavy_write.rs
Normal file
59
prover/benches/user_db_heavy_write.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use alloy::primitives::Address;
|
||||
use std::hint::black_box;
|
||||
// criterion
|
||||
use criterion::{
|
||||
BenchmarkId,
|
||||
Criterion,
|
||||
Throughput,
|
||||
criterion_group,
|
||||
criterion_main,
|
||||
// black_box
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use rand::Rng;
|
||||
use scc::HashMap;
|
||||
|
||||
pub fn criterion_benchmark(c: &mut Criterion) {
|
||||
let size = 1_250_000;
|
||||
let mut rng = rand::thread_rng();
|
||||
let d_1m: DashMap<Address, (u64, u64)> = DashMap::with_capacity(size as usize);
|
||||
let scc_1m: HashMap<Address, (u64, u64)> = HashMap::with_capacity(size as usize);
|
||||
|
||||
(0..size).into_iter().for_each(|_i| {
|
||||
let mut addr = Address::new([0; 20]);
|
||||
addr.0.randomize();
|
||||
let n1 = rng.r#gen();
|
||||
let n2 = rng.r#gen();
|
||||
d_1m.insert(addr, (n1, n2));
|
||||
scc_1m.insert(addr, (n1, n2)).unwrap();
|
||||
});
|
||||
|
||||
let mut group = c.benchmark_group("Scc versus DashMap alter_all");
|
||||
|
||||
group.throughput(Throughput::Elements(size));
|
||||
group.bench_function(BenchmarkId::new("Dashmap", size), |b| {
|
||||
b.iter(|| {
|
||||
black_box(d_1m.alter_all(|_, v| black_box((v.0, 0))));
|
||||
})
|
||||
});
|
||||
|
||||
group.bench_function(BenchmarkId::new("Scc", size), |b| {
|
||||
b.iter(|| {
|
||||
black_box(scc_1m.retain(|_, v| {
|
||||
black_box(*v = (v.0, 0));
|
||||
black_box(true)
|
||||
}));
|
||||
})
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group! {
|
||||
name = benches;
|
||||
config = Criterion::default()
|
||||
.measurement_time(std::time::Duration::from_secs(45))
|
||||
.sample_size(10);
|
||||
targets = criterion_benchmark
|
||||
}
|
||||
criterion_main!(benches);
|
||||
@@ -4,6 +4,7 @@ use std::time::Duration;
|
||||
// third-party
|
||||
use chrono::{DateTime, NaiveDate, NaiveDateTime, OutOfRangeError, TimeDelta, Utc};
|
||||
use parking_lot::RwLock;
|
||||
use tokio::sync::Notify;
|
||||
use tracing::debug;
|
||||
// internal
|
||||
use crate::error::AppError;
|
||||
@@ -24,6 +25,8 @@ pub struct EpochService {
|
||||
pub current_epoch: Arc<RwLock<(Epoch, EpochSlice)>>,
|
||||
/// Genesis time (aka when the service has been started at the first time)
|
||||
genesis: DateTime<Utc>,
|
||||
/// Channel to notify when an epoch / epoch slice has just changed
|
||||
pub epoch_changes: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl EpochService {
|
||||
@@ -66,7 +69,13 @@ impl EpochService {
|
||||
current_epoch += 1;
|
||||
}
|
||||
*self.current_epoch.write() = (current_epoch.into(), current_epoch_slice.into());
|
||||
debug!("epoch: {}, epoch slice: {}", current_epoch, current_epoch_slice);
|
||||
debug!(
|
||||
"epoch: {}, epoch slice: {}",
|
||||
current_epoch, current_epoch_slice
|
||||
);
|
||||
|
||||
// println!("Epoch changed: {}", current_epoch);
|
||||
self.epoch_changes.notify_one();
|
||||
}
|
||||
|
||||
// Ok(())
|
||||
@@ -186,10 +195,9 @@ impl TryFrom<(Duration, DateTime<Utc>)> for EpochService {
|
||||
|
||||
Ok(Self {
|
||||
epoch_slice_duration,
|
||||
// current_epoch: Arc::new(AtomicI64::new(0)),
|
||||
// current_epoch_slice: Arc::new(AtomicI64::new(0)),
|
||||
current_epoch: Arc::new(Default::default()),
|
||||
genesis,
|
||||
epoch_changes: Arc::new(Default::default()),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -203,8 +211,17 @@ pub enum WaitUntilError {
|
||||
}
|
||||
|
||||
/// An Epoch (wrapper type over i64)
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub(crate) struct Epoch(pub(crate) i64);
|
||||
|
||||
impl Add<i64> for Epoch {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: i64) -> Self::Output {
|
||||
Self(self.0 + rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i64> for Epoch {
|
||||
fn from(value: i64) -> Self {
|
||||
Self(value)
|
||||
@@ -218,9 +235,17 @@ impl From<Epoch> for i64 {
|
||||
}
|
||||
|
||||
/// An Epoch slice (wrapper type over i64)
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub(crate) struct EpochSlice(pub(crate) i64);
|
||||
|
||||
impl Add<i64> for EpochSlice {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: i64) -> Self::Output {
|
||||
Self(self.0 + rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i64> for EpochSlice {
|
||||
fn from(value: i64) -> Self {
|
||||
Self(value)
|
||||
@@ -237,6 +262,9 @@ impl From<EpochSlice> for i64 {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::{NaiveDate, NaiveDateTime, TimeDelta};
|
||||
use futures::TryFutureExt;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tracing_test::traced_test;
|
||||
|
||||
/*
|
||||
#[tokio::test]
|
||||
@@ -251,14 +279,15 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_wait_until() {
|
||||
// Check wait_until is correctly computed
|
||||
|
||||
let date_0 = NaiveDate::from_ymd_opt(2025, 5, 14).unwrap();
|
||||
let datetime_0 = date_0.and_hms_opt(0, 0, 0).unwrap();
|
||||
|
||||
{
|
||||
// standard wait until in epoch 0, epoch slice 1
|
||||
|
||||
let genesis: DateTime<Utc> =
|
||||
chrono::DateTime::from_naive_utc_and_offset(datetime_0, chrono::Utc);
|
||||
let genesis: DateTime<Utc> = DateTime::from_naive_utc_and_offset(datetime_0, Utc);
|
||||
let epoch_slice_duration = Duration::from_secs(60);
|
||||
let epoch_service = EpochService::try_from((epoch_slice_duration, genesis)).unwrap();
|
||||
|
||||
@@ -266,7 +295,7 @@ mod tests {
|
||||
let mut now_0: NaiveDateTime = date_0.and_hms_opt(0, 0, 0).unwrap();
|
||||
// Set now_0 to be in epoch slice 1
|
||||
now_0 += epoch_slice_duration;
|
||||
chrono::DateTime::from_naive_utc_and_offset(now_0, chrono::Utc)
|
||||
DateTime::from_naive_utc_and_offset(now_0, chrono::Utc)
|
||||
};
|
||||
|
||||
let (epoch, epoch_slice, wait_until): (_, _, DateTime<Utc>) =
|
||||
@@ -276,7 +305,7 @@ mod tests {
|
||||
assert_eq!(epoch_slice, 1);
|
||||
assert_eq!(
|
||||
wait_until,
|
||||
chrono::DateTime::<Utc>::from_naive_utc_and_offset(datetime_0, chrono::Utc)
|
||||
DateTime::<Utc>::from_naive_utc_and_offset(datetime_0, Utc)
|
||||
+ 2 * epoch_slice_duration
|
||||
);
|
||||
}
|
||||
@@ -284,8 +313,7 @@ mod tests {
|
||||
{
|
||||
// standard wait until (but in epoch 1)
|
||||
|
||||
let genesis: DateTime<Utc> =
|
||||
chrono::DateTime::from_naive_utc_and_offset(datetime_0, chrono::Utc);
|
||||
let genesis: DateTime<Utc> = DateTime::from_naive_utc_and_offset(datetime_0, Utc);
|
||||
let epoch_slice_duration = Duration::from_secs(60);
|
||||
let epoch_service = EpochService::try_from((epoch_slice_duration, genesis)).unwrap();
|
||||
|
||||
@@ -297,7 +325,7 @@ mod tests {
|
||||
now_0 += epoch_slice_duration;
|
||||
// Add 30 secs (but should still wait until epoch slice 2 starts)
|
||||
now_0 += epoch_slice_duration / 2;
|
||||
chrono::DateTime::from_naive_utc_and_offset(now_0, chrono::Utc)
|
||||
chrono::DateTime::from_naive_utc_and_offset(now_0, Utc)
|
||||
};
|
||||
|
||||
let (epoch, epoch_slice, wait_until): (_, _, DateTime<Utc>) =
|
||||
@@ -307,7 +335,7 @@ mod tests {
|
||||
assert_eq!(epoch_slice, 1);
|
||||
assert_eq!(
|
||||
wait_until,
|
||||
chrono::DateTime::<Utc>::from_naive_utc_and_offset(datetime_0, chrono::Utc)
|
||||
DateTime::<Utc>::from_naive_utc_and_offset(datetime_0, Utc)
|
||||
+ EPOCH_DURATION
|
||||
+ 2 * epoch_slice_duration
|
||||
);
|
||||
@@ -317,7 +345,7 @@ mod tests {
|
||||
// Check for WaitUntilError::TooLow
|
||||
|
||||
let genesis: DateTime<Utc> =
|
||||
chrono::DateTime::from_naive_utc_and_offset(datetime_0, chrono::Utc);
|
||||
chrono::DateTime::from_naive_utc_and_offset(datetime_0, Utc);
|
||||
let epoch_slice_duration = Duration::from_secs(60);
|
||||
let epoch_service = EpochService::try_from((epoch_slice_duration, genesis)).unwrap();
|
||||
let epoch_slice_duration_minus_1 =
|
||||
@@ -389,17 +417,20 @@ mod tests {
|
||||
|
||||
let now_f = move || {
|
||||
let now_0: NaiveDateTime = day.and_hms_opt(0, 4, 0).unwrap();
|
||||
let now: DateTime<Utc> = chrono::DateTime::from_utc(now_0, chrono::Utc);
|
||||
let now: DateTime<Utc> =
|
||||
chrono::DateTime::from_naive_utc_and_offset(now_0, chrono::Utc);
|
||||
now
|
||||
};
|
||||
let now_f_2 = move || {
|
||||
let now_0: NaiveDateTime = day.and_hms_opt(0, 5, 59).unwrap();
|
||||
let now: DateTime<Utc> = chrono::DateTime::from_utc(now_0, chrono::Utc);
|
||||
let now: DateTime<Utc> =
|
||||
chrono::DateTime::from_naive_utc_and_offset(now_0, chrono::Utc);
|
||||
now
|
||||
};
|
||||
let now_f_3 = move || {
|
||||
let now_0: NaiveDateTime = day.and_hms_opt(0, 6, 0).unwrap();
|
||||
let now: DateTime<Utc> = chrono::DateTime::from_utc(now_0, chrono::Utc);
|
||||
let now: DateTime<Utc> =
|
||||
chrono::DateTime::from_naive_utc_and_offset(now_0, chrono::Utc);
|
||||
now
|
||||
};
|
||||
|
||||
@@ -430,4 +461,45 @@ mod tests {
|
||||
3
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
enum AppErrorExt {
|
||||
#[error("AppError: {0}")]
|
||||
AppError(#[from] AppError),
|
||||
#[error("Future timeout")]
|
||||
Elapsed,
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[traced_test]
|
||||
async fn test_notify() {
|
||||
// Test epoch_service is really notifying when an epoch or epoch slice has just changed
|
||||
|
||||
let epoch_slice_duration = Duration::from_secs(10);
|
||||
let epoch_service = EpochService::try_from((epoch_slice_duration, Utc::now())).unwrap();
|
||||
let notifier = epoch_service.epoch_changes.clone();
|
||||
let counter_0 = Arc::new(AtomicU64::new(0));
|
||||
let counter = counter_0.clone();
|
||||
|
||||
let res = tokio::try_join!(
|
||||
epoch_service
|
||||
.listen_for_new_epoch()
|
||||
.map_err(|e| AppErrorExt::AppError(e)),
|
||||
// Wait for 3 epoch slices + 100 ms (to wait to receive notif + counter incr)
|
||||
tokio::time::timeout(
|
||||
epoch_slice_duration * 3 + Duration::from_millis(100),
|
||||
async move {
|
||||
loop {
|
||||
notifier.notified().await;
|
||||
debug!("[Notified] Epoch update...");
|
||||
let _v = counter.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
Ok::<(), AppErrorExt>(())
|
||||
}
|
||||
)
|
||||
.map_err(|_e| AppErrorExt::Elapsed)
|
||||
);
|
||||
assert!(matches!(res, Err(AppErrorExt::Elapsed)));
|
||||
assert_eq!(counter_0.load(Ordering::SeqCst), 3);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,16 @@
|
||||
use std::collections::BTreeMap;
|
||||
// std
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
// third-party
|
||||
use alloy::primitives::Address;
|
||||
use alloy::primitives::{Address, U256};
|
||||
use ark_bn254::Fr;
|
||||
use async_channel::Sender;
|
||||
use bytesize::ByteSize;
|
||||
use futures::TryFutureExt;
|
||||
use http::Method;
|
||||
use tokio::sync::{
|
||||
RwLock,
|
||||
broadcast,
|
||||
// broadcast::{Receiver, Sender},
|
||||
mpsc,
|
||||
};
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
use tonic::{
|
||||
Request,
|
||||
Response,
|
||||
@@ -32,13 +27,12 @@ use tracing::{
|
||||
// info
|
||||
};
|
||||
// internal
|
||||
use crate::{
|
||||
error::{
|
||||
AppError,
|
||||
// RegistrationError
|
||||
},
|
||||
registry::UserRegistry,
|
||||
use crate::error::{
|
||||
AppError,
|
||||
// RegistrationError
|
||||
};
|
||||
use crate::user_db_service::{KarmaAmountExt, UserDb, UserTierInfo};
|
||||
use rln_proof::{RlnIdentifier, RlnUserIdentity};
|
||||
|
||||
pub mod prover_proto {
|
||||
|
||||
@@ -48,18 +42,14 @@ pub mod prover_proto {
|
||||
pub(crate) const FILE_DESCRIPTOR_SET: &[u8] =
|
||||
tonic::include_file_descriptor_set!("prover_descriptor");
|
||||
}
|
||||
use crate::tier::{KarmaAmount, TierLimit, TierName};
|
||||
use prover_proto::{
|
||||
RegisterUserReply, RegisterUserRequest, RlnProof, RlnProofFilter, SendTransactionReply,
|
||||
SendTransactionRequest,
|
||||
GetUserTierInfoReply, GetUserTierInfoRequest, RegisterUserReply, RegisterUserRequest, RlnProof,
|
||||
RlnProofFilter, SendTransactionReply, SendTransactionRequest, SetTierLimitsReply,
|
||||
SetTierLimitsRequest, Tier, UserTierInfoError, UserTierInfoResult,
|
||||
get_user_tier_info_reply::Resp,
|
||||
rln_prover_server::{RlnProver, RlnProverServer},
|
||||
};
|
||||
use rln_proof::{
|
||||
// RlnData,
|
||||
RlnIdentifier,
|
||||
RlnUserIdentity,
|
||||
// ZerokitMerkleTree,
|
||||
// compute_rln_proof_and_values,
|
||||
};
|
||||
|
||||
const PROVER_SERVICE_LIMIT_PER_CONNECTION: usize = 16;
|
||||
// Timeout for all handlers of a request
|
||||
@@ -77,14 +67,10 @@ const PROVER_SPAM_LIMIT: u64 = 10_000;
|
||||
#[derive(Debug)]
|
||||
pub struct ProverService {
|
||||
proof_sender: Sender<(RlnUserIdentity, Arc<RlnIdentifier>, u64)>,
|
||||
registry: UserRegistry,
|
||||
user_db: UserDb,
|
||||
rln_identifier: Arc<RlnIdentifier>,
|
||||
message_counters: RwLock<HashMap<Address, u64>>,
|
||||
spam_limit: u64,
|
||||
broadcast_channel: (
|
||||
tokio::sync::broadcast::Sender<Vec<u8>>,
|
||||
tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||
),
|
||||
broadcast_channel: (broadcast::Sender<Vec<u8>>, broadcast::Receiver<Vec<u8>>),
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
@@ -106,23 +92,18 @@ impl RlnProver for ProverService {
|
||||
return Err(Status::invalid_argument("No sender address"));
|
||||
};
|
||||
|
||||
// Update the counter as soon as possible (should help to prevent spamming...)
|
||||
let mut message_counter_guard = self.message_counters.write().await;
|
||||
let counter = *message_counter_guard
|
||||
.entry(sender)
|
||||
.and_modify(|e| *e += 1)
|
||||
.or_insert(1);
|
||||
drop(message_counter_guard);
|
||||
|
||||
let user_id = if let Some(id) = self.registry.get(&sender) {
|
||||
*id
|
||||
let user_id = if let Some(id) = self.user_db.get_user(&sender) {
|
||||
id.clone()
|
||||
} else {
|
||||
return Err(Status::not_found("Sender not registered"));
|
||||
};
|
||||
|
||||
// Update the counter as soon as possible (should help to prevent spamming...)
|
||||
let counter = self.user_db.on_new_tx(&sender).unwrap_or_default();
|
||||
|
||||
let user_identity = RlnUserIdentity {
|
||||
secret_hash: user_id.0,
|
||||
commitment: user_id.1,
|
||||
secret_hash: user_id.secret_hash,
|
||||
commitment: user_id.commitment,
|
||||
user_limit: Fr::from(self.spam_limit),
|
||||
};
|
||||
|
||||
@@ -131,7 +112,7 @@ impl RlnProver for ProverService {
|
||||
|
||||
// Send some data to one of the proof services
|
||||
self.proof_sender
|
||||
.send((user_identity, rln_identifier, counter))
|
||||
.send((user_identity, rln_identifier, counter.into()))
|
||||
.await
|
||||
.map_err(|e| Status::from_error(Box::new(e)))?;
|
||||
|
||||
@@ -173,39 +154,100 @@ impl RlnProver for ProverService {
|
||||
|
||||
Ok(Response::new(ReceiverStream::new(rx)))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct GrpcProverService {
|
||||
proof_sender: Sender<(RlnUserIdentity, Arc<RlnIdentifier>, u64)>,
|
||||
broadcast_channel: (broadcast::Sender<Vec<u8>>, broadcast::Receiver<Vec<u8>>),
|
||||
addr: SocketAddr,
|
||||
rln_identifier: RlnIdentifier,
|
||||
// epoch_counter: Arc<AtomicI64>,
|
||||
}
|
||||
async fn get_user_tier_info(
|
||||
&self,
|
||||
request: Request<GetUserTierInfoRequest>,
|
||||
) -> Result<Response<GetUserTierInfoReply>, Status> {
|
||||
debug!("request: {:?}", request);
|
||||
|
||||
impl GrpcProverService {
|
||||
pub(crate) fn new(
|
||||
proof_sender: Sender<(RlnUserIdentity, Arc<RlnIdentifier>, u64)>,
|
||||
broadcast_channel: (broadcast::Sender<Vec<u8>>, broadcast::Receiver<Vec<u8>>),
|
||||
addr: SocketAddr,
|
||||
rln_identifier: RlnIdentifier,
|
||||
/* epoch_counter: Arc<AtomicI64> */
|
||||
) -> Self {
|
||||
Self {
|
||||
proof_sender,
|
||||
broadcast_channel,
|
||||
addr,
|
||||
rln_identifier,
|
||||
// epoch_counter,
|
||||
let req = request.into_inner();
|
||||
|
||||
let user = if let Some(user) = req.user {
|
||||
if let Ok(user) = Address::try_from(user.value.as_slice()) {
|
||||
user
|
||||
} else {
|
||||
return Err(Status::invalid_argument("Invalid user address"));
|
||||
}
|
||||
} else {
|
||||
return Err(Status::invalid_argument("No user address"));
|
||||
};
|
||||
|
||||
// TODO: SC call
|
||||
struct MockKarmaSc {}
|
||||
|
||||
impl KarmaAmountExt for MockKarmaSc {
|
||||
async fn karma_amount(&self, _address: &Address) -> U256 {
|
||||
U256::from(10)
|
||||
}
|
||||
}
|
||||
let tier_info = self.user_db.user_tier_info(&user, MockKarmaSc {}).await;
|
||||
|
||||
match tier_info {
|
||||
Ok(tier_info) => Ok(Response::new(GetUserTierInfoReply {
|
||||
resp: Some(Resp::Res(tier_info.into())),
|
||||
})),
|
||||
Err(e) => Ok(Response::new(GetUserTierInfoReply {
|
||||
resp: Some(Resp::Error(e.into())),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
async fn set_tier_limits(
|
||||
&self,
|
||||
request: Request<SetTierLimitsRequest>,
|
||||
) -> Result<Response<SetTierLimitsReply>, Status> {
|
||||
debug!("request: {:?}", request);
|
||||
|
||||
let request = request.into_inner();
|
||||
let tier_limits: Option<BTreeMap<KarmaAmount, (TierLimit, TierName)>> = request
|
||||
.karma_amounts
|
||||
.iter()
|
||||
.zip(request.tiers)
|
||||
.map(|(k, tier)| {
|
||||
let karma_amount = U256::try_from_le_slice(k.value.as_slice())?;
|
||||
let karma_amount = KarmaAmount::from(karma_amount);
|
||||
let tier_info = (
|
||||
TierLimit::from(tier.quota),
|
||||
TierName::from(tier.name.clone()),
|
||||
);
|
||||
Some((karma_amount, tier_info))
|
||||
})
|
||||
.collect();
|
||||
|
||||
if tier_limits.is_none() {
|
||||
return Err(Status::invalid_argument("Invalid tier limits"));
|
||||
}
|
||||
|
||||
// unwrap safe - just tested if None
|
||||
let reply = match self.user_db.on_new_tier_limits(tier_limits.unwrap()) {
|
||||
Ok(_) => SetTierLimitsReply {
|
||||
status: true,
|
||||
error: "".to_string(),
|
||||
},
|
||||
Err(e) => SetTierLimitsReply {
|
||||
status: false,
|
||||
error: e.to_string(),
|
||||
},
|
||||
};
|
||||
Ok(Response::new(reply))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct GrpcProverService {
|
||||
pub proof_sender: Sender<(RlnUserIdentity, Arc<RlnIdentifier>, u64)>,
|
||||
pub broadcast_channel: (broadcast::Sender<Vec<u8>>, broadcast::Receiver<Vec<u8>>),
|
||||
pub addr: SocketAddr,
|
||||
pub rln_identifier: RlnIdentifier,
|
||||
pub user_db: UserDb,
|
||||
}
|
||||
|
||||
impl GrpcProverService {
|
||||
pub(crate) async fn serve(&self) -> Result<(), AppError> {
|
||||
let prover_service = ProverService {
|
||||
proof_sender: self.proof_sender.clone(),
|
||||
registry: Default::default(),
|
||||
user_db: self.user_db.clone(),
|
||||
rln_identifier: Arc::new(self.rln_identifier.clone()),
|
||||
message_counters: Default::default(),
|
||||
spam_limit: PROVER_SPAM_LIMIT,
|
||||
broadcast_channel: (
|
||||
self.broadcast_channel.0.clone(),
|
||||
@@ -262,6 +304,36 @@ impl GrpcProverService {
|
||||
}
|
||||
}
|
||||
|
||||
/// UserTierInfo to UserTierInfoResult (Grpc message) conversion
|
||||
impl From<UserTierInfo> for UserTierInfoResult {
|
||||
fn from(tier_info: UserTierInfo) -> Self {
|
||||
let mut res = UserTierInfoResult {
|
||||
current_epoch: tier_info.current_epoch.into(),
|
||||
current_epoch_slice: tier_info.current_epoch_slice.into(),
|
||||
tx_count: tier_info.epoch_tx_count,
|
||||
tier: None,
|
||||
};
|
||||
|
||||
if tier_info.tier_name.is_some() && tier_info.tier_limit.is_some() {
|
||||
res.tier = Some(Tier {
|
||||
name: tier_info.tier_name.unwrap().into(),
|
||||
quota: tier_info.tier_limit.unwrap().into(),
|
||||
})
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
/// UserTierInfoError to UserTierInfoError (Grpc message) conversion
|
||||
impl From<crate::user_db_service::UserTierInfoError> for UserTierInfoError {
|
||||
fn from(value: crate::user_db_service::UserTierInfoError) -> Self {
|
||||
UserTierInfoError {
|
||||
message: value.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::grpc_service::prover_proto::Address;
|
||||
@@ -270,6 +342,7 @@ mod tests {
|
||||
const MAX_ADDRESS_SIZE_BYTES: usize = 20;
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
#[should_panic]
|
||||
fn test_address_size_limit() {
|
||||
// Check if an invalid address can be encoded (as Address grpc type)
|
||||
|
||||
@@ -6,17 +6,14 @@ mod grpc_service;
|
||||
mod proof_service;
|
||||
mod registry;
|
||||
mod registry_listener;
|
||||
mod tier;
|
||||
mod user_db_service;
|
||||
|
||||
// std
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
// third-party
|
||||
use alloy::primitives::address;
|
||||
use chrono::{DateTime, Utc};
|
||||
// use chrono::{
|
||||
// DateTime,
|
||||
// Utc
|
||||
// };
|
||||
use clap::Parser;
|
||||
use rln_proof::RlnIdentifier;
|
||||
use tokio::task::JoinSet;
|
||||
@@ -32,7 +29,7 @@ use crate::args::AppArgs;
|
||||
use crate::epoch_service::EpochService;
|
||||
use crate::grpc_service::GrpcProverService;
|
||||
use crate::proof_service::ProofService;
|
||||
use crate::registry_listener::RegistryListener;
|
||||
use crate::user_db_service::UserDbService;
|
||||
|
||||
const RLN_IDENTIFIER_NAME: &[u8] = b"test-rln-identifier";
|
||||
const PROOF_SERVICE_COUNT: u8 = 8;
|
||||
@@ -53,15 +50,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
// Smart contract
|
||||
|
||||
let uniswap_token_address = address!("1f9840a85d5aF5bf1D1762F925BDADdC4201F984");
|
||||
let event = "Transfer(address,address,uint256)";
|
||||
let registry_listener =
|
||||
RegistryListener::new(app_args.rpc_url.as_str(), uniswap_token_address, event);
|
||||
// let uniswap_token_address = address!("1f9840a85d5aF5bf1D1762F925BDADdC4201F984");
|
||||
// let event = "Transfer(address,address,uint256)";
|
||||
// let registry_listener =
|
||||
// RegistryListener::new(app_args.rpc_url.as_str(), uniswap_token_address, event);
|
||||
|
||||
// Epoch
|
||||
let epoch_service = EpochService::try_from((Duration::from_secs(60 * 2), GENESIS))
|
||||
.expect("Failed to create epoch service");
|
||||
|
||||
// User db service
|
||||
let user_db_service = UserDbService::new(
|
||||
epoch_service.epoch_changes.clone(),
|
||||
epoch_service.current_epoch.clone(),
|
||||
);
|
||||
|
||||
// proof service
|
||||
// FIXME: bound
|
||||
let (tx, rx) = tokio::sync::broadcast::channel(2);
|
||||
@@ -73,13 +76,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let rln_identifier = RlnIdentifier::new(RLN_IDENTIFIER_NAME);
|
||||
let addr = SocketAddr::new(app_args.ip, app_args.port);
|
||||
debug!("Listening on: {}", addr);
|
||||
let prover_service = GrpcProverService::new(
|
||||
let prover_grpc_service = GrpcProverService {
|
||||
proof_sender,
|
||||
(tx.clone(), rx),
|
||||
broadcast_channel: (tx.clone(), rx),
|
||||
addr,
|
||||
rln_identifier,
|
||||
// epoch_service.current_epoch.clone()
|
||||
);
|
||||
user_db: user_db_service.get_user_db(),
|
||||
};
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for _i in 0..PROOF_SERVICE_COUNT {
|
||||
@@ -94,7 +97,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
// set.spawn(async move { registry_listener.listen().await });
|
||||
set.spawn(async move { epoch_service.listen_for_new_epoch().await });
|
||||
set.spawn(async move { prover_service.serve().await });
|
||||
set.spawn(async move { user_db_service.listen_for_epoch_changes().await });
|
||||
set.spawn(async move { prover_grpc_service.serve().await });
|
||||
|
||||
let _ = set.join_all().await;
|
||||
Ok(())
|
||||
|
||||
@@ -28,6 +28,7 @@ enum ProofGenerationError {
|
||||
Misc(String),
|
||||
}
|
||||
|
||||
/// A service to generate a RLN proof (and then to broadcast it)
|
||||
#[derive(Debug)]
|
||||
pub struct ProofService {
|
||||
receiver: Receiver<(RlnUserIdentity, Arc<RlnIdentifier>, u64)>,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use alloy::primitives::Address;
|
||||
use ark_bn254::Fr;
|
||||
use dashmap::DashMap;
|
||||
use dashmap::mapref::one::Ref;
|
||||
use rln::protocol::keygen;
|
||||
// use alloy::primitives::Address;
|
||||
// use ark_bn254::Fr;
|
||||
// use dashmap::DashMap;
|
||||
// use dashmap::mapref::one::Ref;
|
||||
// use rln::protocol::keygen;
|
||||
|
||||
/*
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct UserRegistry {
|
||||
inner: DashMap<Address, (Fr, Fr)>,
|
||||
@@ -47,3 +48,4 @@ mod tests {
|
||||
assert!(reg.get(&address).is_some());
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
60
prover/src/tier.rs
Normal file
60
prover/src/tier.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::LazyLock;
|
||||
// third-party
|
||||
use alloy::primitives::U256;
|
||||
use derive_more::{From, Into};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, From)]
|
||||
pub struct KarmaAmount(U256);
|
||||
|
||||
impl KarmaAmount {
|
||||
pub(crate) const ZERO: KarmaAmount = KarmaAmount(U256::ZERO);
|
||||
}
|
||||
|
||||
impl From<u64> for KarmaAmount {
|
||||
fn from(value: u64) -> Self {
|
||||
Self(U256::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, From, Into)]
|
||||
pub struct TierLimit(u64);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, From, Into)]
|
||||
pub struct TierName(String);
|
||||
|
||||
impl From<&str> for TierName {
|
||||
fn from(value: &str) -> Self {
|
||||
Self(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub static TIER_LIMITS: LazyLock<BTreeMap<KarmaAmount, (TierLimit, TierName)>> =
|
||||
LazyLock::new(|| {
|
||||
BTreeMap::from([
|
||||
(
|
||||
KarmaAmount::from(10),
|
||||
(TierLimit(6), TierName::from("Basic")),
|
||||
),
|
||||
(
|
||||
KarmaAmount::from(50),
|
||||
(TierLimit(120), TierName::from("Active")),
|
||||
),
|
||||
(
|
||||
KarmaAmount::from(100),
|
||||
(TierLimit(720), TierName::from("Regular")),
|
||||
),
|
||||
(
|
||||
KarmaAmount::from(500),
|
||||
(TierLimit(14440), TierName::from("Regular")),
|
||||
),
|
||||
(
|
||||
KarmaAmount::from(1000),
|
||||
(TierLimit(86400), TierName::from("Power User")),
|
||||
),
|
||||
(
|
||||
KarmaAmount::from(5000),
|
||||
(TierLimit(432000), TierName::from("S-Tier")),
|
||||
),
|
||||
])
|
||||
});
|
||||
546
prover/src/user_db_service.rs
Normal file
546
prover/src/user_db_service.rs
Normal file
@@ -0,0 +1,546 @@
|
||||
use std::collections::{BTreeMap, HashSet};
|
||||
use std::ops::Bound::Included;
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
// third-party
|
||||
use alloy::primitives::{Address, U256};
|
||||
use derive_more::{Add, From, Into};
|
||||
use parking_lot::RwLock;
|
||||
use rln::protocol::keygen;
|
||||
use scc::HashMap;
|
||||
use tokio::sync::Notify;
|
||||
use tracing::debug;
|
||||
// internal
|
||||
use crate::epoch_service::{Epoch, EpochSlice};
|
||||
use crate::error::AppError;
|
||||
use crate::tier::{KarmaAmount, TIER_LIMITS, TierLimit, TierName};
|
||||
use rln_proof::RlnUserIdentity;
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub(crate) struct UserRegistry {
|
||||
inner: HashMap<Address, RlnUserIdentity>,
|
||||
}
|
||||
impl UserRegistry {
|
||||
fn register(&self, address: Address) -> bool {
|
||||
let (identity_secret_hash, id_commitment) = keygen();
|
||||
self.inner
|
||||
.insert(
|
||||
address,
|
||||
RlnUserIdentity::from((identity_secret_hash, id_commitment)),
|
||||
)
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
fn has_user(&self, address: &Address) -> bool {
|
||||
self.inner.contains(address)
|
||||
}
|
||||
|
||||
fn get_user(&self, address: &Address) -> Option<RlnUserIdentity> {
|
||||
self.inner.get(address).map(|entry| entry.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Copy, PartialEq, From, Into, Add)]
|
||||
pub(crate) struct EpochCounter(u64);
|
||||
|
||||
#[derive(Debug, Default, Clone, Copy, PartialEq, From, Into, Add)]
|
||||
pub(crate) struct EpochSliceCounter(u64);
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub(crate) struct TxRegistry {
|
||||
inner: HashMap<Address, (EpochCounter, EpochSliceCounter)>,
|
||||
}
|
||||
|
||||
impl Deref for TxRegistry {
|
||||
type Target = HashMap<Address, (EpochCounter, EpochSliceCounter)>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl TxRegistry {
|
||||
/// Update the transaction counter for the given address
|
||||
///
|
||||
/// If incr_value is None, the counter will be incremented by 1
|
||||
/// If incr_value is Some(x), the counter will be incremented by x
|
||||
///
|
||||
/// Returns the new value of the counter
|
||||
pub fn incr_counter(&self, address: &Address, incr_value: Option<u64>) -> EpochSliceCounter {
|
||||
let incr_value = incr_value.unwrap_or(1);
|
||||
let mut entry = self.inner.entry(*address).or_default();
|
||||
*entry = (
|
||||
entry.0 + EpochCounter(incr_value),
|
||||
entry.1 + EpochSliceCounter(incr_value),
|
||||
);
|
||||
entry.1
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct UserTierInfo {
|
||||
pub(crate) current_epoch: Epoch,
|
||||
pub(crate) current_epoch_slice: EpochSlice,
|
||||
pub(crate) epoch_tx_count: u64,
|
||||
pub(crate) epoch_slice_tx_count: u64,
|
||||
karma_amount: U256,
|
||||
pub(crate) tier_name: Option<TierName>,
|
||||
pub(crate) tier_limit: Option<TierLimit>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum UserTierInfoError {
|
||||
#[error("User {0} not registered")]
|
||||
NotRegistered(Address),
|
||||
}
|
||||
|
||||
pub trait KarmaAmountExt {
|
||||
async fn karma_amount(&self, address: &Address) -> U256;
|
||||
}
|
||||
|
||||
/// User registration + tx counters + tier limits storage
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UserDb {
|
||||
user_registry: Arc<UserRegistry>,
|
||||
tx_registry: Arc<TxRegistry>,
|
||||
tier_limits: Arc<RwLock<BTreeMap<KarmaAmount, (TierLimit, TierName)>>>,
|
||||
tier_limits_next: Arc<RwLock<BTreeMap<KarmaAmount, (TierLimit, TierName)>>>,
|
||||
epoch_store: Arc<RwLock<(Epoch, EpochSlice)>>,
|
||||
}
|
||||
|
||||
impl UserDb {
|
||||
fn on_new_epoch(&self) {
|
||||
self.tx_registry.clear()
|
||||
}
|
||||
|
||||
fn on_new_epoch_slice(&self) {
|
||||
self.tx_registry.retain(|_a, v| {
|
||||
*v = (v.0, Default::default());
|
||||
true
|
||||
});
|
||||
|
||||
let tier_limits_next_has_updates = !self.tier_limits_next.read().is_empty();
|
||||
if tier_limits_next_has_updates {
|
||||
let mut guard = self.tier_limits_next.write();
|
||||
// mem::take will clear the BTreeMap in tier_limits_next
|
||||
let new_tier_limits = std::mem::take(&mut *guard);
|
||||
debug!("Installing new tier limits: {:?}", new_tier_limits);
|
||||
*self.tier_limits.write() = new_tier_limits;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_user(&self, address: &Address) -> Option<RlnUserIdentity> {
|
||||
self.user_registry.get_user(address)
|
||||
}
|
||||
|
||||
pub(crate) fn on_new_tx(&self, address: &Address) -> Option<EpochSliceCounter> {
|
||||
if self.user_registry.has_user(address) {
|
||||
Some(self.tx_registry.incr_counter(address, None))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn on_new_tier_limits(
|
||||
&self,
|
||||
tier_limits: BTreeMap<KarmaAmount, (TierLimit, TierName)>,
|
||||
) -> Result<(), SetTierLimitsError> {
|
||||
#[derive(Default)]
|
||||
struct Context<'a> {
|
||||
tier_names: HashSet<TierName>,
|
||||
prev_karma_amount: Option<&'a KarmaAmount>,
|
||||
prev_tier_limit: Option<&'a TierLimit>,
|
||||
i: usize,
|
||||
}
|
||||
|
||||
let _context = tier_limits.iter().try_fold(
|
||||
Context::default(),
|
||||
|mut state, (karma_amount, (tier_limit, tier_name))| {
|
||||
if karma_amount <= state.prev_karma_amount.unwrap_or(&KarmaAmount::ZERO) {
|
||||
return Err(SetTierLimitsError::InvalidKarmaAmount);
|
||||
}
|
||||
|
||||
if tier_limit <= state.prev_tier_limit.unwrap_or(&TierLimit::from(0)) {
|
||||
return Err(SetTierLimitsError::InvalidTierLimit);
|
||||
}
|
||||
|
||||
if state.tier_names.contains(tier_name) {
|
||||
return Err(SetTierLimitsError::NonUniqueTierName);
|
||||
}
|
||||
|
||||
state.prev_karma_amount = Some(karma_amount);
|
||||
state.prev_tier_limit = Some(tier_limit);
|
||||
state.tier_names.insert(tier_name.clone());
|
||||
state.i += 1;
|
||||
Ok(state)
|
||||
},
|
||||
)?;
|
||||
|
||||
*self.tier_limits_next.write() = tier_limits;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get user tier info
|
||||
pub(crate) async fn user_tier_info<KSC: KarmaAmountExt>(
|
||||
&self,
|
||||
address: &Address,
|
||||
karma_sc: KSC,
|
||||
) -> Result<UserTierInfo, UserTierInfoError> {
|
||||
if self.user_registry.has_user(address) {
|
||||
let (epoch_tx_count, epoch_slice_tx_count) = self
|
||||
.tx_registry
|
||||
.get(address)
|
||||
.map(|ref_v| (ref_v.0, ref_v.1))
|
||||
.unwrap_or_default();
|
||||
|
||||
let karma_amount = karma_sc.karma_amount(address).await;
|
||||
let guard = self.tier_limits.read();
|
||||
let range_res = guard.range((
|
||||
Included(&KarmaAmount::ZERO),
|
||||
Included(&KarmaAmount::from(karma_amount)),
|
||||
));
|
||||
let tier_info: Option<(TierLimit, TierName)> =
|
||||
range_res.into_iter().last().map(|o| o.1.clone());
|
||||
drop(guard);
|
||||
|
||||
let user_tier_info = {
|
||||
let (current_epoch, current_epoch_slice) = *self.epoch_store.read();
|
||||
let mut t = UserTierInfo {
|
||||
current_epoch,
|
||||
current_epoch_slice,
|
||||
epoch_tx_count: epoch_tx_count.into(),
|
||||
epoch_slice_tx_count: epoch_slice_tx_count.into(),
|
||||
karma_amount,
|
||||
tier_name: None,
|
||||
tier_limit: None,
|
||||
};
|
||||
if let Some((tier_limit, tier_name)) = tier_info {
|
||||
t.tier_name = Some(tier_name);
|
||||
t.tier_limit = Some(tier_limit);
|
||||
}
|
||||
t
|
||||
};
|
||||
|
||||
Ok(user_tier_info)
|
||||
} else {
|
||||
Err(UserTierInfoError::NotRegistered(*address))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum SetTierLimitsError {
|
||||
#[error("Invalid Karma amount (must be increasing)")]
|
||||
InvalidKarmaAmount,
|
||||
#[error("Invalid Tier limit (must be increasing)")]
|
||||
InvalidTierLimit,
|
||||
#[error("Non unique Tier name")]
|
||||
NonUniqueTierName,
|
||||
}
|
||||
|
||||
/// Async service to update a UserDb on epoch changes
|
||||
#[derive(Debug)]
|
||||
pub struct UserDbService {
|
||||
user_db: UserDb,
|
||||
epoch_changes: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl UserDbService {
|
||||
pub(crate) fn new(
|
||||
epoch_changes_notifier: Arc<Notify>,
|
||||
epoch_store: Arc<RwLock<(Epoch, EpochSlice)>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
user_db: UserDb {
|
||||
user_registry: Default::default(),
|
||||
tx_registry: Default::default(),
|
||||
tier_limits: Arc::new(RwLock::new(TIER_LIMITS.clone())),
|
||||
tier_limits_next: Arc::new(Default::default()),
|
||||
epoch_store,
|
||||
},
|
||||
epoch_changes: epoch_changes_notifier,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_user_db(&self) -> UserDb {
|
||||
self.user_db.clone()
|
||||
}
|
||||
|
||||
pub async fn listen_for_epoch_changes(&self) -> Result<(), AppError> {
|
||||
let (mut current_epoch, mut current_epoch_slice) = *self.user_db.epoch_store.read();
|
||||
|
||||
loop {
|
||||
self.epoch_changes.notified().await;
|
||||
let (new_epoch, new_epoch_slice) = *self.user_db.epoch_store.read();
|
||||
debug!(
|
||||
"new epoch: {:?}, new epoch slice: {:?}",
|
||||
new_epoch, new_epoch_slice
|
||||
);
|
||||
self.update_on_epoch_changes(
|
||||
&mut current_epoch,
|
||||
new_epoch,
|
||||
&mut current_epoch_slice,
|
||||
new_epoch_slice,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal - used by listen_for_epoch_changes
|
||||
fn update_on_epoch_changes(
|
||||
&self,
|
||||
current_epoch: &mut Epoch,
|
||||
new_epoch: Epoch,
|
||||
current_epoch_slice: &mut EpochSlice,
|
||||
new_epoch_slice: EpochSlice,
|
||||
) {
|
||||
if new_epoch > *current_epoch {
|
||||
self.user_db.on_new_epoch()
|
||||
} else if new_epoch_slice > *current_epoch_slice {
|
||||
self.user_db.on_new_epoch_slice()
|
||||
}
|
||||
|
||||
*current_epoch = new_epoch;
|
||||
*current_epoch_slice = new_epoch_slice;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use alloy::primitives::address;
|
||||
use claims::assert_err;
|
||||
|
||||
struct MockKarmaSc {}
|
||||
|
||||
impl KarmaAmountExt for MockKarmaSc {
|
||||
async fn karma_amount(&self, _address: &Address) -> U256 {
|
||||
U256::from(10)
|
||||
}
|
||||
}
|
||||
|
||||
const ADDR_1: Address = address!("0xd8da6bf26964af9d7eed9e03e53415d37aa96045");
|
||||
const ADDR_2: Address = address!("0xb20a608c624Ca5003905aA834De7156C68b2E1d0");
|
||||
|
||||
struct MockKarmaSc2 {}
|
||||
|
||||
impl KarmaAmountExt for MockKarmaSc2 {
|
||||
async fn karma_amount(&self, address: &Address) -> U256 {
|
||||
if address == &ADDR_1 {
|
||||
U256::from(10)
|
||||
} else if address == &ADDR_2 {
|
||||
U256::from(2000)
|
||||
} else {
|
||||
U256::ZERO
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_incr_tx_counter() {
|
||||
let user_db = UserDb {
|
||||
user_registry: Default::default(),
|
||||
tx_registry: Default::default(),
|
||||
tier_limits: Arc::new(RwLock::new(TIER_LIMITS.clone())),
|
||||
tier_limits_next: Arc::new(Default::default()),
|
||||
epoch_store: Arc::new(RwLock::new(Default::default())),
|
||||
};
|
||||
let addr = Address::new([0; 20]);
|
||||
|
||||
// Try to update tx counter without registering first
|
||||
assert_eq!(user_db.on_new_tx(&addr), None);
|
||||
let tier_info = user_db.user_tier_info(&addr, MockKarmaSc {}).await;
|
||||
// User is not registered -> no tier info
|
||||
assert!(matches!(
|
||||
tier_info,
|
||||
Err(UserTierInfoError::NotRegistered(_))
|
||||
));
|
||||
// Register user + update tx counter
|
||||
user_db.user_registry.register(addr);
|
||||
assert_eq!(user_db.on_new_tx(&addr), Some(EpochSliceCounter(1)));
|
||||
let tier_info = user_db.user_tier_info(&addr, MockKarmaSc {}).await.unwrap();
|
||||
assert_eq!(tier_info.epoch_tx_count, 1);
|
||||
assert_eq!(tier_info.epoch_slice_tx_count, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_on_epoch_changes() {
|
||||
let mut epoch = Epoch::from(11);
|
||||
let mut epoch_slice = EpochSlice::from(42);
|
||||
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
|
||||
let user_db_service = UserDbService::new(Default::default(), epoch_store);
|
||||
let user_db = user_db_service.get_user_db();
|
||||
|
||||
let addr_1_tx_count = 2;
|
||||
let addr_2_tx_count = 820;
|
||||
user_db.user_registry.register(ADDR_1);
|
||||
user_db
|
||||
.tx_registry
|
||||
.incr_counter(&ADDR_1, Some(addr_1_tx_count));
|
||||
user_db.user_registry.register(ADDR_2);
|
||||
user_db
|
||||
.tx_registry
|
||||
.incr_counter(&ADDR_2, Some(addr_2_tx_count));
|
||||
|
||||
// incr epoch slice (42 -> 43)
|
||||
{
|
||||
let new_epoch = epoch;
|
||||
let new_epoch_slice = epoch_slice + 1;
|
||||
user_db_service.update_on_epoch_changes(
|
||||
&mut epoch,
|
||||
new_epoch,
|
||||
&mut epoch_slice,
|
||||
new_epoch_slice,
|
||||
);
|
||||
let addr_1_tier_info = user_db
|
||||
.user_tier_info(&ADDR_1, MockKarmaSc2 {})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(addr_1_tier_info.epoch_tx_count, addr_1_tx_count);
|
||||
assert_eq!(addr_1_tier_info.epoch_slice_tx_count, 0);
|
||||
assert_eq!(addr_1_tier_info.tier_name, Some(TierName::from("Basic")));
|
||||
|
||||
let addr_2_tier_info = user_db
|
||||
.user_tier_info(&ADDR_2, MockKarmaSc2 {})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(addr_2_tier_info.epoch_tx_count, addr_2_tx_count);
|
||||
assert_eq!(addr_2_tier_info.epoch_slice_tx_count, 0);
|
||||
assert_eq!(
|
||||
addr_2_tier_info.tier_name,
|
||||
Some(TierName::from("Power User"))
|
||||
);
|
||||
}
|
||||
|
||||
// incr epoch (11 -> 12, epoch slice reset)
|
||||
{
|
||||
let new_epoch = epoch + 1;
|
||||
let new_epoch_slice = EpochSlice::from(0);
|
||||
user_db_service.update_on_epoch_changes(
|
||||
&mut epoch,
|
||||
new_epoch,
|
||||
&mut epoch_slice,
|
||||
new_epoch_slice,
|
||||
);
|
||||
let addr_1_tier_info = user_db
|
||||
.user_tier_info(&ADDR_1, MockKarmaSc2 {})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(addr_1_tier_info.epoch_tx_count, 0);
|
||||
assert_eq!(addr_1_tier_info.epoch_slice_tx_count, 0);
|
||||
assert_eq!(addr_1_tier_info.tier_name, Some(TierName::from("Basic")));
|
||||
|
||||
let addr_2_tier_info = user_db
|
||||
.user_tier_info(&ADDR_2, MockKarmaSc2 {})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(addr_2_tier_info.epoch_tx_count, 0);
|
||||
assert_eq!(addr_2_tier_info.epoch_slice_tx_count, 0);
|
||||
assert_eq!(
|
||||
addr_2_tier_info.tier_name,
|
||||
Some(TierName::from("Power User"))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[tracing_test::traced_test]
|
||||
fn test_set_tier_limits() {
|
||||
// Check if we can update tier limits (and it updates after an epoch slice change)
|
||||
|
||||
let mut epoch = Epoch::from(11);
|
||||
let mut epoch_slice = EpochSlice::from(42);
|
||||
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
|
||||
let user_db_service = UserDbService::new(Default::default(), epoch_store);
|
||||
let user_db = user_db_service.get_user_db();
|
||||
let tier_limits_original = user_db.tier_limits.read().clone();
|
||||
|
||||
let tier_limits = BTreeMap::from([
|
||||
(
|
||||
KarmaAmount::from(100),
|
||||
(TierLimit::from(100), TierName::from("Basic")),
|
||||
),
|
||||
(
|
||||
KarmaAmount::from(200),
|
||||
(TierLimit::from(200), TierName::from("Power User")),
|
||||
),
|
||||
(
|
||||
KarmaAmount::from(300),
|
||||
(TierLimit::from(300), TierName::from("Elite User")),
|
||||
),
|
||||
]);
|
||||
|
||||
user_db.on_new_tier_limits(tier_limits.clone()).unwrap();
|
||||
// Check it is not yet installed
|
||||
assert_ne!(*user_db.tier_limits.read(), tier_limits);
|
||||
assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
|
||||
assert_eq!(*user_db.tier_limits_next.read(), tier_limits);
|
||||
|
||||
let new_epoch = epoch;
|
||||
let new_epoch_slice = epoch_slice + 1;
|
||||
user_db_service.update_on_epoch_changes(
|
||||
&mut epoch,
|
||||
new_epoch,
|
||||
&mut epoch_slice,
|
||||
new_epoch_slice,
|
||||
);
|
||||
|
||||
// Should be installed now
|
||||
assert_eq!(*user_db.tier_limits.read(), tier_limits);
|
||||
// And the tier_limits_next field is expected to be empty
|
||||
assert_eq!(*user_db.tier_limits_next.read(), BTreeMap::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[tracing_test::traced_test]
|
||||
fn test_set_invalid_tier_limits() {
|
||||
// Check we cannot update with invalid tier limits
|
||||
|
||||
let epoch = Epoch::from(11);
|
||||
let epoch_slice = EpochSlice::from(42);
|
||||
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
|
||||
let user_db_service = UserDbService::new(Default::default(), epoch_store);
|
||||
let user_db = user_db_service.get_user_db();
|
||||
|
||||
let tier_limits_original = user_db.tier_limits.read().clone();
|
||||
|
||||
{
|
||||
let tier_limits = BTreeMap::from([
|
||||
(
|
||||
KarmaAmount::from(100),
|
||||
(TierLimit::from(100), TierName::from("Basic")),
|
||||
),
|
||||
(
|
||||
KarmaAmount::from(200),
|
||||
(TierLimit::from(200), TierName::from("Power User")),
|
||||
),
|
||||
(
|
||||
KarmaAmount::from(199),
|
||||
(TierLimit::from(300), TierName::from("Elite User")),
|
||||
),
|
||||
]);
|
||||
|
||||
assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
|
||||
assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
|
||||
}
|
||||
|
||||
{
|
||||
let tier_limits = BTreeMap::from([
|
||||
(
|
||||
KarmaAmount::from(100),
|
||||
(TierLimit::from(100), TierName::from("Basic")),
|
||||
),
|
||||
(
|
||||
KarmaAmount::from(200),
|
||||
(TierLimit::from(200), TierName::from("Power User")),
|
||||
),
|
||||
(
|
||||
KarmaAmount::from(300),
|
||||
(TierLimit::from(300), TierName::from("Basic")),
|
||||
),
|
||||
]);
|
||||
|
||||
assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
|
||||
assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,23 +4,9 @@ version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
# clap = { version = "4.5.37", features = ["derive"] }
|
||||
# tonic = { version = "0.13", features = ["gzip"] }
|
||||
# tonic-reflection = "0.13"
|
||||
# prost = "0.13"
|
||||
# tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
|
||||
# tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
|
||||
# tracing = "0.1.41"
|
||||
# tracing-test = "0.2.5"
|
||||
# alloy = { version = "0.15", features = ["full"] }
|
||||
# thiserror = "2.0"
|
||||
# futures = "0.3"
|
||||
rln = { git = "https://github.com/vacp2p/zerokit", package = "rln", features = ["default"] }
|
||||
zerokit_utils = { git = "https://github.com/vacp2p/zerokit", package = "zerokit_utils", features = ["default"] }
|
||||
ark-bn254 = { version = "0.5", features = ["std"] }
|
||||
serde_json = "1.0"
|
||||
# dashmap = "6.1.0"
|
||||
# bytesize = "2.0.1"
|
||||
ark-groth16 = "*"
|
||||
ark-relations = "*"
|
||||
ark-serialize = "0.5.0"
|
||||
|
||||
@@ -12,12 +12,23 @@ use rln::protocol::{
|
||||
};
|
||||
|
||||
/// A RLN user identity & limit
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RlnUserIdentity {
|
||||
pub commitment: Fr,
|
||||
pub secret_hash: Fr,
|
||||
pub user_limit: Fr,
|
||||
}
|
||||
|
||||
impl From<(Fr, Fr)> for RlnUserIdentity {
|
||||
fn from((commitment, secret_hash): (Fr, Fr)) -> Self {
|
||||
Self {
|
||||
commitment,
|
||||
secret_hash,
|
||||
user_limit: Fr::from(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// RLN info for a channel / group
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RlnIdentifier {
|
||||
|
||||
Reference in New Issue
Block a user