Add UserRocksDb - initial attempt to rewrite UserDb with persistent st… (#10)

* Use RocksDb + PmTree for UserDb persistent storage
Sydhds authored on 2025-06-27 15:41:41 +02:00 · committed by GitHub
parent cbb058d330
commit 9f4027ed2b
23 changed files with 2174 additions and 964 deletions
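The core pattern this commit introduces: per-user transaction counters are no longer read-modify-written in memory but updated through RocksDB merge operators (see the new rocksdb_operands module below), so increments are queued as operands and folded into the stored value by the engine itself. Below is a minimal, self-contained sketch of that pattern, assuming the same rocksdb crate API the diff uses; the path and key name are illustrative only.

use rocksdb::{DB, MergeOperands, Options};

// Associative merge: fold little-endian u64 increments into the stored value.
fn add_u64(_key: &[u8], existing: Option<&[u8]>, operands: &MergeOperands) -> Option<Vec<u8>> {
    let mut sum = existing
        .map(|v| u64::from_le_bytes(v.try_into().unwrap()))
        .unwrap_or(0);
    for op in operands.iter() {
        sum = sum.saturating_add(u64::from_le_bytes(op.try_into().unwrap()));
    }
    Some(sum.to_le_bytes().to_vec())
}

fn main() -> Result<(), rocksdb::Error> {
    let _ = std::fs::remove_dir_all("/tmp/merge_demo"); // fresh state for the demo
    let mut opts = Options::default();
    opts.create_if_missing(true);
    opts.set_merge_operator_associative("u64_add", add_u64);
    let db = DB::open(&opts, "/tmp/merge_demo")?;
    // Two queued increments, no read-modify-write on the caller side.
    db.merge("tx_count", 1u64.to_le_bytes())?;
    db.merge("tx_count", 2u64.to_le_bytes())?;
    let got = db
        .get("tx_count")?
        .map(|v| u64::from_le_bytes(v.as_slice().try_into().unwrap()));
    assert_eq!(got, Some(3));
    Ok(())
}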

Cargo.lock (generated, 188 lines changed)

@@ -331,7 +331,7 @@ dependencies = [
"proptest",
"rand 0.9.1",
"ruint",
"rustc-hash",
"rustc-hash 2.1.1",
"serde",
"sha3",
"tiny-keccak",
@@ -1272,6 +1272,26 @@ version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "230c5f1ca6a325a32553f8640d31ac9b49f2411e901e427570154868b46da4f7"
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"lazy_static",
"lazycell",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.100",
]
[[package]]
name = "bit-set"
version = "0.8.0"
@@ -1390,6 +1410,16 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3c8f83209414aacf0eeae3cf730b18d6981697fba62f200fcfb92b9f082acba"
[[package]]
name = "bzip2-sys"
version = "0.1.13+1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14"
dependencies = [
"cc",
"pkg-config",
]
[[package]]
name = "c-kzg"
version = "2.1.1"
@@ -1417,9 +1447,20 @@ version = "1.2.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04da6a0d40b948dfc4fa8f5bbf402b0fc1a64a28dbf7d12ffd683550f2c1b63a"
dependencies = [
"jobserver",
"libc",
"shlex",
]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom 7.1.3",
]
[[package]]
name = "cfg-if"
version = "1.0.0"
@@ -1474,6 +1515,17 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bba18ee93d577a8428902687bcc2b6b45a56b1981a1f6d779731c86cc4c5db18"
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "clap"
version = "4.5.40"
@@ -2704,6 +2756,16 @@ version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "jobserver"
version = "0.1.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a"
dependencies = [
"getrandom 0.3.2",
"libc",
]
[[package]]
name = "js-sys"
version = "0.3.77"
@@ -2753,18 +2815,59 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.172"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
[[package]]
name = "libloading"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
dependencies = [
"cfg-if",
"windows-targets 0.53.0",
]
[[package]]
name = "libm"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9627da5196e5d8ed0b0495e61e518847578da83483c37288316d9b2e03a7f72"
[[package]]
name = "librocksdb-sys"
version = "0.17.2+9.10.0"
source = "git+https://github.com/tillrohrmann/rust-rocksdb?branch=issues%2F836#9c2482629ac4a3f73e622fc6be1b6458b6bae8cc"
dependencies = [
"bindgen",
"bzip2-sys",
"cc",
"libc",
"libz-sys",
"lz4-sys",
"zstd-sys",
]
[[package]]
name = "libz-sys"
version = "1.1.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b70e7a7df205e92a1a4cd9aaae7898dac0aa555503cc0a649494d0d60e7651d"
dependencies = [
"cc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "linux-raw-sys"
version = "0.9.4"
@@ -2802,6 +2905,16 @@ dependencies = [
"hashbrown 0.15.2",
]
[[package]]
name = "lz4-sys"
version = "1.11.1+lz4-1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6"
dependencies = [
"cc",
"libc",
]
[[package]]
name = "macro-string"
version = "0.1.4"
@@ -2852,6 +2965,12 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "minimal-lexical"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.7.4"
@@ -2904,6 +3023,25 @@ dependencies = [
"tempfile",
]
[[package]]
name = "nom"
version = "7.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
dependencies = [
"memchr",
"minimal-lexical",
]
[[package]]
name = "nom"
version = "8.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405"
dependencies = [
"memchr",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
@@ -3377,7 +3515,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf"
dependencies = [
"heck",
"itertools 0.10.5",
"itertools 0.13.0",
"log",
"multimap",
"once_cell",
@@ -3397,7 +3535,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d"
dependencies = [
"anyhow",
"itertools 0.10.5",
"itertools 0.13.0",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -3709,6 +3847,15 @@ dependencies = [
"rustc-hex",
]
[[package]]
name = "rocksdb"
version = "0.23.0"
source = "git+https://github.com/tillrohrmann/rust-rocksdb?branch=issues%2F836#9c2482629ac4a3f73e622fc6be1b6458b6bae8cc"
dependencies = [
"libc",
"librocksdb-sys",
]
[[package]]
name = "ruint"
version = "1.15.0"
@@ -3748,6 +3895,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.1.1"
@@ -4217,6 +4370,7 @@ version = "0.1.0"
dependencies = [
"alloy",
"ark-bn254",
"ark-ff 0.5.0",
"ark-groth16",
"ark-serialize 0.5.0",
"async-channel",
@@ -4230,16 +4384,19 @@ dependencies = [
"derive_more",
"futures",
"http",
"nom 8.0.0",
"num-bigint",
"parking_lot 0.12.4",
"prost",
"rand 0.8.5",
"rln",
"rln_proof",
"rocksdb",
"scc",
"serde",
"serde_json",
"smart_contract",
"tempfile",
"thiserror 2.0.12",
"tokio",
"tonic",
@@ -4251,6 +4408,7 @@ dependencies = [
"tracing-subscriber 0.3.19",
"tracing-test",
"url",
"zerokit_utils",
]
[[package]]
@@ -4349,9 +4507,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "tempfile"
version = "3.19.1"
version = "3.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf"
checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1"
dependencies = [
"fastrand",
"getrandom 0.3.2",
@@ -4915,6 +5073,15 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "vacp2p_pmtree"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632293f506ca10d412dbe1d427295317b4c794fa9ddfd66fbd2fa971de88c1f6"
dependencies = [
"rayon",
]
[[package]]
name = "valuable"
version = "0.1.1"
@@ -5521,6 +5688,7 @@ dependencies = [
"serde_json",
"sled",
"thiserror 2.0.12",
"vacp2p_pmtree",
]
[[package]]
@@ -5544,3 +5712,13 @@ dependencies = [
"quote",
"syn 2.0.100",
]
[[package]]
name = "zstd-sys"
version = "2.0.15+zstd.1.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237"
dependencies = [
"cc",
"pkg-config",
]


@@ -9,6 +9,7 @@ resolver = "2"
ark-bn254 = { version = "0.5", features = ["std"] }
ark-serialize = "0.5"
ark-groth16 = "0.5"
ark-ff = "0.5"
url = "2.5.4"
alloy = { version = "1.0", features = ["getrandom", "sol-types", "contract", "provider-ws"] }
async-trait = "0.1"


@@ -211,7 +211,7 @@ message UserTierInfoResult {
message Tier {
string name = 1;
uint64 quota = 2;
uint32 quota = 2;
}
/*


@@ -16,9 +16,9 @@ tracing-test = "0.2.5"
alloy.workspace = true
thiserror = "2.0"
futures = "0.3"
rln = { git = "https://github.com/vacp2p/zerokit", default-features = false }
ark-bn254.workspace = true
ark-serialize.workspace = true
ark-ff.workspace = true
dashmap = "6.1.0"
scc = "2.3"
bytesize = "2.0.1"
@@ -34,6 +34,12 @@ num-bigint = "0.4"
async-trait.workspace = true
serde = { version="1", features = ["derive"] }
serde_json = "1.0"
# rocksdb = "0.23"
rocksdb = { git = "https://github.com/tillrohrmann/rust-rocksdb", branch="issues/836" }
nom = "8.0"
claims = "0.8"
rln = { git = "https://github.com/vacp2p/zerokit", features = ["pmtree-ft"] }
zerokit_utils = { git = "https://github.com/vacp2p/zerokit", package = "zerokit_utils", features = ["default"] }
rln_proof = { path = "../rln_proof" }
smart_contract = { path = "../smart_contract" }
@@ -42,8 +48,8 @@ tonic-build = "*"
[dev-dependencies]
criterion.workspace = true
claims = "0.8"
ark-groth16.workspace = true
tempfile = "3.20"
[[bench]]
name = "user_db_heavy_write"


@@ -23,6 +23,14 @@ pub struct AppArgs {
help = "Websocket rpc url (e.g. wss://eth-mainnet.g.alchemy.com/v2/your-api-key)"
)]
pub(crate) ws_rpc_url: Option<Url>,
#[arg(long = "db", help = "Db path", default_value = "./storage/db")]
pub(crate) db_path: PathBuf,
#[arg(
long = "tree",
help = "Merkle tree path",
default_value = "./storage/tree"
)]
pub(crate) merkle_tree_path: PathBuf,
#[arg(short = 'k', long = "ksc", help = "Karma smart contract address")]
pub(crate) ksc_address: Option<Address>,
#[arg(short = 'r', long = "rlnsc", help = "RLN smart contract address")]


@@ -3,6 +3,7 @@ use std::sync::Arc;
use std::time::Duration;
// third-party
use chrono::{DateTime, NaiveDate, NaiveDateTime, OutOfRangeError, TimeDelta, Utc};
use derive_more::{Deref, From, Into};
use parking_lot::RwLock;
use tokio::sync::Notify;
use tracing::debug;
@@ -211,8 +212,8 @@ pub enum WaitUntilError {
}
/// An Epoch (wrapper type over i64)
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct Epoch(pub(crate) i64);
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From, Into, Deref)]
pub(crate) struct Epoch(i64);
impl Add<i64> for Epoch {
type Output = Self;
@@ -222,21 +223,9 @@ impl Add<i64> for Epoch {
}
}
impl From<i64> for Epoch {
fn from(value: i64) -> Self {
Self(value)
}
}
impl From<Epoch> for i64 {
fn from(value: Epoch) -> Self {
value.0
}
}
/// An Epoch slice (wrapper type over i64)
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct EpochSlice(pub(crate) i64);
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From, Into, Deref)]
pub(crate) struct EpochSlice(i64);
impl Add<i64> for EpochSlice {
type Output = Self;
@@ -246,18 +235,6 @@ impl Add<i64> for EpochSlice {
}
}
impl From<i64> for EpochSlice {
fn from(value: i64) -> Self {
Self(value)
}
}
impl From<EpochSlice> for i64 {
fn from(value: EpochSlice) -> Self {
value.0
}
}
#[cfg(test)]
mod tests {
use super::*;
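
The hand-written From<i64>/From<Epoch> impls removed above collapse into derive_more attributes, and the added Deref derive is what later lets proof_service.rs call to_le_bytes() directly on the wrapper. Here is a standalone sketch of the equivalent behavior; it assumes the same single-field wrapper shape as the diff's Epoch, shown outside the crate for illustration.

use derive_more::{Deref, From, Into};

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From, Into, Deref)]
struct Epoch(i64);

fn main() {
    let e = Epoch::from(11);     // From<i64>, generated by derive_more
    let raw: i64 = e.into();     // Into<i64>, generated by derive_more
    let bytes = e.to_le_bytes(); // i64 method reached through Deref
    assert_eq!(raw, 11);
    assert_eq!(bytes, 11i64.to_le_bytes());
}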


@@ -1,10 +1,9 @@
use crate::epoch_service::WaitUntilError;
use alloy::{
primitives::Address,
transports::{RpcError, TransportErrorKind},
};
use alloy::transports::{RpcError, TransportErrorKind};
use ark_serialize::SerializationError;
use rln::error::ProofError;
// internal
use crate::epoch_service::WaitUntilError;
use crate::user_db_error::{RegisterError, UserMerkleTreeIndexError};
#[derive(thiserror::Error, Debug)]
pub enum AppError {
@@ -53,33 +52,19 @@ impl From<ProofGenerationError> for ProofGenerationStringError {
fn from(value: ProofGenerationError) -> Self {
match value {
ProofGenerationError::Proof(e) => ProofGenerationStringError::Proof(e.to_string()),
ProofGenerationError::Serialization(e) => {
ProofGenerationStringError::Serialization(e.to_string())
}
ProofGenerationError::SerializationWrite(e) => {
ProofGenerationStringError::SerializationWrite(e.to_string())
}
ProofGenerationError::MerkleProofError(e) => {
ProofGenerationStringError::MerkleProofError(e)
}
ProofGenerationError::Serialization(e) => Self::Serialization(e.to_string()),
ProofGenerationError::SerializationWrite(e) => Self::SerializationWrite(e.to_string()),
ProofGenerationError::MerkleProofError(e) => Self::MerkleProofError(e),
}
}
}
#[derive(thiserror::Error, Debug)]
pub enum RegisterError {
#[error("User (address: {0:?}) has already been registered")]
AlreadyRegistered(Address),
#[error("Merkle tree error: {0}")]
TreeError(String),
}
#[derive(thiserror::Error, Debug, Clone)]
pub enum GetMerkleTreeProofError {
#[error("User not registered")]
NotRegistered,
#[error("Merkle tree error: {0}")]
TreeError(String),
#[error(transparent)]
MerkleTree(#[from] UserMerkleTreeIndexError),
}
#[derive(thiserror::Error, Debug)]


@@ -18,9 +18,9 @@ use tower_http::cors::{Any, CorsLayer};
use tracing::debug;
use url::Url;
// internal
use crate::error::{AppError, ProofGenerationStringError, RegisterError};
use crate::error::{AppError, ProofGenerationStringError};
use crate::proof_generation::{ProofGenerationData, ProofSendingData};
use crate::user_db_service::{UserDb, UserTierInfo};
use crate::user_db::{UserDb, UserTierInfo};
use rln_proof::RlnIdentifier;
use smart_contract::{
KarmaAmountExt,
@@ -39,6 +39,7 @@ pub mod prover_proto {
pub(crate) const FILE_DESCRIPTOR_SET: &[u8] =
tonic::include_file_descriptor_set!("prover_descriptor");
}
use crate::user_db_error::RegisterError;
use prover_proto::{
GetUserTierInfoReply,
GetUserTierInfoRequest,
@@ -165,7 +166,7 @@ where
return Err(Status::invalid_argument("No sender address"));
};
let result = self.user_db.on_new_user(user);
let result = self.user_db.on_new_user(&user);
let status = match result {
Ok(id_commitment) => {
@@ -466,11 +467,11 @@ impl From<UserTierInfo> for UserTierInfoResult {
}
/// UserTierInfoError to UserTierInfoError (Grpc message) conversion
impl<E> From<crate::user_db_service::UserTierInfoError<E>> for UserTierInfoError
impl<E> From<crate::user_db_error::UserTierInfoError<E>> for UserTierInfoError
where
E: std::error::Error,
{
fn from(value: crate::user_db_service::UserTierInfoError<E>) -> Self {
fn from(value: crate::user_db_error::UserTierInfoError<E>) -> Self {
UserTierInfoError {
message: value.to_string(),
}


@@ -7,9 +7,14 @@ mod mock;
mod proof_generation;
mod proof_service;
mod registry_listener;
mod rocksdb_operands;
mod tier;
mod tiers_listener;
mod user_db;
mod user_db_error;
mod user_db_serialization;
mod user_db_service;
mod user_db_types;
// std
use std::net::SocketAddr;
@@ -38,7 +43,8 @@ use crate::proof_service::ProofService;
use crate::registry_listener::RegistryListener;
use crate::tier::TierLimits;
use crate::tiers_listener::TiersListener;
use crate::user_db_service::{RateLimit, UserDbService};
use crate::user_db_service::UserDbService;
use crate::user_db_types::RateLimit;
const RLN_IDENTIFIER_NAME: &[u8] = b"test-rln-identifier";
const PROVER_SPAM_LIMIT: RateLimit = RateLimit::new(10_000u64);
@@ -98,11 +104,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// User db service
let user_db_service = UserDbService::new(
app_args.db_path,
app_args.merkle_tree_path,
epoch_service.epoch_changes.clone(),
epoch_service.current_epoch.clone(),
PROVER_SPAM_LIMIT,
tier_limits,
);
)?;
if app_args.mock_sc.is_some() {
if let Some(user_filepath) = app_args.mock_user.as_ref() {
@@ -114,7 +122,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
mock_user.address, mock_user.tx_count
);
let user_db = user_db_service.get_user_db();
user_db.on_new_user(mock_user.address).unwrap();
user_db.on_new_user(&mock_user.address).unwrap();
user_db
.on_new_tx(&mock_user.address, Some(mock_user.tx_count))
.unwrap();


@@ -12,7 +12,8 @@ use tracing::{debug, info};
use crate::epoch_service::{Epoch, EpochSlice};
use crate::error::{AppError, ProofGenerationError, ProofGenerationStringError};
use crate::proof_generation::{ProofGenerationData, ProofSendingData};
use crate::user_db_service::{RateLimit, UserDb};
use crate::user_db::UserDb;
use crate::user_db_types::RateLimit;
use rln_proof::{RlnData, compute_rln_proof_and_values};
const PROOF_SIZE: usize = 512;
@@ -82,8 +83,8 @@ impl ProofService {
};
let epoch_bytes = {
let mut v = i64::from(current_epoch).to_be_bytes().to_vec();
v.extend(i64::from(current_epoch_slice).to_be_bytes());
let mut v = current_epoch.to_le_bytes().to_vec();
v.extend(current_epoch_slice.to_le_bytes());
v
};
let epoch = hash_to_field(epoch_bytes.as_slice());
@@ -140,6 +141,7 @@ impl ProofService {
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
// third-party
use alloy::primitives::{Address, address};
use ark_groth16::{Proof as ArkProof, Proof, VerifyingKey};
@@ -256,15 +258,20 @@ mod tests {
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
// User db
let temp_folder = tempfile::tempdir().unwrap();
let temp_folder_tree = tempfile::tempdir().unwrap();
let user_db_service = UserDbService::new(
PathBuf::from(temp_folder.path()),
PathBuf::from(temp_folder_tree.path()),
Default::default(),
epoch_store.clone(),
10.into(),
Default::default(),
);
)
.unwrap();
let user_db = user_db_service.get_user_db();
user_db.on_new_user(ADDR_1).unwrap();
user_db.on_new_user(ADDR_2).unwrap();
user_db.on_new_user(&ADDR_1).unwrap();
user_db.on_new_user(&ADDR_2).unwrap();
let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz"));
@@ -308,14 +315,19 @@ mod tests {
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
// User db
let temp_folder = tempfile::tempdir().unwrap();
let temp_folder_tree = tempfile::tempdir().unwrap();
let user_db_service = UserDbService::new(
PathBuf::from(temp_folder.path()),
PathBuf::from(temp_folder_tree.path()),
Default::default(),
epoch_store.clone(),
10.into(),
Default::default(),
);
)
.unwrap();
let user_db = user_db_service.get_user_db();
user_db.on_new_user(ADDR_1).unwrap();
user_db.on_new_user(&ADDR_1).unwrap();
// user_db.on_new_user(ADDR_2).unwrap();
let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz"));
@@ -463,16 +475,21 @@ mod tests {
let rate_limit = RateLimit::from(1);
// User db
let temp_folder = tempfile::tempdir().unwrap();
let temp_folder_tree = tempfile::tempdir().unwrap();
let user_db_service = UserDbService::new(
PathBuf::from(temp_folder.path()),
PathBuf::from(temp_folder_tree.path()),
Default::default(),
epoch_store.clone(),
rate_limit,
Default::default(),
);
)
.unwrap();
let user_db = user_db_service.get_user_db();
user_db.on_new_user(ADDR_1).unwrap();
user_db.on_new_user(&ADDR_1).unwrap();
let user_addr_1 = user_db.get_user(&ADDR_1).unwrap();
user_db.on_new_user(ADDR_2).unwrap();
user_db.on_new_user(&ADDR_2).unwrap();
let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz"));
@@ -528,17 +545,22 @@ mod tests {
let rate_limit = RateLimit::from(1);
// User db - limit is 1 message per epoch
let temp_folder = tempfile::tempdir().unwrap();
let temp_folder_tree = tempfile::tempdir().unwrap();
let user_db_service = UserDbService::new(
PathBuf::from(temp_folder.path()),
PathBuf::from(temp_folder_tree.path()),
Default::default(),
epoch_store.clone(),
rate_limit.into(),
rate_limit,
Default::default(),
);
)
.unwrap();
let user_db = user_db_service.get_user_db();
user_db.on_new_user(ADDR_1).unwrap();
user_db.on_new_user(&ADDR_1).unwrap();
let user_addr_1 = user_db.get_user(&ADDR_1).unwrap();
debug!("user_addr_1: {:?}", user_addr_1);
user_db.on_new_user(ADDR_2).unwrap();
user_db.on_new_user(&ADDR_2).unwrap();
let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz"));


@@ -9,8 +9,9 @@ use alloy::{
use tonic::codegen::tokio_stream::StreamExt;
use tracing::{debug, error, info};
// internal
use crate::error::{AppError, HandleTransferError, RegisterError};
use crate::user_db_service::UserDb;
use crate::error::{AppError, HandleTransferError};
use crate::user_db::UserDb;
use crate::user_db_error::RegisterError;
use smart_contract::{AlloyWsProvider, KarmaAmountExt, KarmaSC};
pub(crate) struct RegistryListener {
@@ -113,7 +114,7 @@ impl RegistryListener {
if should_register {
self.user_db
.on_new_user(to_address)
.on_new_user(&to_address)
.map_err(HandleTransferError::Register)?;
}
}
@@ -125,6 +126,7 @@ impl RegistryListener {
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
// std
use std::sync::Arc;
// third-party
@@ -152,12 +154,17 @@ mod tests {
let epoch = Epoch::from(11);
let epoch_slice = EpochSlice::from(42);
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
let temp_folder = tempfile::tempdir().unwrap();
let temp_folder_tree = tempfile::tempdir().unwrap();
let user_db_service = UserDbService::new(
PathBuf::from(temp_folder.path()),
PathBuf::from(temp_folder_tree.path()),
Default::default(),
epoch_store,
10.into(),
Default::default(),
);
)
.unwrap();
let user_db = user_db_service.get_user_db();
assert!(user_db_service.get_user_db().get_user(&ADDR_2).is_none());


@@ -0,0 +1,380 @@
use crate::epoch_service::{Epoch, EpochSlice};
use nom::{
IResult,
error::ContextError,
number::complete::{le_i64, le_u64},
};
use rocksdb::MergeOperands;
#[derive(Debug, PartialEq)]
pub enum DeserializeError<I> {
Nom(I, nom::error::ErrorKind),
}
impl<I> nom::error::ParseError<I> for DeserializeError<I> {
fn from_error_kind(input: I, kind: nom::error::ErrorKind) -> Self {
DeserializeError::Nom(input, kind)
}
fn append(_: I, _: nom::error::ErrorKind, other: Self) -> Self {
other
}
}
impl<I> ContextError<I> for DeserializeError<I> {}
#[derive(Debug, Default, PartialEq)]
pub struct EpochCounters {
pub epoch: Epoch,
pub epoch_slice: EpochSlice,
pub epoch_counter: u64,
pub epoch_slice_counter: u64,
}
pub struct EpochCounterSerializer {}
impl EpochCounterSerializer {
fn serialize(&self, value: &EpochCounters, buffer: &mut Vec<u8>) {
buffer.extend(value.epoch.to_le_bytes());
buffer.extend(value.epoch_slice.to_le_bytes());
buffer.extend(value.epoch_counter.to_le_bytes());
buffer.extend(value.epoch_slice_counter.to_le_bytes());
}
pub(crate) const fn size_hint_() -> usize {
size_of::<EpochCounters>()
}
pub(crate) fn size_hint(&self) -> usize {
Self::size_hint_()
}
pub const fn default() -> [u8; Self::size_hint_()] {
[0u8; Self::size_hint_()]
}
}
pub struct EpochCounterDeserializer {}
impl EpochCounterDeserializer {
pub fn deserialize<'a>(
&self,
buffer: &'a [u8],
) -> IResult<&'a [u8], EpochCounters, DeserializeError<&'a [u8]>> {
let (input, epoch) = le_i64(buffer).map(|(i, e)| (i, Epoch::from(e)))?;
let (input, epoch_slice) = le_i64(input).map(|(i, es)| (i, EpochSlice::from(es)))?;
let (input, epoch_counter) = le_u64(input)?;
let (input, epoch_slice_counter) = le_u64(input)?;
Ok((
input,
EpochCounters {
epoch,
epoch_slice,
epoch_counter,
epoch_slice_counter,
},
))
}
}
#[derive(Debug, Default, PartialEq)]
pub struct EpochIncr {
pub epoch: Epoch,
pub epoch_slice: EpochSlice,
pub incr_value: u64,
}
pub struct EpochIncrSerializer {}
impl EpochIncrSerializer {
pub fn serialize(&self, value: &EpochIncr, buffer: &mut Vec<u8>) {
buffer.extend(value.epoch.to_le_bytes());
buffer.extend(value.epoch_slice.to_le_bytes());
buffer.extend(value.incr_value.to_le_bytes());
}
pub fn size_hint(&self) -> usize {
size_of::<u64>() * 3
}
}
pub struct EpochIncrDeserializer {}
impl EpochIncrDeserializer {
pub fn deserialize<'a>(
&self,
buffer: &'a [u8],
) -> IResult<&'a [u8], EpochIncr, DeserializeError<&'a [u8]>> {
let (input, epoch) = le_i64(buffer).map(|(i, e)| (i, Epoch::from(e)))?;
let (input, epoch_slice) = le_i64(input).map(|(i, es)| (i, EpochSlice::from(es)))?;
let (input, incr_value) = le_u64(input)?;
Ok((
input,
EpochIncr {
epoch,
epoch_slice,
incr_value,
},
))
}
}
pub fn epoch_counters_operands(
_key: &[u8],
existing_val: Option<&[u8]>,
operands: &MergeOperands,
) -> Option<Vec<u8>> {
let counter_ser = EpochCounterSerializer {};
let counter_deser = EpochCounterDeserializer {};
let deser = EpochIncrDeserializer {};
// Current epoch counter structure (stored in DB)
let epoch_counter_current = counter_deser
.deserialize(existing_val.unwrap_or_default())
.map(|(_, c)| c)
.unwrap_or_default();
// Iterate over the merge operands (a single WriteBatch can queue several)
let counter_value = operands.iter().fold(epoch_counter_current, |mut acc, x| {
// Note: unwrap on EpochIncr deserialize error - serialization is done by the prover
// so no deserialization error should ever happen here
let (_, epoch_incr) = deser.deserialize(x).unwrap();
// TODO - optim: partial deser ?
// TODO: check if increasing ? debug_assert otherwise?
if acc == Default::default() {
// Default value - so this is the first time
acc = EpochCounters {
epoch: epoch_incr.epoch,
epoch_slice: epoch_incr.epoch_slice,
epoch_counter: epoch_incr.incr_value,
epoch_slice_counter: epoch_incr.incr_value,
}
} else if epoch_incr.epoch != acc.epoch {
// New epoch
acc = EpochCounters {
epoch: epoch_incr.epoch,
epoch_slice: Default::default(),
epoch_counter: epoch_incr.incr_value,
epoch_slice_counter: epoch_incr.incr_value,
}
} else if epoch_incr.epoch_slice != acc.epoch_slice {
// New epoch slice
acc = EpochCounters {
epoch: epoch_incr.epoch,
epoch_slice: epoch_incr.epoch_slice,
epoch_counter: acc.epoch_counter.saturating_add(epoch_incr.incr_value),
epoch_slice_counter: epoch_incr.incr_value,
}
} else {
acc = EpochCounters {
epoch: acc.epoch,
epoch_slice: acc.epoch_slice,
epoch_counter: acc.epoch_counter.saturating_add(epoch_incr.incr_value),
epoch_slice_counter: acc
.epoch_slice_counter
.saturating_add(epoch_incr.incr_value),
}
}
acc
});
let mut buffer = Vec::with_capacity(counter_ser.size_hint());
counter_ser.serialize(&counter_value, &mut buffer);
Some(buffer)
}
pub fn u64_counter_operands(
_key: &[u8],
existing_val: Option<&[u8]>,
operands: &MergeOperands,
) -> Option<Vec<u8>> {
let counter_current_value = if let Some(existing_val) = existing_val {
u64::from_le_bytes(existing_val.try_into().unwrap())
} else {
0
};
let counter_value = operands.iter().fold(counter_current_value, |mut acc, x| {
let incr_value = u64::from_le_bytes(x.try_into().unwrap());
acc = acc.saturating_add(incr_value);
acc
});
Some(counter_value.to_le_bytes().to_vec())
}
#[cfg(test)]
mod tests {
use super::*;
// std
// third-party
use rocksdb::{DB, Options, WriteBatch};
use tempfile::TempDir;
#[test]
fn test_ser_der() {
// EpochCounter struct
{
let epoch_counter = EpochCounters {
epoch: 1.into(),
epoch_slice: 42.into(),
epoch_counter: 12,
epoch_slice_counter: u64::MAX,
};
let serializer = EpochCounterSerializer {};
let mut buffer = Vec::with_capacity(serializer.size_hint());
serializer.serialize(&epoch_counter, &mut buffer);
let deserializer = EpochCounterDeserializer {};
let (_, de) = deserializer.deserialize(&buffer).unwrap();
assert_eq!(epoch_counter, de);
}
{
let deserializer = EpochCounterDeserializer {};
let (_, de) = deserializer
.deserialize(EpochCounterSerializer::default().as_slice())
.unwrap();
assert_eq!(EpochCounters::default(), de);
}
// EpochIncr struct
{
let epoch_incr = EpochIncr {
epoch: 1.into(),
epoch_slice: 42.into(),
incr_value: 1,
};
let serializer = EpochIncrSerializer {};
let mut buffer = Vec::with_capacity(serializer.size_hint());
serializer.serialize(&epoch_incr, &mut buffer);
let deserializer = EpochIncrDeserializer {};
let (_, de) = deserializer.deserialize(&buffer).unwrap();
assert_eq!(epoch_incr, de);
}
}
#[test]
fn test_counter() {
let tmp_dir = TempDir::new().unwrap(); // keep the guard alive so the dir is cleaned up on drop
let options = {
let mut opts = Options::default();
opts.create_if_missing(true);
opts.set_merge_operator("o", u64_counter_operands, u64_counter_operands);
opts
};
let db = DB::open(&options, tmp_dir.path()).unwrap();
let key_1 = "foo1";
// let key_2 = "baz42";
let index = 42u64;
let buffer = index.to_le_bytes();
let mut db_batch = WriteBatch::default();
db_batch.merge(key_1, &buffer);
db_batch.merge(key_1, &buffer);
db.write(db_batch).unwrap();
let get_key_1 = db.get(&key_1).unwrap().unwrap();
let value = u64::from_le_bytes(get_key_1.try_into().unwrap());
assert_eq!(value, index * 2); // 2x merge
}
#[test]
fn test_counters() {
let tmp_dir = TempDir::new().unwrap(); // keep the guard alive so the dir is cleaned up on drop
let options = {
let mut opts = Options::default();
opts.create_if_missing(true);
opts.set_merge_operator("o", epoch_counters_operands, epoch_counters_operands);
opts
};
let db = DB::open(&options, tmp_dir.path()).unwrap();
let key_1 = "foo1";
let key_2 = "baz42";
let value_1 = EpochIncr {
epoch: 0.into(),
epoch_slice: 0.into(),
incr_value: 2,
};
let epoch_incr_ser = EpochIncrSerializer {};
let epoch_counter_deser = EpochCounterDeserializer {};
let mut buffer = Vec::with_capacity(epoch_incr_ser.size_hint());
epoch_incr_ser.serialize(&value_1, &mut buffer);
let mut db_batch = WriteBatch::default();
db_batch.merge(key_1, &buffer);
db_batch.merge(key_1, &buffer);
db.write(db_batch).unwrap();
let get_key_1 = db.get(&key_1).unwrap().unwrap();
let (_, get_value_k1) = epoch_counter_deser.deserialize(&get_key_1).unwrap();
// Applied EpochIncr 2x
assert_eq!(get_value_k1.epoch_counter, 4);
assert_eq!(get_value_k1.epoch_slice_counter, 4);
let get_key_2 = db.get(&key_2).unwrap();
assert!(get_key_2.is_none());
// new epoch slice
{
let value_2 = EpochIncr {
epoch: 0.into(),
epoch_slice: 1.into(),
incr_value: 1,
};
let mut buffer = Vec::with_capacity(epoch_incr_ser.size_hint());
epoch_incr_ser.serialize(&value_2, &mut buffer);
db.merge(key_1, buffer).unwrap();
let get_key_1 = db.get(&key_1).unwrap().unwrap();
let (_, get_value_2) = epoch_counter_deser.deserialize(&get_key_1).unwrap();
assert_eq!(
get_value_2,
EpochCounters {
epoch: 0.into(),
epoch_slice: 1.into(),
epoch_counter: 5,
epoch_slice_counter: 1,
}
)
}
// new epoch
{
let value_3 = EpochIncr {
epoch: 1.into(),
epoch_slice: 0.into(),
incr_value: 3,
};
let mut buffer = Vec::with_capacity(epoch_incr_ser.size_hint());
epoch_incr_ser.serialize(&value_3, &mut buffer);
db.merge(key_1, buffer).unwrap();
let get_key_1 = db.get(&key_1).unwrap().unwrap();
let (_, get_value_3) = epoch_counter_deser.deserialize(&get_key_1).unwrap();
assert_eq!(
get_value_3,
EpochCounters {
epoch: 1.into(),
epoch_slice: 0.into(),
epoch_counter: 3,
epoch_slice_counter: 3,
}
)
}
}
}
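
The wiring of these merge operators into the actual database lives in user_db.rs, whose diff is suppressed below, so the following is only a hedged sketch of how such wiring typically looks with this rocksdb API. It assumes it sits in the same module as the operands above; the column-family names are hypothetical, not necessarily the ones user_db.rs uses.

use rocksdb::{ColumnFamilyDescriptor, DB, Options};

fn open_counters_db(path: &std::path::Path) -> Result<DB, rocksdb::Error> {
    let mut db_opts = Options::default();
    db_opts.create_if_missing(true);
    db_opts.create_missing_column_families(true);

    // Epoch counters need the full/partial pair (the fold is order-sensitive).
    let mut counters_opts = Options::default();
    counters_opts.set_merge_operator("epoch_counters", epoch_counters_operands, epoch_counters_operands);

    // Plain u64 addition is associative, so a single function suffices.
    let mut index_opts = Options::default();
    index_opts.set_merge_operator_associative("u64_counter", u64_counter_operands);

    DB::open_cf_descriptors(
        &db_opts,
        path,
        vec![
            ColumnFamilyDescriptor::new("tx_counters", counters_opts), // hypothetical name
            ColumnFamilyDescriptor::new("merkle_index", index_opts),   // hypothetical name
        ],
    )
}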


@@ -1,17 +1,14 @@
use std::collections::{BTreeMap, HashSet};
use std::ops::{
ControlFlow,
Deref, DerefMut
};
use std::ops::{ControlFlow, Deref, DerefMut};
// third-party
use derive_more::{From, Into};
use alloy::primitives::U256;
use derive_more::{From, Into};
// internal
use crate::user_db_service::SetTierLimitsError;
// use crate::user_db_service::SetTierLimitsError;
use smart_contract::{Tier, TierIndex};
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, From, Into)]
pub struct TierLimit(u64);
pub struct TierLimit(u32);
#[derive(Debug, Clone, PartialEq, Eq, Hash, From, Into)]
pub struct TierName(String);
@@ -47,7 +44,7 @@ impl TierLimits {
}
/// Validate tier limits (unique names, increasing min & max karma ...)
pub(crate) fn validate(&self) -> Result<(), SetTierLimitsError> {
pub(crate) fn validate(&self) -> Result<(), ValidateTierLimitsError> {
#[derive(Default)]
struct Context<'a> {
tier_names: HashSet<String>,
@@ -61,30 +58,30 @@ impl TierLimits {
.iter()
.try_fold(Context::default(), |mut state, (tier_index, tier)| {
if !tier.active {
return Err(SetTierLimitsError::InactiveTier);
return Err(ValidateTierLimitsError::InactiveTier);
}
if *tier_index <= *state.prev_index.unwrap_or(&TierIndex::default()) {
return Err(SetTierLimitsError::InvalidTierIndex);
return Err(ValidateTierLimitsError::InvalidTierIndex);
}
if tier.min_karma >= tier.max_karma {
return Err(SetTierLimitsError::InvalidMaxAmount(
return Err(ValidateTierLimitsError::InvalidMaxAmount(
tier.min_karma,
tier.max_karma,
));
}
if tier.min_karma <= *state.prev_amount.unwrap_or(&U256::ZERO) {
return Err(SetTierLimitsError::InvalidKarmaAmount);
return Err(ValidateTierLimitsError::InvalidKarmaAmount);
}
if tier.tx_per_epoch <= *state.prev_tx_per_epoch.unwrap_or(&0) {
return Err(SetTierLimitsError::InvalidTierLimit);
return Err(ValidateTierLimitsError::InvalidTierLimit);
}
if state.tier_names.contains(&tier.name) {
return Err(SetTierLimitsError::NonUniqueTierName);
return Err(ValidateTierLimitsError::NonUniqueTierName);
}
state.prev_amount = Some(&tier.min_karma);
@@ -99,7 +96,6 @@ impl TierLimits {
/// Given some karma amount, find the matching Tier
pub(crate) fn get_tier_by_karma(&self, karma_amount: &U256) -> Option<(TierIndex, Tier)> {
struct Context<'a> {
prev: Option<(&'a TierIndex, &'a Tier)>,
}
@@ -124,3 +120,19 @@ impl TierLimits {
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ValidateTierLimitsError {
#[error("Invalid Karma amount (must be increasing)")]
InvalidKarmaAmount,
#[error("Invalid Karma max amount (min: {0} vs max: {1})")]
InvalidMaxAmount(U256, U256),
#[error("Invalid Tier limit (must be increasing)")]
InvalidTierLimit,
#[error("Invalid Tier index (must be increasing)")]
InvalidTierIndex,
#[error("Non unique Tier name")]
NonUniqueTierName,
#[error("Non active Tier")]
InactiveTier,
}


@@ -9,7 +9,7 @@ use futures::StreamExt;
use tracing::error;
// internal
use crate::error::AppError;
use crate::user_db_service::UserDb;
use crate::user_db::UserDb;
use smart_contract::{AlloyWsProvider, KarmaTiersSC, Tier, TierIndex};
pub(crate) struct TiersListener {

prover/src/user_db.rs (new file, 1018 lines; diff suppressed because it is too large)


@@ -0,0 +1,78 @@
use std::num::TryFromIntError;
// third-party
use alloy::primitives::Address;
use zerokit_utils::error::{FromConfigError, ZerokitMerkleTreeError};
// internal
use crate::tier::ValidateTierLimitsError;
#[derive(Debug, thiserror::Error)]
pub(crate) enum UserDbOpenError {
#[error(transparent)]
RocksDb(#[from] rocksdb::Error),
#[error("Serialization error: {0}")]
Serialization(#[from] TryFromIntError),
#[error(transparent)]
JsonSerialization(#[from] serde_json::Error),
#[error(transparent)]
TreeConfig(#[from] FromConfigError),
#[error(transparent)]
MerkleTree(#[from] ZerokitMerkleTreeError),
#[error(transparent)]
MerkleTreeIndex(#[from] MerkleTreeIndexError),
}
#[derive(thiserror::Error, Debug)]
pub enum RegisterError {
#[error("User (address: {0:?}) has already been registered")]
AlreadyRegistered(Address),
#[error(transparent)]
Db(#[from] rocksdb::Error),
#[error("Merkle tree error: {0}")]
TreeError(String),
}
#[derive(thiserror::Error, Debug, PartialEq)]
pub enum TxCounterError {
#[error("User (address: {0:?}) is not registered")]
NotRegistered(Address),
#[error(transparent)]
Db(#[from] rocksdb::Error),
}
#[derive(thiserror::Error, Debug, PartialEq, Clone)]
pub enum MerkleTreeIndexError {
#[error("Uninitialized counter")]
DbUninitialized,
#[error(transparent)]
Db(#[from] rocksdb::Error),
}
#[derive(thiserror::Error, Debug, PartialEq, Clone)]
pub enum UserMerkleTreeIndexError {
#[error("User (address: {0:?}) is not registered")]
NotRegistered(Address),
#[error(transparent)]
Db(#[from] rocksdb::Error),
}
#[derive(Debug, thiserror::Error)]
pub enum SetTierLimitsError {
#[error(transparent)]
Validate(#[from] ValidateTierLimitsError),
#[error("Updating an invalid tier index")]
InvalidUpdateTierIndex,
#[error(transparent)]
Db(#[from] rocksdb::Error),
}
#[derive(Debug, thiserror::Error)]
pub enum UserTierInfoError<E: std::error::Error> {
#[error("User {0} not registered")]
NotRegistered(Address),
#[error(transparent)]
Contract(E),
#[error(transparent)]
TxCounter(#[from] TxCounterError),
#[error(transparent)]
Db(#[from] rocksdb::Error),
}


@@ -0,0 +1,317 @@
use std::collections::BTreeMap;
use std::num::TryFromIntError;
use std::string::FromUtf8Error;
// third-party
use alloy::primitives::U256;
use ark_bn254::Fr;
use ark_ff::fields::AdditiveGroup;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError};
use nom::{
IResult, Parser,
bytes::complete::take,
error::{ContextError, context},
multi::length_count,
number::complete::{le_u32, le_u64},
};
use rln_proof::RlnUserIdentity;
// internal
use crate::tier::TierLimits;
use crate::user_db_types::MerkleTreeIndex;
use smart_contract::{Tier, TierIndex};
pub(crate) struct RlnUserIdentitySerializer {}
impl RlnUserIdentitySerializer {
pub(crate) fn serialize(
&self,
value: &RlnUserIdentity,
buffer: &mut Vec<u8>,
) -> Result<(), SerializationError> {
buffer.resize(self.size_hint(), 0);
let compressed_size = value.commitment.compressed_size();
let (co_buffer, rem_buffer) = buffer.split_at_mut(compressed_size);
value.commitment.serialize_compressed(co_buffer)?;
let (secret_buffer, user_limit_buffer) = rem_buffer.split_at_mut(compressed_size);
value.secret_hash.serialize_compressed(secret_buffer)?;
value.user_limit.serialize_compressed(user_limit_buffer)?;
Ok(())
}
pub(crate) fn size_hint(&self) -> usize {
Fr::ZERO.compressed_size() * 3
}
}
pub(crate) struct RlnUserIdentityDeserializer {}
impl RlnUserIdentityDeserializer {
pub(crate) fn deserialize(&self, buffer: &[u8]) -> Result<RlnUserIdentity, SerializationError> {
let compressed_size = Fr::ZERO.compressed_size();
let (co_buffer, rem_buffer) = buffer.split_at(compressed_size);
let commitment: Fr = CanonicalDeserialize::deserialize_compressed(co_buffer)?;
let (secret_buffer, user_limit_buffer) = rem_buffer.split_at(compressed_size);
// TODO: IdSecret
let secret_hash: Fr = CanonicalDeserialize::deserialize_compressed(secret_buffer)?;
let user_limit: Fr = CanonicalDeserialize::deserialize_compressed(user_limit_buffer)?;
Ok({
RlnUserIdentity {
commitment,
secret_hash,
user_limit,
}
})
}
}
pub(crate) struct MerkleTreeIndexSerializer {}
impl MerkleTreeIndexSerializer {
pub(crate) fn serialize(&self, value: &MerkleTreeIndex, buffer: &mut Vec<u8>) {
let value: u64 = (*value).into();
buffer.extend(value.to_le_bytes());
}
pub(crate) fn size_hint(&self) -> usize {
// Note: Assume usize size == 8 bytes
size_of::<MerkleTreeIndex>()
}
}
pub(crate) struct MerkleTreeIndexDeserializer {}
impl MerkleTreeIndexDeserializer {
pub(crate) fn deserialize<'a>(
&self,
buffer: &'a [u8],
) -> IResult<&'a [u8], MerkleTreeIndex, nom::error::Error<&'a [u8]>> {
le_u64(buffer).map(|(input, idx)| (input, MerkleTreeIndex::from(idx)))
}
}
#[derive(Default)]
pub(crate) struct TierSerializer {}
impl TierSerializer {
pub(crate) fn serialize(
&self,
value: &Tier,
buffer: &mut Vec<u8>,
) -> Result<(), TryFromIntError> {
const U256_SIZE: usize = size_of::<U256>();
buffer.extend(value.min_karma.to_le_bytes::<U256_SIZE>().as_slice());
buffer.extend(value.max_karma.to_le_bytes::<U256_SIZE>().as_slice());
let name_len = u32::try_from(value.name.len())?;
buffer.extend(name_len.to_le_bytes());
buffer.extend(value.name.as_bytes());
buffer.extend(value.tx_per_epoch.to_le_bytes().as_slice());
buffer.push(u8::from(value.active));
Ok(())
}
pub(crate) fn size_hint(&self) -> usize {
size_of::<Tier>()
}
}
#[derive(Default)]
pub(crate) struct TierDeserializer {}
#[derive(Debug, PartialEq)]
pub enum TierDeserializeError<I> {
Utf8Error(FromUtf8Error),
TryFrom,
Nom(I, nom::error::ErrorKind),
}
impl<I> nom::error::ParseError<I> for TierDeserializeError<I> {
fn from_error_kind(input: I, kind: nom::error::ErrorKind) -> Self {
TierDeserializeError::Nom(input, kind)
}
fn append(_: I, _: nom::error::ErrorKind, other: Self) -> Self {
other
}
}
impl<I> ContextError<I> for TierDeserializeError<I> {}
impl TierDeserializer {
pub(crate) fn deserialize<'a>(
&self,
buffer: &'a [u8],
) -> IResult<&'a [u8], Tier, TierDeserializeError<&'a [u8]>> {
let (input, min_karma) = take(32usize)(buffer)?;
let min_karma = U256::try_from_le_slice(min_karma)
.ok_or(nom::Err::Error(TierDeserializeError::TryFrom))?;
let (input, max_karma) = take(32usize)(input)?;
let max_karma = U256::try_from_le_slice(max_karma)
.ok_or(nom::Err::Error(TierDeserializeError::TryFrom))?;
let (input, name_len) = le_u32(input)?;
let name_len_ = usize::try_from(name_len)
.map_err(|_e| nom::Err::Error(TierDeserializeError::TryFrom))?;
let (input, name) = take(name_len_)(input)?;
let name = String::from_utf8(name.to_vec())
.map_err(|e| nom::Err::Error(TierDeserializeError::Utf8Error(e)))?;
let (input, tx_per_epoch) = le_u32(input)?;
let (input, active) = take(1usize)(input)?;
let active = active[0] != 0;
Ok((
input,
Tier {
min_karma,
max_karma,
name,
tx_per_epoch,
active,
},
))
}
}
#[derive(Default)]
pub(crate) struct TierLimitsSerializer {
tier_serializer: TierSerializer,
}
impl TierLimitsSerializer {
pub(crate) fn serialize(
&self,
value: &TierLimits,
buffer: &mut Vec<u8>,
) -> Result<(), TryFromIntError> {
let len = u32::try_from(value.len())?;
buffer.extend(len.to_le_bytes());
let mut tier_buffer = Vec::with_capacity(self.tier_serializer.size_hint());
value.iter().try_for_each(|(k, v)| {
buffer.push(k.into());
self.tier_serializer.serialize(v, &mut tier_buffer)?;
buffer.extend_from_slice(&tier_buffer);
tier_buffer.clear();
Ok(())
})
}
pub(crate) fn size_hint(&self, len: usize) -> usize {
size_of::<u32>() + len * self.tier_serializer.size_hint()
}
}
#[derive(Default)]
pub(crate) struct TierLimitsDeserializer {
pub(crate) tier_deserializer: TierDeserializer,
}
impl TierLimitsDeserializer {
pub(crate) fn deserialize<'a>(
&self,
buffer: &'a [u8],
) -> IResult<&'a [u8], TierLimits, TierDeserializeError<&'a [u8]>> {
let (input, tiers): (&[u8], BTreeMap<TierIndex, Tier>) = length_count(
le_u32,
context("Tier index & Tier deser", |input: &'a [u8]| {
let (input, tier_index) = take(1usize)(input)?;
let tier_index = TierIndex::from(tier_index[0]);
let (input, tier) = self.tier_deserializer.deserialize(input)?;
Ok((input, (tier_index, tier)))
}),
)
.map(BTreeMap::from_iter)
.parse(buffer)?;
Ok((input, TierLimits::from(tiers)))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rln_ser_der() {
let rln_user_identity = RlnUserIdentity {
commitment: Fr::from(42),
secret_hash: Fr::from(u16::MAX),
user_limit: Fr::from(1_000_000),
};
let serializer = RlnUserIdentitySerializer {};
let mut buffer = Vec::with_capacity(serializer.size_hint());
serializer
.serialize(&rln_user_identity, &mut buffer)
.unwrap();
let deserializer = RlnUserIdentityDeserializer {};
let de = deserializer.deserialize(&buffer).unwrap();
assert_eq!(rln_user_identity, de);
}
#[test]
fn test_mtree_ser_der() {
let index = MerkleTreeIndex::from(4242);
let serializer = MerkleTreeIndexSerializer {};
let mut buffer = Vec::with_capacity(serializer.size_hint());
serializer.serialize(&index, &mut buffer);
let deserializer = MerkleTreeIndexDeserializer {};
let (_, de) = deserializer.deserialize(&buffer).unwrap();
assert_eq!(index, de);
}
#[test]
fn test_tier_ser_der() {
let tier = Tier {
min_karma: U256::from(10),
max_karma: U256::from(u64::MAX),
name: "All".to_string(),
tx_per_epoch: 10_000_000,
active: false,
};
let serializer = TierSerializer {};
let mut buffer = Vec::with_capacity(serializer.size_hint());
serializer.serialize(&tier, &mut buffer).unwrap();
let deserializer = TierDeserializer {};
let (_, de) = deserializer.deserialize(&buffer).unwrap();
assert_eq!(tier, de);
}
#[test]
fn test_tier_limits_ser_der() {
let tier_1 = Tier {
min_karma: U256::from(2),
max_karma: U256::from(4),
name: "Basic".to_string(),
tx_per_epoch: 10_000,
active: false,
};
let tier_2 = Tier {
min_karma: U256::from(10),
max_karma: U256::from(u64::MAX),
name: "Premium".to_string(),
tx_per_epoch: 1_000_000_000,
active: true,
};
let tier_limits = TierLimits::from(BTreeMap::from([
(TierIndex::from(1), tier_1),
(TierIndex::from(2), tier_2),
]));
let serializer = TierLimitsSerializer::default();
let mut buffer = Vec::with_capacity(serializer.size_hint(tier_limits.len()));
serializer.serialize(&tier_limits, &mut buffer).unwrap();
let deserializer = TierLimitsDeserializer::default();
let (_, de) = deserializer.deserialize(&buffer).unwrap();
assert_eq!(tier_limits, de);
}
}


@@ -1,349 +1,17 @@
use std::ops::Deref;
// std
use parking_lot::RwLock;
use std::path::PathBuf;
use std::sync::Arc;
// third-party
use alloy::primitives::{Address, U256};
use ark_bn254::Fr;
use derive_more::{Add, From, Into};
use parking_lot::RwLock;
use rln::hashers::poseidon_hash;
use rln::poseidon_tree::{MerkleProof, PoseidonTree};
use rln::protocol::keygen;
use scc::HashMap;
use tokio::sync::Notify;
use tracing::debug;
// internal
use crate::epoch_service::{Epoch, EpochSlice};
use crate::error::{AppError, GetMerkleTreeProofError, RegisterError};
use crate::tier::{TierLimit, TierLimits, TierName};
use rln_proof::{RlnUserIdentity, ZerokitMerkleTree};
use smart_contract::{KarmaAmountExt, Tier, TierIndex};
const MERKLE_TREE_HEIGHT: usize = 20;
#[derive(Debug, Clone, Copy, From, Into)]
struct MerkleTreeIndex(usize);
#[derive(Debug, Clone, Copy, Default, PartialOrd, PartialEq, From, Into)]
pub struct RateLimit(u64);
impl RateLimit {
pub(crate) const ZERO: RateLimit = RateLimit(0);
pub(crate) const fn new(value: u64) -> Self {
Self(value)
}
}
impl From<RateLimit> for Fr {
fn from(rate_limit: RateLimit) -> Self {
Fr::from(rate_limit.0)
}
}
#[derive(Clone)]
pub(crate) struct UserRegistry {
inner: HashMap<Address, (RlnUserIdentity, MerkleTreeIndex)>,
merkle_tree: Arc<RwLock<PoseidonTree>>,
rate_limit: RateLimit,
}
impl std::fmt::Debug for UserRegistry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "UserRegistry {{ inner: {:?} }}", self.inner)
}
}
impl Default for UserRegistry {
fn default() -> Self {
Self {
inner: Default::default(),
// unwrap safe - no config
merkle_tree: Arc::new(RwLock::new(
PoseidonTree::new(MERKLE_TREE_HEIGHT, Default::default(), Default::default())
.unwrap(),
)),
rate_limit: Default::default(),
}
}
}
impl From<RateLimit> for UserRegistry {
fn from(rate_limit: RateLimit) -> Self {
Self {
inner: Default::default(),
// unwrap safe - no config
merkle_tree: Arc::new(RwLock::new(
PoseidonTree::new(MERKLE_TREE_HEIGHT, Default::default(), Default::default())
.unwrap(),
)),
rate_limit,
}
}
}
impl UserRegistry {
fn register(&self, address: Address) -> Result<Fr, RegisterError> {
let (identity_secret_hash, id_commitment) = keygen();
let index = self.inner.len();
self.inner
.insert(
address,
(
RlnUserIdentity::from((
identity_secret_hash,
id_commitment,
Fr::from(self.rate_limit),
)),
MerkleTreeIndex(index),
),
)
.map_err(|_e| RegisterError::AlreadyRegistered(address))?;
let rate_commit = poseidon_hash(&[id_commitment, Fr::from(u64::from(self.rate_limit))]);
self.merkle_tree
.write()
.set(index, rate_commit)
.map_err(|e| RegisterError::TreeError(e.to_string()))?;
Ok(id_commitment)
}
fn has_user(&self, address: &Address) -> bool {
self.inner.contains(address)
}
fn get_user(&self, address: &Address) -> Option<RlnUserIdentity> {
self.inner.get(address).map(|entry| entry.0.clone())
}
fn get_merkle_proof(&self, address: &Address) -> Result<MerkleProof, GetMerkleTreeProofError> {
let index = self
.inner
.get(address)
.map(|entry| entry.1)
.ok_or(GetMerkleTreeProofError::NotRegistered)?;
self.merkle_tree
.read()
.proof(index.into())
.map_err(|e| GetMerkleTreeProofError::TreeError(e.to_string()))
}
}
#[derive(Debug, Default, Clone, Copy, PartialEq, From, Into, Add)]
pub(crate) struct EpochCounter(u64);
#[derive(Debug, Default, Clone, Copy, PartialEq, From, Into, Add)]
pub(crate) struct EpochSliceCounter(u64);
#[derive(Debug, Default, Clone)]
pub(crate) struct TxRegistry {
inner: HashMap<Address, (EpochCounter, EpochSliceCounter)>,
}
impl Deref for TxRegistry {
type Target = HashMap<Address, (EpochCounter, EpochSliceCounter)>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl TxRegistry {
/// Update the transaction counter for the given address
///
/// If incr_value is None, the counter will be incremented by 1
/// If incr_value is Some(x), the counter will be incremented by x
///
/// Returns the new value of the counter
pub fn incr_counter(&self, address: &Address, incr_value: Option<u64>) -> EpochSliceCounter {
let incr_value = incr_value.unwrap_or(1);
let mut entry = self.inner.entry(*address).or_default();
*entry = (
entry.0 + EpochCounter(incr_value),
entry.1 + EpochSliceCounter(incr_value),
);
entry.1
}
}
#[derive(Debug, PartialEq)]
pub struct UserTierInfo {
pub(crate) current_epoch: Epoch,
pub(crate) current_epoch_slice: EpochSlice,
pub(crate) epoch_tx_count: u64,
pub(crate) epoch_slice_tx_count: u64,
karma_amount: U256,
pub(crate) tier_name: Option<TierName>,
pub(crate) tier_limit: Option<TierLimit>,
}
#[derive(Debug, thiserror::Error)]
pub enum UserTierInfoError<E: std::error::Error> {
#[error("User {0} not registered")]
NotRegistered(Address),
#[error(transparent)]
Contract(E),
}
/// User registration + tx counters + tier limits storage
#[derive(Debug, Clone)]
pub struct UserDb {
user_registry: Arc<UserRegistry>,
tx_registry: Arc<TxRegistry>,
tier_limits: Arc<RwLock<TierLimits>>,
tier_limits_next: Arc<RwLock<TierLimits>>,
epoch_store: Arc<RwLock<(Epoch, EpochSlice)>>,
}
impl UserDb {
fn on_new_epoch(&self) {
self.tx_registry.clear()
}
fn on_new_epoch_slice(&self) {
self.tx_registry.retain(|_a, v| {
*v = (v.0, Default::default());
true
});
let tier_limits_next_has_updates = !self.tier_limits_next.read().is_empty();
if tier_limits_next_has_updates {
let mut guard = self.tier_limits_next.write();
// mem::take will clear the TierLimits in tier_limits_next
let new_tier_limits = std::mem::take(&mut *guard);
debug!("Installing new tier limits: {:?}", new_tier_limits);
*self.tier_limits.write() = new_tier_limits;
}
}
pub fn on_new_user(&self, address: Address) -> Result<Fr, RegisterError> {
self.user_registry.register(address)
}
pub fn get_user(&self, address: &Address) -> Option<RlnUserIdentity> {
self.user_registry.get_user(address)
}
pub fn get_merkle_proof(
&self,
address: &Address,
) -> Result<MerkleProof, GetMerkleTreeProofError> {
self.user_registry.get_merkle_proof(address)
}
pub(crate) fn on_new_tx(
&self,
address: &Address,
incr_value: Option<u64>,
) -> Option<EpochSliceCounter> {
if self.user_registry.has_user(address) {
Some(self.tx_registry.incr_counter(address, incr_value))
} else {
None
}
}
pub(crate) fn on_new_tier_limits(
&self,
tier_limits: TierLimits,
) -> Result<(), SetTierLimitsError> {
let tier_limits = tier_limits.clone().filter_inactive();
tier_limits.validate()?;
*self.tier_limits_next.write() = tier_limits;
Ok(())
}
pub(crate) fn on_new_tier(
&self,
tier_index: TierIndex,
tier: Tier,
) -> Result<(), SetTierLimitsError> {
let mut tier_limits = self.tier_limits.read().clone();
tier_limits.insert(tier_index, tier);
tier_limits.validate()?;
// Write
*self.tier_limits_next.write() = tier_limits;
Ok(())
}
pub(crate) fn on_tier_updated(
&self,
tier_index: TierIndex,
tier: Tier,
) -> Result<(), SetTierLimitsError> {
let mut tier_limits = self.tier_limits.read().clone();
if !tier_limits.contains_key(&tier_index) {
return Err(SetTierLimitsError::InvalidTierIndex);
}
tier_limits.entry(tier_index).and_modify(|e| *e = tier);
tier_limits.validate()?;
// Write
*self.tier_limits_next.write() = tier_limits;
Ok(())
}
/// Get user tier info
pub(crate) async fn user_tier_info<E: std::error::Error, KSC: KarmaAmountExt<Error = E>>(
&self,
address: &Address,
karma_sc: &KSC,
) -> Result<UserTierInfo, UserTierInfoError<E>> {
if self.user_registry.has_user(address) {
let (epoch_tx_count, epoch_slice_tx_count) = self
.tx_registry
.get(address)
.map(|ref_v| (ref_v.0, ref_v.1))
.unwrap_or_default();
let karma_amount = karma_sc
.karma_amount(address)
.await
.map_err(|e| UserTierInfoError::Contract(e))?;
let tier_limits_guard = self.tier_limits.read();
let tier_info = tier_limits_guard.get_tier_by_karma(&karma_amount);
drop(tier_limits_guard);
let user_tier_info = {
let (current_epoch, current_epoch_slice) = *self.epoch_store.read();
let mut t = UserTierInfo {
current_epoch,
current_epoch_slice,
epoch_tx_count: epoch_tx_count.into(),
epoch_slice_tx_count: epoch_slice_tx_count.into(),
karma_amount,
tier_name: None,
tier_limit: None,
};
if let Some((_tier_index, tier)) = tier_info {
t.tier_name = Some(tier.name.into());
// TODO
t.tier_limit = Some(0.into());
}
t
};
Ok(user_tier_info)
} else {
Err(UserTierInfoError::NotRegistered(*address))
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum SetTierLimitsError {
#[error("Invalid Karma amount (must be increasing)")]
InvalidKarmaAmount,
#[error("Invalid Karma max amount (min: {0} vs max: {1})")]
InvalidMaxAmount(U256, U256),
#[error("Invalid Tier limit (must be increasing)")]
InvalidTierLimit,
#[error("Invalid Tier index (must be increasing)")]
InvalidTierIndex,
#[error("Non unique Tier name")]
NonUniqueTierName,
#[error("Non active Tier")]
InactiveTier,
}
use crate::error::AppError;
use crate::tier::TierLimits;
use crate::user_db::UserDb;
use crate::user_db_error::UserDbOpenError;
use crate::user_db_types::RateLimit;
/// Async service to update a UserDb on epoch changes
#[derive(Debug)]
@@ -353,22 +21,25 @@ pub struct UserDbService {
}
impl UserDbService {
pub(crate) fn new(
pub fn new(
db_path: PathBuf,
merkle_tree_path: PathBuf,
epoch_changes_notifier: Arc<Notify>,
epoch_store: Arc<RwLock<(Epoch, EpochSlice)>>,
rate_limit: RateLimit,
tier_limits: TierLimits,
) -> Self {
Self {
user_db: UserDb {
user_registry: Arc::new(UserRegistry::from(rate_limit)),
tx_registry: Default::default(),
tier_limits: Arc::new(RwLock::new(tier_limits)),
tier_limits_next: Arc::new(Default::default()),
epoch_store,
},
) -> Result<Self, UserDbOpenError> {
let user_db = UserDb::new(
db_path,
merkle_tree_path,
epoch_store,
tier_limits,
rate_limit,
)?;
Ok(Self {
user_db,
epoch_changes: epoch_changes_notifier,
}
})
}
pub fn get_user_db(&self) -> UserDb {
@@ -412,507 +83,3 @@ impl UserDbService {
*current_epoch_slice = new_epoch_slice;
}
}
#[cfg(test)]
mod tests {
use super::*;
// std
use std::collections::BTreeMap;
// third-party
use alloy::primitives::address;
use async_trait::async_trait;
use claims::{assert_err, assert_matches};
use derive_more::Display;
#[derive(Debug, Display, thiserror::Error)]
struct DummyError();
struct MockKarmaSc {}
#[async_trait]
impl KarmaAmountExt for MockKarmaSc {
type Error = DummyError;
async fn karma_amount(&self, _address: &Address) -> Result<U256, Self::Error> {
Ok(U256::from(10))
}
}
const ADDR_1: Address = address!("0xd8da6bf26964af9d7eed9e03e53415d37aa96045");
const ADDR_2: Address = address!("0xb20a608c624Ca5003905aA834De7156C68b2E1d0");
struct MockKarmaSc2 {}
#[async_trait]
impl KarmaAmountExt for MockKarmaSc2 {
type Error = DummyError;
async fn karma_amount(&self, address: &Address) -> Result<U256, Self::Error> {
if address == &ADDR_1 {
Ok(U256::from(10))
} else if address == &ADDR_2 {
Ok(U256::from(2000))
} else {
Ok(U256::ZERO)
}
}
}
#[test]
fn test_user_register() {
let user_db = UserDb {
user_registry: Default::default(),
tx_registry: Default::default(),
tier_limits: Arc::new(RwLock::new(Default::default())),
tier_limits_next: Arc::new(Default::default()),
epoch_store: Arc::new(RwLock::new(Default::default())),
};
let addr = Address::new([0; 20]);
user_db.user_registry.register(addr).unwrap();
assert_matches!(
user_db.user_registry.register(addr),
Err(RegisterError::AlreadyRegistered(_))
);
}
#[tokio::test]
async fn test_incr_tx_counter() {
let user_db = UserDb {
user_registry: Default::default(),
tx_registry: Default::default(),
tier_limits: Arc::new(RwLock::new(Default::default())),
tier_limits_next: Arc::new(Default::default()),
epoch_store: Arc::new(RwLock::new(Default::default())),
};
let addr = Address::new([0; 20]);
// Try to update tx counter without registering first
assert_eq!(user_db.on_new_tx(&addr, None), None);
let tier_info = user_db.user_tier_info(&addr, &MockKarmaSc {}).await;
// User is not registered -> no tier info
assert!(matches!(
tier_info,
Err(UserTierInfoError::NotRegistered(_))
));
// Register user
user_db.user_registry.register(addr).unwrap();
// Now update user tx counter
assert_eq!(user_db.on_new_tx(&addr, None), Some(EpochSliceCounter(1)));
let tier_info = user_db
.user_tier_info(&addr, &MockKarmaSc {})
.await
.unwrap();
assert_eq!(tier_info.epoch_tx_count, 1);
assert_eq!(tier_info.epoch_slice_tx_count, 1);
}
#[tokio::test]
async fn test_update_on_epoch_changes() {
let mut epoch = Epoch::from(11);
let mut epoch_slice = EpochSlice::from(42);
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
let tier_limits = BTreeMap::from([
(
TierIndex::from(0),
Tier {
name: "Basic".into(),
min_karma: U256::from(10),
max_karma: U256::from(49),
tx_per_epoch: 5,
active: true,
},
),
(
TierIndex::from(1),
Tier {
name: "Active".into(),
min_karma: U256::from(50),
max_karma: U256::from(99),
tx_per_epoch: 10,
active: true,
},
),
(
TierIndex::from(2),
Tier {
name: "Regular".into(),
min_karma: U256::from(100),
max_karma: U256::from(499),
tx_per_epoch: 15,
active: true,
},
),
(
TierIndex::from(3),
Tier {
name: "Power User".into(),
min_karma: U256::from(500),
max_karma: U256::from(4999),
tx_per_epoch: 20,
active: true,
},
),
(
TierIndex::from(4),
Tier {
name: "S-Tier".into(),
min_karma: U256::from(5000),
max_karma: U256::MAX,
tx_per_epoch: 25,
active: true,
},
),
]);
let user_db_service = UserDbService::new(
Default::default(),
epoch_store,
10.into(),
tier_limits.into(),
);
let user_db = user_db_service.get_user_db();
let addr_1_tx_count = 2;
let addr_2_tx_count = 820;
user_db.user_registry.register(ADDR_1).unwrap();
user_db
.tx_registry
.incr_counter(&ADDR_1, Some(addr_1_tx_count));
user_db.user_registry.register(ADDR_2).unwrap();
user_db
.tx_registry
.incr_counter(&ADDR_2, Some(addr_2_tx_count));
// incr epoch slice (42 -> 43)
{
let new_epoch = epoch;
let new_epoch_slice = epoch_slice + 1;
user_db_service.update_on_epoch_changes(
&mut epoch,
new_epoch,
&mut epoch_slice,
new_epoch_slice,
);
let addr_1_tier_info = user_db
.user_tier_info(&ADDR_1, &MockKarmaSc2 {})
.await
.unwrap();
assert_eq!(addr_1_tier_info.epoch_tx_count, addr_1_tx_count);
assert_eq!(addr_1_tier_info.epoch_slice_tx_count, 0);
assert_eq!(addr_1_tier_info.tier_name, Some(TierName::from("Basic")));
let addr_2_tier_info = user_db
.user_tier_info(&ADDR_2, &MockKarmaSc2 {})
.await
.unwrap();
assert_eq!(addr_2_tier_info.epoch_tx_count, addr_2_tx_count);
assert_eq!(addr_2_tier_info.epoch_slice_tx_count, 0);
assert_eq!(
addr_2_tier_info.tier_name,
Some(TierName::from("Power User"))
);
}
// incr epoch (11 -> 12, epoch slice reset)
{
let new_epoch = epoch + 1;
let new_epoch_slice = EpochSlice::from(0);
user_db_service.update_on_epoch_changes(
&mut epoch,
new_epoch,
&mut epoch_slice,
new_epoch_slice,
);
let addr_1_tier_info = user_db
.user_tier_info(&ADDR_1, &MockKarmaSc2 {})
.await
.unwrap();
assert_eq!(addr_1_tier_info.epoch_tx_count, 0);
assert_eq!(addr_1_tier_info.epoch_slice_tx_count, 0);
assert_eq!(addr_1_tier_info.tier_name, Some(TierName::from("Basic")));
let addr_2_tier_info = user_db
.user_tier_info(&ADDR_2, &MockKarmaSc2 {})
.await
.unwrap();
assert_eq!(addr_2_tier_info.epoch_tx_count, 0);
assert_eq!(addr_2_tier_info.epoch_slice_tx_count, 0);
assert_eq!(
addr_2_tier_info.tier_name,
Some(TierName::from("Power User"))
);
}
}
#[test]
#[tracing_test::traced_test]
fn test_set_tier_limits() {
// Check that we can update tier limits and that the update is only applied after an epoch slice change
let mut epoch = Epoch::from(11);
let mut epoch_slice = EpochSlice::from(42);
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
let user_db_service = UserDbService::new(
Default::default(),
epoch_store,
10.into(),
Default::default(),
);
let user_db = user_db_service.get_user_db();
let tier_limits_original = user_db.tier_limits.read().clone();
let tier_limits = BTreeMap::from([
(
TierIndex::from(1),
Tier {
name: "Basic".into(),
min_karma: U256::from(10),
max_karma: U256::from(49),
tx_per_epoch: 5,
active: true,
},
),
(
TierIndex::from(2),
Tier {
name: "Power User".into(),
min_karma: U256::from(50),
max_karma: U256::from(299),
tx_per_epoch: 20,
active: true,
},
),
]);
let tier_limits = TierLimits::from(tier_limits);
user_db.on_new_tier_limits(tier_limits.clone()).unwrap();
// Check it is not yet installed
assert_ne!(*user_db.tier_limits.read(), tier_limits);
assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
assert_eq!(*user_db.tier_limits_next.read(), tier_limits);
let new_epoch = epoch;
let new_epoch_slice = epoch_slice + 1;
user_db_service.update_on_epoch_changes(
&mut epoch,
new_epoch,
&mut epoch_slice,
new_epoch_slice,
);
// Should be installed now
assert_eq!(*user_db.tier_limits.read(), tier_limits);
// And the tier_limits_next field is expected to be empty
assert!(user_db.tier_limits_next.read().is_empty());
}
#[test]
#[tracing_test::traced_test]
fn test_set_invalid_tier_limits() {
// Check we cannot update with invalid tier limits
let epoch = Epoch::from(11);
let epoch_slice = EpochSlice::from(42);
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
let user_db_service = UserDbService::new(
Default::default(),
epoch_store,
10.into(),
Default::default(),
);
let user_db = user_db_service.get_user_db();
let tier_limits_original = user_db.tier_limits.read().clone();
// Invalid: non-unique index (note: BTreeMap::from keeps only the last entry for a duplicate key)
{
let tier_limits = BTreeMap::from([
(
TierIndex::from(0),
Tier {
min_karma: Default::default(),
max_karma: Default::default(),
name: "Basic".to_string(),
tx_per_epoch: 100,
active: true,
},
),
(
TierIndex::from(0),
Tier {
min_karma: Default::default(),
max_karma: Default::default(),
name: "Power User".to_string(),
tx_per_epoch: 200,
active: true,
},
),
]);
let tier_limits = TierLimits::from(tier_limits);
assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
}
// Invalid: min Karma amount not increasing
{
let tier_limits = BTreeMap::from([
(
TierIndex::from(0),
Tier {
min_karma: U256::from(10),
max_karma: U256::from(49),
name: "Basic".to_string(),
tx_per_epoch: 5,
active: true,
},
),
(
TierIndex::from(1),
Tier {
min_karma: U256::from(50),
max_karma: U256::from(99),
name: "Power User".to_string(),
tx_per_epoch: 10,
active: true,
},
),
(
TierIndex::from(2),
Tier {
min_karma: U256::from(60),
max_karma: U256::from(99),
name: "Power User".to_string(),
tx_per_epoch: 15,
active: true,
},
),
]);
let tier_limits = TierLimits::from(tier_limits);
assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
}
// Invalid: non-unique Tier name
{
let tier_limits = BTreeMap::from([
(
TierIndex::from(0),
Tier {
min_karma: U256::from(10),
max_karma: U256::from(49),
name: "Basic".to_string(),
tx_per_epoch: 5,
active: true,
},
),
(
TierIndex::from(1),
Tier {
min_karma: U256::from(50),
max_karma: U256::from(99),
name: "Power User".to_string(),
tx_per_epoch: 10,
active: true,
},
),
(
TierIndex::from(2),
Tier {
min_karma: U256::from(100),
max_karma: U256::from(999),
name: "Power User".to_string(),
tx_per_epoch: 15,
active: true,
},
),
]);
let tier_limits = TierLimits::from(tier_limits);
assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
}
}
#[test]
#[tracing_test::traced_test]
fn test_set_invalid_tier_limits_2() {
// Check we cannot update with invalid tier limits
let epoch = Epoch::from(11);
let epoch_slice = EpochSlice::from(42);
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
let user_db_service = UserDbService::new(
Default::default(),
epoch_store,
10.into(),
Default::default(),
);
let user_db = user_db_service.get_user_db();
let tier_limits_original = user_db.tier_limits.read().clone();
// Invalid: inactive tier
{
let tier_limits = BTreeMap::from([
(
TierIndex::from(0),
Tier {
min_karma: U256::from(10),
max_karma: U256::from(49),
name: "Basic".to_string(),
tx_per_epoch: 5,
active: true,
},
),
(
TierIndex::from(1),
Tier {
min_karma: U256::from(50),
max_karma: U256::from(99),
name: "Power User".to_string(),
tx_per_epoch: 10,
active: false, // the inactive tier must cause rejection
},
),
]);
let tier_limits = TierLimits::from(tier_limits);
assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
}
// Invalid: non-increasing tx_per_epoch
{
let tier_limits = BTreeMap::from([
(
TierIndex::from(0),
Tier {
min_karma: U256::from(10),
max_karma: U256::from(49),
name: "Basic".to_string(),
tx_per_epoch: 5,
active: true,
},
),
(
TierIndex::from(1),
Tier {
min_karma: U256::from(50),
max_karma: U256::from(99),
name: "Power User".to_string(),
tx_per_epoch: 5,
active: true,
},
),
]);
let tier_limits = TierLimits::from(tier_limits);
assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
}
}
}

View File

@@ -0,0 +1,37 @@
use ark_bn254::Fr;
use derive_more::{Add, From, Into};
/// A wrapper type over u64, indexing a leaf in the Merkle tree
#[derive(Debug, Clone, Copy, From, Into, PartialEq)]
pub(crate) struct MerkleTreeIndex(u64);
impl From<MerkleTreeIndex> for usize {
fn from(value: MerkleTreeIndex) -> Self {
// TODO: compile-time assert that u64 fits into usize on the target platform
value.0 as usize
}
}
/// A wrapper type over u64, expressing a user's transaction rate limit
#[derive(Debug, Clone, Copy, Default, PartialOrd, PartialEq, From, Into)]
pub struct RateLimit(u64);
impl RateLimit {
pub(crate) const ZERO: RateLimit = RateLimit(0);
pub(crate) const fn new(value: u64) -> Self {
Self(value)
}
}
impl From<RateLimit> for Fr {
fn from(rate_limit: RateLimit) -> Self {
Fr::from(rate_limit.0)
}
}
#[derive(Debug, Default, Clone, Copy, PartialEq, From, Into, Add)]
pub(crate) struct EpochCounter(u64);
#[derive(Debug, Default, Clone, Copy, PartialEq, From, Into, Add)]
pub(crate) struct EpochSliceCounter(u64);
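
The `From<RateLimit> for Fr` impl is what lets a rate limit enter the RLN circuit as a field element, and the counter wrappers support addition and conversion back to `u64` via their derives. A small sketch of these conversions (values are arbitrary):

use ark_bn254::Fr;

let limit = RateLimit::new(100); // within the crate; external callers can use RateLimit::from(100u64)
let as_field: Fr = limit.into(); // field element handed to the RLN circuit
assert_eq!(as_field, Fr::from(100u64));

let count = EpochSliceCounter::from(1u64) + EpochSliceCounter::from(2u64);
assert_eq!(u64::from(count), 3);
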

View File

@@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2024"
[dependencies]
rln = { git = "https://github.com/vacp2p/zerokit", package = "rln", default-features = false }
rln = { git = "https://github.com/vacp2p/zerokit", package = "rln", features = ["pmtree-ft"] }
zerokit_utils = { git = "https://github.com/vacp2p/zerokit", package = "zerokit_utils", features = ["default"] }
ark-bn254.workspace = true
ark-groth16.workspace = true
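
Note: enabling the `pmtree-ft` feature (and no longer disabling default features) is presumably what backs zerokit's `PoseidonTree` alias with the persistent pmtree storage this commit introduces; the benches and tests below are updated to exercise that same tree type.
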

View File

@@ -6,6 +6,7 @@ use criterion::{Criterion, criterion_group, criterion_main};
use ark_bn254::Fr;
use ark_serialize::CanonicalSerialize;
use rln::hashers::{hash_to_field, poseidon_hash};
use rln::poseidon_tree::PoseidonTree;
use rln::protocol::{keygen, serialize_proof_values};
use zerokit_utils::OptimalMerkleTree;
// internal
@@ -29,7 +30,7 @@ pub fn criterion_benchmark(c: &mut Criterion) {
// Merkle tree
let tree_height = 20;
let mut tree = OptimalMerkleTree::new(tree_height, Fr::from(0), Default::default()).unwrap();
let mut tree = PoseidonTree::new(tree_height, Fr::from(0), Default::default()).unwrap();
let rate_commit = poseidon_hash(&[rln_identity.commitment, rln_identity.user_limit]);
tree.set(0, rate_commit).unwrap();
let merkle_proof = tree.proof(0).unwrap();

View File

@@ -15,7 +15,7 @@ use rln::{
};
/// A RLN user identity & limit
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct RlnUserIdentity {
pub commitment: Fr,
pub secret_hash: Fr,
@@ -42,7 +42,6 @@ pub struct RlnIdentifier {
impl RlnIdentifier {
pub fn new(identifier: &[u8]) -> Self {
// TODO: valid / correct ?
let pk_and_matrices = {
let mut reader = Cursor::new(ZKEY_BYTES);
read_zkey(&mut reader).unwrap()
@@ -96,8 +95,9 @@ pub fn compute_rln_proof_and_values(
#[cfg(test)]
mod tests {
use super::*;
use rln::poseidon_tree::PoseidonTree;
use rln::protocol::{compute_id_secret, keygen};
use zerokit_utils::{OptimalMerkleTree, ZerokitMerkleTree};
use zerokit_utils::ZerokitMerkleTree;
#[test]
fn test_recover_secret_hash() {
@@ -105,7 +105,8 @@ mod tests {
let epoch = hash_to_field(b"foo");
let spam_limit = Fr::from(10);
let mut tree = OptimalMerkleTree::new(20, Default::default(), Default::default()).unwrap();
let mut tree = PoseidonTree::new(20, Default::default(), Default::default()).unwrap();
tree.set(0, spam_limit).unwrap();
let m_proof = tree.proof(0).unwrap();

View File

@@ -61,6 +61,12 @@ impl KarmaTiersSC::KarmaTiersSCInstance<AlloyWsProvider> {
#[derive(Debug, Clone, Default, Copy, From, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TierIndex(u8);
impl From<&TierIndex> for u8 {
fn from(value: &TierIndex) -> u8 {
value.0
}
}
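
Only `From<u8> for TierIndex` is derived above, so this manual impl adds the borrowed direction, letting code that merely holds a reference (e.g. a `BTreeMap` key during iteration) read out the raw `u8`. A trivial sketch:

let index = TierIndex::from(3u8);
let raw: u8 = u8::from(&index);
assert_eq!(raw, 3);
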
#[derive(Debug, Clone, PartialEq)]
pub struct Tier {
pub min_karma: U256,