From 9f4027ed2bd5e6b3e1fe32affd3c9ffd69e3cfdb Mon Sep 17 00:00:00 2001 From: Sydhds Date: Fri, 27 Jun 2025 15:41:41 +0200 Subject: [PATCH] =?UTF-8?q?Add=20UserRocksDb=20-=20initial=20attempt=20to?= =?UTF-8?q?=20rewrite=20UserDb=20with=20persistent=20st=E2=80=A6=20(#10)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Use RocksDb + PmTree for UserDb persistent storage --- Cargo.lock | 188 ++++- Cargo.toml | 1 + proto/net/vac/prover/prover.proto | 2 +- prover/Cargo.toml | 10 +- prover/src/args.rs | 8 + prover/src/epoch_service.rs | 33 +- prover/src/error.rs | 33 +- prover/src/grpc_service.rs | 11 +- prover/src/main.rs | 14 +- prover/src/proof_service.rs | 52 +- prover/src/registry_listener.rs | 15 +- prover/src/rocksdb_operands.rs | 380 ++++++++++ prover/src/tier.rs | 42 +- prover/src/tiers_listener.rs | 2 +- prover/src/user_db.rs | 1018 +++++++++++++++++++++++++++ prover/src/user_db_error.rs | 78 ++ prover/src/user_db_serialization.rs | 317 +++++++++ prover/src/user_db_service.rs | 877 +---------------------- prover/src/user_db_types.rs | 37 + rln_proof/Cargo.toml | 2 +- rln_proof/benches/generate_proof.rs | 3 +- rln_proof/src/proof.rs | 9 +- smart_contract/src/karma_tiers.rs | 6 + 23 files changed, 2174 insertions(+), 964 deletions(-) create mode 100644 prover/src/rocksdb_operands.rs create mode 100644 prover/src/user_db.rs create mode 100644 prover/src/user_db_error.rs create mode 100644 prover/src/user_db_serialization.rs create mode 100644 prover/src/user_db_types.rs diff --git a/Cargo.lock b/Cargo.lock index ecf3058..1985e3d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -331,7 +331,7 @@ dependencies = [ "proptest", "rand 0.9.1", "ruint", - "rustc-hash", + "rustc-hash 2.1.1", "serde", "sha3", "tiny-keccak", @@ -1272,6 +1272,26 @@ version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "230c5f1ca6a325a32553f8640d31ac9b49f2411e901e427570154868b46da4f7" +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.9.0", + "cexpr", + "clang-sys", + "itertools 0.10.5", + "lazy_static", + "lazycell", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.100", +] + [[package]] name = "bit-set" version = "0.8.0" @@ -1390,6 +1410,16 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3c8f83209414aacf0eeae3cf730b18d6981697fba62f200fcfb92b9f082acba" +[[package]] +name = "bzip2-sys" +version = "0.1.13+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "c-kzg" version = "2.1.1" @@ -1417,9 +1447,20 @@ version = "1.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04da6a0d40b948dfc4fa8f5bbf402b0fc1a64a28dbf7d12ffd683550f2c1b63a" dependencies = [ + "jobserver", + "libc", "shlex", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom 7.1.3", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -1474,6 +1515,17 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"bba18ee93d577a8428902687bcc2b6b45a56b1981a1f6d779731c86cc4c5db18" +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "4.5.40" @@ -2704,6 +2756,16 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +[[package]] +name = "jobserver" +version = "0.1.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" +dependencies = [ + "getrandom 0.3.2", + "libc", +] + [[package]] name = "js-sys" version = "0.3.77" @@ -2753,18 +2815,59 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "libc" version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +[[package]] +name = "libloading" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" +dependencies = [ + "cfg-if", + "windows-targets 0.53.0", +] + [[package]] name = "libm" version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9627da5196e5d8ed0b0495e61e518847578da83483c37288316d9b2e03a7f72" +[[package]] +name = "librocksdb-sys" +version = "0.17.2+9.10.0" +source = "git+https://github.com/tillrohrmann/rust-rocksdb?branch=issues%2F836#9c2482629ac4a3f73e622fc6be1b6458b6bae8cc" +dependencies = [ + "bindgen", + "bzip2-sys", + "cc", + "libc", + "libz-sys", + "lz4-sys", + "zstd-sys", +] + +[[package]] +name = "libz-sys" +version = "1.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b70e7a7df205e92a1a4cd9aaae7898dac0aa555503cc0a649494d0d60e7651d" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -2802,6 +2905,16 @@ dependencies = [ "hashbrown 0.15.2", ] +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "macro-string" version = "0.1.4" @@ -2852,6 +2965,12 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.7.4" @@ -2904,6 +3023,25 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + 
"memchr", + "minimal-lexical", +] + +[[package]] +name = "nom" +version = "8.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" +dependencies = [ + "memchr", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -3377,7 +3515,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" dependencies = [ "heck", - "itertools 0.10.5", + "itertools 0.13.0", "log", "multimap", "once_cell", @@ -3397,7 +3535,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.13.0", "proc-macro2", "quote", "syn 2.0.100", @@ -3709,6 +3847,15 @@ dependencies = [ "rustc-hex", ] +[[package]] +name = "rocksdb" +version = "0.23.0" +source = "git+https://github.com/tillrohrmann/rust-rocksdb?branch=issues%2F836#9c2482629ac4a3f73e622fc6be1b6458b6bae8cc" +dependencies = [ + "libc", + "librocksdb-sys", +] + [[package]] name = "ruint" version = "1.15.0" @@ -3748,6 +3895,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc-hash" version = "2.1.1" @@ -4217,6 +4370,7 @@ version = "0.1.0" dependencies = [ "alloy", "ark-bn254", + "ark-ff 0.5.0", "ark-groth16", "ark-serialize 0.5.0", "async-channel", @@ -4230,16 +4384,19 @@ dependencies = [ "derive_more", "futures", "http", + "nom 8.0.0", "num-bigint", "parking_lot 0.12.4", "prost", "rand 0.8.5", "rln", "rln_proof", + "rocksdb", "scc", "serde", "serde_json", "smart_contract", + "tempfile", "thiserror 2.0.12", "tokio", "tonic", @@ -4251,6 +4408,7 @@ dependencies = [ "tracing-subscriber 0.3.19", "tracing-test", "url", + "zerokit_utils", ] [[package]] @@ -4349,9 +4507,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.19.1" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ "fastrand", "getrandom 0.3.2", @@ -4915,6 +5073,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "vacp2p_pmtree" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632293f506ca10d412dbe1d427295317b4c794fa9ddfd66fbd2fa971de88c1f6" +dependencies = [ + "rayon", +] + [[package]] name = "valuable" version = "0.1.1" @@ -5521,6 +5688,7 @@ dependencies = [ "serde_json", "sled", "thiserror 2.0.12", + "vacp2p_pmtree", ] [[package]] @@ -5544,3 +5712,13 @@ dependencies = [ "quote", "syn 2.0.100", ] + +[[package]] +name = "zstd-sys" +version = "2.0.15+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml 
b/Cargo.toml index 24cce6a..0435789 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ resolver = "2" ark-bn254 = { version = "0.5", features = ["std"] } ark-serialize = "0.5" ark-groth16 = "0.5" +ark-ff = "0.5" url = "2.5.4" alloy = { version = "1.0", features = ["getrandom", "sol-types", "contract", "provider-ws"] } async-trait = "0.1" diff --git a/proto/net/vac/prover/prover.proto b/proto/net/vac/prover/prover.proto index 2e70a19..bc65fb8 100644 --- a/proto/net/vac/prover/prover.proto +++ b/proto/net/vac/prover/prover.proto @@ -211,7 +211,7 @@ message UserTierInfoResult { message Tier { string name = 1; - uint64 quota = 2; + uint32 quota = 2; } /* diff --git a/prover/Cargo.toml b/prover/Cargo.toml index bc3bef9..ad04782 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -16,9 +16,9 @@ tracing-test = "0.2.5" alloy.workspace = true thiserror = "2.0" futures = "0.3" -rln = { git = "https://github.com/vacp2p/zerokit", default-features = false } ark-bn254.workspace = true ark-serialize.workspace = true +ark-ff.workspace = true dashmap = "6.1.0" scc = "2.3" bytesize = "2.0.1" @@ -34,6 +34,12 @@ num-bigint = "0.4" async-trait.workspace = true serde = { version="1", features = ["derive"] } serde_json = "1.0" +# rocksdb = "0.23" +rocksdb = { git = "https://github.com/tillrohrmann/rust-rocksdb", branch="issues/836" } +nom = "8.0" +claims = "0.8" +rln = { git = "https://github.com/vacp2p/zerokit", features = ["pmtree-ft"] } +zerokit_utils = { git = "https://github.com/vacp2p/zerokit", package = "zerokit_utils", features = ["default"] } rln_proof = { path = "../rln_proof" } smart_contract = { path = "../smart_contract" } @@ -42,8 +48,8 @@ tonic-build = "*" [dev-dependencies] criterion.workspace = true -claims = "0.8" ark-groth16.workspace = true +tempfile = "3.20" [[bench]] name = "user_db_heavy_write" diff --git a/prover/src/args.rs b/prover/src/args.rs index b1ccf77..e67c244 100644 --- a/prover/src/args.rs +++ b/prover/src/args.rs @@ -23,6 +23,14 @@ pub struct AppArgs { help = "Websocket rpc url (e.g. wss://eth-mainnet.g.alchemy.com/v2/your-api-key)" )] pub(crate) ws_rpc_url: Option, + #[arg(long = "db", help = "Db path", default_value = "./storage/db")] + pub(crate) db_path: PathBuf, + #[arg( + long = "tree", + help = "Merkle tree path", + default_value = "./storage/tree" + )] + pub(crate) merkle_tree_path: PathBuf, #[arg(short = 'k', long = "ksc", help = "Karma smart contract address")] pub(crate) ksc_address: Option
, #[arg(short = 'r', long = "rlnsc", help = "RLN smart contract address")] diff --git a/prover/src/epoch_service.rs b/prover/src/epoch_service.rs index 98b2f73..5678a41 100644 --- a/prover/src/epoch_service.rs +++ b/prover/src/epoch_service.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use std::time::Duration; // third-party use chrono::{DateTime, NaiveDate, NaiveDateTime, OutOfRangeError, TimeDelta, Utc}; +use derive_more::{Deref, From, Into}; use parking_lot::RwLock; use tokio::sync::Notify; use tracing::debug; @@ -211,8 +212,8 @@ pub enum WaitUntilError { } /// An Epoch (wrapper type over i64) -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)] -pub(crate) struct Epoch(pub(crate) i64); +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From, Into, Deref)] +pub(crate) struct Epoch(i64); impl Add for Epoch { type Output = Self; @@ -222,21 +223,9 @@ impl Add for Epoch { } } -impl From for Epoch { - fn from(value: i64) -> Self { - Self(value) - } -} - -impl From for i64 { - fn from(value: Epoch) -> Self { - value.0 - } -} - /// An Epoch slice (wrapper type over i64) -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)] -pub(crate) struct EpochSlice(pub(crate) i64); +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From, Into, Deref)] +pub(crate) struct EpochSlice(i64); impl Add for EpochSlice { type Output = Self; @@ -246,18 +235,6 @@ impl Add for EpochSlice { } } -impl From for EpochSlice { - fn from(value: i64) -> Self { - Self(value) - } -} - -impl From for i64 { - fn from(value: EpochSlice) -> Self { - value.0 - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/prover/src/error.rs b/prover/src/error.rs index 5332b59..4fcc27d 100644 --- a/prover/src/error.rs +++ b/prover/src/error.rs @@ -1,10 +1,9 @@ -use crate::epoch_service::WaitUntilError; -use alloy::{ - primitives::Address, - transports::{RpcError, TransportErrorKind}, -}; +use alloy::transports::{RpcError, TransportErrorKind}; use ark_serialize::SerializationError; use rln::error::ProofError; +// internal +use crate::epoch_service::WaitUntilError; +use crate::user_db_error::{RegisterError, UserMerkleTreeIndexError}; #[derive(thiserror::Error, Debug)] pub enum AppError { @@ -53,33 +52,19 @@ impl From for ProofGenerationStringError { fn from(value: ProofGenerationError) -> Self { match value { ProofGenerationError::Proof(e) => ProofGenerationStringError::Proof(e.to_string()), - ProofGenerationError::Serialization(e) => { - ProofGenerationStringError::Serialization(e.to_string()) - } - ProofGenerationError::SerializationWrite(e) => { - ProofGenerationStringError::SerializationWrite(e.to_string()) - } - ProofGenerationError::MerkleProofError(e) => { - ProofGenerationStringError::MerkleProofError(e) - } + ProofGenerationError::Serialization(e) => Self::Serialization(e.to_string()), + ProofGenerationError::SerializationWrite(e) => Self::SerializationWrite(e.to_string()), + ProofGenerationError::MerkleProofError(e) => Self::MerkleProofError(e), } } } -#[derive(thiserror::Error, Debug)] -pub enum RegisterError { - #[error("User (address: {0:?}) has already been registered")] - AlreadyRegistered(Address), - #[error("Merkle tree error: {0}")] - TreeError(String), -} - #[derive(thiserror::Error, Debug, Clone)] pub enum GetMerkleTreeProofError { - #[error("User not registered")] - NotRegistered, #[error("Merkle tree error: {0}")] TreeError(String), + #[error(transparent)] + MerkleTree(#[from] UserMerkleTreeIndexError), } #[derive(thiserror::Error, Debug)] diff 
--git a/prover/src/grpc_service.rs b/prover/src/grpc_service.rs index 942e377..b030090 100644 --- a/prover/src/grpc_service.rs +++ b/prover/src/grpc_service.rs @@ -18,9 +18,9 @@ use tower_http::cors::{Any, CorsLayer}; use tracing::debug; use url::Url; // internal -use crate::error::{AppError, ProofGenerationStringError, RegisterError}; +use crate::error::{AppError, ProofGenerationStringError}; use crate::proof_generation::{ProofGenerationData, ProofSendingData}; -use crate::user_db_service::{UserDb, UserTierInfo}; +use crate::user_db::{UserDb, UserTierInfo}; use rln_proof::RlnIdentifier; use smart_contract::{ KarmaAmountExt, @@ -39,6 +39,7 @@ pub mod prover_proto { pub(crate) const FILE_DESCRIPTOR_SET: &[u8] = tonic::include_file_descriptor_set!("prover_descriptor"); } +use crate::user_db_error::RegisterError; use prover_proto::{ GetUserTierInfoReply, GetUserTierInfoRequest, @@ -165,7 +166,7 @@ where return Err(Status::invalid_argument("No sender address")); }; - let result = self.user_db.on_new_user(user); + let result = self.user_db.on_new_user(&user); let status = match result { Ok(id_commitment) => { @@ -466,11 +467,11 @@ impl From for UserTierInfoResult { } /// UserTierInfoError to UserTierInfoError (Grpc message) conversion -impl From> for UserTierInfoError +impl From> for UserTierInfoError where E: std::error::Error, { - fn from(value: crate::user_db_service::UserTierInfoError) -> Self { + fn from(value: crate::user_db_error::UserTierInfoError) -> Self { UserTierInfoError { message: value.to_string(), } diff --git a/prover/src/main.rs b/prover/src/main.rs index 748ec6f..495958d 100644 --- a/prover/src/main.rs +++ b/prover/src/main.rs @@ -7,9 +7,14 @@ mod mock; mod proof_generation; mod proof_service; mod registry_listener; +mod rocksdb_operands; mod tier; mod tiers_listener; +mod user_db; +mod user_db_error; +mod user_db_serialization; mod user_db_service; +mod user_db_types; // std use std::net::SocketAddr; @@ -38,7 +43,8 @@ use crate::proof_service::ProofService; use crate::registry_listener::RegistryListener; use crate::tier::TierLimits; use crate::tiers_listener::TiersListener; -use crate::user_db_service::{RateLimit, UserDbService}; +use crate::user_db_service::UserDbService; +use crate::user_db_types::RateLimit; const RLN_IDENTIFIER_NAME: &[u8] = b"test-rln-identifier"; const PROVER_SPAM_LIMIT: RateLimit = RateLimit::new(10_000u64); @@ -98,11 +104,13 @@ async fn main() -> Result<(), Box> { // User db service let user_db_service = UserDbService::new( + app_args.db_path, + app_args.merkle_tree_path, epoch_service.epoch_changes.clone(), epoch_service.current_epoch.clone(), PROVER_SPAM_LIMIT, tier_limits, - ); + )?; if app_args.mock_sc.is_some() { if let Some(user_filepath) = app_args.mock_user.as_ref() { @@ -114,7 +122,7 @@ async fn main() -> Result<(), Box> { mock_user.address, mock_user.tx_count ); let user_db = user_db_service.get_user_db(); - user_db.on_new_user(mock_user.address).unwrap(); + user_db.on_new_user(&mock_user.address).unwrap(); user_db .on_new_tx(&mock_user.address, Some(mock_user.tx_count)) .unwrap(); diff --git a/prover/src/proof_service.rs b/prover/src/proof_service.rs index f3bf4ba..a139718 100644 --- a/prover/src/proof_service.rs +++ b/prover/src/proof_service.rs @@ -12,7 +12,8 @@ use tracing::{debug, info}; use crate::epoch_service::{Epoch, EpochSlice}; use crate::error::{AppError, ProofGenerationError, ProofGenerationStringError}; use crate::proof_generation::{ProofGenerationData, ProofSendingData}; -use crate::user_db_service::{RateLimit, UserDb}; +use 
crate::user_db::UserDb; +use crate::user_db_types::RateLimit; use rln_proof::{RlnData, compute_rln_proof_and_values}; const PROOF_SIZE: usize = 512; @@ -82,8 +83,8 @@ impl ProofService { }; let epoch_bytes = { - let mut v = i64::from(current_epoch).to_be_bytes().to_vec(); - v.extend(i64::from(current_epoch_slice).to_be_bytes()); + let mut v = current_epoch.to_le_bytes().to_vec(); + v.extend(current_epoch_slice.to_le_bytes()); v }; let epoch = hash_to_field(epoch_bytes.as_slice()); @@ -140,6 +141,7 @@ impl ProofService { #[cfg(test)] mod tests { use super::*; + use std::path::PathBuf; // third-party use alloy::primitives::{Address, address}; use ark_groth16::{Proof as ArkProof, Proof, VerifyingKey}; @@ -256,15 +258,20 @@ mod tests { let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice))); // User db + let temp_folder = tempfile::tempdir().unwrap(); + let temp_folder_tree = tempfile::tempdir().unwrap(); let user_db_service = UserDbService::new( + PathBuf::from(temp_folder.path()), + PathBuf::from(temp_folder_tree.path()), Default::default(), epoch_store.clone(), 10.into(), Default::default(), - ); + ) + .unwrap(); let user_db = user_db_service.get_user_db(); - user_db.on_new_user(ADDR_1).unwrap(); - user_db.on_new_user(ADDR_2).unwrap(); + user_db.on_new_user(&ADDR_1).unwrap(); + user_db.on_new_user(&ADDR_2).unwrap(); let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz")); @@ -308,14 +315,19 @@ mod tests { let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice))); // User db + let temp_folder = tempfile::tempdir().unwrap(); + let temp_folder_tree = tempfile::tempdir().unwrap(); let user_db_service = UserDbService::new( + PathBuf::from(temp_folder.path()), + PathBuf::from(temp_folder_tree.path()), Default::default(), epoch_store.clone(), 10.into(), Default::default(), - ); + ) + .unwrap(); let user_db = user_db_service.get_user_db(); - user_db.on_new_user(ADDR_1).unwrap(); + user_db.on_new_user(&ADDR_1).unwrap(); // user_db.on_new_user(ADDR_2).unwrap(); let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz")); @@ -463,16 +475,21 @@ mod tests { let rate_limit = RateLimit::from(1); // User db + let temp_folder = tempfile::tempdir().unwrap(); + let temp_folder_tree = tempfile::tempdir().unwrap(); let user_db_service = UserDbService::new( + PathBuf::from(temp_folder.path()), + PathBuf::from(temp_folder_tree.path()), Default::default(), epoch_store.clone(), rate_limit, Default::default(), - ); + ) + .unwrap(); let user_db = user_db_service.get_user_db(); - user_db.on_new_user(ADDR_1).unwrap(); + user_db.on_new_user(&ADDR_1).unwrap(); let user_addr_1 = user_db.get_user(&ADDR_1).unwrap(); - user_db.on_new_user(ADDR_2).unwrap(); + user_db.on_new_user(&ADDR_2).unwrap(); let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz")); @@ -528,17 +545,22 @@ mod tests { let rate_limit = RateLimit::from(1); // User db - limit is 1 message per epoch + let temp_folder = tempfile::tempdir().unwrap(); + let temp_folder_tree = tempfile::tempdir().unwrap(); let user_db_service = UserDbService::new( + PathBuf::from(temp_folder.path()), + PathBuf::from(temp_folder_tree.path()), Default::default(), epoch_store.clone(), - rate_limit.into(), + rate_limit, Default::default(), - ); + ) + .unwrap(); let user_db = user_db_service.get_user_db(); - user_db.on_new_user(ADDR_1).unwrap(); + user_db.on_new_user(&ADDR_1).unwrap(); let user_addr_1 = user_db.get_user(&ADDR_1).unwrap(); debug!("user_addr_1: {:?}", user_addr_1); - user_db.on_new_user(ADDR_2).unwrap(); + 
user_db.on_new_user(&ADDR_2).unwrap(); let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz")); diff --git a/prover/src/registry_listener.rs b/prover/src/registry_listener.rs index a7f0916..9e50bbe 100644 --- a/prover/src/registry_listener.rs +++ b/prover/src/registry_listener.rs @@ -9,8 +9,9 @@ use alloy::{ use tonic::codegen::tokio_stream::StreamExt; use tracing::{debug, error, info}; // internal -use crate::error::{AppError, HandleTransferError, RegisterError}; -use crate::user_db_service::UserDb; +use crate::error::{AppError, HandleTransferError}; +use crate::user_db::UserDb; +use crate::user_db_error::RegisterError; use smart_contract::{AlloyWsProvider, KarmaAmountExt, KarmaSC}; pub(crate) struct RegistryListener { @@ -113,7 +114,7 @@ impl RegistryListener { if should_register { self.user_db - .on_new_user(to_address) + .on_new_user(&to_address) .map_err(HandleTransferError::Register)?; } } @@ -125,6 +126,7 @@ impl RegistryListener { #[cfg(test)] mod tests { use super::*; + use std::path::PathBuf; // std use std::sync::Arc; // third-party @@ -152,12 +154,17 @@ mod tests { let epoch = Epoch::from(11); let epoch_slice = EpochSlice::from(42); let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice))); + let temp_folder = tempfile::tempdir().unwrap(); + let temp_folder_tree = tempfile::tempdir().unwrap(); let user_db_service = UserDbService::new( + PathBuf::from(temp_folder.path()), + PathBuf::from(temp_folder_tree.path()), Default::default(), epoch_store, 10.into(), Default::default(), - ); + ) + .unwrap(); let user_db = user_db_service.get_user_db(); assert!(user_db_service.get_user_db().get_user(&ADDR_2).is_none()); diff --git a/prover/src/rocksdb_operands.rs b/prover/src/rocksdb_operands.rs new file mode 100644 index 0000000..5326809 --- /dev/null +++ b/prover/src/rocksdb_operands.rs @@ -0,0 +1,380 @@ +use crate::epoch_service::{Epoch, EpochSlice}; +use nom::{ + IResult, + error::ContextError, + number::complete::{le_i64, le_u64}, +}; +use rocksdb::MergeOperands; + +#[derive(Debug, PartialEq)] +pub enum DeserializeError<I> { + Nom(I, nom::error::ErrorKind), +} + +impl<I> nom::error::ParseError<I> for DeserializeError<I> { + fn from_error_kind(input: I, kind: nom::error::ErrorKind) -> Self { + DeserializeError::Nom(input, kind) + } + + fn append(_: I, _: nom::error::ErrorKind, other: Self) -> Self { + other + } +} + +impl<I> ContextError<I> for DeserializeError<I> {} + +#[derive(Debug, Default, PartialEq)] +pub struct EpochCounters { + pub epoch: Epoch, + pub epoch_slice: EpochSlice, + pub epoch_counter: u64, + pub epoch_slice_counter: u64, +} + +pub struct EpochCounterSerializer {} + +impl EpochCounterSerializer { + fn serialize(&self, value: &EpochCounters, buffer: &mut Vec<u8>) { + buffer.extend(value.epoch.to_le_bytes()); + buffer.extend(value.epoch_slice.to_le_bytes()); + buffer.extend(value.epoch_counter.to_le_bytes()); + buffer.extend(value.epoch_slice_counter.to_le_bytes()); + } + + pub(crate) const fn size_hint_() -> usize { + size_of::<EpochCounters>() + } + + pub(crate) fn size_hint(&self) -> usize { + Self::size_hint_() + } + + pub const fn default() -> [u8; Self::size_hint_()] { + [0u8; Self::size_hint_()] + } +} + +pub struct EpochCounterDeserializer {} + +impl EpochCounterDeserializer { + pub fn deserialize<'a>( + &self, + buffer: &'a [u8], + ) -> IResult<&'a [u8], EpochCounters, DeserializeError<&'a [u8]>> { + let (input, epoch) = le_i64(buffer).map(|(i, e)| (i, Epoch::from(e)))?; + let (input, epoch_slice) = le_i64(input).map(|(i, es)| (i, EpochSlice::from(es)))?; + let (input, epoch_counter) = 
le_u64(input)?; + let (input, epoch_slice_counter) = le_u64(input)?; + Ok(( + input, + EpochCounters { + epoch, + epoch_slice, + epoch_counter, + epoch_slice_counter, + }, + )) + } +} + +#[derive(Debug, Default, PartialEq)] +pub struct EpochIncr { + pub epoch: Epoch, + pub epoch_slice: EpochSlice, + pub incr_value: u64, +} + +pub struct EpochIncrSerializer {} + +impl EpochIncrSerializer { + pub fn serialize(&self, value: &EpochIncr, buffer: &mut Vec<u8>) { + buffer.extend(value.epoch.to_le_bytes()); + buffer.extend(value.epoch_slice.to_le_bytes()); + buffer.extend(value.incr_value.to_le_bytes()); + } + + pub fn size_hint(&self) -> usize { + size_of::<u64>() * 3 + } +} + +pub struct EpochIncrDeserializer {} + +impl EpochIncrDeserializer { + pub fn deserialize<'a>( + &self, + buffer: &'a [u8], + ) -> IResult<&'a [u8], EpochIncr, DeserializeError<&'a [u8]>> { + let (input, epoch) = le_i64(buffer).map(|(i, e)| (i, Epoch::from(e)))?; + let (input, epoch_slice) = le_i64(input).map(|(i, es)| (i, EpochSlice::from(es)))?; + let (input, incr_value) = le_u64(input)?; + Ok(( + input, + EpochIncr { + epoch, + epoch_slice, + incr_value, + }, + )) + } +} + +pub fn epoch_counters_operands( + _key: &[u8], + existing_val: Option<&[u8]>, + operands: &MergeOperands, +) -> Option<Vec<u8>> { + let counter_ser = EpochCounterSerializer {}; + let counter_deser = EpochCounterDeserializer {}; + let ser = EpochIncrSerializer {}; + let deser = EpochIncrDeserializer {}; + + // Current epoch counter structure (stored in DB) + let epoch_counter_current = counter_deser + .deserialize(existing_val.unwrap_or_default()) + .map(|(_, c)| c) + .unwrap_or_default(); + + // Iterate over the merge operands (a single write batch can carry several operands) + let counter_value = operands.iter().fold(epoch_counter_current, |mut acc, x| { + // Note: unwrap on EpochIncr deserialize error - serialization is done by the prover + // thus no error should ever happen here + let (_, epoch_incr) = deser.deserialize(x).unwrap(); + + // TODO - optim: partial deser ? + // TODO: check if increasing ? debug_assert otherwise? 
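+ // Fold semantics, one case per branch below: + // - acc still default: first merge for this key, seed every field from the operand + // - operand carries a newer epoch: reset both counters to the increment + // - operand carries a newer epoch slice: keep accumulating the epoch counter, restart the epoch slice counter from the increment + // - same epoch & epoch slice: accumulate both counters (saturating) 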
+ if acc == Default::default() { + // Default value - so this is the first time + acc = EpochCounters { + epoch: epoch_incr.epoch, + epoch_slice: epoch_incr.epoch_slice, + epoch_counter: epoch_incr.incr_value, + epoch_slice_counter: epoch_incr.incr_value, + } + } else if epoch_incr.epoch != acc.epoch { + // New epoch + acc = EpochCounters { + epoch: epoch_incr.epoch, + epoch_slice: Default::default(), + epoch_counter: epoch_incr.incr_value, + epoch_slice_counter: epoch_incr.incr_value, + } + } else if epoch_incr.epoch_slice != acc.epoch_slice { + // New epoch slice + acc = EpochCounters { + epoch: epoch_incr.epoch, + epoch_slice: epoch_incr.epoch_slice, + epoch_counter: acc.epoch_counter.saturating_add(epoch_incr.incr_value), + epoch_slice_counter: epoch_incr.incr_value, + } + } else { + acc = EpochCounters { + epoch: acc.epoch, + epoch_slice: acc.epoch_slice, + epoch_counter: acc.epoch_counter.saturating_add(epoch_incr.incr_value), + epoch_slice_counter: acc + .epoch_slice_counter + .saturating_add(epoch_incr.incr_value), + } + } + + acc + }); + + let mut buffer = Vec::with_capacity(ser.size_hint()); + counter_ser.serialize(&counter_value, &mut buffer); + Some(buffer) +} + +pub fn u64_counter_operands( + _key: &[u8], + existing_val: Option<&[u8]>, + operands: &MergeOperands, +) -> Option<Vec<u8>> { + let counter_current_value = if let Some(existing_val) = existing_val { + u64::from_le_bytes(existing_val.try_into().unwrap()) + } else { + 0 + }; + + let counter_value = operands.iter().fold(counter_current_value, |mut acc, x| { + let incr_value = u64::from_le_bytes(x.try_into().unwrap()); + acc = acc.saturating_add(incr_value); + acc + }); + + Some(counter_value.to_le_bytes().to_vec()) +} + +#[cfg(test)] +mod tests { + use super::*; + // std + // third-party + use rocksdb::{DB, Options, WriteBatch}; + use tempfile::TempDir; + + #[test] + fn test_ser_der() { + // EpochCounter struct + { + let epoch_counter = EpochCounters { + epoch: 1.into(), + epoch_slice: 42.into(), + epoch_counter: 12, + epoch_slice_counter: u64::MAX, + }; + + let serializer = EpochCounterSerializer {}; + let mut buffer = Vec::with_capacity(serializer.size_hint()); + serializer.serialize(&epoch_counter, &mut buffer); + + let deserializer = EpochCounterDeserializer {}; + let (_, de) = deserializer.deserialize(&buffer).unwrap(); + assert_eq!(epoch_counter, de); + } + + { + let deserializer = EpochCounterDeserializer {}; + let (_, de) = deserializer + .deserialize(EpochCounterSerializer::default().as_slice()) + .unwrap(); + assert_eq!(EpochCounters::default(), de); + } + + // EpochIncr struct + { + let epoch_incr = EpochIncr { + epoch: 1.into(), + epoch_slice: 42.into(), + incr_value: 1, + }; + + let serializer = EpochIncrSerializer {}; + let mut buffer = Vec::with_capacity(serializer.size_hint()); + serializer.serialize(&epoch_incr, &mut buffer); + + let deserializer = EpochIncrDeserializer {}; + let (_, de) = deserializer.deserialize(&buffer).unwrap(); + assert_eq!(epoch_incr, de); + } + } + + #[test] + fn test_counter() { + let tmp_path = TempDir::new().unwrap().path().to_path_buf(); + let options = { + let mut opts = Options::default(); + opts.create_if_missing(true); + opts.set_merge_operator("o", u64_counter_operands, u64_counter_operands); + opts + }; + let db = DB::open(&options, tmp_path).unwrap(); + let key_1 = "foo1"; + // let key_2 = "baz42"; + + let index = 42u64; + let buffer = index.to_le_bytes(); + + let mut db_batch = WriteBatch::default(); + db_batch.merge(key_1, &buffer); + db_batch.merge(key_1, &buffer); 
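+ // Two merge operands for the same key inside one WriteBatch: RocksDB hands + // both to the associative operator, which folds them into index * 2. 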
+ db.write(db_batch).unwrap(); + + let get_key_1 = db.get(&key_1).unwrap().unwrap(); + let value = u64::from_le_bytes(get_key_1.try_into().unwrap()); + + assert_eq!(value, index * 2); // 2x merge + } + + #[test] + fn test_counters() { + let tmp_path = TempDir::new().unwrap().path().to_path_buf(); + let options = { + let mut opts = Options::default(); + opts.create_if_missing(true); + opts.set_merge_operator("o", epoch_counters_operands, epoch_counters_operands); + opts + }; + let db = DB::open(&options, tmp_path).unwrap(); + let key_1 = "foo1"; + let key_2 = "baz42"; + + let value_1 = EpochIncr { + epoch: 0.into(), + epoch_slice: 0.into(), + incr_value: 2, + }; + let epoch_incr_ser = EpochIncrSerializer {}; + let epoch_counter_deser = EpochCounterDeserializer {}; + + let mut buffer = Vec::with_capacity(epoch_incr_ser.size_hint()); + epoch_incr_ser.serialize(&value_1, &mut buffer); + let mut db_batch = WriteBatch::default(); + db_batch.merge(key_1, &buffer); + db_batch.merge(key_1, &buffer); + db.write(db_batch).unwrap(); + + let get_key_1 = db.get(&key_1).unwrap().unwrap(); + let (_, get_value_k1) = epoch_counter_deser.deserialize(&get_key_1).unwrap(); + + // Applied EpochIncr 2x + assert_eq!(get_value_k1.epoch_counter, 4); + assert_eq!(get_value_k1.epoch_slice_counter, 4); + + let get_key_2 = db.get(&key_2).unwrap(); + assert!(get_key_2.is_none()); + + // new epoch slice + { + let value_2 = EpochIncr { + epoch: 0.into(), + epoch_slice: 1.into(), + incr_value: 1, + }; + + let mut buffer = Vec::with_capacity(epoch_incr_ser.size_hint()); + epoch_incr_ser.serialize(&value_2, &mut buffer); + db.merge(key_1, buffer).unwrap(); + + let get_key_1 = db.get(&key_1).unwrap().unwrap(); + let (_, get_value_2) = epoch_counter_deser.deserialize(&get_key_1).unwrap(); + + assert_eq!( + get_value_2, + EpochCounters { + epoch: 0.into(), + epoch_slice: 1.into(), + epoch_counter: 5, + epoch_slice_counter: 1, + } + ) + } + + // new epoch + { + let value_3 = EpochIncr { + epoch: 1.into(), + epoch_slice: 0.into(), + incr_value: 3, + }; + + let mut buffer = Vec::with_capacity(epoch_incr_ser.size_hint()); + epoch_incr_ser.serialize(&value_3, &mut buffer); + db.merge(key_1, buffer).unwrap(); + + let get_key_1 = db.get(&key_1).unwrap().unwrap(); + let (_, get_value_3) = epoch_counter_deser.deserialize(&get_key_1).unwrap(); + + assert_eq!( + get_value_3, + EpochCounters { + epoch: 1.into(), + epoch_slice: 0.into(), + epoch_counter: 3, + epoch_slice_counter: 3, + } + ) + } + } +} diff --git a/prover/src/tier.rs b/prover/src/tier.rs index 3afc2d7..a62e4e6 100644 --- a/prover/src/tier.rs +++ b/prover/src/tier.rs @@ -1,17 +1,14 @@ use std::collections::{BTreeMap, HashSet}; -use std::ops::{ - ControlFlow, - Deref, DerefMut -}; +use std::ops::{ControlFlow, Deref, DerefMut}; // third-party -use derive_more::{From, Into}; use alloy::primitives::U256; +use derive_more::{From, Into}; // internal -use crate::user_db_service::SetTierLimitsError; +// use crate::user_db_service::SetTierLimitsError; use smart_contract::{Tier, TierIndex}; #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, From, Into)] -pub struct TierLimit(u64); +pub struct TierLimit(u32); #[derive(Debug, Clone, PartialEq, Eq, Hash, From, Into)] pub struct TierName(String); @@ -47,7 +44,7 @@ impl TierLimits { } /// Validate tier limits (unique names, increasing min & max karma ...) 
- pub(crate) fn validate(&self) -> Result<(), SetTierLimitsError> { + pub(crate) fn validate(&self) -> Result<(), ValidateTierLimitsError> { #[derive(Default)] struct Context<'a> { tier_names: HashSet, @@ -61,30 +58,30 @@ impl TierLimits { .iter() .try_fold(Context::default(), |mut state, (tier_index, tier)| { if !tier.active { - return Err(SetTierLimitsError::InactiveTier); + return Err(ValidateTierLimitsError::InactiveTier); } if *tier_index <= *state.prev_index.unwrap_or(&TierIndex::default()) { - return Err(SetTierLimitsError::InvalidTierIndex); + return Err(ValidateTierLimitsError::InvalidTierIndex); } if tier.min_karma >= tier.max_karma { - return Err(SetTierLimitsError::InvalidMaxAmount( + return Err(ValidateTierLimitsError::InvalidMaxAmount( tier.min_karma, tier.max_karma, )); } if tier.min_karma <= *state.prev_amount.unwrap_or(&U256::ZERO) { - return Err(SetTierLimitsError::InvalidKarmaAmount); + return Err(ValidateTierLimitsError::InvalidKarmaAmount); } if tier.tx_per_epoch <= *state.prev_tx_per_epoch.unwrap_or(&0) { - return Err(SetTierLimitsError::InvalidTierLimit); + return Err(ValidateTierLimitsError::InvalidTierLimit); } if state.tier_names.contains(&tier.name) { - return Err(SetTierLimitsError::NonUniqueTierName); + return Err(ValidateTierLimitsError::NonUniqueTierName); } state.prev_amount = Some(&tier.min_karma); @@ -99,7 +96,6 @@ impl TierLimits { /// Given some karma amount, find the matching Tier pub(crate) fn get_tier_by_karma(&self, karma_amount: &U256) -> Option<(TierIndex, Tier)> { - struct Context<'a> { prev: Option<(&'a TierIndex, &'a Tier)>, } @@ -124,3 +120,19 @@ impl TierLimits { } } } + +#[derive(Debug, thiserror::Error)] +pub enum ValidateTierLimitsError { + #[error("Invalid Karma amount (must be increasing)")] + InvalidKarmaAmount, + #[error("Invalid Karma max amount (min: {0} vs max: {1})")] + InvalidMaxAmount(U256, U256), + #[error("Invalid Tier limit (must be increasing)")] + InvalidTierLimit, + #[error("Invalid Tier index (must be increasing)")] + InvalidTierIndex, + #[error("Non unique Tier name")] + NonUniqueTierName, + #[error("Non active Tier")] + InactiveTier, +} diff --git a/prover/src/tiers_listener.rs b/prover/src/tiers_listener.rs index 4f848ce..4105002 100644 --- a/prover/src/tiers_listener.rs +++ b/prover/src/tiers_listener.rs @@ -9,7 +9,7 @@ use futures::StreamExt; use tracing::error; // internal use crate::error::AppError; -use crate::user_db_service::UserDb; +use crate::user_db::UserDb; use smart_contract::{AlloyWsProvider, KarmaTiersSC, Tier, TierIndex}; pub(crate) struct TiersListener { diff --git a/prover/src/user_db.rs b/prover/src/user_db.rs new file mode 100644 index 0000000..9a0c2ca --- /dev/null +++ b/prover/src/user_db.rs @@ -0,0 +1,1018 @@ +use std::path::PathBuf; +use std::str::FromStr; +use std::sync::Arc; +// third-party +use alloy::primitives::{Address, U256}; +use ark_bn254::Fr; +use claims::debug_assert_lt; +use parking_lot::RwLock; +use rln::{ + hashers::poseidon_hash, + pm_tree_adapter::PmtreeConfig, + poseidon_tree::{MerkleProof, PoseidonTree}, + protocol::keygen, +}; +use rocksdb::{ + ColumnFamily, ColumnFamilyDescriptor, DB, Options, ReadOptions, WriteBatchWithIndex, +}; +use serde::{Deserialize, Serialize}; +// internal +use crate::epoch_service::{Epoch, EpochSlice}; +use crate::error::GetMerkleTreeProofError; +use crate::rocksdb_operands::{ + EpochCounterDeserializer, EpochCounterSerializer, EpochIncr, EpochIncrSerializer, + epoch_counters_operands, u64_counter_operands, +}; +use crate::tier::{TierLimit, TierLimits, 
TierName}; +use crate::user_db_error::{ + MerkleTreeIndexError, RegisterError, SetTierLimitsError, TxCounterError, UserDbOpenError, + UserMerkleTreeIndexError, UserTierInfoError, +}; +use crate::user_db_serialization::{ + MerkleTreeIndexDeserializer, MerkleTreeIndexSerializer, RlnUserIdentityDeserializer, + RlnUserIdentitySerializer, TierDeserializer, TierLimitsDeserializer, TierLimitsSerializer, +}; +use crate::user_db_types::{EpochCounter, EpochSliceCounter, MerkleTreeIndex, RateLimit}; +use rln_proof::{RlnUserIdentity, ZerokitMerkleTree}; +use smart_contract::{KarmaAmountExt, Tier, TierIndex}; + +const MERKLE_TREE_HEIGHT: usize = 20; +pub const USER_CF: &str = "user"; +pub const MERKLE_TREE_COUNTER_CF: &str = "mtree"; +pub const TX_COUNTER_CF: &str = "tx_counter"; +pub const TIER_LIMITS_CF: &str = "tier_limits"; + +const MERKLE_TREE_INDEX_KEY: &[u8; 4] = b"TREE"; +const TIER_LIMITS_KEY: &[u8; 7] = b"CURRENT"; +const TIER_LIMITS_NEXT_KEY: &[u8; 4] = b"NEXT"; + +#[derive(Debug, PartialEq)] +pub struct UserTierInfo { + pub(crate) current_epoch: Epoch, + pub(crate) current_epoch_slice: EpochSlice, + pub(crate) epoch_tx_count: u64, + pub(crate) epoch_slice_tx_count: u64, + pub(crate) karma_amount: U256, + pub(crate) tier_name: Option<TierName>, + pub(crate) tier_limit: Option<TierLimit>, +} + +#[derive(Clone)] +pub(crate) struct UserDb { + db: Arc<DB>, + merkle_tree: Arc<RwLock<PoseidonTree>>, + rate_limit: RateLimit, + pub(crate) epoch_store: Arc<RwLock<(Epoch, EpochSlice)>>, +} + +impl std::fmt::Debug for UserDb { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt.debug_struct("UserDb") + .field("db", &self.db) + .field("rate limit", &self.rate_limit) + .field("epoch store", &self.epoch_store) + .finish() + } +} + +impl UserDb { + /// Returns a new `UserDb` instance + pub fn new( + db_path: PathBuf, + merkle_tree_path: PathBuf, + epoch_store: Arc<RwLock<(Epoch, EpochSlice)>>, + tier_limits: TierLimits, + rate_limit: RateLimit, + ) -> Result<Self, UserDbOpenError> { + let db_options = { + let mut db_opts = Options::default(); + db_opts.set_max_open_files(820); + db_opts.create_if_missing(true); + db_opts.create_missing_column_families(true); + db_opts + }; + + let mut tx_counter_cf_opts = Options::default(); + tx_counter_cf_opts + .set_merge_operator_associative("counters operator", epoch_counters_operands); + let mut user_mtree_cf_opts = Options::default(); + user_mtree_cf_opts.set_merge_operator_associative("counter operator", u64_counter_operands); + + let db = DB::open_cf_descriptors( + &db_options, + db_path, + vec![ + // Db column for users, key: User address, value: RlnUserIdentity + MerkleTreeIndex + ColumnFamilyDescriptor::new(USER_CF, Options::default()), + // Db column for merkle tree index, key: tree, value: counter + ColumnFamilyDescriptor::new(MERKLE_TREE_COUNTER_CF, user_mtree_cf_opts), + // Db column for user tx counters, key: User address, value: EpochCounters + ColumnFamilyDescriptor::new(TX_COUNTER_CF, tx_counter_cf_opts), + // Db column for tier limits - key: current && next, value: TierLimits + // Note: only 2 keys in this column + ColumnFamilyDescriptor::new(TIER_LIMITS_CF, Options::default()), + ], + )?; + + debug_assert!(tier_limits.validate().is_ok()); + let tier_limits_serializer = TierLimitsSerializer::default(); + let mut buffer = Vec::with_capacity(tier_limits_serializer.size_hint(tier_limits.len())); + tier_limits_serializer.serialize(&tier_limits, &mut buffer)?; + + // unwrap safe - db is always created with this column + let cf = db.cf_handle(TIER_LIMITS_CF).unwrap(); + db.delete_cf(cf, TIER_LIMITS_NEXT_KEY.as_slice())?; + db.put_cf(cf, 
TIER_LIMITS_KEY.as_slice(), buffer)?; + + let db = Arc::new(db); + + // merkle tree index + + let cf_mtree = db.cf_handle(MERKLE_TREE_COUNTER_CF).unwrap(); + if let Err(e) = Self::get_merkle_tree_index_(db.clone(), cf_mtree) { + match e { + MerkleTreeIndexError::DbUninitialized => { + // Check if the value is already there (e.g. after a restart) + // if not, we create it + db.merge_cf(cf_mtree, MERKLE_TREE_INDEX_KEY, 0u64.to_le_bytes())?; + } + _ => return Err(UserDbOpenError::MerkleTreeIndex(e)), + } + } + + // merkle tree + + #[derive(Serialize, Deserialize)] + struct PmTreeConfigJson { + path: PathBuf, + temporary: bool, + cache_capacity: u64, + flush_every_ms: u64, + mode: String, + use_compression: bool, + } + + let config_ = PmTreeConfigJson { + path: merkle_tree_path, + temporary: false, + cache_capacity: 100_000, + flush_every_ms: 12_000, + mode: "HighThroughput".to_string(), + use_compression: false, + }; + let config_str = serde_json::to_string(&config_)?; + // Note: in Zerokit 0.8 this is the only way to initialize a PmTreeConfig + let config = PmtreeConfig::from_str(config_str.as_str())?; + let tree = PoseidonTree::new(MERKLE_TREE_HEIGHT, Default::default(), config)?; + + Ok(Self { + db, + merkle_tree: Arc::new(RwLock::new(tree)), + rate_limit, + epoch_store, + }) + } + + fn get_user_cf(&self) -> &ColumnFamily { + // unwrap safe - db is always created with this column + self.db.cf_handle(USER_CF).unwrap() + } + + fn get_mtree_cf(&self) -> &ColumnFamily { + // unwrap safe - db is always created with this column + self.db.cf_handle(MERKLE_TREE_COUNTER_CF).unwrap() + } + + fn get_counter_cf(&self) -> &ColumnFamily { + // unwrap safe - db is always created with this column + self.db.cf_handle(TX_COUNTER_CF).unwrap() + } + + fn get_tier_limits_cf(&self) -> &ColumnFamily { + // unwrap safe - db is always created with this column + self.db.cf_handle(TIER_LIMITS_CF).unwrap() + } + + fn register(&self, address: Address) -> Result<Fr, RegisterError> { + let rln_identity_serializer = RlnUserIdentitySerializer {}; + let merkle_index_serializer = MerkleTreeIndexSerializer {}; + let merkle_index_deserializer = MerkleTreeIndexDeserializer {}; + + let (identity_secret_hash, id_commitment) = keygen(); + + let rln_identity = RlnUserIdentity::from(( + identity_secret_hash, + id_commitment, + Fr::from(self.rate_limit), + )); + + let key = address.as_slice(); + let mut buffer = + vec![0; rln_identity_serializer.size_hint() + merkle_index_serializer.size_hint()]; + + // unwrap safe - this is serialized by the Prover + RlnUserIdentitySerializer is unit tested + rln_identity_serializer + .serialize(&rln_identity, &mut buffer) + .unwrap(); + + let cf_user = self.get_user_cf(); + + let index = match self.db.get_cf(cf_user, key) { + Ok(Some(_)) => { + return Err(RegisterError::AlreadyRegistered(address)); + } + Ok(None) => { + let cf_mtree = self.get_mtree_cf(); + let cf_counter = self.get_counter_cf(); + + // Note: this should be updated with everything added to db_batch + debug_assert_lt!( + MERKLE_TREE_INDEX_KEY.len() + + size_of::<u64>() + + (2 * size_of::<Address>
()) + + EpochCounterSerializer::size_hint_() + + buffer.len(), + 1024 + ); + let mut db_batch = WriteBatchWithIndex::new(1024, true); + + // Increase merkle tree index + db_batch.merge_cf(cf_mtree, MERKLE_TREE_INDEX_KEY, 1u64.to_le_bytes()); + // Read the new index + // Unwrap safe - just used merge_cf + let batch_read = db_batch + .get_from_batch_and_db_cf( + &*self.db, + cf_mtree, + MERKLE_TREE_INDEX_KEY, + &ReadOptions::default(), + )? + .unwrap(); + // Unwrap safe - serialization is handled by the prover + let (_, new_index) = merkle_index_deserializer + .deserialize(batch_read.as_slice()) + .unwrap(); + + // Add index for user + merkle_index_serializer.serialize(&new_index, &mut buffer); + // Put user + db_batch.put_cf(cf_user, key, buffer.as_slice()); + // Put user tx counter + db_batch.put_cf( + cf_counter, + key, + EpochCounterSerializer::default().as_slice(), + ); + + self.db.write_wbwi(&db_batch).map_err(RegisterError::Db)?; + new_index + } + Err(e) => { + return Err(RegisterError::Db(e)); + } + }; + + let rate_commit = poseidon_hash(&[id_commitment, Fr::from(u64::from(self.rate_limit))]); + // FIXME: what to do if write to merkle tree fails? Should we include this in the Db transaction as well? + self.merkle_tree + .write() + .set(index.into(), rate_commit) + .map_err(|e| RegisterError::TreeError(e.to_string()))?; + + Ok(id_commitment) + } + + fn has_user(&self, address: &Address) -> Result<bool, rocksdb::Error> { + let cf_user = self.get_user_cf(); + self.db + .get_pinned_cf(cf_user, address.as_slice()) + .map(|value| value.is_some()) + } + + pub fn get_user(&self, address: &Address) -> Option<RlnUserIdentity> { + let cf_user = self.get_user_cf(); + let rln_identity_deserializer = RlnUserIdentityDeserializer {}; + match self.db.get_pinned_cf(cf_user, address.as_slice()) { + Ok(Some(value)) => { + // Here we silence the error - this is safe as the prover controls this + rln_identity_deserializer.deserialize(&value).ok() + } + Ok(None) => None, + Err(_e) => None, + } + } + + pub fn get_user_merkle_tree_index( + &self, + address: &Address, + ) -> Result<MerkleTreeIndex, UserMerkleTreeIndexError> { + let cf_user = self.get_user_cf(); + let rln_identity_serializer = RlnUserIdentitySerializer {}; + let merkle_tree_index_deserializer = MerkleTreeIndexDeserializer {}; + match self.db.get_pinned_cf(cf_user, address.as_slice()) { + Ok(Some(buffer)) => { + // Here we silence the error - this is safe as the prover controls this + let start = rln_identity_serializer.size_hint(); + let (_, index) = merkle_tree_index_deserializer + .deserialize(&buffer[start..]) + .unwrap(); + Ok(index) + } + Ok(None) => Err(UserMerkleTreeIndexError::NotRegistered(*address)), + Err(e) => Err(UserMerkleTreeIndexError::Db(e)), + } + } + + fn incr_tx_counter( + &self, + address: &Address, + incr_value: Option<u64>, + ) -> Result<EpochSliceCounter, TxCounterError> { + let incr_value = incr_value.unwrap_or(1); + let cf_counter = self.get_counter_cf(); + + let (epoch, epoch_slice) = *self.epoch_store.read(); + let incr = EpochIncr { + epoch, + epoch_slice, + incr_value, + }; + let incr_ser = EpochIncrSerializer {}; + let mut buffer = Vec::with_capacity(incr_ser.size_hint()); + incr_ser.serialize(&incr, &mut buffer); + + // Create a transaction + // By using a WriteBatchWithIndex, we can "read your own writes" so here we incr then read the new value + // https://rocksdb.org/blog/2015/02/27/write-batch-with-index.html + let mut batch = WriteBatchWithIndex::new(buffer.len() + size_of::<Address>
(), true); + batch.merge_cf(cf_counter, address.as_slice(), buffer); + let res = batch.get_from_batch_and_db_cf( + &*self.db, + cf_counter, + address.as_slice(), + &ReadOptions::default(), + )?; + self.db.write_wbwi(&batch).map_err(TxCounterError::Db)?; + let (_, epoch_slice_counter) = self.counters_from_key(address, res)?; + + Ok(epoch_slice_counter) + } + + fn get_tx_counter( + &self, + address: &Address, + ) -> Result<(EpochCounter, EpochSliceCounter), TxCounterError> { + let cf_counter = self.get_counter_cf(); + match self.db.get_cf(cf_counter, address.as_slice()) { + Ok(v) => self.counters_from_key(address, v), + Err(e) => Err(TxCounterError::Db(e)), + } + } + + fn counters_from_key( + &self, + address: &Address, + key: Option<Vec<u8>>, + ) -> Result<(EpochCounter, EpochSliceCounter), TxCounterError> { + let deserializer = EpochCounterDeserializer {}; + + match key { + Some(value) => { + let (_, counter) = deserializer.deserialize(&value).unwrap(); + let (epoch, epoch_slice) = *self.epoch_store.read(); + + let cmp = (counter.epoch == epoch, counter.epoch_slice == epoch_slice); + + match cmp { + (true, true) => { + // EpochCounter stored in DB == epoch store + // We query for an epoch / epoch slice and this is what is stored in the Db + // Return the counters + Ok(( + counter.epoch_counter.into(), + counter.epoch_slice_counter.into(), + )) + } + (true, false) => { + // EpochCounter.epoch_slice (stored in Db) != epoch_store.epoch_slice + // We query for an epoch slice after what is stored in Db + // This can happen if no Tx has updated the epoch slice counter (yet) + Ok((counter.epoch_counter.into(), EpochSliceCounter::from(0))) + } + (false, true) => { + // EpochCounter.epoch (stored in DB) != epoch_store.epoch + // We query for an epoch after what is stored in Db + // This can happen if no Tx has updated the epoch counter (yet) + Ok((EpochCounter::from(0), EpochSliceCounter::from(0))) + } + (false, false) => { + // EpochCounter (stored in DB) != epoch_store + // Outdated value (both for epoch & epoch slice) + Ok((EpochCounter::from(0), EpochSliceCounter::from(0))) + } + } + } + None => Err(TxCounterError::NotRegistered(*address)), + } + } + + // pub + + pub(crate) fn on_new_epoch(&self) {} + + pub(crate) fn on_new_epoch_slice(&self) {} + + pub fn on_new_user(&self, address: &Address) -> Result<Fr, RegisterError> { + self.register(*address) + } + + #[cfg(test)] + fn get_merkle_tree_index(&self) -> Result<MerkleTreeIndex, MerkleTreeIndexError> { + let cf_mtree = self.get_mtree_cf(); + Self::get_merkle_tree_index_(self.db.clone(), cf_mtree) + } + + fn get_merkle_tree_index_( + db: Arc<DB>, + cf: &ColumnFamily, + ) -> Result<MerkleTreeIndex, MerkleTreeIndexError> { + let deserializer = MerkleTreeIndexDeserializer {}; + + match db.get_cf(cf, MERKLE_TREE_INDEX_KEY) { + Ok(Some(v)) => { + // Unwrap safe - serialization is done by the prover + let (_, index) = deserializer.deserialize(v.as_slice()).unwrap(); + Ok(index) + } + Ok(None) => Err(MerkleTreeIndexError::DbUninitialized), + Err(e) => Err(MerkleTreeIndexError::Db(e)), + } + } + + pub fn get_merkle_proof( + &self, + address: &Address, + ) -> Result<MerkleProof, GetMerkleTreeProofError> { + let index = self + .get_user_merkle_tree_index(address) + .map_err(GetMerkleTreeProofError::MerkleTree)?; + self.merkle_tree + .read() + .proof(index.into()) + .map_err(|e| GetMerkleTreeProofError::TreeError(e.to_string())) + } + + pub(crate) fn on_new_tx( + &self, + address: &Address, + incr_value: Option<u64>, + ) -> Result<EpochSliceCounter, TxCounterError> { + let has_user = self.has_user(address).map_err(TxCounterError::Db)?; + + if has_user { + self.incr_tx_counter(address, incr_value) + } else { + 
Err(TxCounterError::NotRegistered(*address)) + } + } + + fn get_tier_limits(&self) -> Result { + let cf = self.get_tier_limits_cf(); + // Unwrap safe - Db is initialized with valid tier limits + let buffer = self.db.get_cf(cf, TIER_LIMITS_KEY.as_slice())?.unwrap(); + let tier_limits_deserializer = TierLimitsDeserializer { + tier_deserializer: TierDeserializer {}, + }; + + // Unwrap safe - serialized by the prover (should always deserialize) + let (_, tier_limits) = tier_limits_deserializer.deserialize(&buffer).unwrap(); + Ok(tier_limits) + } + + pub(crate) fn on_new_tier( + &self, + tier_index: TierIndex, + tier: Tier, + ) -> Result<(), SetTierLimitsError> { + let mut tier_limits = self.get_tier_limits()?; + tier_limits.insert(tier_index, tier); + tier_limits.validate()?; + + // Serialize + let tier_limits_serializer = TierLimitsSerializer::default(); + let mut buffer = Vec::with_capacity(tier_limits_serializer.size_hint(tier_limits.len())); + // Unwrap safe - already validated - should always serialize + tier_limits_serializer + .serialize(&tier_limits, &mut buffer) + .unwrap(); + + // Write + let cf = self.get_tier_limits_cf(); + self.db + .put_cf(cf, TIER_LIMITS_NEXT_KEY.as_slice(), buffer) + .map_err(SetTierLimitsError::Db) + } + + pub(crate) fn on_tier_updated( + &self, + tier_index: TierIndex, + tier: Tier, + ) -> Result<(), SetTierLimitsError> { + let mut tier_limits = self.get_tier_limits()?; + if !tier_limits.contains_key(&tier_index) { + return Err(SetTierLimitsError::InvalidUpdateTierIndex); + } + + tier_limits.entry(tier_index).and_modify(|e| *e = tier); + tier_limits.validate()?; + + // Serialize + let tier_limits_serializer = TierLimitsSerializer::default(); + let mut buffer = Vec::with_capacity(tier_limits_serializer.size_hint(tier_limits.len())); + // Unwrap safe - already validated - should always serialize + tier_limits_serializer + .serialize(&tier_limits, &mut buffer) + .unwrap(); + + // Write + let cf = self.get_tier_limits_cf(); + self.db + .put_cf(cf, TIER_LIMITS_NEXT_KEY.as_slice(), buffer) + .map_err(SetTierLimitsError::Db)?; + + Ok(()) + } + + /// Get user tier info + pub(crate) async fn user_tier_info>( + &self, + address: &Address, + karma_sc: &KSC, + ) -> Result> { + let has_user = self.has_user(address).map_err(UserTierInfoError::Db)?; + + if !has_user { + return Err(UserTierInfoError::NotRegistered(*address)); + } + + let (epoch_tx_count, epoch_slice_tx_count) = self.get_tx_counter(address)?; + + let karma_amount = karma_sc + .karma_amount(address) + .await + .map_err(|e| UserTierInfoError::Contract(e))?; + + let tier_limits = self.get_tier_limits()?; + let tier_info = tier_limits.get_tier_by_karma(&karma_amount); + + let user_tier_info = { + let (current_epoch, current_epoch_slice) = *self.epoch_store.read(); + let mut t = UserTierInfo { + current_epoch, + current_epoch_slice, + epoch_tx_count: epoch_tx_count.into(), + epoch_slice_tx_count: epoch_slice_tx_count.into(), + karma_amount, + tier_name: None, + tier_limit: None, + }; + if let Some((_tier_index, tier)) = tier_info { + t.tier_name = Some(tier.name.into()); + t.tier_limit = Some(TierLimit::from(tier.tx_per_epoch)); + } + t + }; + + Ok(user_tier_info) + } +} + +#[cfg(test)] +mod tests { + use super::*; + // std + // third-party + use alloy::primitives::address; + use async_trait::async_trait; + use claims::assert_matches; + use derive_more::Display; + + #[derive(Debug, Display, thiserror::Error)] + struct DummyError(); + + struct MockKarmaSc {} + + #[async_trait] + impl KarmaAmountExt for MockKarmaSc 
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    // std
+    // third-party
+    use alloy::primitives::address;
+    use async_trait::async_trait;
+    use claims::assert_matches;
+    use derive_more::Display;
+
+    #[derive(Debug, Display, thiserror::Error)]
+    struct DummyError();
+
+    struct MockKarmaSc {}
+
+    #[async_trait]
+    impl KarmaAmountExt for MockKarmaSc {
+        type Error = DummyError;
+
+        async fn karma_amount(&self, _address: &Address) -> Result<U256, Self::Error> {
+            Ok(U256::from(10))
+        }
+    }
+
+    const ADDR_1: Address = address!("0xd8da6bf26964af9d7eed9e03e53415d37aa96045");
+    const ADDR_2: Address = address!("0xb20a608c624Ca5003905aA834De7156C68b2E1d0");
+
+    /*
+    struct MockKarmaSc2 {}
+
+    #[async_trait]
+    impl KarmaAmountExt for MockKarmaSc2 {
+        type Error = DummyError;
+
+        async fn karma_amount(&self, address: &Address) -> Result<U256, Self::Error> {
+            if address == &ADDR_1 {
+                Ok(U256::from(10))
+            } else if address == &ADDR_2 {
+                Ok(U256::from(2000))
+            } else {
+                Ok(U256::ZERO)
+            }
+        }
+    }
+    */
+
+    #[test]
+    fn test_user_register() {
+        let temp_folder = tempfile::tempdir().unwrap();
+        let temp_folder_tree = tempfile::tempdir().unwrap();
+        let epoch_store = Arc::new(RwLock::new(Default::default()));
+        let user_db = UserDb::new(
+            PathBuf::from(temp_folder.path()),
+            PathBuf::from(temp_folder_tree.path()),
+            epoch_store,
+            Default::default(),
+            Default::default(),
+        )
+        .unwrap();
+
+        let addr = Address::new([0; 20]);
+        user_db.register(addr).unwrap();
+        assert_matches!(
+            user_db.register(addr),
+            Err(RegisterError::AlreadyRegistered(_))
+        );
+
+        assert!(user_db.get_user(&addr).is_some());
+        assert_eq!(user_db.get_tx_counter(&addr).unwrap(), (0.into(), 0.into()));
+
+        assert!(user_db.get_user(&ADDR_1).is_none());
+        user_db.register(ADDR_1).unwrap();
+
+        assert!(user_db.get_user(&ADDR_1).is_some());
+        assert_eq!(user_db.get_tx_counter(&addr).unwrap(), (0.into(), 0.into()));
+        user_db.incr_tx_counter(&addr, Some(42)).unwrap();
+        assert_eq!(
+            user_db.get_tx_counter(&addr).unwrap(),
+            (42.into(), 42.into())
+        );
+    }
+
+    #[test]
+    fn test_get_tx_counter() {
+        let temp_folder = tempfile::tempdir().unwrap();
+        let temp_folder_tree = tempfile::tempdir().unwrap();
+        let epoch_store = Arc::new(RwLock::new(Default::default()));
+        let user_db = UserDb::new(
+            PathBuf::from(temp_folder.path()),
+            PathBuf::from(temp_folder_tree.path()),
+            epoch_store,
+            Default::default(),
+            Default::default(),
+        )
+        .unwrap();
+
+        let addr = Address::new([0; 20]);
+
+        user_db.register(addr).unwrap();
+
+        let (ec, ecs) = user_db.get_tx_counter(&addr).unwrap();
+        assert_eq!(ec, 0.into());
+        assert_eq!(ecs, 0.into());
+        let ecs_2 = user_db.incr_tx_counter(&addr, Some(42)).unwrap();
+        assert_eq!(ecs_2, 42.into());
+    }
+
+    #[tokio::test]
+    async fn test_incr_tx_counter() {
+        let temp_folder = tempfile::tempdir().unwrap();
+        let temp_folder_tree = tempfile::tempdir().unwrap();
+        let epoch_store = Arc::new(RwLock::new(Default::default()));
+        let user_db = UserDb::new(
+            PathBuf::from(temp_folder.path()),
+            PathBuf::from(temp_folder_tree.path()),
+            epoch_store,
+            Default::default(),
+            Default::default(),
+        )
+        .unwrap();
+
+        let addr = Address::new([0; 20]);
+
+        // Try to update tx counter without registering first
+        assert_matches!(
+            user_db.on_new_tx(&addr, None),
+            Err(TxCounterError::NotRegistered(_))
+        );
+
+        let tier_info = user_db.user_tier_info(&addr, &MockKarmaSc {}).await;
+        // User is not registered -> no tier info
+        assert!(matches!(
+            tier_info,
+            Err(UserTierInfoError::NotRegistered(_))
+        ));
+        // Register user
+        user_db.register(addr).unwrap();
+        // Now update user tx counter
+        assert_eq!(
+            user_db.on_new_tx(&addr, None),
+            Ok(EpochSliceCounter::from(1))
+        );
+        let tier_info = user_db
+            .user_tier_info(&addr, &MockKarmaSc {})
+            .await
+            .unwrap();
+        assert_eq!(tier_info.epoch_tx_count, 1);
+        assert_eq!(tier_info.epoch_slice_tx_count, 1);
+    }
+
+    /*
+    #[tokio::test]
+    async fn test_update_on_epoch_changes() {
+
+        let temp_folder = tempfile::tempdir().unwrap();
+        let mut epoch = Epoch::from(11);
+        let mut epoch_slice = EpochSlice::from(42);
+        let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
+        let user_db = UserRocksDb::new(
+            PathBuf::from(temp_folder.path()),
+            Default::default(),
+            epoch_store,
+        ).unwrap();
+
+        let tier_limits = BTreeMap::from([
+            (
+                TierIndex::from(1),
+                Tier {
+                    name: "Basic".into(),
+                    min_karma: U256::from(10),
+                    max_karma: U256::from(49),
+                    tx_per_epoch: 5,
+                    active: true,
+                },
+            ),
+            (
+                TierIndex::from(2),
+                Tier {
+                    name: "Active".into(),
+                    min_karma: U256::from(50),
+                    max_karma: U256::from(99),
+                    tx_per_epoch: 10,
+                    active: true,
+                },
+            ),
+            (
+                TierIndex::from(3),
+                Tier {
+                    name: "Regular".into(),
+                    min_karma: U256::from(100),
+                    max_karma: U256::from(499),
+                    tx_per_epoch: 15,
+                    active: true,
+                },
+            ),
+            (
+                TierIndex::from(4),
+                Tier {
+                    name: "Power User".into(),
+                    min_karma: U256::from(500),
+                    max_karma: U256::from(4999),
+                    tx_per_epoch: 20,
+                    active: true,
+                },
+            ),
+            (
+                TierIndex::from(5),
+                Tier {
+                    name: "S-Tier".into(),
+                    min_karma: U256::from(5000),
+                    max_karma: U256::from(9999),
+                    tx_per_epoch: 25,
+                    active: true,
+                },
+            ),
+        ]);
+
+        let tier_limits: TierLimits = tier_limits.into();
+        tier_limits.validate().unwrap();
+
+        let user_db_service = UserDbService2::new(
+            temp_folder.path().to_path_buf(),
+            Default::default(),
+            epoch_store.clone(),
+            10.into(),
+            tier_limits,
+        ).unwrap();
+        let user_db = user_db_service.get_user_db();
+
+        let addr_1_tx_count = 2;
+        let addr_2_tx_count = 820;
+        user_db.register(ADDR_1).unwrap();
+        user_db.incr_tx_counter(&ADDR_1, Some(addr_1_tx_count));
+        println!("user_db tx counter: {:?}", user_db.get_tx_counter(&ADDR_1));
+        user_db.register(ADDR_2).unwrap();
+        user_db.incr_tx_counter(&ADDR_2, Some(addr_2_tx_count));
+
+        // incr epoch slice (42 -> 43)
+        {
+            let new_epoch = epoch;
+            let new_epoch_slice = epoch_slice + 1;
+            // FIXME: UserRocksDb relies on EpochStore, so is this func still needed?
+            user_db_service.update_on_epoch_changes(
+                &mut epoch,
+                new_epoch,
+                &mut epoch_slice,
+                new_epoch_slice,
+            );
+
+            let mut guard = epoch_store.write();
+            *guard = (new_epoch, epoch_slice);
+            drop(guard);
+
+            let addr_1_tier_info = user_db
+                .user_tier_info(&ADDR_1, &MockKarmaSc2 {})
+                .await
+                .unwrap();
+            assert_eq!(addr_1_tier_info.epoch_tx_count, addr_1_tx_count);
+            assert_eq!(addr_1_tier_info.epoch_slice_tx_count, 0);
+            assert_eq!(addr_1_tier_info.tier_name, Some(TierName::from("Basic")));
+
+            let addr_2_tier_info = user_db
+                .user_tier_info(&ADDR_2, &MockKarmaSc2 {})
+                .await
+                .unwrap();
+            assert_eq!(addr_2_tier_info.epoch_tx_count, addr_2_tx_count);
+            assert_eq!(addr_2_tier_info.epoch_slice_tx_count, 0);
+            assert_eq!(
+                addr_2_tier_info.tier_name,
+                Some(TierName::from("Power User"))
+            );
+        }
+
+        // incr epoch (11 -> 12, epoch slice reset)
+        {
+            let new_epoch = epoch + 1;
+            let new_epoch_slice = EpochSlice::from(0);
+            user_db_service.update_on_epoch_changes(
+                &mut epoch,
+                new_epoch,
+                &mut epoch_slice,
+                new_epoch_slice,
+            );
+            let mut guard = epoch_store.write();
+            *guard = (new_epoch, epoch_slice);
+            drop(guard);
+
+            let addr_1_tier_info = user_db
+                .user_tier_info(&ADDR_1, &MockKarmaSc2 {})
+                .await
+                .unwrap();
+            assert_eq!(addr_1_tier_info.epoch_tx_count, 0);
+            assert_eq!(addr_1_tier_info.epoch_slice_tx_count, 0);
+            assert_eq!(addr_1_tier_info.tier_name, Some(TierName::from("Basic")));
+
+            let addr_2_tier_info = user_db
+                .user_tier_info(&ADDR_2, &MockKarmaSc2 {})
+                .await
+                .unwrap();
+            assert_eq!(addr_2_tier_info.epoch_tx_count, 0);
+            assert_eq!(addr_2_tier_info.epoch_slice_tx_count, 0);
+            assert_eq!(
+                addr_2_tier_info.tier_name,
+                Some(TierName::from("Power User"))
+            );
+        }
+    }
+    */
+
+    #[tokio::test]
+    async fn test_persistent_storage() {
+        let temp_folder = tempfile::tempdir().unwrap();
+        let temp_folder_tree = tempfile::tempdir().unwrap();
+        let epoch_store = Arc::new(RwLock::new(Default::default()));
+
+        let addr = Address::new([0; 20]);
+        {
+            let user_db = UserDb::new(
+                PathBuf::from(temp_folder.path()),
+                PathBuf::from(temp_folder_tree.path()),
+                epoch_store.clone(),
+                Default::default(),
+                Default::default(),
+            )
+            .unwrap();
+
+            // Register user
+            assert_eq!(
+                user_db.get_merkle_tree_index().unwrap(),
+                MerkleTreeIndex::from(0)
+            );
+            user_db.register(ADDR_1).unwrap();
+            assert_eq!(
+                user_db.get_merkle_tree_index().unwrap(),
+                MerkleTreeIndex::from(1)
+            );
+            user_db.register(ADDR_2).unwrap();
+            assert_eq!(
+                user_db.get_merkle_tree_index().unwrap(),
+                MerkleTreeIndex::from(2)
+            );
+            assert_eq!(
+                user_db.get_user_merkle_tree_index(&ADDR_1).unwrap(),
+                MerkleTreeIndex::from(1)
+            );
+            assert_eq!(
+                user_db.get_user_merkle_tree_index(&ADDR_2).unwrap(),
+                MerkleTreeIndex::from(2)
+            );
+
+            assert_eq!(
+                user_db.on_new_tx(&ADDR_1, Some(2)),
+                Ok(EpochSliceCounter::from(2))
+            );
+            assert_eq!(
+                user_db.on_new_tx(&ADDR_2, Some(1000)),
+                Ok(EpochSliceCounter::from(1000))
+            );
+
+            // Should be dropped but let's make it explicit
+            drop(user_db);
+        }
+
+        {
+            // Reopen Db and check the data is still there
+            let user_db = UserDb::new(
+                PathBuf::from(temp_folder.path()),
+                PathBuf::from(temp_folder_tree.path()),
+                epoch_store,
+                Default::default(),
+                Default::default(),
+            )
+            .unwrap();
+
+            assert_eq!(user_db.has_user(&addr).unwrap(), false);
+            assert_eq!(user_db.has_user(&ADDR_1).unwrap(), true);
+            assert_eq!(user_db.has_user(&ADDR_2).unwrap(), true);
+            assert_eq!(
+                user_db.get_tx_counter(&ADDR_1).unwrap(),
+                (2.into(), 2.into())
+            );
+            assert_eq!(
+                user_db.get_tx_counter(&ADDR_2).unwrap(),
+                (1000.into(), 1000.into())
+            );
+
+            assert_eq!(
+                user_db.get_merkle_tree_index().unwrap(),
+                MerkleTreeIndex::from(2)
+            );
+            assert_eq!(
+                user_db.get_user_merkle_tree_index(&ADDR_1).unwrap(),
+                MerkleTreeIndex::from(1)
+            );
+            assert_eq!(
+                user_db.get_user_merkle_tree_index(&ADDR_2).unwrap(),
+                MerkleTreeIndex::from(2)
+            );
+        }
+    }
+
+    /*
+    // Try to update tx counter without registering first
+    assert_matches!(
+        user_db.on_new_tx(&addr, None),
+        Err(TxCounterError::NotRegistered(_))
+    );
+
+    let tier_info = user_db.user_tier_info(&addr, &MockKarmaSc {}).await;
+    // User is not registered -> no tier info
+    assert!(matches!(
+        tier_info,
+        Err(UserTierInfoError::NotRegistered(_))
+    ));
+    // Register user
+    user_db.register(addr).unwrap();
+    // Now update user tx counter
+    assert_eq!(
+        user_db.on_new_tx(&addr, None),
+        Ok(EpochSliceCounter::from(1))
+    );
+    let tier_info = user_db
+        .user_tier_info(&addr, &MockKarmaSc {})
+        .await
+        .unwrap();
+    assert_eq!(tier_info.epoch_tx_count, 1);
+    assert_eq!(tier_info.epoch_slice_tx_count, 1);
+    */
+}
diff --git a/prover/src/user_db_error.rs b/prover/src/user_db_error.rs
new file mode 100644
index 0000000..bcfdc62
--- /dev/null
+++ b/prover/src/user_db_error.rs
@@ -0,0 +1,78 @@
+use std::num::TryFromIntError;
+// third-party
+use alloy::primitives::Address;
+use zerokit_utils::error::{FromConfigError, ZerokitMerkleTreeError};
+// internal
+use crate::tier::ValidateTierLimitsError;
+
+#[derive(Debug, thiserror::Error)]
+pub(crate) enum UserDbOpenError {
+    #[error(transparent)]
+    RocksDb(#[from] rocksdb::Error),
+    #[error("Serialization error: {0}")]
+    Serialization(#[from] TryFromIntError),
+    #[error(transparent)]
+    JsonSerialization(#[from] serde_json::Error),
+    #[error(transparent)]
+    TreeConfig(#[from] FromConfigError),
+    #[error(transparent)]
+    MerkleTree(#[from] ZerokitMerkleTreeError),
+    #[error(transparent)]
+    MerkleTreeIndex(#[from] MerkleTreeIndexError),
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum RegisterError {
+    #[error("User (address: {0:?}) has already been registered")]
+    AlreadyRegistered(Address),
+    #[error(transparent)]
+    Db(#[from] rocksdb::Error),
+    #[error("Merkle tree error: {0}")]
+    TreeError(String),
+}
+
+#[derive(thiserror::Error, Debug, PartialEq)]
+pub enum TxCounterError {
+    #[error("User (address: {0:?}) is not registered")]
+    NotRegistered(Address),
+    #[error(transparent)]
+    Db(#[from] rocksdb::Error),
+}
+
+#[derive(thiserror::Error, Debug, PartialEq, Clone)]
+pub enum MerkleTreeIndexError {
+    #[error("Uninitialized counter")]
+    DbUninitialized,
+    #[error(transparent)]
+    Db(#[from] rocksdb::Error),
+}
+
+#[derive(thiserror::Error, Debug, PartialEq, Clone)]
+pub enum UserMerkleTreeIndexError {
+    #[error("User (address: {0:?}) is not registered")]
+    NotRegistered(Address),
+    #[error(transparent)]
+    Db(#[from] rocksdb::Error),
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum SetTierLimitsError {
+    #[error(transparent)]
+    Validate(#[from] ValidateTierLimitsError),
+    #[error("Updating an invalid tier index")]
+    InvalidUpdateTierIndex,
+    #[error(transparent)]
+    Db(#[from] rocksdb::Error),
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum UserTierInfoError<E> {
+    #[error("User {0} not registered")]
+    NotRegistered(Address),
+    #[error(transparent)]
+    Contract(E),
+    #[error(transparent)]
+    TxCounter(#[from] TxCounterError),
+    #[error(transparent)]
+    Db(#[from] rocksdb::Error),
+}
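
A small usage sketch of the generic parameter on `UserTierInfoError`: the `E` in `Contract(E)` is whatever error the karma smart-contract client returns, so callers can match on it alongside the storage variants. `RpcError` below is a hypothetical stand-in for a real contract-client error type:

// Sketch: how a contract client error type E surfaces through
// UserTierInfoError<E>. RpcError is hypothetical.
use crate::user_db_error::UserTierInfoError;

#[derive(Debug, thiserror::Error)]
#[error("rpc failed: {0}")]
pub struct RpcError(String);

fn describe(err: &UserTierInfoError<RpcError>) -> &'static str {
    match err {
        UserTierInfoError::NotRegistered(_) => "user must register first",
        UserTierInfoError::Contract(_) => "karma contract call failed",
        UserTierInfoError::TxCounter(_) | UserTierInfoError::Db(_) => "storage error",
    }
}
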
diff --git a/prover/src/user_db_serialization.rs b/prover/src/user_db_serialization.rs
new file mode 100644
index 0000000..2090b05
--- /dev/null
+++ b/prover/src/user_db_serialization.rs
@@ -0,0 +1,317 @@
+use std::collections::BTreeMap;
+use std::num::TryFromIntError;
+use std::string::FromUtf8Error;
+// third-party
+use alloy::primitives::U256;
+use ark_bn254::Fr;
+use ark_ff::fields::AdditiveGroup;
+use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError};
+use nom::{
+    IResult, Parser,
+    bytes::complete::take,
+    error::{ContextError, context},
+    multi::length_count,
+    number::complete::{le_u32, le_u64},
+};
+use rln_proof::RlnUserIdentity;
+// internal
+use crate::tier::TierLimits;
+use crate::user_db_types::MerkleTreeIndex;
+use smart_contract::{Tier, TierIndex};
+
+pub(crate) struct RlnUserIdentitySerializer {}
+
+impl RlnUserIdentitySerializer {
+    pub(crate) fn serialize(
+        &self,
+        value: &RlnUserIdentity,
+        buffer: &mut Vec<u8>,
+    ) -> Result<(), SerializationError> {
+        buffer.resize(self.size_hint(), 0);
+        let compressed_size = value.commitment.compressed_size();
+        let (co_buffer, rem_buffer) = buffer.split_at_mut(compressed_size);
+        value.commitment.serialize_compressed(co_buffer)?;
+        let (secret_buffer, user_limit_buffer) = rem_buffer.split_at_mut(compressed_size);
+        value.secret_hash.serialize_compressed(secret_buffer)?;
+        value.user_limit.serialize_compressed(user_limit_buffer)?;
+        Ok(())
+    }
+
+    pub(crate) fn size_hint(&self) -> usize {
+        Fr::ZERO.compressed_size() * 3
+    }
+}
+
+pub(crate) struct RlnUserIdentityDeserializer {}
+
+impl RlnUserIdentityDeserializer {
+    pub(crate) fn deserialize(&self, buffer: &[u8]) -> Result<RlnUserIdentity, SerializationError> {
+        let compressed_size = Fr::ZERO.compressed_size();
+        let (co_buffer, rem_buffer) = buffer.split_at(compressed_size);
+        let commitment: Fr = CanonicalDeserialize::deserialize_compressed(co_buffer)?;
+        let (secret_buffer, user_limit_buffer) = rem_buffer.split_at(compressed_size);
+        // TODO: IdSecret
+        let secret_hash: Fr = CanonicalDeserialize::deserialize_compressed(secret_buffer)?;
+        let user_limit: Fr = CanonicalDeserialize::deserialize_compressed(user_limit_buffer)?;
+
+        Ok({
+            RlnUserIdentity {
+                commitment,
+                secret_hash,
+                user_limit,
+            }
+        })
+    }
+}
+
+pub(crate) struct MerkleTreeIndexSerializer {}
+
+impl MerkleTreeIndexSerializer {
+    pub(crate) fn serialize(&self, value: &MerkleTreeIndex, buffer: &mut Vec<u8>) {
+        let value: u64 = (*value).into();
+        buffer.extend(value.to_le_bytes());
+    }
+
+    pub(crate) fn size_hint(&self) -> usize {
+        // Note: Assume usize size == 8 bytes
+        size_of::<u64>()
+    }
+}
+
+pub(crate) struct MerkleTreeIndexDeserializer {}
+
+impl MerkleTreeIndexDeserializer {
+    pub(crate) fn deserialize<'a>(
+        &self,
+        buffer: &'a [u8],
+    ) -> IResult<&'a [u8], MerkleTreeIndex, nom::error::Error<&'a [u8]>> {
+        le_u64(buffer).map(|(input, idx)| (input, MerkleTreeIndex::from(idx)))
+    }
+}
+
+#[derive(Default)]
+pub(crate) struct TierSerializer {}
+
+impl TierSerializer {
+    pub(crate) fn serialize(
+        &self,
+        value: &Tier,
+        buffer: &mut Vec<u8>,
+    ) -> Result<(), TryFromIntError> {
+        const U256_SIZE: usize = size_of::<U256>();
+        buffer.extend(value.min_karma.to_le_bytes::<U256_SIZE>().as_slice());
+        buffer.extend(value.max_karma.to_le_bytes::<U256_SIZE>().as_slice());
+
+        let name_len = u32::try_from(value.name.len())?;
+        buffer.extend(name_len.to_le_bytes());
+        buffer.extend(value.name.as_bytes());
+        buffer.extend(value.tx_per_epoch.to_le_bytes().as_slice());
+        buffer.push(u8::from(value.active));
+        Ok(())
+    }
+
+    pub(crate) fn size_hint(&self) -> usize {
+        size_of::<Tier>()
+    }
+}
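
For reference, the `Tier` wire format produced by `TierSerializer` above is fixed-order and little-endian: 32 bytes of `min_karma`, 32 bytes of `max_karma`, a `u32` name length, the raw name bytes, a `u32` `tx_per_epoch`, and a single `active` flag byte. A tiny helper showing the implied encoded size (a sketch, not part of the patch):

// Sketch: encoded length of one Tier under the layout above.
// min_karma (32) | max_karma (32) | name_len: u32 (4) | name (name_len)
// | tx_per_epoch: u32 (4) | active (1)
fn tier_encoded_len(name: &str) -> usize {
    32 + 32 + 4 + name.len() + 4 + 1
}
// e.g. tier_encoded_len("Basic") == 78
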
+
+#[derive(Default)]
+pub(crate) struct TierDeserializer {}
+
+#[derive(Debug, PartialEq)]
+pub enum TierDeserializeError<I> {
+    Utf8Error(FromUtf8Error),
+    TryFrom,
+    Nom(I, nom::error::ErrorKind),
+}
+
+impl<I> nom::error::ParseError<I> for TierDeserializeError<I> {
+    fn from_error_kind(input: I, kind: nom::error::ErrorKind) -> Self {
+        TierDeserializeError::Nom(input, kind)
+    }
+
+    fn append(_: I, _: nom::error::ErrorKind, other: Self) -> Self {
+        other
+    }
+}
+
+impl<I> ContextError<I> for TierDeserializeError<I> {}
+
+impl TierDeserializer {
+    pub(crate) fn deserialize<'a>(
+        &self,
+        buffer: &'a [u8],
+    ) -> IResult<&'a [u8], Tier, TierDeserializeError<&'a [u8]>> {
+        let (input, min_karma) = take(32usize)(buffer)?;
+        let min_karma = U256::try_from_le_slice(min_karma)
+            .ok_or(nom::Err::Error(TierDeserializeError::TryFrom))?;
+        let (input, max_karma) = take(32usize)(input)?;
+        let max_karma = U256::try_from_le_slice(max_karma)
+            .ok_or(nom::Err::Error(TierDeserializeError::TryFrom))?;
+        let (input, name_len) = le_u32(input)?;
+        let name_len_ = usize::try_from(name_len)
+            .map_err(|_e| nom::Err::Error(TierDeserializeError::TryFrom))?;
+        let (input, name) = take(name_len_)(input)?;
+        let name = String::from_utf8(name.to_vec())
+            .map_err(|e| nom::Err::Error(TierDeserializeError::Utf8Error(e)))?;
+        let (input, tx_per_epoch) = le_u32(input)?;
+        let (input, active) = take(1usize)(input)?;
+        let active = active[0] != 0;
+
+        Ok((
+            input,
+            Tier {
+                min_karma,
+                max_karma,
+                name,
+                tx_per_epoch,
+                active,
+            },
+        ))
+    }
+}
+
+#[derive(Default)]
+pub(crate) struct TierLimitsSerializer {
+    tier_serializer: TierSerializer,
+}
+
+impl TierLimitsSerializer {
+    pub(crate) fn serialize(
+        &self,
+        value: &TierLimits,
+        buffer: &mut Vec<u8>,
+    ) -> Result<(), TryFromIntError> {
+        let len = value.len() as u32;
+        buffer.extend(len.to_le_bytes());
+        let mut tier_buffer = Vec::with_capacity(self.tier_serializer.size_hint());
+        value.iter().try_for_each(|(k, v)| {
+            buffer.push(k.into());
+            self.tier_serializer.serialize(v, &mut tier_buffer)?;
+            buffer.extend_from_slice(&tier_buffer);
+            tier_buffer.clear();
+            Ok(())
+        })
+    }
+
+    pub(crate) fn size_hint(&self, len: usize) -> usize {
+        size_of::<u32>() + len * self.tier_serializer.size_hint()
+    }
+}
+
+#[derive(Default)]
+pub(crate) struct TierLimitsDeserializer {
+    pub(crate) tier_deserializer: TierDeserializer,
+}
+
+impl TierLimitsDeserializer {
+    pub(crate) fn deserialize<'a>(
+        &self,
+        buffer: &'a [u8],
+    ) -> IResult<&'a [u8], TierLimits, TierDeserializeError<&'a [u8]>> {
+        let (input, tiers): (&[u8], BTreeMap<TierIndex, Tier>) = length_count(
+            le_u32,
+            context("Tier index & Tier deser", |input: &'a [u8]| {
+                let (input, tier_index) = take(1usize)(input)?;
+                let tier_index = TierIndex::from(tier_index[0]);
+                let (input, tier) = self.tier_deserializer.deserialize(input)?;
+                Ok((input, (tier_index, tier)))
+            }),
+        )
+        .map(BTreeMap::from_iter)
+        .parse(buffer)?;
+
+        Ok((input, TierLimits::from(tiers)))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+
+    #[test]
+    fn test_rln_ser_der() {
+        let rln_user_identity = RlnUserIdentity {
+            commitment: Fr::from(42),
+            secret_hash: Fr::from(u16::MAX),
+            user_limit: Fr::from(1_000_000),
+        };
+        let serializer = RlnUserIdentitySerializer {};
+        let mut buffer = Vec::with_capacity(serializer.size_hint());
+        serializer
+            .serialize(&rln_user_identity, &mut buffer)
+            .unwrap();
+
+        let deserializer = RlnUserIdentityDeserializer {};
+        let de = deserializer.deserialize(&buffer).unwrap();
+
+        assert_eq!(rln_user_identity, de);
+    }
+
+    #[test]
+    fn test_mtree_ser_der() {
+        let index = MerkleTreeIndex::from(4242);
+
+        let serializer = MerkleTreeIndexSerializer {};
+        let mut buffer = Vec::with_capacity(serializer.size_hint());
+        serializer.serialize(&index, &mut buffer);
+
+        let deserializer = MerkleTreeIndexDeserializer {};
+        let (_, de) = deserializer.deserialize(&buffer).unwrap();
+
+        assert_eq!(index, de);
+    }
+
+    #[test]
+    fn test_tier_ser_der() {
+        let tier = Tier {
+            min_karma: U256::from(10),
+            max_karma: U256::from(u64::MAX),
+            name: "All".to_string(),
+            tx_per_epoch: 10_000_000,
+            active: false,
+        };
+
+        let serializer = TierSerializer {};
+        let mut buffer = Vec::with_capacity(serializer.size_hint());
+        serializer.serialize(&tier, &mut buffer).unwrap();
+
+        let deserializer = TierDeserializer {};
+        let (_, de) = deserializer.deserialize(&buffer).unwrap();
+
+        assert_eq!(tier, de);
+    }
+
+    #[test]
+    fn test_tier_limits_ser_der() {
+        let tier_1 = Tier {
+            min_karma: U256::from(2),
+            max_karma: U256::from(4),
+            name: "Basic".to_string(),
+            tx_per_epoch: 10_000,
+            active: false,
+        };
+        let tier_2 = Tier {
+            min_karma: U256::from(10),
+            max_karma: U256::from(u64::MAX),
+            name: "Premium".to_string(),
+            tx_per_epoch: 1_000_000_000,
+            active: true,
+        };
+
+        let tier_limits = TierLimits::from(BTreeMap::from([
+            (TierIndex::from(1), tier_1),
+            (TierIndex::from(2), tier_2),
+        ]));
+
+        let serializer = TierLimitsSerializer::default();
+        let mut buffer = Vec::with_capacity(serializer.size_hint(tier_limits.len()));
+        serializer.serialize(&tier_limits, &mut buffer).unwrap();
+
+        let deserializer = TierLimitsDeserializer::default();
+        let (_, de) = deserializer.deserialize(&buffer).unwrap();
+
+        assert_eq!(tier_limits, de);
+    }
+}
diff --git a/prover/src/user_db_service.rs b/prover/src/user_db_service.rs
index cbab4aa..186631d 100644
--- a/prover/src/user_db_service.rs
+++ b/prover/src/user_db_service.rs
@@ -1,349 +1,17 @@
-use std::ops::Deref;
+// std
+use parking_lot::RwLock;
+use std::path::PathBuf;
 use std::sync::Arc;
 // third-party
-use alloy::primitives::{Address, U256};
-use ark_bn254::Fr;
-use derive_more::{Add, From, Into};
-use parking_lot::RwLock;
-use rln::hashers::poseidon_hash;
-use rln::poseidon_tree::{MerkleProof, PoseidonTree};
-use rln::protocol::keygen;
-use scc::HashMap;
 use tokio::sync::Notify;
 use tracing::debug;
 // internal
 use crate::epoch_service::{Epoch, EpochSlice};
-use crate::error::{AppError, GetMerkleTreeProofError, RegisterError};
-use crate::tier::{TierLimit, TierLimits, TierName};
-use rln_proof::{RlnUserIdentity, ZerokitMerkleTree};
-use smart_contract::{KarmaAmountExt, Tier, TierIndex};
-
-const MERKLE_TREE_HEIGHT: usize = 20;
-
-#[derive(Debug, Clone, Copy, From, Into)]
-struct MerkleTreeIndex(usize);
-
-#[derive(Debug, Clone, Copy, Default, PartialOrd, PartialEq, From, Into)]
-pub struct RateLimit(u64);
-
-impl RateLimit {
-    pub(crate) const ZERO: RateLimit = RateLimit(0);
-
-    pub(crate) const fn new(value: u64) -> Self {
-        Self(value)
-    }
-}
-
-impl From<RateLimit> for Fr {
-    fn from(rate_limit: RateLimit) -> Self {
-        Fr::from(rate_limit.0)
-    }
-}
-
-#[derive(Clone)]
-pub(crate) struct UserRegistry {
-    inner: HashMap<Address, (RlnUserIdentity, MerkleTreeIndex)>,
-    merkle_tree: Arc<RwLock<PoseidonTree>>,
-    rate_limit: RateLimit,
-}
-
-impl std::fmt::Debug for UserRegistry {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "UserRegistry {{ inner: {:?} }}", self.inner)
-    }
-}
-
-impl Default for UserRegistry {
-    fn default() -> Self {
-        Self {
-            inner: Default::default(),
-            // unwrap safe - no config
-            merkle_tree: Arc::new(RwLock::new(
-                PoseidonTree::new(MERKLE_TREE_HEIGHT, Default::default(), Default::default())
-                    .unwrap(),
-            )),
-            rate_limit: Default::default(),
-        }
-    }
-}
-
-impl From<RateLimit> for UserRegistry {
-    fn from(rate_limit: RateLimit) -> Self {
-        Self {
-            inner: Default::default(),
-            // unwrap safe - no config
-            merkle_tree: Arc::new(RwLock::new(
-                PoseidonTree::new(MERKLE_TREE_HEIGHT, Default::default(), Default::default())
-                    .unwrap(),
-            )),
-            rate_limit,
-        }
-    }
-}
-
-impl UserRegistry {
-    fn register(&self, address: Address) -> Result<Fr, RegisterError> {
-        let (identity_secret_hash, id_commitment) = keygen();
-        let index = self.inner.len();
-
-        self.inner
-            .insert(
-                address,
-                (
-                    RlnUserIdentity::from((
-                        identity_secret_hash,
-                        id_commitment,
-                        Fr::from(self.rate_limit),
-                    )),
-                    MerkleTreeIndex(index),
-                ),
-            )
-            .map_err(|_e| RegisterError::AlreadyRegistered(address))?;
-
-        let rate_commit = poseidon_hash(&[id_commitment, Fr::from(u64::from(self.rate_limit))]);
-        self.merkle_tree
-            .write()
-            .set(index, rate_commit)
-            .map_err(|e| RegisterError::TreeError(e.to_string()))?;
-        Ok(id_commitment)
-    }
-
-    fn has_user(&self, address: &Address) -> bool {
-        self.inner.contains(address)
-    }
-
-    fn get_user(&self, address: &Address) -> Option<RlnUserIdentity> {
-        self.inner.get(address).map(|entry| entry.0.clone())
-    }
-
-    fn get_merkle_proof(&self, address: &Address) -> Result<MerkleProof, GetMerkleTreeProofError> {
-        let index = self
-            .inner
-            .get(address)
-            .map(|entry| entry.1)
-            .ok_or(GetMerkleTreeProofError::NotRegistered)?;
-        self.merkle_tree
-            .read()
-            .proof(index.into())
-            .map_err(|e| GetMerkleTreeProofError::TreeError(e.to_string()))
-    }
-}
-
-#[derive(Debug, Default, Clone, Copy, PartialEq, From, Into, Add)]
-pub(crate) struct EpochCounter(u64);
-
-#[derive(Debug, Default, Clone, Copy, PartialEq, From, Into, Add)]
-pub(crate) struct EpochSliceCounter(u64);
-
-#[derive(Debug, Default, Clone)]
-pub(crate) struct TxRegistry {
-    inner: HashMap<Address, (EpochCounter, EpochSliceCounter)>,
-}
-
-impl Deref for TxRegistry {
-    type Target = HashMap<Address, (EpochCounter, EpochSliceCounter)>;
-
-    fn deref(&self) -> &Self::Target {
-        &self.inner
-    }
-}
-
-impl TxRegistry {
-    /// Update the transaction counter for the given address
-    ///
-    /// If incr_value is None, the counter will be incremented by 1
-    /// If incr_value is Some(x), the counter will be incremented by x
-    ///
-    /// Returns the new value of the counter
-    pub fn incr_counter(&self, address: &Address, incr_value: Option<u64>) -> EpochSliceCounter {
-        let incr_value = incr_value.unwrap_or(1);
-        let mut entry = self.inner.entry(*address).or_default();
-        *entry = (
-            entry.0 + EpochCounter(incr_value),
-            entry.1 + EpochSliceCounter(incr_value),
-        );
-        entry.1
-    }
-}
-
-#[derive(Debug, PartialEq)]
-pub struct UserTierInfo {
-    pub(crate) current_epoch: Epoch,
-    pub(crate) current_epoch_slice: EpochSlice,
-    pub(crate) epoch_tx_count: u64,
-    pub(crate) epoch_slice_tx_count: u64,
-    karma_amount: U256,
-    pub(crate) tier_name: Option<TierName>,
-    pub(crate) tier_limit: Option<TierLimit>,
-}
-
-#[derive(Debug, thiserror::Error)]
-pub enum UserTierInfoError<E> {
-    #[error("User {0} not registered")]
-    NotRegistered(Address),
-    #[error(transparent)]
-    Contract(E),
-}
-
-/// User registration + tx counters + tier limits storage
-#[derive(Debug, Clone)]
-pub struct UserDb {
-    user_registry: Arc<UserRegistry>,
-    tx_registry: Arc<TxRegistry>,
-    tier_limits: Arc<RwLock<TierLimits>>,
-    tier_limits_next: Arc<RwLock<TierLimits>>,
-    epoch_store: Arc<RwLock<(Epoch, EpochSlice)>>,
-}
-
-impl UserDb {
-    fn on_new_epoch(&self) {
-        self.tx_registry.clear()
-    }
-
-    fn on_new_epoch_slice(&self) {
-        self.tx_registry.retain(|_a, v| {
-            *v = (v.0, Default::default());
-            true
-        });
-
-        let tier_limits_next_has_updates = !self.tier_limits_next.read().is_empty();
-        if tier_limits_next_has_updates {
-            let mut guard = self.tier_limits_next.write();
-            // mem::take will clear the TierLimits in tier_limits_next
-            let new_tier_limits = std::mem::take(&mut *guard);
-            debug!("Installing new tier limits: {:?}", new_tier_limits);
-            *self.tier_limits.write() = new_tier_limits;
-        }
-    }
-
-    pub fn on_new_user(&self, address: Address) -> Result<Fr, RegisterError> {
-        self.user_registry.register(address)
-    }
-
-    pub fn get_user(&self, address: &Address) -> Option<RlnUserIdentity> {
-        self.user_registry.get_user(address)
-    }
-
-    pub fn get_merkle_proof(
-        &self,
-        address: &Address,
-    ) -> Result<MerkleProof, GetMerkleTreeProofError> {
-        self.user_registry.get_merkle_proof(address)
-    }
-
-    pub(crate) fn on_new_tx(
-        &self,
-        address: &Address,
-        incr_value: Option<u64>,
-    ) -> Option<EpochSliceCounter> {
-        if self.user_registry.has_user(address) {
-            Some(self.tx_registry.incr_counter(address, incr_value))
-        } else {
-            None
-        }
-    }
-
-    pub(crate) fn on_new_tier_limits(
-        &self,
-        tier_limits: TierLimits,
-    ) -> Result<(), SetTierLimitsError> {
-        let tier_limits = tier_limits.clone().filter_inactive();
-        tier_limits.validate()?;
-        *self.tier_limits_next.write() = tier_limits;
-        Ok(())
-    }
-
-    pub(crate) fn on_new_tier(
-        &self,
-        tier_index: TierIndex,
-        tier: Tier,
-    ) -> Result<(), SetTierLimitsError> {
-        let mut tier_limits = self.tier_limits.read().clone();
-        tier_limits.insert(tier_index, tier);
-        tier_limits.validate()?;
-        // Write
-        *self.tier_limits_next.write() = tier_limits;
-        Ok(())
-    }
-
-    pub(crate) fn on_tier_updated(
-        &self,
-        tier_index: TierIndex,
-        tier: Tier,
-    ) -> Result<(), SetTierLimitsError> {
-        let mut tier_limits = self.tier_limits.read().clone();
-        if !tier_limits.contains_key(&tier_index) {
-            return Err(SetTierLimitsError::InvalidTierIndex);
-        }
-        tier_limits.entry(tier_index).and_modify(|e| *e = tier);
-        tier_limits.validate()?;
-        // Write
-        *self.tier_limits_next.write() = tier_limits;
-        Ok(())
-    }
-
-    /// Get user tier info
-    pub(crate) async fn user_tier_info<E, KSC: KarmaAmountExt<Error = E>>(
-        &self,
-        address: &Address,
-        karma_sc: &KSC,
-    ) -> Result<UserTierInfo, UserTierInfoError<E>> {
-        if self.user_registry.has_user(address) {
-            let (epoch_tx_count, epoch_slice_tx_count) = self
-                .tx_registry
-                .get(address)
-                .map(|ref_v| (ref_v.0, ref_v.1))
-                .unwrap_or_default();
-
-            let karma_amount = karma_sc
-                .karma_amount(address)
-                .await
-                .map_err(|e| UserTierInfoError::Contract(e))?;
-            let tier_limits_guard = self.tier_limits.read();
-            let tier_info = tier_limits_guard.get_tier_by_karma(&karma_amount);
-            drop(tier_limits_guard);
-
-            let user_tier_info = {
-                let (current_epoch, current_epoch_slice) = *self.epoch_store.read();
-                let mut t = UserTierInfo {
-                    current_epoch,
-                    current_epoch_slice,
-                    epoch_tx_count: epoch_tx_count.into(),
-                    epoch_slice_tx_count: epoch_slice_tx_count.into(),
-                    karma_amount,
-                    tier_name: None,
-                    tier_limit: None,
-                };
-                if let Some((_tier_index, tier)) = tier_info {
-                    t.tier_name = Some(tier.name.into());
-                    // TODO
-                    t.tier_limit = Some(0.into());
-                }
-                t
-            };
-
-            Ok(user_tier_info)
-        } else {
-            Err(UserTierInfoError::NotRegistered(*address))
-        }
-    }
-}
-
-#[derive(Debug, thiserror::Error)]
-pub enum SetTierLimitsError {
-    #[error("Invalid Karma amount (must be increasing)")]
-    InvalidKarmaAmount,
-    #[error("Invalid Karma max amount (min: {0} vs max: {1})")]
-    InvalidMaxAmount(U256, U256),
-    #[error("Invalid Tier limit (must be increasing)")]
-    InvalidTierLimit,
-    #[error("Invalid Tier index (must be increasing)")]
-    InvalidTierIndex,
-    #[error("Non unique Tier name")]
-    NonUniqueTierName,
-    #[error("Non active Tier")]
-    InactiveTier,
-}
+use crate::error::AppError;
+use crate::tier::TierLimits;
+use crate::user_db::UserDb;
+use crate::user_db_error::UserDbOpenError;
+use crate::user_db_types::RateLimit;
 
 /// Async service to update a UserDb on epoch changes
 #[derive(Debug)]
@@ -353,22 +21,25 @@ pub struct UserDbService {
 }
 
 impl UserDbService {
-    pub(crate) fn new(
+    pub fn new(
+        db_path: PathBuf,
+        merkle_tree_path: PathBuf,
         epoch_changes_notifier: Arc<Notify>,
         epoch_store: Arc<RwLock<(Epoch, EpochSlice)>>,
         rate_limit: RateLimit,
         tier_limits: TierLimits,
-    ) -> Self {
-        Self {
-            user_db: UserDb {
-                user_registry: Arc::new(UserRegistry::from(rate_limit)),
-                tx_registry: Default::default(),
-                tier_limits: Arc::new(RwLock::new(tier_limits)),
-                tier_limits_next: Arc::new(Default::default()),
-                epoch_store,
-            },
+    ) -> Result<Self, UserDbOpenError> {
+        let user_db = UserDb::new(
+            db_path,
+            merkle_tree_path,
+            epoch_store,
+            tier_limits,
+            rate_limit,
+        )?;
+        Ok(Self {
+            user_db,
             epoch_changes: epoch_changes_notifier,
-        }
+        })
     }
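
Since `UserDbService::new` now opens persistent storage, construction takes two directory paths and can fail with `UserDbOpenError`. A minimal wiring sketch, with hypothetical paths and default epoch/tier parameters:

// Sketch only: constructing the service with persistent storage.
// Paths are illustrative; rate limit and tier limits use defaults.
use std::path::PathBuf;
use std::sync::Arc;

fn build_service() -> Result<UserDbService, UserDbOpenError> {
    let epoch_changes = Arc::new(tokio::sync::Notify::new());
    let epoch_store = Arc::new(parking_lot::RwLock::new(Default::default()));
    UserDbService::new(
        PathBuf::from("data/user_db"),     // RocksDB directory (hypothetical)
        PathBuf::from("data/merkle_tree"), // PmTree directory (hypothetical)
        epoch_changes,
        epoch_store,
        10.into(),          // RateLimit
        Default::default(), // TierLimits
    )
}
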
 
     pub fn get_user_db(&self) -> UserDb {
@@ -412,507 +83,3 @@ impl UserDbService {
         *current_epoch_slice = new_epoch_slice;
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    // std
-    use std::collections::BTreeMap;
-    // third-party
-    use alloy::primitives::address;
-    use async_trait::async_trait;
-    use claims::{assert_err, assert_matches};
-    use derive_more::Display;
-
-    #[derive(Debug, Display, thiserror::Error)]
-    struct DummyError();
-
-    struct MockKarmaSc {}
-
-    #[async_trait]
-    impl KarmaAmountExt for MockKarmaSc {
-        type Error = DummyError;
-
-        async fn karma_amount(&self, _address: &Address) -> Result<U256, Self::Error> {
-            Ok(U256::from(10))
-        }
-    }
-
-    const ADDR_1: Address = address!("0xd8da6bf26964af9d7eed9e03e53415d37aa96045");
-    const ADDR_2: Address = address!("0xb20a608c624Ca5003905aA834De7156C68b2E1d0");
-
-    struct MockKarmaSc2 {}
-
-    #[async_trait]
-    impl KarmaAmountExt for MockKarmaSc2 {
-        type Error = DummyError;
-
-        async fn karma_amount(&self, address: &Address) -> Result<U256, Self::Error> {
-            if address == &ADDR_1 {
-                Ok(U256::from(10))
-            } else if address == &ADDR_2 {
-                Ok(U256::from(2000))
-            } else {
-                Ok(U256::ZERO)
-            }
-        }
-    }
-
-    #[test]
-    fn test_user_register() {
-        let user_db = UserDb {
-            user_registry: Default::default(),
-            tx_registry: Default::default(),
-            tier_limits: Arc::new(RwLock::new(Default::default())),
-            tier_limits_next: Arc::new(Default::default()),
-            epoch_store: Arc::new(RwLock::new(Default::default())),
-        };
-        let addr = Address::new([0; 20]);
-        user_db.user_registry.register(addr).unwrap();
-        assert_matches!(
-            user_db.user_registry.register(addr),
-            Err(RegisterError::AlreadyRegistered(_))
-        );
-    }
-
-    #[tokio::test]
-    async fn test_incr_tx_counter() {
-        let user_db = UserDb {
-            user_registry: Default::default(),
-            tx_registry: Default::default(),
-            tier_limits: Arc::new(RwLock::new(Default::default())),
-            tier_limits_next: Arc::new(Default::default()),
-            epoch_store: Arc::new(RwLock::new(Default::default())),
-        };
-        let addr = Address::new([0; 20]);
-
-        // Try to update tx counter without registering first
-        assert_eq!(user_db.on_new_tx(&addr, None), None);
-        let tier_info = user_db.user_tier_info(&addr, &MockKarmaSc {}).await;
-        // User is not registered -> no tier info
-        assert!(matches!(
-            tier_info,
-            Err(UserTierInfoError::NotRegistered(_))
-        ));
-        // Register user
-        user_db.user_registry.register(addr).unwrap();
-        // Now update user tx counter
-        assert_eq!(user_db.on_new_tx(&addr, None), Some(EpochSliceCounter(1)));
-        let tier_info = user_db
-            .user_tier_info(&addr, &MockKarmaSc {})
-            .await
-            .unwrap();
-        assert_eq!(tier_info.epoch_tx_count, 1);
-        assert_eq!(tier_info.epoch_slice_tx_count, 1);
-    }
-
-    #[tokio::test]
-    async fn test_update_on_epoch_changes() {
-        let mut epoch = Epoch::from(11);
-        let mut epoch_slice = EpochSlice::from(42);
-        let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
-
-        let tier_limits = BTreeMap::from([
-            (
-                TierIndex::from(0),
-                Tier {
-                    name: "Basic".into(),
-                    min_karma: U256::from(10),
-                    max_karma: U256::from(49),
-                    tx_per_epoch: 5,
-                    active: true,
-                },
-            ),
-            (
-                TierIndex::from(1),
-                Tier {
-                    name: "Active".into(),
-                    min_karma: U256::from(50),
-                    max_karma: U256::from(99),
-                    tx_per_epoch: 10,
-                    active: true,
-                },
-            ),
-            (
-                TierIndex::from(2),
-                Tier {
-                    name: "Regular".into(),
-                    min_karma: U256::from(100),
-                    max_karma: U256::from(499),
-                    tx_per_epoch: 15,
-                    active: true,
-                },
-            ),
-            (
-                TierIndex::from(3),
-                Tier {
-                    name: "Power User".into(),
-                    min_karma: U256::from(500),
-                    max_karma: U256::from(4999),
-                    tx_per_epoch: 20,
-                    active: true,
-                },
-            ),
-            (
-                TierIndex::from(4),
-                Tier {
-                    name: "S-Tier".into(),
-                    min_karma: U256::from(5000),
-                    max_karma: U256::from(U256::MAX),
-                    tx_per_epoch: 25,
-                    active: true,
-                },
-            ),
-        ]);
-
-        let user_db_service = UserDbService::new(
-            Default::default(),
-            epoch_store,
-            10.into(),
-            tier_limits.into(),
-        );
-        let user_db = user_db_service.get_user_db();
-
-        let addr_1_tx_count = 2;
-        let addr_2_tx_count = 820;
-        user_db.user_registry.register(ADDR_1).unwrap();
-        user_db
-            .tx_registry
-            .incr_counter(&ADDR_1, Some(addr_1_tx_count));
-        user_db.user_registry.register(ADDR_2).unwrap();
-        user_db
-            .tx_registry
-            .incr_counter(&ADDR_2, Some(addr_2_tx_count));
-
-        // incr epoch slice (42 -> 43)
-        {
-            let new_epoch = epoch;
-            let new_epoch_slice = epoch_slice + 1;
-            user_db_service.update_on_epoch_changes(
-                &mut epoch,
-                new_epoch,
-                &mut epoch_slice,
-                new_epoch_slice,
-            );
-            let addr_1_tier_info = user_db
-                .user_tier_info(&ADDR_1, &MockKarmaSc2 {})
-                .await
-                .unwrap();
-            assert_eq!(addr_1_tier_info.epoch_tx_count, addr_1_tx_count);
-            assert_eq!(addr_1_tier_info.epoch_slice_tx_count, 0);
-            assert_eq!(addr_1_tier_info.tier_name, Some(TierName::from("Basic")));
-
-            let addr_2_tier_info = user_db
-                .user_tier_info(&ADDR_2, &MockKarmaSc2 {})
-                .await
-                .unwrap();
-            assert_eq!(addr_2_tier_info.epoch_tx_count, addr_2_tx_count);
-            assert_eq!(addr_2_tier_info.epoch_slice_tx_count, 0);
-            assert_eq!(
-                addr_2_tier_info.tier_name,
-                Some(TierName::from("Power User"))
-            );
-        }
-
-        // incr epoch (11 -> 12, epoch slice reset)
-        {
-            let new_epoch = epoch + 1;
-            let new_epoch_slice = EpochSlice::from(0);
-            user_db_service.update_on_epoch_changes(
-                &mut epoch,
-                new_epoch,
-                &mut epoch_slice,
-                new_epoch_slice,
-            );
-            let addr_1_tier_info = user_db
-                .user_tier_info(&ADDR_1, &MockKarmaSc2 {})
-                .await
-                .unwrap();
-            assert_eq!(addr_1_tier_info.epoch_tx_count, 0);
-            assert_eq!(addr_1_tier_info.epoch_slice_tx_count, 0);
-            assert_eq!(addr_1_tier_info.tier_name, Some(TierName::from("Basic")));
-
-            let addr_2_tier_info = user_db
-                .user_tier_info(&ADDR_2, &MockKarmaSc2 {})
-                .await
-                .unwrap();
-            assert_eq!(addr_2_tier_info.epoch_tx_count, 0);
-            assert_eq!(addr_2_tier_info.epoch_slice_tx_count, 0);
-            assert_eq!(
-                addr_2_tier_info.tier_name,
-                Some(TierName::from("Power User"))
-            );
-        }
-    }
-
-    #[test]
-    #[tracing_test::traced_test]
-    fn test_set_tier_limits() {
-        // Check if we can update tier limits (and it updates after an epoch slice change)
-
-        let mut epoch = Epoch::from(11);
-        let mut epoch_slice = EpochSlice::from(42);
-        let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
-
-        let user_db_service = UserDbService::new(
-            Default::default(),
-            epoch_store,
-            10.into(),
-            Default::default(),
-        );
-        let user_db = user_db_service.get_user_db();
-        let tier_limits_original = user_db.tier_limits.read().clone();
-
-        let tier_limits = BTreeMap::from([
-            (
-                TierIndex::from(1),
-                Tier {
-                    name: "Basic".into(),
-                    min_karma: U256::from(10),
-                    max_karma: U256::from(49),
-                    tx_per_epoch: 5,
-                    active: true,
-                },
-            ),
-            (
-                TierIndex::from(2),
-                Tier {
-                    name: "Power User".into(),
-                    min_karma: U256::from(50),
-                    max_karma: U256::from(299),
-                    tx_per_epoch: 20,
-                    active: true,
-                },
-            ),
-        ]);
-        let tier_limits = TierLimits::from(tier_limits);
-
-        user_db.on_new_tier_limits(tier_limits.clone()).unwrap();
-        // Check it is not yet installed
-        assert_ne!(*user_db.tier_limits.read(), tier_limits);
-        assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
-        assert_eq!(*user_db.tier_limits_next.read(), tier_limits);
-
-        let new_epoch = epoch;
-        let new_epoch_slice = epoch_slice + 1;
-        user_db_service.update_on_epoch_changes(
-            &mut epoch,
-            new_epoch,
-            &mut epoch_slice,
-            new_epoch_slice,
-        );
-
-        // Should be installed now
-        assert_eq!(*user_db.tier_limits.read(), tier_limits);
-        // And the tier_limits_next field is expected to be empty
-        assert!(user_db.tier_limits_next.read().is_empty());
-    }
-
-    #[test]
-    #[tracing_test::traced_test]
-    fn test_set_invalid_tier_limits() {
-        // Check we cannot update with invalid tier limits
-
-        let epoch = Epoch::from(11);
-        let epoch_slice = EpochSlice::from(42);
-        let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
-        let user_db_service = UserDbService::new(
-            Default::default(),
-            epoch_store,
-            10.into(),
-            Default::default(),
-        );
-        let user_db = user_db_service.get_user_db();
-
-        let tier_limits_original = user_db.tier_limits.read().clone();
-
-        // Invalid: non unique index
-        {
-            let tier_limits = BTreeMap::from([
-                (
-                    TierIndex::from(0),
-                    Tier {
-                        min_karma: Default::default(),
-                        max_karma: Default::default(),
-                        name: "Basic".to_string(),
-                        tx_per_epoch: 100,
-                        active: true,
-                    },
-                ),
-                (
-                    TierIndex::from(0),
-                    Tier {
-                        min_karma: Default::default(),
-                        max_karma: Default::default(),
-                        name: "Power User".to_string(),
-                        tx_per_epoch: 200,
-                        active: true,
-                    },
-                ),
-            ]);
-            let tier_limits = TierLimits::from(tier_limits);
-
-            assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
-            assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
-        }
-
-        // Invalid: min Karma amount not increasing
-        {
-            let tier_limits = BTreeMap::from([
-                (
-                    TierIndex::from(0),
-                    Tier {
-                        min_karma: U256::from(10),
-                        max_karma: U256::from(49),
-                        name: "Basic".to_string(),
-                        tx_per_epoch: 5,
-                        active: true,
-                    },
-                ),
-                (
-                    TierIndex::from(1),
-                    Tier {
-                        min_karma: U256::from(50),
-                        max_karma: U256::from(99),
-                        name: "Power User".to_string(),
-                        tx_per_epoch: 10,
-                        active: true,
-                    },
-                ),
-                (
-                    TierIndex::from(2),
-                    Tier {
-                        min_karma: U256::from(60),
-                        max_karma: U256::from(99),
-                        name: "Power User".to_string(),
-                        tx_per_epoch: 15,
-                        active: true,
-                    },
-                ),
-            ]);
-            let tier_limits = TierLimits::from(tier_limits);
-
-            assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
-            assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
-        }
-
-        // Invalid: Non unique tier name
-        {
-            let tier_limits = BTreeMap::from([
-                (
-                    TierIndex::from(0),
-                    Tier {
-                        min_karma: U256::from(10),
-                        max_karma: U256::from(49),
-                        name: "Basic".to_string(),
-                        tx_per_epoch: 5,
-                        active: true,
-                    },
-                ),
-                (
-                    TierIndex::from(1),
-                    Tier {
-                        min_karma: U256::from(50),
-                        max_karma: U256::from(99),
-                        name: "Power User".to_string(),
-                        tx_per_epoch: 10,
-                        active: true,
-                    },
-                ),
-                (
-                    TierIndex::from(2),
-                    Tier {
-                        min_karma: U256::from(100),
-                        max_karma: U256::from(999),
-                        name: "Power User".to_string(),
-                        tx_per_epoch: 15,
-                        active: true,
-                    },
-                ),
-            ]);
-            let tier_limits = TierLimits::from(tier_limits);
-
-            assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
-            assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
-        }
-    }
-
-    #[test]
-    #[tracing_test::traced_test]
-    fn test_set_invalid_tier_limits_2() {
-        // Check we cannot update with invalid tier limits
-
-        let epoch = Epoch::from(11);
-        let epoch_slice = EpochSlice::from(42);
-        let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
-        let user_db_service = UserDbService::new(
-            Default::default(),
-            epoch_store,
-            10.into(),
-            Default::default(),
-        );
-        let user_db = user_db_service.get_user_db();
-
-        let tier_limits_original = user_db.tier_limits.read().clone();
-
-        // Invalid: inactive tier
-        {
-            let tier_limits = BTreeMap::from([
-                (
-                    TierIndex::from(0),
-                    Tier {
-                        min_karma: U256::from(10),
-                        max_karma: U256::from(49),
-                        name: "Basic".to_string(),
-                        tx_per_epoch: 5,
-                        active: true,
-                    },
-                ),
-                (
-                    TierIndex::from(1),
-                    Tier {
-                        min_karma: U256::from(50),
-                        max_karma: U256::from(99),
-                        name: "Power User".to_string(),
-                        tx_per_epoch: 10,
-                        active: true,
-                    },
-                ),
-            ]);
-            let tier_limits = TierLimits::from(tier_limits);
-
-            assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
-            assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
-        }
-
-        // Invalid: non-increasing tx_per_epoch
-        {
-            let tier_limits = BTreeMap::from([
-                (
-                    TierIndex::from(0),
-                    Tier {
-                        min_karma: U256::from(10),
-                        max_karma: U256::from(49),
-                        name: "Basic".to_string(),
-                        tx_per_epoch: 5,
-                        active: true,
-                    },
-                ),
-                (
-                    TierIndex::from(1),
-                    Tier {
-                        min_karma: U256::from(50),
-                        max_karma: U256::from(99),
-                        name: "Power User".to_string(),
-                        tx_per_epoch: 5,
-                        active: true,
-                    },
-                ),
-            ]);
-            let tier_limits = TierLimits::from(tier_limits);
-
-            assert_err!(user_db.on_new_tier_limits(tier_limits.clone()));
-            assert_eq!(*user_db.tier_limits.read(), tier_limits_original);
-        }
-    }
-}
diff --git a/prover/src/user_db_types.rs b/prover/src/user_db_types.rs
new file mode 100644
index 0000000..d3b4509
--- /dev/null
+++ b/prover/src/user_db_types.rs
@@ -0,0 +1,37 @@
+use ark_bn254::Fr;
+use derive_more::{Add, From, Into};
+
+/// A wrapper type over u64
+#[derive(Debug, Clone, Copy, From, Into, PartialEq)]
+pub(crate) struct MerkleTreeIndex(u64);
+
+impl From<MerkleTreeIndex> for usize {
+    fn from(value: MerkleTreeIndex) -> Self {
+        // TODO: compile time assert
+        value.0 as usize
+    }
+}
+
+/// A wrapper type over u64
+#[derive(Debug, Clone, Copy, Default, PartialOrd, PartialEq, From, Into)]
+pub struct RateLimit(u64);
+
+impl RateLimit {
+    pub(crate) const ZERO: RateLimit = RateLimit(0);
+
+    pub(crate) const fn new(value: u64) -> Self {
+        Self(value)
+    }
+}
+
+impl From<RateLimit> for Fr {
+    fn from(rate_limit: RateLimit) -> Self {
+        Fr::from(rate_limit.0)
+    }
+}
+
+#[derive(Debug, Default, Clone, Copy, PartialEq, From, Into, Add)]
+pub(crate) struct EpochCounter(u64);
+
+#[derive(Debug, Default, Clone, Copy, PartialEq, From, Into, Add)]
+pub(crate) struct EpochSliceCounter(u64);
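
The `TODO: compile time assert` above presumably wants a guard that the `u64` to `usize` cast cannot truncate. One way to express that, as a sketch:

// Sketch: a compile-time guard for the u64 -> usize cast above.
// Fails to compile on targets where usize is narrower than u64.
const _: () = assert!(size_of::<usize>() >= size_of::<u64>());
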
= "rln", default-features = false } +rln = { git = "https://github.com/vacp2p/zerokit", package = "rln", features = ["pmtree-ft"] } zerokit_utils = { git = "https://github.com/vacp2p/zerokit", package = "zerokit_utils", features = ["default"] } ark-bn254.workspace = true ark-groth16.workspace = true diff --git a/rln_proof/benches/generate_proof.rs b/rln_proof/benches/generate_proof.rs index e73c601..5b8a1a6 100644 --- a/rln_proof/benches/generate_proof.rs +++ b/rln_proof/benches/generate_proof.rs @@ -6,6 +6,7 @@ use criterion::{Criterion, criterion_group, criterion_main}; use ark_bn254::Fr; use ark_serialize::CanonicalSerialize; use rln::hashers::{hash_to_field, poseidon_hash}; +use rln::poseidon_tree::PoseidonTree; use rln::protocol::{keygen, serialize_proof_values}; use zerokit_utils::OptimalMerkleTree; // internal @@ -29,7 +30,7 @@ pub fn criterion_benchmark(c: &mut Criterion) { // Merkle tree let tree_height = 20; - let mut tree = OptimalMerkleTree::new(tree_height, Fr::from(0), Default::default()).unwrap(); + let mut tree = PoseidonTree::new(tree_height, Fr::from(0), Default::default()).unwrap(); let rate_commit = poseidon_hash(&[rln_identity.commitment, rln_identity.user_limit]); tree.set(0, rate_commit).unwrap(); let merkle_proof = tree.proof(0).unwrap(); diff --git a/rln_proof/src/proof.rs b/rln_proof/src/proof.rs index a2ef443..827f503 100644 --- a/rln_proof/src/proof.rs +++ b/rln_proof/src/proof.rs @@ -15,7 +15,7 @@ use rln::{ }; /// A RLN user identity & limit -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct RlnUserIdentity { pub commitment: Fr, pub secret_hash: Fr, @@ -42,7 +42,6 @@ pub struct RlnIdentifier { impl RlnIdentifier { pub fn new(identifier: &[u8]) -> Self { - // TODO: valid / correct ? let pk_and_matrices = { let mut reader = Cursor::new(ZKEY_BYTES); read_zkey(&mut reader).unwrap() @@ -96,8 +95,9 @@ pub fn compute_rln_proof_and_values( #[cfg(test)] mod tests { use super::*; + use rln::poseidon_tree::PoseidonTree; use rln::protocol::{compute_id_secret, keygen}; - use zerokit_utils::{OptimalMerkleTree, ZerokitMerkleTree}; + use zerokit_utils::ZerokitMerkleTree; #[test] fn test_recover_secret_hash() { @@ -105,7 +105,8 @@ mod tests { let epoch = hash_to_field(b"foo"); let spam_limit = Fr::from(10); - let mut tree = OptimalMerkleTree::new(20, Default::default(), Default::default()).unwrap(); + // let mut tree = OptimalMerkleTree::new(20, Default::default(), Default::default()).unwrap(); + let mut tree = PoseidonTree::new(20, Default::default(), Default::default()).unwrap(); tree.set(0, spam_limit).unwrap(); let m_proof = tree.proof(0).unwrap(); diff --git a/smart_contract/src/karma_tiers.rs b/smart_contract/src/karma_tiers.rs index 95c0b92..c133271 100644 --- a/smart_contract/src/karma_tiers.rs +++ b/smart_contract/src/karma_tiers.rs @@ -61,6 +61,12 @@ impl KarmaTiersSC::KarmaTiersSCInstance { #[derive(Debug, Clone, Default, Copy, From, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct TierIndex(u8); +impl From<&TierIndex> for u8 { + fn from(value: &TierIndex) -> u8 { + value.0 + } +} + #[derive(Debug, Clone, PartialEq)] pub struct Tier { pub min_karma: U256,