diff --git a/Cargo.lock b/Cargo.lock index 839be72..30a4acb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4407,6 +4407,7 @@ version = "0.1.0" dependencies = [ "alloy", "ark-bn254", + "ark-groth16", "ark-serialize 0.5.0", "async-channel", "bytesize", diff --git a/proto/net/vac/prover/prover.proto b/proto/net/vac/prover/prover.proto index a12f7db..76d6a0b 100644 --- a/proto/net/vac/prover/prover.proto +++ b/proto/net/vac/prover/prover.proto @@ -12,27 +12,29 @@ service RlnProver { rpc SendTransaction (SendTransactionRequest) returns (SendTransactionReply); rpc RegisterUser (RegisterUserRequest) returns (RegisterUserReply); // Server side streaming RPC: 1 request -> X responses (stream) - rpc GetProofs(RlnProofFilter) returns (stream RlnProof); + rpc GetProofs(RlnProofFilter) returns (stream RlnProofReply); rpc GetUserTierInfo(GetUserTierInfoRequest) returns (GetUserTierInfoReply); rpc SetTierLimits(SetTierLimitsRequest) returns (SetTierLimitsReply); } +/* // TransactionType: https://github.com/Consensys/linea-besu/blob/09cbed1142cfe4d29b50ecf2f156639a4bc8c854/datatypes/src/main/java/org/hyperledger/besu/datatypes/TransactionType.java#L22 enum TransactionType { - /** The Frontier. */ - // FRONTIER(0xf8 /* this is serialized as 0x0 in TransactionCompleteResult */), + // The Frontier + // FRONTIER(0xf8), // FIXME: is this 0xF8 or 0x00 ? FRONTIER = 0; - /** Access list transaction type. */ + // Access list transaction type ACCESS_LIST = 1; // 0x01 - /** Eip1559 transaction type. */ + // Eip1559 transaction type EIP1559 = 2; // 0x02 - /** Blob transaction type. */ + // Blob transaction type BLOB = 3; // 0x03 - /** Eip7702 transaction type. */ + // Eip7702 transaction type DELEGATE_CODE = 4; // 0x04 } +*/ extend google.protobuf.FieldOptions { optional uint32 max_size = 50000; @@ -53,42 +55,57 @@ message Address { bytes value = 1 [(max_size) = 20]; } +/* message SECPSignature { // https://github.com/Consensys/linea-besu/blob/zkbesu/crypto/algorithms/src/main/java/org/hyperledger/besu/crypto/SECPSignature.java#L30 bytes value = 1 [(max_size) = 65]; } +*/ +/* message StorageKey { bytes value = 1 [(max_size) = 32]; } +*/ +/* message AccessListEntry { Address address = 1; repeated StorageKey storageKeys = 2; } +*/ +/* message AccessListEntries { // https://github.com/Consensys/linea-besu/blob/zkbesu/datatypes/src/main/java/org/hyperledger/besu/datatypes/AccessListEntry.java#L31 repeated AccessListEntry entries = 1; } +*/ +/* message VersionedHash { // https://github.com/Consensys/linea-besu/blob/zkbesu/datatypes/src/main/java/org/hyperledger/besu/datatypes/VersionedHash.java#L28 bytes value = 1 [(max_size) = 32]; } +*/ +/* message BlobsWithCommitments { // https://github.com/Consensys/linea-besu/blob/zkbesu/datatypes/src/main/java/org/hyperledger/besu/datatypes/BlobsWithCommitments.java#L23 // TODO: need this? } +*/ +/* message CodeDelegation { // https://github.com/Consensys/linea-besu/blob/zkbesu/ethereum/core/src/main/java/org/hyperledger/besu/ethereum/core/CodeDelegation.java#L40 // TODO: need this? } +*/ // Transaction: https://github.com/Consensys/linea-besu/blob/c99bdbd533707a45fad97017fb964578c3e87fde/ethereum/core/src/main/java/org/hyperledger/besu/ethereum/core/Transaction.java#L168 message SendTransactionRequest { + /* bool forCopy = 1; TransactionType transactionType = 2; // Java long == signed 64-bit integer @@ -113,21 +130,45 @@ message SendTransactionRequest { repeated BlobsWithCommitments blobsWithCommitments = 17; repeated CodeDelegation maybeCodeDelegationList = 18; optional bytes rawRlp = 19; + */ + optional Wei gasPrice = 1; + optional Address sender = 2; + optional U256 chainId = 3; + bytes transactionHash = 4 [(max_size) = 32]; } message SendTransactionReply { bool result = 1; - // string message = 2; } message RlnProofFilter { optional string address = 1; } +message RlnProofReply { + oneof resp { + // variant for success + RlnProof proof = 1; + // variant for error + RlnProofError error = 2; + } +} + message RlnProof { - string sender = 1; - string id_commitment = 2; - bytes proof = 3; + // From https://rfc.vac.dev/vac/32/rln-v1#sending-the-output-message + bytes sender = 1; + bytes tx_hash = 2; // Transaction hash for the proof (non hash signal) + bytes proof = 3; // The RLN proof itself, hex encoded + // bytes internal_nullifier = 4; + // bytes x = 5; // signal hash + // bytes y = 6; + // bytes rln_identifier = 7; + // bytes merkle_proof_root = 8; + // bytes epoch = 9; +} + +message RlnProofError { + string error = 2; } message RegisterUserRequest { diff --git a/prover/Cargo.toml b/prover/Cargo.toml index a64a51d..6b659c9 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -37,6 +37,7 @@ tonic-build = "*" [dev-dependencies] criterion = "0.6.0" claims = "0.8" +ark-groth16 = "0.5.0" [[bench]] name = "user_db_heavy_write" diff --git a/prover/src/epoch_service.rs b/prover/src/epoch_service.rs index be25c75..6145858 100644 --- a/prover/src/epoch_service.rs +++ b/prover/src/epoch_service.rs @@ -487,7 +487,7 @@ mod tests { .map_err(|e| AppErrorExt::AppError(e)), // Wait for 3 epoch slices + 100 ms (to wait to receive notif + counter incr) tokio::time::timeout( - epoch_slice_duration * 3 + Duration::from_millis(100), + epoch_slice_duration * 3 + Duration::from_millis(500), async move { loop { notifier.notified().await; diff --git a/prover/src/error.rs b/prover/src/error.rs index 42b8270..e4e9363 100644 --- a/prover/src/error.rs +++ b/prover/src/error.rs @@ -3,6 +3,8 @@ use alloy::{ primitives::Address, transports::{RpcError, TransportErrorKind}, }; +use ark_serialize::SerializationError; +use rln::protocol::ProofError; #[derive(thiserror::Error, Debug)] pub enum AppError { @@ -18,6 +20,7 @@ pub enum AppError { EpochError(#[from] WaitUntilError), } +/* #[derive(thiserror::Error, Debug)] pub enum RegistrationError { #[error("Transaction has no sender address")] @@ -27,3 +30,62 @@ pub enum RegistrationError { #[error("Cannot find id_commitment for address: {0:?}")] NotFound(Address), } +*/ + +#[derive(thiserror::Error, Debug)] +pub enum ProofGenerationError { + #[error("Proof generation failed: {0}")] + Proof(#[from] ProofError), + #[error("Proof serialization failed: {0}")] + Serialization(#[from] SerializationError), + #[error("Proof serialization failed: {0}")] + SerializationWrite(#[from] std::io::Error), + #[error(transparent)] + MerkleProofError(#[from] GetMerkleTreeProofError), +} + +/// Same as ProofGenerationError but can be Cloned (can be used in Tokio broadcast channels) +#[derive(thiserror::Error, Debug, Clone)] +pub enum ProofGenerationStringError { + #[error("Proof generation failed: {0}")] + Proof(String), + #[error("Proof serialization failed: {0}")] + Serialization(String), + #[error("Proof serialization failed: {0}")] + SerializationWrite(String), + #[error(transparent)] + MerkleProofError(#[from] GetMerkleTreeProofError), +} + +impl From for ProofGenerationStringError { + fn from(value: ProofGenerationError) -> Self { + match value { + ProofGenerationError::Proof(e) => ProofGenerationStringError::Proof(e.to_string()), + ProofGenerationError::Serialization(e) => { + ProofGenerationStringError::Serialization(e.to_string()) + } + ProofGenerationError::SerializationWrite(e) => { + ProofGenerationStringError::SerializationWrite(e.to_string()) + } + ProofGenerationError::MerkleProofError(e) => { + ProofGenerationStringError::MerkleProofError(e) + } + } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum RegisterError { + #[error("User (address: {0:?}) has already been registered")] + AlreadyRegistered(Address), + #[error("Merkle tree error: {0}")] + TreeError(String), +} + +#[derive(thiserror::Error, Debug, Clone)] +pub enum GetMerkleTreeProofError { + #[error("User not registered")] + NotRegistered, + #[error("Merkle tree error: {0}")] + TreeError(String), +} diff --git a/prover/src/grpc_service.rs b/prover/src/grpc_service.rs index f152d2f..83d7722 100644 --- a/prover/src/grpc_service.rs +++ b/prover/src/grpc_service.rs @@ -1,38 +1,27 @@ -use std::collections::BTreeMap; // std +use std::collections::BTreeMap; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; // third-party use alloy::primitives::{Address, U256}; -use ark_bn254::Fr; use async_channel::Sender; use bytesize::ByteSize; use futures::TryFutureExt; use http::Method; use tokio::sync::{broadcast, mpsc}; use tonic::{ - Request, - Response, - Status, - codegen::tokio_stream::wrappers::ReceiverStream, - transport::Server, - // codec::CompressionEncoding + Request, Response, Status, codegen::tokio_stream::wrappers::ReceiverStream, transport::Server, }; use tonic_web::GrpcWebLayer; use tower_http::cors::{Any, CorsLayer}; -use tracing::{ - debug, - // error, - // info -}; +use tracing::debug; // internal -use crate::error::{ - AppError, - // RegistrationError -}; +use crate::error::{AppError, ProofGenerationStringError, RegisterError}; +use crate::proof_generation::{ProofGenerationData, ProofSendingData}; +use crate::tier::{KarmaAmount, TierLimit, TierName}; use crate::user_db_service::{KarmaAmountExt, UserDb, UserTierInfo}; -use rln_proof::{RlnIdentifier, RlnUserIdentity}; +use rln_proof::RlnIdentifier; pub mod prover_proto { @@ -42,12 +31,13 @@ pub mod prover_proto { pub(crate) const FILE_DESCRIPTOR_SET: &[u8] = tonic::include_file_descriptor_set!("prover_descriptor"); } -use crate::tier::{KarmaAmount, TierLimit, TierName}; use prover_proto::{ - GetUserTierInfoReply, GetUserTierInfoRequest, RegisterUserReply, RegisterUserRequest, RlnProof, - RlnProofFilter, SendTransactionReply, SendTransactionRequest, SetTierLimitsReply, - SetTierLimitsRequest, Tier, UserTierInfoError, UserTierInfoResult, + GetUserTierInfoReply, GetUserTierInfoRequest, RegisterUserReply, RegisterUserRequest, + RegistrationStatus, RlnProof, RlnProofFilter, RlnProofReply, SendTransactionReply, + SendTransactionRequest, SetTierLimitsReply, SetTierLimitsRequest, Tier, UserTierInfoError, + UserTierInfoResult, get_user_tier_info_reply::Resp, + rln_proof_reply::Resp as GetProofsResp, rln_prover_server::{RlnProver, RlnProverServer}, }; @@ -62,15 +52,16 @@ const PROVER_SERVICE_HTTP2_MAX_FRAME_SIZE: ByteSize = ByteSize::kib(16); const PROVER_SERVICE_MESSAGE_DECODING_MAX_SIZE: ByteSize = ByteSize::mib(5); // Max size for Message (encoding, e.g., 5 Mb) const PROVER_SERVICE_MESSAGE_ENCODING_MAX_SIZE: ByteSize = ByteSize::mib(5); -const PROVER_SPAM_LIMIT: u64 = 10_000; #[derive(Debug)] pub struct ProverService { - proof_sender: Sender<(RlnUserIdentity, Arc, u64)>, + proof_sender: Sender, user_db: UserDb, rln_identifier: Arc, - spam_limit: u64, - broadcast_channel: (broadcast::Sender>, broadcast::Receiver>), + broadcast_channel: ( + broadcast::Sender>, + broadcast::Receiver>, + ), } #[tonic::async_trait] @@ -101,18 +92,26 @@ impl RlnProver for ProverService { // Update the counter as soon as possible (should help to prevent spamming...) let counter = self.user_db.on_new_tx(&sender).unwrap_or_default(); - let user_identity = RlnUserIdentity { - secret_hash: user_id.secret_hash, - commitment: user_id.commitment, - user_limit: Fr::from(self.spam_limit), - }; + if req.transaction_hash.len() != 32 { + return Err(Status::invalid_argument( + "Invalid transaction hash (should be 32 bytes)", + )); + } // Inexpensive clone (behind Arc ptr) let rln_identifier = self.rln_identifier.clone(); + let proof_data = ProofGenerationData::from(( + user_id, + rln_identifier, + counter.into(), + sender, + req.transaction_hash, + )); + // Send some data to one of the proof services self.proof_sender - .send((user_identity, rln_identifier, counter.into())) + .send(proof_data) .await .map_err(|e| Status::from_error(Box::new(e)))?; @@ -122,13 +121,36 @@ impl RlnProver for ProverService { async fn register_user( &self, - _request: Request, + request: Request, ) -> Result, Status> { - let reply = RegisterUserReply { status: 0 }; + debug!("register_user request: {:?}", request); + + let req = request.into_inner(); + let user = if let Some(user) = req.user { + if let Ok(user) = Address::try_from(user.value.as_slice()) { + user + } else { + return Err(Status::invalid_argument("Invalid sender address")); + } + } else { + return Err(Status::invalid_argument("No sender address")); + }; + + let result = self.user_db.on_new_user(user); + + let status = match result { + Ok(_) => RegistrationStatus::Success, + Err(RegisterError::AlreadyRegistered(_a)) => RegistrationStatus::AlreadyRegistered, + _ => RegistrationStatus::Failure, + }; + + let reply = RegisterUserReply { + status: status.into(), + }; Ok(Response::new(reply)) } - type GetProofsStream = ReceiverStream>; + type GetProofsStream = ReceiverStream>; async fn get_proofs( &self, @@ -139,13 +161,19 @@ impl RlnProver for ProverService { let (tx, rx) = mpsc::channel(100); let mut rx2 = self.broadcast_channel.0.subscribe(); tokio::spawn(async move { - while let Ok(data) = rx2.recv().await { + // FIXME: Should we send the error here? + while let Ok(Ok(data)) = rx2.recv().await { let rln_proof = RlnProof { - sender: "0xAA".to_string(), - id_commitment: "1".to_string(), - proof: data, + sender: data.tx_sender.to_vec(), + tx_hash: data.tx_hash, + proof: data.proof, }; - if let Err(e) = tx.send(Ok(rln_proof)).await { + + let resp = RlnProofReply { + resp: Some(GetProofsResp::Proof(rln_proof)), + }; + + if let Err(e) = tx.send(Ok(resp)).await { debug!("Done: sending dummy rln proofs: {}", e); break; }; @@ -235,8 +263,11 @@ impl RlnProver for ProverService { } pub(crate) struct GrpcProverService { - pub proof_sender: Sender<(RlnUserIdentity, Arc, u64)>, - pub broadcast_channel: (broadcast::Sender>, broadcast::Receiver>), + pub proof_sender: Sender, + pub broadcast_channel: ( + broadcast::Sender>, + broadcast::Receiver>, + ), pub addr: SocketAddr, pub rln_identifier: RlnIdentifier, pub user_db: UserDb, @@ -248,7 +279,6 @@ impl GrpcProverService { proof_sender: self.proof_sender.clone(), user_db: self.user_db.clone(), rln_identifier: Arc::new(self.rln_identifier.clone()), - spam_limit: PROVER_SPAM_LIMIT, broadcast_channel: ( self.broadcast_channel.0.clone(), self.broadcast_channel.0.subscribe(), diff --git a/prover/src/main.rs b/prover/src/main.rs index c355842..2db3690 100644 --- a/prover/src/main.rs +++ b/prover/src/main.rs @@ -3,8 +3,8 @@ mod args; mod epoch_service; mod error; mod grpc_service; +mod proof_generation; mod proof_service; -mod registry; mod registry_listener; mod tier; mod user_db_service; @@ -29,9 +29,10 @@ use crate::args::AppArgs; use crate::epoch_service::EpochService; use crate::grpc_service::GrpcProverService; use crate::proof_service::ProofService; -use crate::user_db_service::UserDbService; +use crate::user_db_service::{RateLimit, UserDbService}; const RLN_IDENTIFIER_NAME: &[u8] = b"test-rln-identifier"; +const PROVER_SPAM_LIMIT: RateLimit = RateLimit::new(10_000u64); const PROOF_SERVICE_COUNT: u8 = 8; const GENESIS: DateTime = DateTime::from_timestamp(1431648000, 0).unwrap(); @@ -63,6 +64,7 @@ async fn main() -> Result<(), Box> { let user_db_service = UserDbService::new( epoch_service.epoch_changes.clone(), epoch_service.current_epoch.clone(), + PROVER_SPAM_LIMIT, ); // proof service @@ -76,6 +78,7 @@ async fn main() -> Result<(), Box> { let rln_identifier = RlnIdentifier::new(RLN_IDENTIFIER_NAME); let addr = SocketAddr::new(app_args.ip, app_args.port); debug!("Listening on: {}", addr); + // TODO: broadcast subscribe? let prover_grpc_service = GrpcProverService { proof_sender, broadcast_channel: (tx.clone(), rx), @@ -89,9 +92,16 @@ async fn main() -> Result<(), Box> { let proof_recv = proof_receiver.clone(); let broadcast_sender = tx.clone(); let current_epoch = epoch_service.current_epoch.clone(); + let user_db = user_db_service.get_user_db(); set.spawn(async { - let proof_service = ProofService::new(proof_recv, broadcast_sender, current_epoch); + let proof_service = ProofService::new( + proof_recv, + broadcast_sender, + current_epoch, + user_db, + PROVER_SPAM_LIMIT, + ); proof_service.serve().await }); } diff --git a/prover/src/proof_generation.rs b/prover/src/proof_generation.rs new file mode 100644 index 0000000..28be833 --- /dev/null +++ b/prover/src/proof_generation.rs @@ -0,0 +1,43 @@ +use std::sync::Arc; +// third-party +use alloy::primitives::Address; +// internal +use rln_proof::{RlnIdentifier, RlnUserIdentity}; + +#[derive(Debug, Clone)] +pub(crate) struct ProofGenerationData { + pub(crate) user_identity: RlnUserIdentity, + pub(crate) rln_identifier: Arc, + pub(crate) tx_counter: u64, + pub(crate) tx_sender: Address, + pub(crate) tx_hash: Vec, +} + +impl From<(RlnUserIdentity, Arc, u64, Address, Vec)> for ProofGenerationData { + /// Create a new ProofGenerationData - assume tx_hash is 32 bytes long + fn from( + (user_identity, rln_identifier, tx_counter, tx_sender, tx_hash): ( + RlnUserIdentity, + Arc, + u64, + Address, + Vec, + ), + ) -> Self { + debug_assert!(tx_hash.len() == 32); + Self { + user_identity, + rln_identifier, + tx_counter, + tx_sender, + tx_hash, + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ProofSendingData { + pub(crate) tx_hash: Vec, + pub(crate) tx_sender: Address, + pub(crate) proof: Vec, +} diff --git a/prover/src/proof_service.rs b/prover/src/proof_service.rs index 253620b..e098b4b 100644 --- a/prover/src/proof_service.rs +++ b/prover/src/proof_service.rs @@ -2,50 +2,48 @@ use std::io::{Cursor, Write}; use std::sync::Arc; // third-party use ark_bn254::Fr; -use ark_serialize::{CanonicalSerialize, SerializationError}; +use ark_serialize::CanonicalSerialize; use async_channel::Receiver; use parking_lot::RwLock; -use rln::hashers::{hash_to_field, poseidon_hash}; -use rln::pm_tree_adapter::PmTree; -use rln::protocol::{ProofError, serialize_proof_values}; -use tracing::debug; +use rln::hashers::hash_to_field; +use rln::protocol::serialize_proof_values; +use tracing::{debug, info}; // internal use crate::epoch_service::{Epoch, EpochSlice}; -use crate::error::AppError; -use rln_proof::{ - RlnData, RlnIdentifier, RlnUserIdentity, ZerokitMerkleTree, compute_rln_proof_and_values, -}; +use crate::error::{AppError, ProofGenerationError, ProofGenerationStringError}; +use crate::proof_generation::{ProofGenerationData, ProofSendingData}; +use crate::user_db_service::{RateLimit, UserDb}; +use rln_proof::{RlnData, compute_rln_proof_and_values}; -#[derive(thiserror::Error, Debug)] -enum ProofGenerationError { - #[error("Proof generation failed: {0}")] - Proof(#[from] ProofError), - #[error("Proof serialization failed: {0}")] - Serialization(#[from] SerializationError), - #[error("Proof serialization failed: {0}")] - SerializationWrite(#[from] std::io::Error), - #[error("Error: {0}")] - Misc(String), -} +const PROOF_SIZE: usize = 512; /// A service to generate a RLN proof (and then to broadcast it) -#[derive(Debug)] pub struct ProofService { - receiver: Receiver<(RlnUserIdentity, Arc, u64)>, - broadcast_sender: tokio::sync::broadcast::Sender>, + receiver: Receiver, + broadcast_sender: + tokio::sync::broadcast::Sender>, current_epoch: Arc>, + user_db: UserDb, + rate_limit: RateLimit, } impl ProofService { pub(crate) fn new( - receiver: Receiver<(RlnUserIdentity, Arc, u64)>, - broadcast_sender: tokio::sync::broadcast::Sender>, + receiver: Receiver, + broadcast_sender: tokio::sync::broadcast::Sender< + Result, + >, current_epoch: Arc>, + user_db: UserDb, + rate_limit: RateLimit, ) -> Self { + debug_assert!(rate_limit > RateLimit::ZERO); Self { receiver, broadcast_sender, current_epoch, + user_db, + rate_limit, } } @@ -54,19 +52,33 @@ impl ProofService { let received = self.receiver.recv().await; if let Err(e) = received { - debug!("Stopping proof generation service: {}", e); + info!("Stopping proof generation service: {}", e); break; } - let (user_identity, rln_identifier, counter) = received.unwrap(); + + let proof_generation_data = received.unwrap(); let (current_epoch, current_epoch_slice) = *self.current_epoch.read(); + let user_db = self.user_db.clone(); + let proof_generation_data_ = proof_generation_data.clone(); + let rate_limit = self.rate_limit; // Move to a task (as generating the proof can take quite some time) let blocking_task = tokio::task::spawn_blocking(move || { + let message_id = { + let mut m_id = proof_generation_data.tx_counter; + // Note: Zerokit can only recover user secret hash with 2 messages with the + // same message_id so here we force to use the previous message_id + // so the Verifier could recover the secret hash + if RateLimit::from(m_id) == rate_limit { + m_id -= 1; + } + m_id + }; + let rln_data = RlnData { - message_id: Fr::from(counter), - // TODO: tx hash to field - data: hash_to_field(b"RLN is awesome"), + message_id: Fr::from(message_id), + data: hash_to_field(proof_generation_data.tx_hash.as_slice()), }; let epoch_bytes = { @@ -76,33 +88,22 @@ impl ProofService { }; let epoch = hash_to_field(epoch_bytes.as_slice()); - // FIXME: maintain tree in Prover or query RLN Reg SC ? - // Merkle tree - let tree_height = 20; - let mut tree = PmTree::new(tree_height, Fr::from(0), Default::default()) - .map_err(|e| ProofGenerationError::Misc(e.to_string()))?; - - // let mut tree = OptimalMerkleTree::new(tree_height, Fr::from(0), Default::default()).unwrap(); - let rate_commit = - poseidon_hash(&[user_identity.commitment, user_identity.user_limit]); - tree.set(0, rate_commit) - .map_err(|e| ProofGenerationError::Misc(e.to_string()))?; - let merkle_proof = tree - .proof(0) - .map_err(|e| ProofGenerationError::Misc(e.to_string()))?; + let merkle_proof = user_db.get_merkle_proof(&proof_generation_data.tx_sender)?; let (proof, proof_values) = compute_rln_proof_and_values( - &user_identity, - &rln_identifier, + &proof_generation_data.user_identity, + &proof_generation_data.rln_identifier, rln_data, epoch, &merkle_proof, ) .map_err(ProofGenerationError::Proof)?; + debug!("proof: {:?}", proof); + debug!("proof_values: {:?}", proof_values); + // Serialize proof - // FIXME: proof size? - let mut output_buffer = Cursor::new(Vec::with_capacity(512)); + let mut output_buffer = Cursor::new(Vec::with_capacity(PROOF_SIZE)); proof .serialize_compressed(&mut output_buffer) .map_err(ProofGenerationError::Serialization)?; @@ -114,18 +115,446 @@ impl ProofService { }); let result = blocking_task.await; - // if let Err(e) = result { - // return Err(Status::from_error(Box::new(e))); - // } - // blocking_task returns Result, _>> // Result (1st) is a JoinError (and should not happen) // Result (2nd) is a ProofGenerationError - let result = result.unwrap().unwrap(); - // TODO: no unwrap() - // FIXME: send proof + other info - self.broadcast_sender.send(result).unwrap(); + let result = result.unwrap(); // Should never happen (but should panic if it does) + + let proof_sending_data = result + .map(|r| ProofSendingData { + tx_hash: proof_generation_data_.tx_hash, + tx_sender: proof_generation_data_.tx_sender, + proof: r, + }) + .map_err(ProofGenerationStringError::from); + + if let Err(e) = self.broadcast_sender.send(proof_sending_data) { + info!("Stopping proof generation service: {}", e); + break; + }; } Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + // third-party + use alloy::primitives::{Address, address}; + use ark_groth16::{Proof as ArkProof, Proof, VerifyingKey}; + use ark_serialize::CanonicalDeserialize; + use claims::assert_matches; + use futures::TryFutureExt; + use rln::circuit::{Curve, zkey_from_folder}; + use tokio::sync::broadcast; + use tracing::info; + // third-party: zerokit + use rln::protocol::{compute_id_secret, deserialize_proof_values, keygen, verify_proof}; + // internal + use crate::user_db_service::UserDbService; + use rln_proof::RlnIdentifier; + + const ADDR_1: Address = address!("0xd8da6bf26964af9d7eed9e03e53415d37aa96045"); + const ADDR_2: Address = address!("0xb20a608c624Ca5003905aA834De7156C68b2E1d0"); + + const TX_HASH_1: [u8; 32] = [0x011; 32]; + const TX_HASH_1_2: [u8; 32] = [0x12; 32]; + + #[derive(thiserror::Error, Debug)] + enum AppErrorExt { + #[error("AppError: {0}")] + AppError(#[from] AppError), + #[error("Future timeout")] + Elapsed, + #[error("Proof generation failed: {0}")] + ProofGeneration(#[from] ProofGenerationStringError), + #[error("Proof verification failed")] + ProofVerification, + #[error("Exiting...")] + Exit, + #[error("Recovered secret")] + RecoveredSecret(Fr), + } + + async fn proof_sender( + sender: Address, + proof_tx: &mut async_channel::Sender, + rln_identifier: Arc, + user_db: &UserDb, + ) -> Result<(), AppErrorExt> { + // used by test_proof_generation unit test + + debug!("Starting proof sender..."); + debug!("Waiting a bit before sending proof..."); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + debug!("Sending proof..."); + proof_tx + .send(ProofGenerationData { + user_identity: user_db.get_user(&ADDR_1).unwrap(), + rln_identifier, + tx_counter: 0, + tx_sender: sender, + tx_hash: TX_HASH_1.to_vec(), + }) + .await + .unwrap(); + debug!("Sending proof done"); + // tokio::time::sleep(std::time::Duration::from_secs(10)).await; + Ok::<(), AppErrorExt>(()) + } + + async fn proof_verifier( + broadcast_receiver: &mut broadcast::Receiver< + Result, + >, + verifying_key: &VerifyingKey, + ) -> Result<(), AppErrorExt> { + // used by test_proof_generation unit test + + debug!("Starting broadcast receiver..."); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + let res = + tokio::time::timeout(std::time::Duration::from_secs(5), broadcast_receiver.recv()) + .await + .map_err(|_e| AppErrorExt::Elapsed)?; + debug!("res: {:?}", res); + + let res = res.unwrap(); + let res = res?; + let mut proof_cursor = Cursor::new(&res.proof); + debug!("proof cursor: {:?}", proof_cursor); + let proof = ArkProof::deserialize_compressed(&mut proof_cursor).unwrap(); + let position = proof_cursor.position() as usize; + let proof_cursor_2 = &proof_cursor.get_ref().as_slice()[position..]; + let (proof_values, _) = deserialize_proof_values(proof_cursor_2); + debug!("[proof verifier] proof: {:?}", proof); + debug!("[proof verifier] proof_values: {:?}", proof_values); + + let verified = verify_proof(verifying_key, &proof, &proof_values) + .map_err(|_e| AppErrorExt::ProofVerification)?; + + debug!("verified: {:?}", verified); + + // Exit after receiving one proof + Err::<(), AppErrorExt>(AppErrorExt::Exit) + } + + #[tokio::test] + #[tracing_test::traced_test] + async fn test_proof_generation() { + // Queues + let (broadcast_sender, _broadcast_receiver) = broadcast::channel(2); + let mut broadcast_receiver = broadcast_sender.subscribe(); + let (mut proof_tx, proof_rx) = async_channel::unbounded(); + + // Epoch + let epoch = Epoch::from(11); + let epoch_slice = EpochSlice::from(42); + let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice))); + + // User db + let user_db_service = + UserDbService::new(Default::default(), epoch_store.clone(), 10.into()); + let user_db = user_db_service.get_user_db(); + user_db.on_new_user(ADDR_1).unwrap(); + user_db.on_new_user(ADDR_2).unwrap(); + + let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz")); + + // Proof service + let proof_service = ProofService::new( + proof_rx, + broadcast_sender, + epoch_store, + user_db.clone(), + RateLimit::from(10), + ); + + // Verification + let proving_key = zkey_from_folder(); + let verification_key = &proving_key.0.vk; + + info!("Starting..."); + let res = tokio::try_join!( + proof_service.serve().map_err(AppErrorExt::AppError), + proof_verifier(&mut broadcast_receiver, verification_key), + proof_sender(ADDR_1, &mut proof_tx, rln_identifier.clone(), &user_db), + ); + + // Everything ok if proof_verifier return AppErrorExt::Exit else there is a real error + assert_matches!(res, Err(AppErrorExt::Exit)); + } + + #[tokio::test] + #[tracing_test::traced_test] + async fn test_user_not_registered() { + // Ask for a proof for an unregistered user + + // Queues + let (broadcast_sender, _broadcast_receiver) = broadcast::channel(2); + let mut broadcast_receiver = broadcast_sender.subscribe(); + let (mut proof_tx, proof_rx) = async_channel::unbounded(); + + // Epoch + let epoch = Epoch::from(11); + let epoch_slice = EpochSlice::from(42); + let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice))); + + // User db + let user_db_service = + UserDbService::new(Default::default(), epoch_store.clone(), 10.into()); + let user_db = user_db_service.get_user_db(); + user_db.on_new_user(ADDR_1).unwrap(); + // user_db.on_new_user(ADDR_2).unwrap(); + + let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz")); + + // Proof service + let proof_service = ProofService::new( + proof_rx, + broadcast_sender, + epoch_store, + user_db.clone(), + RateLimit::from(10), + ); + + // Verification + let proving_key = zkey_from_folder(); + let verification_key = &proving_key.0.vk; + + info!("Starting..."); + let res = tokio::try_join!( + proof_service.serve().map_err(AppErrorExt::AppError), + proof_verifier(&mut broadcast_receiver, verification_key), + proof_sender(ADDR_2, &mut proof_tx, rln_identifier.clone(), &user_db), + ); + + // Expect this error (any other error is a real error) + assert_matches!( + res, + Err(AppErrorExt::ProofGeneration( + ProofGenerationStringError::MerkleProofError(_) + )) + ); + } + + async fn proof_reveal_secret( + broadcast_receiver: &mut broadcast::Receiver< + Result, + >, + verifying_key: &VerifyingKey, + ) -> Result<(), AppErrorExt> { + // used by test_user_spamming unit test + + debug!("Starting broadcast receiver..."); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + let mut proof_values_store = vec![]; + + loop { + let res = + tokio::time::timeout(std::time::Duration::from_secs(5), broadcast_receiver.recv()) + .await + .map_err(|_e| AppErrorExt::Elapsed)?; + + let res = res.unwrap(); + let res = res?; + let mut proof_cursor = Cursor::new(&res.proof); + let proof: Proof = ArkProof::deserialize_compressed(&mut proof_cursor).unwrap(); + let position = proof_cursor.position() as usize; + let proof_cursor_2 = &proof_cursor.get_ref().as_slice()[position..]; + let (proof_values, _) = deserialize_proof_values(proof_cursor_2); + proof_values_store.push(proof_values); + if proof_values_store.len() >= 2 { + break; + } + } + + debug!("Now recovering secret hash..."); + let proof_values_0 = proof_values_store.get(0).unwrap(); + let proof_values_1 = proof_values_store.get(1).unwrap(); + println!("proof_values_0: {:?}", proof_values_0); + println!("proof_values_1: {:?}", proof_values_1); + let share1 = (proof_values_0.x, proof_values_0.y); + let share2 = (proof_values_1.x, proof_values_1.y); + + // TODO: should we check external nullifier as well? + let recovered_identity_secret_hash = compute_id_secret(share1, share2).unwrap(); + + debug!( + "recovered_identity_secret_hash: {:?}", + recovered_identity_secret_hash + ); + + // Exit after receiving one proof + Err::<(), AppErrorExt>(AppErrorExt::RecoveredSecret(recovered_identity_secret_hash)) + } + + async fn proof_sender_2( + proof_tx: &mut async_channel::Sender, + rln_identifier: Arc, + user_db: &UserDb, + sender: Address, + tx_hashes: ([u8; 32], [u8; 32]), + ) -> Result<(), AppErrorExt> { + // used by test_proof_generation unit test + + debug!("Starting proof sender 2..."); + debug!("Waiting a bit before sending proof..."); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + debug!("Sending proof..."); + proof_tx + .send(ProofGenerationData { + user_identity: user_db.get_user(&sender).unwrap(), + rln_identifier: rln_identifier.clone(), + tx_counter: 0, + tx_sender: sender.clone(), + tx_hash: tx_hashes.0.to_vec(), + }) + .await + .unwrap(); + debug!("Sending proof done"); + + debug!("Waiting a bit before sending 2nd proof..."); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + debug!("Sending 2nd proof..."); + proof_tx + .send(ProofGenerationData { + user_identity: user_db.get_user(&sender).unwrap(), + rln_identifier, + tx_counter: 1, + tx_sender: sender, + tx_hash: tx_hashes.1.to_vec(), + }) + .await + .unwrap(); + debug!("Sending 2nd proof done"); + + Ok::<(), AppErrorExt>(()) + } + + #[tokio::test] + #[tracing_test::traced_test] + async fn test_user_spamming() { + // Recover secret from a user spamming the system + + // Queues + let (broadcast_sender, _broadcast_receiver) = broadcast::channel(2); + let mut broadcast_receiver = broadcast_sender.subscribe(); + let (mut proof_tx, proof_rx) = async_channel::unbounded(); + + // Epoch + let epoch = Epoch::from(11); + let epoch_slice = EpochSlice::from(42); + let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice))); + + // Limits + let rate_limit = RateLimit::from(1); + + // User db + let user_db_service = + UserDbService::new(Default::default(), epoch_store.clone(), rate_limit); + let user_db = user_db_service.get_user_db(); + user_db.on_new_user(ADDR_1).unwrap(); + let user_addr_1 = user_db.get_user(&ADDR_1).unwrap(); + user_db.on_new_user(ADDR_2).unwrap(); + + let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz")); + + // Proof service + let proof_service = ProofService::new( + proof_rx, + broadcast_sender, + epoch_store, + user_db.clone(), + rate_limit, + ); + + // Verification + let proving_key = zkey_from_folder(); + let verification_key = &proving_key.0.vk; + + info!("Starting..."); + let res = tokio::try_join!( + proof_service.serve().map_err(AppErrorExt::AppError), + proof_reveal_secret(&mut broadcast_receiver, verification_key), + proof_sender_2( + &mut proof_tx, + rln_identifier.clone(), + &user_db, + ADDR_1, + (TX_HASH_1, TX_HASH_1_2) + ), + ); + + match res { + Err(AppErrorExt::RecoveredSecret(secret_hash)) => { + assert_eq!(secret_hash, user_addr_1.secret_hash); + } + _ => { + panic!("Unexpected result"); + } + } + } + + #[tokio::test] + #[ignore] + #[tracing_test::traced_test] + async fn test_user_spamming_same_signal() { + // Recover secret from a user spamming the system + + // Queues + let (broadcast_sender, _broadcast_receiver) = broadcast::channel(2); + let mut broadcast_receiver = broadcast_sender.subscribe(); + let (mut proof_tx, proof_rx) = async_channel::unbounded(); + + // Epoch + let epoch = Epoch::from(11); + let epoch_slice = EpochSlice::from(42); + let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice))); + + // Limits + let rate_limit = RateLimit::from(1); + + // User db - limit is 1 message per epoch + let user_db_service = + UserDbService::new(Default::default(), epoch_store.clone(), rate_limit.into()); + let user_db = user_db_service.get_user_db(); + user_db.on_new_user(ADDR_1).unwrap(); + let user_addr_1 = user_db.get_user(&ADDR_1).unwrap(); + debug!("user_addr_1: {:?}", user_addr_1); + user_db.on_new_user(ADDR_2).unwrap(); + + let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz")); + + // Proof service + let proof_service = ProofService::new( + proof_rx, + broadcast_sender, + epoch_store, + user_db.clone(), + rate_limit, + ); + + // Verification + let proving_key = zkey_from_folder(); + let verification_key = &proving_key.0.vk; + + info!("Starting..."); + let res = tokio::try_join!( + proof_service.serve().map_err(AppErrorExt::AppError), + proof_reveal_secret(&mut broadcast_receiver, verification_key), + proof_sender_2( + &mut proof_tx, + rln_identifier.clone(), + &user_db, + ADDR_1, + (TX_HASH_1, TX_HASH_1) + ), + ); + + // TODO: wait for Zerokit 0.8 + // assert_matches!(res, Err(AppErrorExt::Exit)); + } +} diff --git a/prover/src/registry.rs b/prover/src/registry.rs deleted file mode 100644 index fee0ace..0000000 --- a/prover/src/registry.rs +++ /dev/null @@ -1,51 +0,0 @@ -// use alloy::primitives::Address; -// use ark_bn254::Fr; -// use dashmap::DashMap; -// use dashmap::mapref::one::Ref; -// use rln::protocol::keygen; - -/* -#[derive(Debug)] -pub(crate) struct UserRegistry { - inner: DashMap, -} - -impl UserRegistry { - fn new() -> Self { - Self { - inner: DashMap::new(), - } - } - - pub(crate) fn get(&self, address: &Address) -> Option> { - self.inner.get(address) - } - - fn register(&self, address: Address) { - let (identity_secret_hash, id_commitment) = keygen(); - self.inner - .insert(address, (identity_secret_hash, id_commitment)); - } -} - -impl Default for UserRegistry { - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use alloy::primitives::address; - - #[test] - fn test_user_registration() { - let address = address!("0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f"); - let reg = UserRegistry::default(); - reg.register(address); - - assert!(reg.get(&address).is_some()); - } -} -*/ diff --git a/prover/src/user_db_service.rs b/prover/src/user_db_service.rs index 04d4df9..44f0945 100644 --- a/prover/src/user_db_service.rs +++ b/prover/src/user_db_service.rs @@ -4,31 +4,107 @@ use std::ops::Deref; use std::sync::Arc; // third-party use alloy::primitives::{Address, U256}; +use ark_bn254::Fr; use derive_more::{Add, From, Into}; use parking_lot::RwLock; +use rln::hashers::poseidon_hash; +use rln::pm_tree_adapter::{PmTree, PmTreeProof}; use rln::protocol::keygen; use scc::HashMap; use tokio::sync::Notify; use tracing::debug; // internal use crate::epoch_service::{Epoch, EpochSlice}; -use crate::error::AppError; +use crate::error::{AppError, GetMerkleTreeProofError, RegisterError}; use crate::tier::{KarmaAmount, TIER_LIMITS, TierLimit, TierName}; -use rln_proof::RlnUserIdentity; +use rln_proof::{RlnUserIdentity, ZerokitMerkleTree}; -#[derive(Debug, Default, Clone)] -pub(crate) struct UserRegistry { - inner: HashMap, +const MERKLE_TREE_HEIGHT: usize = 20; + +#[derive(Debug, Clone, Copy, From, Into)] +struct MerkleTreeIndex(usize); + +#[derive(Debug, Clone, Copy, Default, PartialOrd, PartialEq, From, Into)] +pub struct RateLimit(u64); + +impl RateLimit { + pub(crate) const ZERO: RateLimit = RateLimit(0); + + pub(crate) const fn new(value: u64) -> Self { + Self(value) + } } + +impl From for Fr { + fn from(rate_limit: RateLimit) -> Self { + Fr::from(rate_limit.0) + } +} + +#[derive(Clone)] +pub(crate) struct UserRegistry { + inner: HashMap, + tree: Arc>, + rate_limit: RateLimit, +} + +impl std::fmt::Debug for UserRegistry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "UserRegistry {{ inner: {:?} }}", self.inner) + } +} + +impl Default for UserRegistry { + fn default() -> Self { + Self { + inner: Default::default(), + // unwrap safe - no config + tree: Arc::new(RwLock::new( + PmTree::new(MERKLE_TREE_HEIGHT, Default::default(), Default::default()).unwrap(), + )), + rate_limit: Default::default(), + } + } +} + +impl From for UserRegistry { + fn from(rate_limit: RateLimit) -> Self { + Self { + inner: Default::default(), + // unwrap safe - no config + tree: Arc::new(RwLock::new( + PmTree::new(MERKLE_TREE_HEIGHT, Default::default(), Default::default()).unwrap(), + )), + rate_limit, + } + } +} + impl UserRegistry { - fn register(&self, address: Address) -> bool { + fn register(&self, address: Address) -> Result<(), RegisterError> { let (identity_secret_hash, id_commitment) = keygen(); - self.inner + let index = self.inner.len(); + let res = self + .inner .insert( address, - RlnUserIdentity::from((identity_secret_hash, id_commitment)), + ( + RlnUserIdentity::from(( + identity_secret_hash, + id_commitment, + Fr::from(self.rate_limit), + )), + MerkleTreeIndex(index), + ), ) - .is_ok() + .map_err(|_e| RegisterError::AlreadyRegistered(address)); + + let rate_commit = poseidon_hash(&[id_commitment, Fr::from(u64::from(self.rate_limit))]); + self.tree + .write() + .set(index, rate_commit) + .map_err(|e| RegisterError::TreeError(e.to_string()))?; + res } fn has_user(&self, address: &Address) -> bool { @@ -36,7 +112,19 @@ impl UserRegistry { } fn get_user(&self, address: &Address) -> Option { - self.inner.get(address).map(|entry| entry.clone()) + self.inner.get(address).map(|entry| entry.0.clone()) + } + + fn get_merkle_proof(&self, address: &Address) -> Result { + let index = self + .inner + .get(address) + .map(|entry| entry.1) + .ok_or(GetMerkleTreeProofError::NotRegistered)?; + self.tree + .read() + .proof(index.into()) + .map_err(|e| GetMerkleTreeProofError::TreeError(e.to_string())) } } @@ -129,10 +217,21 @@ impl UserDb { } } + pub fn on_new_user(&self, address: Address) -> Result<(), RegisterError> { + self.user_registry.register(address) + } + pub fn get_user(&self, address: &Address) -> Option { self.user_registry.get_user(address) } + pub fn get_merkle_proof( + &self, + address: &Address, + ) -> Result { + self.user_registry.get_merkle_proof(address) + } + pub(crate) fn on_new_tx(&self, address: &Address) -> Option { if self.user_registry.has_user(address) { Some(self.tx_registry.incr_counter(address, None)) @@ -249,10 +348,11 @@ impl UserDbService { pub(crate) fn new( epoch_changes_notifier: Arc, epoch_store: Arc>, + rate_limit: RateLimit, ) -> Self { Self { user_db: UserDb { - user_registry: Default::default(), + user_registry: Arc::new(UserRegistry::from(rate_limit)), tx_registry: Default::default(), tier_limits: Arc::new(RwLock::new(TIER_LIMITS.clone())), tier_limits_next: Arc::new(Default::default()), @@ -308,7 +408,7 @@ impl UserDbService { mod tests { use super::*; use alloy::primitives::address; - use claims::assert_err; + use claims::{assert_err, assert_matches}; struct MockKarmaSc {} @@ -335,6 +435,23 @@ mod tests { } } + #[test] + fn test_user_register() { + let user_db = UserDb { + user_registry: Default::default(), + tx_registry: Default::default(), + tier_limits: Arc::new(RwLock::new(TIER_LIMITS.clone())), + tier_limits_next: Arc::new(Default::default()), + epoch_store: Arc::new(RwLock::new(Default::default())), + }; + let addr = Address::new([0; 20]); + user_db.user_registry.register(addr).unwrap(); + assert_matches!( + user_db.user_registry.register(addr), + Err(RegisterError::AlreadyRegistered(_)) + ); + } + #[tokio::test] async fn test_incr_tx_counter() { let user_db = UserDb { @@ -354,8 +471,9 @@ mod tests { tier_info, Err(UserTierInfoError::NotRegistered(_)) )); - // Register user + update tx counter - user_db.user_registry.register(addr); + // Register user + user_db.user_registry.register(addr).unwrap(); + // Now update user tx counter assert_eq!(user_db.on_new_tx(&addr), Some(EpochSliceCounter(1))); let tier_info = user_db.user_tier_info(&addr, MockKarmaSc {}).await.unwrap(); assert_eq!(tier_info.epoch_tx_count, 1); @@ -367,16 +485,16 @@ mod tests { let mut epoch = Epoch::from(11); let mut epoch_slice = EpochSlice::from(42); let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice))); - let user_db_service = UserDbService::new(Default::default(), epoch_store); + let user_db_service = UserDbService::new(Default::default(), epoch_store, 10.into()); let user_db = user_db_service.get_user_db(); let addr_1_tx_count = 2; let addr_2_tx_count = 820; - user_db.user_registry.register(ADDR_1); + user_db.user_registry.register(ADDR_1).unwrap(); user_db .tx_registry .incr_counter(&ADDR_1, Some(addr_1_tx_count)); - user_db.user_registry.register(ADDR_2); + user_db.user_registry.register(ADDR_2).unwrap(); user_db .tx_registry .incr_counter(&ADDR_2, Some(addr_2_tx_count)); @@ -450,7 +568,7 @@ mod tests { let mut epoch = Epoch::from(11); let mut epoch_slice = EpochSlice::from(42); let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice))); - let user_db_service = UserDbService::new(Default::default(), epoch_store); + let user_db_service = UserDbService::new(Default::default(), epoch_store, 10.into()); let user_db = user_db_service.get_user_db(); let tier_limits_original = user_db.tier_limits.read().clone(); @@ -498,7 +616,7 @@ mod tests { let epoch = Epoch::from(11); let epoch_slice = EpochSlice::from(42); let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice))); - let user_db_service = UserDbService::new(Default::default(), epoch_store); + let user_db_service = UserDbService::new(Default::default(), epoch_store, 10.into()); let user_db = user_db_service.get_user_db(); let tier_limits_original = user_db.tier_limits.read().clone(); diff --git a/rln_proof/src/proof.rs b/rln_proof/src/proof.rs index bf422ed..3825133 100644 --- a/rln_proof/src/proof.rs +++ b/rln_proof/src/proof.rs @@ -7,8 +7,10 @@ use ark_relations::r1cs::ConstraintMatrices; use rln::circuit::ZKEY_BYTES; use rln::circuit::zkey::read_zkey; use rln::hashers::{hash_to_field, poseidon_hash}; +use rln::pm_tree_adapter::PmTree; use rln::protocol::{ - ProofError, RLNProofValues, generate_proof, proof_values_from_witness, rln_witness_from_values, + ProofError, RLNProofValues, compute_id_secret, generate_proof, keygen, + proof_values_from_witness, rln_witness_from_values, }; /// A RLN user identity & limit @@ -19,12 +21,12 @@ pub struct RlnUserIdentity { pub user_limit: Fr, } -impl From<(Fr, Fr)> for RlnUserIdentity { - fn from((commitment, secret_hash): (Fr, Fr)) -> Self { +impl From<(Fr, Fr, Fr)> for RlnUserIdentity { + fn from((commitment, secret_hash, user_limit): (Fr, Fr, Fr)) -> Self { Self { commitment, secret_hash, - user_limit: Fr::from(0), + user_limit, } } } @@ -89,3 +91,61 @@ pub fn compute_rln_proof_and_values( )?; Ok((proof, proof_values)) } + +#[cfg(test)] +mod tests { + use super::*; + use zerokit_utils::ZerokitMerkleTree; + + #[test] + fn test_recover_secret_hash() { + let (user_co, user_sh) = keygen(); + let epoch = hash_to_field(b"foo"); + let spam_limit = Fr::from(10); + + let mut tree = PmTree::new(20, Default::default(), Default::default()).unwrap(); + tree.set(0, spam_limit).unwrap(); + let m_proof = tree.proof(0).unwrap(); + + let rln_identifier = RlnIdentifier::new(b"rln id test"); + + let message_id = Fr::from(1); + + let (_proof_0, proof_values_0) = compute_rln_proof_and_values( + &RlnUserIdentity { + commitment: user_co, + secret_hash: user_sh, + user_limit: spam_limit, + }, + &rln_identifier, + RlnData { + message_id, + data: hash_to_field(b"sig"), + }, + epoch, + &m_proof, + ) + .unwrap(); + + let (_proof_1, proof_values_1) = compute_rln_proof_and_values( + &RlnUserIdentity { + commitment: user_co, + secret_hash: user_sh, + user_limit: spam_limit, + }, + &rln_identifier, + RlnData { + message_id, + data: hash_to_field(b"sig 2"), + }, + epoch, + &m_proof, + ) + .unwrap(); + + let share1 = (proof_values_0.x, proof_values_0.y); + let share2 = (proof_values_1.x, proof_values_1.y); + let recovered_identity_secret_hash = compute_id_secret(share1, share2).unwrap(); + assert_eq!(user_sh, recovered_identity_secret_hash); + } +}