Add ProofGenerationData structure (#3)

* Add ProofGenerationData structure
* Use shared merkle tree
* Add unit tests for proof service
This commit is contained in:
Sydhds
2025-06-05 10:03:31 +02:00
committed by GitHub
parent 194ce101fc
commit 7525c8c226
12 changed files with 932 additions and 188 deletions

1
Cargo.lock generated
View File

@@ -4407,6 +4407,7 @@ version = "0.1.0"
dependencies = [
"alloy",
"ark-bn254",
"ark-groth16",
"ark-serialize 0.5.0",
"async-channel",
"bytesize",

View File

@@ -12,27 +12,29 @@ service RlnProver {
rpc SendTransaction (SendTransactionRequest) returns (SendTransactionReply);
rpc RegisterUser (RegisterUserRequest) returns (RegisterUserReply);
// Server side streaming RPC: 1 request -> X responses (stream)
rpc GetProofs(RlnProofFilter) returns (stream RlnProof);
rpc GetProofs(RlnProofFilter) returns (stream RlnProofReply);
rpc GetUserTierInfo(GetUserTierInfoRequest) returns (GetUserTierInfoReply);
rpc SetTierLimits(SetTierLimitsRequest) returns (SetTierLimitsReply);
}
/*
// TransactionType: https://github.com/Consensys/linea-besu/blob/09cbed1142cfe4d29b50ecf2f156639a4bc8c854/datatypes/src/main/java/org/hyperledger/besu/datatypes/TransactionType.java#L22
enum TransactionType {
/** The Frontier. */
// FRONTIER(0xf8 /* this is serialized as 0x0 in TransactionCompleteResult */),
// The Frontier
// FRONTIER(0xf8),
// FIXME: is this 0xF8 or 0x00 ?
FRONTIER = 0;
/** Access list transaction type. */
// Access list transaction type
ACCESS_LIST = 1; // 0x01
/** Eip1559 transaction type. */
// Eip1559 transaction type
EIP1559 = 2; // 0x02
/** Blob transaction type. */
// Blob transaction type
BLOB = 3; // 0x03
/** Eip7702 transaction type. */
// Eip7702 transaction type
DELEGATE_CODE = 4; // 0x04
}
*/
extend google.protobuf.FieldOptions {
optional uint32 max_size = 50000;
@@ -53,42 +55,57 @@ message Address {
bytes value = 1 [(max_size) = 20];
}
/*
message SECPSignature {
// https://github.com/Consensys/linea-besu/blob/zkbesu/crypto/algorithms/src/main/java/org/hyperledger/besu/crypto/SECPSignature.java#L30
bytes value = 1 [(max_size) = 65];
}
*/
/*
message StorageKey {
bytes value = 1 [(max_size) = 32];
}
*/
/*
message AccessListEntry {
Address address = 1;
repeated StorageKey storageKeys = 2;
}
*/
/*
message AccessListEntries {
// https://github.com/Consensys/linea-besu/blob/zkbesu/datatypes/src/main/java/org/hyperledger/besu/datatypes/AccessListEntry.java#L31
repeated AccessListEntry entries = 1;
}
*/
/*
message VersionedHash {
// https://github.com/Consensys/linea-besu/blob/zkbesu/datatypes/src/main/java/org/hyperledger/besu/datatypes/VersionedHash.java#L28
bytes value = 1 [(max_size) = 32];
}
*/
/*
message BlobsWithCommitments {
// https://github.com/Consensys/linea-besu/blob/zkbesu/datatypes/src/main/java/org/hyperledger/besu/datatypes/BlobsWithCommitments.java#L23
// TODO: need this?
}
*/
/*
message CodeDelegation {
// https://github.com/Consensys/linea-besu/blob/zkbesu/ethereum/core/src/main/java/org/hyperledger/besu/ethereum/core/CodeDelegation.java#L40
// TODO: need this?
}
*/
// Transaction: https://github.com/Consensys/linea-besu/blob/c99bdbd533707a45fad97017fb964578c3e87fde/ethereum/core/src/main/java/org/hyperledger/besu/ethereum/core/Transaction.java#L168
message SendTransactionRequest {
/*
bool forCopy = 1;
TransactionType transactionType = 2;
// Java long == signed 64-bit integer
@@ -113,21 +130,45 @@ message SendTransactionRequest {
repeated BlobsWithCommitments blobsWithCommitments = 17;
repeated CodeDelegation maybeCodeDelegationList = 18;
optional bytes rawRlp = 19;
*/
optional Wei gasPrice = 1;
optional Address sender = 2;
optional U256 chainId = 3;
bytes transactionHash = 4 [(max_size) = 32];
}
message SendTransactionReply {
bool result = 1;
// string message = 2;
}
message RlnProofFilter {
optional string address = 1;
}
message RlnProofReply {
oneof resp {
// variant for success
RlnProof proof = 1;
// variant for error
RlnProofError error = 2;
}
}
message RlnProof {
string sender = 1;
string id_commitment = 2;
bytes proof = 3;
// From https://rfc.vac.dev/vac/32/rln-v1#sending-the-output-message
bytes sender = 1;
bytes tx_hash = 2; // Transaction hash for the proof (non hash signal)
bytes proof = 3; // The RLN proof itself, hex encoded
// bytes internal_nullifier = 4;
// bytes x = 5; // signal hash
// bytes y = 6;
// bytes rln_identifier = 7;
// bytes merkle_proof_root = 8;
// bytes epoch = 9;
}
message RlnProofError {
  // Human-readable description of why proof generation failed.
  // Renumbered from 2 to 1: this message is newly introduced and had no
  // field 1, so starting at 1 is safe (no wire-compatibility concern yet).
  string error = 1;
}
message RegisterUserRequest {

View File

@@ -37,6 +37,7 @@ tonic-build = "*"
[dev-dependencies]
criterion = "0.6.0"
claims = "0.8"
ark-groth16 = "0.5.0"
[[bench]]
name = "user_db_heavy_write"

View File

@@ -487,7 +487,7 @@ mod tests {
.map_err(|e| AppErrorExt::AppError(e)),
// Wait for 3 epoch slices + 100 ms (to wait to receive notif + counter incr)
tokio::time::timeout(
epoch_slice_duration * 3 + Duration::from_millis(100),
epoch_slice_duration * 3 + Duration::from_millis(500),
async move {
loop {
notifier.notified().await;

View File

@@ -3,6 +3,8 @@ use alloy::{
primitives::Address,
transports::{RpcError, TransportErrorKind},
};
use ark_serialize::SerializationError;
use rln::protocol::ProofError;
#[derive(thiserror::Error, Debug)]
pub enum AppError {
@@ -18,6 +20,7 @@ pub enum AppError {
EpochError(#[from] WaitUntilError),
}
/*
#[derive(thiserror::Error, Debug)]
pub enum RegistrationError {
#[error("Transaction has no sender address")]
@@ -27,3 +30,62 @@ pub enum RegistrationError {
#[error("Cannot find id_commitment for address: {0:?}")]
NotFound(Address),
}
*/
#[derive(thiserror::Error, Debug)]
pub enum ProofGenerationError {
#[error("Proof generation failed: {0}")]
Proof(#[from] ProofError),
#[error("Proof serialization failed: {0}")]
Serialization(#[from] SerializationError),
#[error("Proof serialization failed: {0}")]
SerializationWrite(#[from] std::io::Error),
#[error(transparent)]
MerkleProofError(#[from] GetMerkleTreeProofError),
}
/// Same as ProofGenerationError but can be Cloned (can be used in Tokio broadcast channels)
#[derive(thiserror::Error, Debug, Clone)]
pub enum ProofGenerationStringError {
#[error("Proof generation failed: {0}")]
Proof(String),
#[error("Proof serialization failed: {0}")]
Serialization(String),
#[error("Proof serialization failed: {0}")]
SerializationWrite(String),
#[error(transparent)]
MerkleProofError(#[from] GetMerkleTreeProofError),
}
impl From<ProofGenerationError> for ProofGenerationStringError {
fn from(value: ProofGenerationError) -> Self {
match value {
ProofGenerationError::Proof(e) => ProofGenerationStringError::Proof(e.to_string()),
ProofGenerationError::Serialization(e) => {
ProofGenerationStringError::Serialization(e.to_string())
}
ProofGenerationError::SerializationWrite(e) => {
ProofGenerationStringError::SerializationWrite(e.to_string())
}
ProofGenerationError::MerkleProofError(e) => {
ProofGenerationStringError::MerkleProofError(e)
}
}
}
}
/// Errors returned when registering a new user in the user database.
#[derive(thiserror::Error, Debug)]
pub enum RegisterError {
/// The address already has an entry in the user db.
#[error("User (address: {0:?}) has already been registered")]
AlreadyRegistered(Address),
/// Underlying merkle tree operation failed (error kept as a String).
#[error("Merkle tree error: {0}")]
TreeError(String),
}
/// Errors returned when fetching a user's merkle proof.
/// Clone is required so the error can travel through broadcast channels
/// (via ProofGenerationStringError::MerkleProofError).
#[derive(thiserror::Error, Debug, Clone)]
pub enum GetMerkleTreeProofError {
/// The address has no leaf in the tree (never registered).
#[error("User not registered")]
NotRegistered,
/// Underlying merkle tree operation failed (error kept as a String).
#[error("Merkle tree error: {0}")]
TreeError(String),
}

View File

@@ -1,38 +1,27 @@
use std::collections::BTreeMap;
// std
use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
// third-party
use alloy::primitives::{Address, U256};
use ark_bn254::Fr;
use async_channel::Sender;
use bytesize::ByteSize;
use futures::TryFutureExt;
use http::Method;
use tokio::sync::{broadcast, mpsc};
use tonic::{
Request,
Response,
Status,
codegen::tokio_stream::wrappers::ReceiverStream,
transport::Server,
// codec::CompressionEncoding
Request, Response, Status, codegen::tokio_stream::wrappers::ReceiverStream, transport::Server,
};
use tonic_web::GrpcWebLayer;
use tower_http::cors::{Any, CorsLayer};
use tracing::{
debug,
// error,
// info
};
use tracing::debug;
// internal
use crate::error::{
AppError,
// RegistrationError
};
use crate::error::{AppError, ProofGenerationStringError, RegisterError};
use crate::proof_generation::{ProofGenerationData, ProofSendingData};
use crate::tier::{KarmaAmount, TierLimit, TierName};
use crate::user_db_service::{KarmaAmountExt, UserDb, UserTierInfo};
use rln_proof::{RlnIdentifier, RlnUserIdentity};
use rln_proof::RlnIdentifier;
pub mod prover_proto {
@@ -42,12 +31,13 @@ pub mod prover_proto {
pub(crate) const FILE_DESCRIPTOR_SET: &[u8] =
tonic::include_file_descriptor_set!("prover_descriptor");
}
use crate::tier::{KarmaAmount, TierLimit, TierName};
use prover_proto::{
GetUserTierInfoReply, GetUserTierInfoRequest, RegisterUserReply, RegisterUserRequest, RlnProof,
RlnProofFilter, SendTransactionReply, SendTransactionRequest, SetTierLimitsReply,
SetTierLimitsRequest, Tier, UserTierInfoError, UserTierInfoResult,
GetUserTierInfoReply, GetUserTierInfoRequest, RegisterUserReply, RegisterUserRequest,
RegistrationStatus, RlnProof, RlnProofFilter, RlnProofReply, SendTransactionReply,
SendTransactionRequest, SetTierLimitsReply, SetTierLimitsRequest, Tier, UserTierInfoError,
UserTierInfoResult,
get_user_tier_info_reply::Resp,
rln_proof_reply::Resp as GetProofsResp,
rln_prover_server::{RlnProver, RlnProverServer},
};
@@ -62,15 +52,16 @@ const PROVER_SERVICE_HTTP2_MAX_FRAME_SIZE: ByteSize = ByteSize::kib(16);
const PROVER_SERVICE_MESSAGE_DECODING_MAX_SIZE: ByteSize = ByteSize::mib(5);
// Max size for Message (encoding, e.g., 5 Mb)
const PROVER_SERVICE_MESSAGE_ENCODING_MAX_SIZE: ByteSize = ByteSize::mib(5);
const PROVER_SPAM_LIMIT: u64 = 10_000;
#[derive(Debug)]
pub struct ProverService {
proof_sender: Sender<(RlnUserIdentity, Arc<RlnIdentifier>, u64)>,
proof_sender: Sender<ProofGenerationData>,
user_db: UserDb,
rln_identifier: Arc<RlnIdentifier>,
spam_limit: u64,
broadcast_channel: (broadcast::Sender<Vec<u8>>, broadcast::Receiver<Vec<u8>>),
broadcast_channel: (
broadcast::Sender<Result<ProofSendingData, ProofGenerationStringError>>,
broadcast::Receiver<Result<ProofSendingData, ProofGenerationStringError>>,
),
}
#[tonic::async_trait]
@@ -101,18 +92,26 @@ impl RlnProver for ProverService {
// Update the counter as soon as possible (should help to prevent spamming...)
let counter = self.user_db.on_new_tx(&sender).unwrap_or_default();
let user_identity = RlnUserIdentity {
secret_hash: user_id.secret_hash,
commitment: user_id.commitment,
user_limit: Fr::from(self.spam_limit),
};
if req.transaction_hash.len() != 32 {
return Err(Status::invalid_argument(
"Invalid transaction hash (should be 32 bytes)",
));
}
// Inexpensive clone (behind Arc ptr)
let rln_identifier = self.rln_identifier.clone();
let proof_data = ProofGenerationData::from((
user_id,
rln_identifier,
counter.into(),
sender,
req.transaction_hash,
));
// Send some data to one of the proof services
self.proof_sender
.send((user_identity, rln_identifier, counter.into()))
.send(proof_data)
.await
.map_err(|e| Status::from_error(Box::new(e)))?;
@@ -122,13 +121,36 @@ impl RlnProver for ProverService {
async fn register_user(
&self,
_request: Request<RegisterUserRequest>,
request: Request<RegisterUserRequest>,
) -> Result<Response<RegisterUserReply>, Status> {
let reply = RegisterUserReply { status: 0 };
debug!("register_user request: {:?}", request);
let req = request.into_inner();
let user = if let Some(user) = req.user {
if let Ok(user) = Address::try_from(user.value.as_slice()) {
user
} else {
return Err(Status::invalid_argument("Invalid sender address"));
}
} else {
return Err(Status::invalid_argument("No sender address"));
};
let result = self.user_db.on_new_user(user);
let status = match result {
Ok(_) => RegistrationStatus::Success,
Err(RegisterError::AlreadyRegistered(_a)) => RegistrationStatus::AlreadyRegistered,
_ => RegistrationStatus::Failure,
};
let reply = RegisterUserReply {
status: status.into(),
};
Ok(Response::new(reply))
}
type GetProofsStream = ReceiverStream<Result<RlnProof, Status>>;
type GetProofsStream = ReceiverStream<Result<RlnProofReply, Status>>;
async fn get_proofs(
&self,
@@ -139,13 +161,19 @@ impl RlnProver for ProverService {
let (tx, rx) = mpsc::channel(100);
let mut rx2 = self.broadcast_channel.0.subscribe();
tokio::spawn(async move {
while let Ok(data) = rx2.recv().await {
// FIXME: Should we send the error here?
while let Ok(Ok(data)) = rx2.recv().await {
let rln_proof = RlnProof {
sender: "0xAA".to_string(),
id_commitment: "1".to_string(),
proof: data,
sender: data.tx_sender.to_vec(),
tx_hash: data.tx_hash,
proof: data.proof,
};
if let Err(e) = tx.send(Ok(rln_proof)).await {
let resp = RlnProofReply {
resp: Some(GetProofsResp::Proof(rln_proof)),
};
if let Err(e) = tx.send(Ok(resp)).await {
debug!("Done: sending dummy rln proofs: {}", e);
break;
};
@@ -235,8 +263,11 @@ impl RlnProver for ProverService {
}
pub(crate) struct GrpcProverService {
pub proof_sender: Sender<(RlnUserIdentity, Arc<RlnIdentifier>, u64)>,
pub broadcast_channel: (broadcast::Sender<Vec<u8>>, broadcast::Receiver<Vec<u8>>),
pub proof_sender: Sender<ProofGenerationData>,
pub broadcast_channel: (
broadcast::Sender<Result<ProofSendingData, ProofGenerationStringError>>,
broadcast::Receiver<Result<ProofSendingData, ProofGenerationStringError>>,
),
pub addr: SocketAddr,
pub rln_identifier: RlnIdentifier,
pub user_db: UserDb,
@@ -248,7 +279,6 @@ impl GrpcProverService {
proof_sender: self.proof_sender.clone(),
user_db: self.user_db.clone(),
rln_identifier: Arc::new(self.rln_identifier.clone()),
spam_limit: PROVER_SPAM_LIMIT,
broadcast_channel: (
self.broadcast_channel.0.clone(),
self.broadcast_channel.0.subscribe(),

View File

@@ -3,8 +3,8 @@ mod args;
mod epoch_service;
mod error;
mod grpc_service;
mod proof_generation;
mod proof_service;
mod registry;
mod registry_listener;
mod tier;
mod user_db_service;
@@ -29,9 +29,10 @@ use crate::args::AppArgs;
use crate::epoch_service::EpochService;
use crate::grpc_service::GrpcProverService;
use crate::proof_service::ProofService;
use crate::user_db_service::UserDbService;
use crate::user_db_service::{RateLimit, UserDbService};
const RLN_IDENTIFIER_NAME: &[u8] = b"test-rln-identifier";
const PROVER_SPAM_LIMIT: RateLimit = RateLimit::new(10_000u64);
const PROOF_SERVICE_COUNT: u8 = 8;
const GENESIS: DateTime<Utc> = DateTime::from_timestamp(1431648000, 0).unwrap();
@@ -63,6 +64,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let user_db_service = UserDbService::new(
epoch_service.epoch_changes.clone(),
epoch_service.current_epoch.clone(),
PROVER_SPAM_LIMIT,
);
// proof service
@@ -76,6 +78,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let rln_identifier = RlnIdentifier::new(RLN_IDENTIFIER_NAME);
let addr = SocketAddr::new(app_args.ip, app_args.port);
debug!("Listening on: {}", addr);
// TODO: broadcast subscribe?
let prover_grpc_service = GrpcProverService {
proof_sender,
broadcast_channel: (tx.clone(), rx),
@@ -89,9 +92,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let proof_recv = proof_receiver.clone();
let broadcast_sender = tx.clone();
let current_epoch = epoch_service.current_epoch.clone();
let user_db = user_db_service.get_user_db();
set.spawn(async {
let proof_service = ProofService::new(proof_recv, broadcast_sender, current_epoch);
let proof_service = ProofService::new(
proof_recv,
broadcast_sender,
current_epoch,
user_db,
PROVER_SPAM_LIMIT,
);
proof_service.serve().await
});
}

View File

@@ -0,0 +1,43 @@
use std::sync::Arc;
// third-party
use alloy::primitives::Address;
// internal
use rln_proof::{RlnIdentifier, RlnUserIdentity};
/// Everything the proof service needs to generate one RLN proof
/// for one transaction (sent over the proof channel by the gRPC service).
#[derive(Debug, Clone)]
pub(crate) struct ProofGenerationData {
/// Identity material of the sending user (secret hash, commitment, limit).
pub(crate) user_identity: RlnUserIdentity,
/// Shared RLN identifier (cheap to clone, behind an Arc).
pub(crate) rln_identifier: Arc<RlnIdentifier>,
/// Per-user transaction counter; used as the RLN message_id downstream.
pub(crate) tx_counter: u64,
/// Address of the transaction sender (used for the merkle-proof lookup).
pub(crate) tx_sender: Address,
/// Transaction hash; expected to be 32 bytes (checked by the From impl).
pub(crate) tx_hash: Vec<u8>,
}
impl From<(RlnUserIdentity, Arc<RlnIdentifier>, u64, Address, Vec<u8>)> for ProofGenerationData {
    /// Build a `ProofGenerationData` from its raw parts.
    ///
    /// The caller is expected to supply a 32-byte `tx_hash`
    /// (checked in debug builds only).
    fn from(parts: (RlnUserIdentity, Arc<RlnIdentifier>, u64, Address, Vec<u8>)) -> Self {
        let (user_identity, rln_identifier, tx_counter, tx_sender, tx_hash) = parts;
        debug_assert!(tx_hash.len() == 32);
        Self {
            user_identity,
            rln_identifier,
            tx_counter,
            tx_sender,
            tx_hash,
        }
    }
}
/// Result of a successful proof generation, broadcast to gRPC
/// `get_proofs` subscribers (Clone is required by the broadcast channel).
#[derive(Debug, Clone)]
pub(crate) struct ProofSendingData {
/// Transaction hash the proof was generated for.
pub(crate) tx_hash: Vec<u8>,
/// Address of the transaction sender.
pub(crate) tx_sender: Address,
/// The serialized (compressed) proof bytes plus proof values.
pub(crate) proof: Vec<u8>,
}

View File

@@ -2,50 +2,48 @@ use std::io::{Cursor, Write};
use std::sync::Arc;
// third-party
use ark_bn254::Fr;
use ark_serialize::{CanonicalSerialize, SerializationError};
use ark_serialize::CanonicalSerialize;
use async_channel::Receiver;
use parking_lot::RwLock;
use rln::hashers::{hash_to_field, poseidon_hash};
use rln::pm_tree_adapter::PmTree;
use rln::protocol::{ProofError, serialize_proof_values};
use tracing::debug;
use rln::hashers::hash_to_field;
use rln::protocol::serialize_proof_values;
use tracing::{debug, info};
// internal
use crate::epoch_service::{Epoch, EpochSlice};
use crate::error::AppError;
use rln_proof::{
RlnData, RlnIdentifier, RlnUserIdentity, ZerokitMerkleTree, compute_rln_proof_and_values,
};
use crate::error::{AppError, ProofGenerationError, ProofGenerationStringError};
use crate::proof_generation::{ProofGenerationData, ProofSendingData};
use crate::user_db_service::{RateLimit, UserDb};
use rln_proof::{RlnData, compute_rln_proof_and_values};
#[derive(thiserror::Error, Debug)]
enum ProofGenerationError {
#[error("Proof generation failed: {0}")]
Proof(#[from] ProofError),
#[error("Proof serialization failed: {0}")]
Serialization(#[from] SerializationError),
#[error("Proof serialization failed: {0}")]
SerializationWrite(#[from] std::io::Error),
#[error("Error: {0}")]
Misc(String),
}
const PROOF_SIZE: usize = 512;
/// A service to generate a RLN proof (and then to broadcast it)
#[derive(Debug)]
pub struct ProofService {
receiver: Receiver<(RlnUserIdentity, Arc<RlnIdentifier>, u64)>,
broadcast_sender: tokio::sync::broadcast::Sender<Vec<u8>>,
receiver: Receiver<ProofGenerationData>,
broadcast_sender:
tokio::sync::broadcast::Sender<Result<ProofSendingData, ProofGenerationStringError>>,
current_epoch: Arc<RwLock<(Epoch, EpochSlice)>>,
user_db: UserDb,
rate_limit: RateLimit,
}
impl ProofService {
pub(crate) fn new(
receiver: Receiver<(RlnUserIdentity, Arc<RlnIdentifier>, u64)>,
broadcast_sender: tokio::sync::broadcast::Sender<Vec<u8>>,
receiver: Receiver<ProofGenerationData>,
broadcast_sender: tokio::sync::broadcast::Sender<
Result<ProofSendingData, ProofGenerationStringError>,
>,
current_epoch: Arc<RwLock<(Epoch, EpochSlice)>>,
user_db: UserDb,
rate_limit: RateLimit,
) -> Self {
debug_assert!(rate_limit > RateLimit::ZERO);
Self {
receiver,
broadcast_sender,
current_epoch,
user_db,
rate_limit,
}
}
@@ -54,19 +52,33 @@ impl ProofService {
let received = self.receiver.recv().await;
if let Err(e) = received {
debug!("Stopping proof generation service: {}", e);
info!("Stopping proof generation service: {}", e);
break;
}
let (user_identity, rln_identifier, counter) = received.unwrap();
let proof_generation_data = received.unwrap();
let (current_epoch, current_epoch_slice) = *self.current_epoch.read();
let user_db = self.user_db.clone();
let proof_generation_data_ = proof_generation_data.clone();
let rate_limit = self.rate_limit;
// Move to a task (as generating the proof can take quite some time)
let blocking_task = tokio::task::spawn_blocking(move || {
let message_id = {
let mut m_id = proof_generation_data.tx_counter;
// Note: Zerokit can only recover user secret hash with 2 messages with the
// same message_id so here we force to use the previous message_id
// so the Verifier could recover the secret hash
if RateLimit::from(m_id) == rate_limit {
m_id -= 1;
}
m_id
};
let rln_data = RlnData {
message_id: Fr::from(counter),
// TODO: tx hash to field
data: hash_to_field(b"RLN is awesome"),
message_id: Fr::from(message_id),
data: hash_to_field(proof_generation_data.tx_hash.as_slice()),
};
let epoch_bytes = {
@@ -76,33 +88,22 @@ impl ProofService {
};
let epoch = hash_to_field(epoch_bytes.as_slice());
// FIXME: maintain tree in Prover or query RLN Reg SC ?
// Merkle tree
let tree_height = 20;
let mut tree = PmTree::new(tree_height, Fr::from(0), Default::default())
.map_err(|e| ProofGenerationError::Misc(e.to_string()))?;
// let mut tree = OptimalMerkleTree::new(tree_height, Fr::from(0), Default::default()).unwrap();
let rate_commit =
poseidon_hash(&[user_identity.commitment, user_identity.user_limit]);
tree.set(0, rate_commit)
.map_err(|e| ProofGenerationError::Misc(e.to_string()))?;
let merkle_proof = tree
.proof(0)
.map_err(|e| ProofGenerationError::Misc(e.to_string()))?;
let merkle_proof = user_db.get_merkle_proof(&proof_generation_data.tx_sender)?;
let (proof, proof_values) = compute_rln_proof_and_values(
&user_identity,
&rln_identifier,
&proof_generation_data.user_identity,
&proof_generation_data.rln_identifier,
rln_data,
epoch,
&merkle_proof,
)
.map_err(ProofGenerationError::Proof)?;
debug!("proof: {:?}", proof);
debug!("proof_values: {:?}", proof_values);
// Serialize proof
// FIXME: proof size?
let mut output_buffer = Cursor::new(Vec::with_capacity(512));
let mut output_buffer = Cursor::new(Vec::with_capacity(PROOF_SIZE));
proof
.serialize_compressed(&mut output_buffer)
.map_err(ProofGenerationError::Serialization)?;
@@ -114,18 +115,446 @@ impl ProofService {
});
let result = blocking_task.await;
// if let Err(e) = result {
// return Err(Status::from_error(Box::new(e)));
// }
// blocking_task returns Result<Result<Vec<u8>, _>>
// Result (1st) is a JoinError (and should not happen)
// Result (2nd) is a ProofGenerationError
let result = result.unwrap().unwrap();
// TODO: no unwrap()
// FIXME: send proof + other info
self.broadcast_sender.send(result).unwrap();
let result = result.unwrap(); // Should never happen (but should panic if it does)
let proof_sending_data = result
.map(|r| ProofSendingData {
tx_hash: proof_generation_data_.tx_hash,
tx_sender: proof_generation_data_.tx_sender,
proof: r,
})
.map_err(ProofGenerationStringError::from);
if let Err(e) = self.broadcast_sender.send(proof_sending_data) {
info!("Stopping proof generation service: {}", e);
break;
};
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
// third-party
use alloy::primitives::{Address, address};
use ark_groth16::{Proof as ArkProof, Proof, VerifyingKey};
use ark_serialize::CanonicalDeserialize;
use claims::assert_matches;
use futures::TryFutureExt;
use rln::circuit::{Curve, zkey_from_folder};
use tokio::sync::broadcast;
use tracing::info;
// third-party: zerokit
use rln::protocol::{compute_id_secret, deserialize_proof_values, keygen, verify_proof};
// internal
use crate::user_db_service::UserDbService;
use rln_proof::RlnIdentifier;
// Fixed test addresses: ADDR_1 is registered in every test,
// ADDR_2 only in some (test_user_not_registered leaves it out on purpose).
const ADDR_1: Address = address!("0xd8da6bf26964af9d7eed9e03e53415d37aa96045");
const ADDR_2: Address = address!("0xb20a608c624Ca5003905aA834De7156C68b2E1d0");
// Dummy 32-byte transaction hashes used as proof signals in the tests below.
// `0x11` was previously written `0x011` — same value (17) but a misleading
// odd-digit hex literal that reads like a typo.
const TX_HASH_1: [u8; 32] = [0x11; 32];
const TX_HASH_1_2: [u8; 32] = [0x12; 32];
// Test-only error type aggregating everything the try_join! futures below can
// produce. Some variants are not real errors but control-flow signals used to
// stop try_join!: Exit (proof verified ok) and RecoveredSecret (slashing ok).
#[derive(thiserror::Error, Debug)]
enum AppErrorExt {
#[error("AppError: {0}")]
AppError(#[from] AppError),
// A tokio::time::timeout expired while waiting for a proof.
#[error("Future timeout")]
Elapsed,
#[error("Proof generation failed: {0}")]
ProofGeneration(#[from] ProofGenerationStringError),
#[error("Proof verification failed")]
ProofVerification,
// Control-flow signal: verifier succeeded, abort the other futures.
#[error("Exiting...")]
Exit,
// Control-flow signal: carries the recovered identity secret hash.
#[error("Recovered secret")]
RecoveredSecret(Fr),
}
// Test helper: push a single ProofGenerationData onto the proof service queue.
// NOTE: the identity is always ADDR_1's while `tx_sender` is the caller-chosen
// `sender` — test_user_not_registered relies on this mismatch to request a
// proof for an address that has no merkle-tree leaf.
async fn proof_sender(
sender: Address,
proof_tx: &mut async_channel::Sender<ProofGenerationData>,
rln_identifier: Arc<RlnIdentifier>,
user_db: &UserDb,
) -> Result<(), AppErrorExt> {
// used by test_proof_generation unit test
debug!("Starting proof sender...");
debug!("Waiting a bit before sending proof...");
// Give the proof service / verifier tasks time to start up first
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
debug!("Sending proof...");
proof_tx
.send(ProofGenerationData {
user_identity: user_db.get_user(&ADDR_1).unwrap(),
rln_identifier,
tx_counter: 0,
tx_sender: sender,
tx_hash: TX_HASH_1.to_vec(),
})
.await
.unwrap();
debug!("Sending proof done");
// tokio::time::sleep(std::time::Duration::from_secs(10)).await;
Ok::<(), AppErrorExt>(())
}
// Test helper: receive one proof from the broadcast channel, deserialize the
// compressed groth16 proof followed by the proof values, and verify it.
// Returns AppErrorExt::Exit on success so the calling try_join! stops.
async fn proof_verifier(
broadcast_receiver: &mut broadcast::Receiver<
Result<ProofSendingData, ProofGenerationStringError>,
>,
verifying_key: &VerifyingKey<Curve>,
) -> Result<(), AppErrorExt> {
// used by test_proof_generation unit test
debug!("Starting broadcast receiver...");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let res =
tokio::time::timeout(std::time::Duration::from_secs(5), broadcast_receiver.recv())
.await
.map_err(|_e| AppErrorExt::Elapsed)?;
debug!("res: {:?}", res);
let res = res.unwrap();
let res = res?;
// The proof buffer is: compressed ArkProof, then serialized proof values;
// the cursor position after deserializing marks where the values start.
let mut proof_cursor = Cursor::new(&res.proof);
debug!("proof cursor: {:?}", proof_cursor);
let proof = ArkProof::deserialize_compressed(&mut proof_cursor).unwrap();
let position = proof_cursor.position() as usize;
let proof_cursor_2 = &proof_cursor.get_ref().as_slice()[position..];
let (proof_values, _) = deserialize_proof_values(proof_cursor_2);
debug!("[proof verifier] proof: {:?}", proof);
debug!("[proof verifier] proof_values: {:?}", proof_values);
let verified = verify_proof(verifying_key, &proof, &proof_values)
.map_err(|_e| AppErrorExt::ProofVerification)?;
// NOTE(review): `verified` is logged but never asserted — if verify_proof
// can return Ok(false), an invalid proof would still pass; confirm.
debug!("verified: {:?}", verified);
// Exit after receiving one proof
Err::<(), AppErrorExt>(AppErrorExt::Exit)
}
// Happy path: one proof request for a registered user flows through
// ProofService and comes back on the broadcast channel as a verifiable proof.
// proof_verifier deliberately returns AppErrorExt::Exit on success, which
// aborts try_join! — that's the outcome the final assert expects.
#[tokio::test]
#[tracing_test::traced_test]
async fn test_proof_generation() {
// Queues
let (broadcast_sender, _broadcast_receiver) = broadcast::channel(2);
let mut broadcast_receiver = broadcast_sender.subscribe();
let (mut proof_tx, proof_rx) = async_channel::unbounded();
// Epoch
let epoch = Epoch::from(11);
let epoch_slice = EpochSlice::from(42);
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
// User db
let user_db_service =
UserDbService::new(Default::default(), epoch_store.clone(), 10.into());
let user_db = user_db_service.get_user_db();
user_db.on_new_user(ADDR_1).unwrap();
user_db.on_new_user(ADDR_2).unwrap();
let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz"));
// Proof service
let proof_service = ProofService::new(
proof_rx,
broadcast_sender,
epoch_store,
user_db.clone(),
RateLimit::from(10),
);
// Verification
let proving_key = zkey_from_folder();
let verification_key = &proving_key.0.vk;
info!("Starting...");
let res = tokio::try_join!(
proof_service.serve().map_err(AppErrorExt::AppError),
proof_verifier(&mut broadcast_receiver, verification_key),
proof_sender(ADDR_1, &mut proof_tx, rln_identifier.clone(), &user_db),
);
// Everything ok if proof_verifier return AppErrorExt::Exit else there is a real error
assert_matches!(res, Err(AppErrorExt::Exit));
}
// ADDR_2 is never registered (note the commented-out on_new_user call), so the
// merkle-proof lookup keyed by tx_sender must fail with MerkleProofError.
#[tokio::test]
#[tracing_test::traced_test]
async fn test_user_not_registered() {
// Ask for a proof for an unregistered user
// Queues
let (broadcast_sender, _broadcast_receiver) = broadcast::channel(2);
let mut broadcast_receiver = broadcast_sender.subscribe();
let (mut proof_tx, proof_rx) = async_channel::unbounded();
// Epoch
let epoch = Epoch::from(11);
let epoch_slice = EpochSlice::from(42);
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
// User db
let user_db_service =
UserDbService::new(Default::default(), epoch_store.clone(), 10.into());
let user_db = user_db_service.get_user_db();
user_db.on_new_user(ADDR_1).unwrap();
// user_db.on_new_user(ADDR_2).unwrap();
let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz"));
// Proof service
let proof_service = ProofService::new(
proof_rx,
broadcast_sender,
epoch_store,
user_db.clone(),
RateLimit::from(10),
);
// Verification
let proving_key = zkey_from_folder();
let verification_key = &proving_key.0.vk;
info!("Starting...");
let res = tokio::try_join!(
proof_service.serve().map_err(AppErrorExt::AppError),
proof_verifier(&mut broadcast_receiver, verification_key),
// Request a proof for ADDR_2, which has no merkle-tree leaf
proof_sender(ADDR_2, &mut proof_tx, rln_identifier.clone(), &user_db),
);
// Expect this error (any other error is a real error)
assert_matches!(
res,
Err(AppErrorExt::ProofGeneration(
ProofGenerationStringError::MerkleProofError(_)
))
);
}
// Test helper: collect two proofs from the broadcast channel and recover the
// sender's identity secret hash from their (x, y) Shamir shares — the RLN
// slashing property. Returns RecoveredSecret as a control-flow "success"
// error so the calling try_join! stops.
async fn proof_reveal_secret(
broadcast_receiver: &mut broadcast::Receiver<
Result<ProofSendingData, ProofGenerationStringError>,
>,
verifying_key: &VerifyingKey<Curve>,
) -> Result<(), AppErrorExt> {
// used by test_user_spamming unit test
debug!("Starting broadcast receiver...");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let mut proof_values_store = vec![];
loop {
let res =
tokio::time::timeout(std::time::Duration::from_secs(5), broadcast_receiver.recv())
.await
.map_err(|_e| AppErrorExt::Elapsed)?;
let res = res.unwrap();
let res = res?;
let mut proof_cursor = Cursor::new(&res.proof);
// NOTE(review): `proof` is deserialized only to advance the cursor to
// the proof values; the proof itself is not verified here.
let proof: Proof<Curve> = ArkProof::deserialize_compressed(&mut proof_cursor).unwrap();
let position = proof_cursor.position() as usize;
let proof_cursor_2 = &proof_cursor.get_ref().as_slice()[position..];
let (proof_values, _) = deserialize_proof_values(proof_cursor_2);
proof_values_store.push(proof_values);
if proof_values_store.len() >= 2 {
break;
}
}
debug!("Now recovering secret hash...");
let proof_values_0 = proof_values_store.get(0).unwrap();
let proof_values_1 = proof_values_store.get(1).unwrap();
println!("proof_values_0: {:?}", proof_values_0);
println!("proof_values_1: {:?}", proof_values_1);
let share1 = (proof_values_0.x, proof_values_0.y);
let share2 = (proof_values_1.x, proof_values_1.y);
// TODO: should we check external nullifier as well?
let recovered_identity_secret_hash = compute_id_secret(share1, share2).unwrap();
debug!(
"recovered_identity_secret_hash: {:?}",
recovered_identity_secret_hash
);
// Exit after recovering the secret from two proofs (control-flow signal)
Err::<(), AppErrorExt>(AppErrorExt::RecoveredSecret(recovered_identity_secret_hash))
}
// Test helper: send two proof requests for `sender` (tx_counter 0 then 1)
// with the given pair of tx hashes, pausing between the two sends.
async fn proof_sender_2(
proof_tx: &mut async_channel::Sender<ProofGenerationData>,
rln_identifier: Arc<RlnIdentifier>,
user_db: &UserDb,
sender: Address,
tx_hashes: ([u8; 32], [u8; 32]),
) -> Result<(), AppErrorExt> {
// used by test_proof_generation unit test
debug!("Starting proof sender 2...");
debug!("Waiting a bit before sending proof...");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
debug!("Sending proof...");
proof_tx
.send(ProofGenerationData {
user_identity: user_db.get_user(&sender).unwrap(),
rln_identifier: rln_identifier.clone(),
tx_counter: 0,
// NOTE(review): Address is Copy — this `.clone()` is redundant
tx_sender: sender.clone(),
tx_hash: tx_hashes.0.to_vec(),
})
.await
.unwrap();
debug!("Sending proof done");
debug!("Waiting a bit before sending 2nd proof...");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
debug!("Sending 2nd proof...");
proof_tx
.send(ProofGenerationData {
user_identity: user_db.get_user(&sender).unwrap(),
rln_identifier,
tx_counter: 1,
tx_sender: sender,
tx_hash: tx_hashes.1.to_vec(),
})
.await
.unwrap();
debug!("Sending 2nd proof done");
Ok::<(), AppErrorExt>(())
}
// With rate_limit = 1, the second request (tx_counter 1) gets its message_id
// clamped by ProofService to the same value as the first, so the two proofs'
// (x, y) shares reveal the user's identity secret hash (RLN slashing).
#[tokio::test]
#[tracing_test::traced_test]
async fn test_user_spamming() {
// Recover secret from a user spamming the system
// Queues
let (broadcast_sender, _broadcast_receiver) = broadcast::channel(2);
let mut broadcast_receiver = broadcast_sender.subscribe();
let (mut proof_tx, proof_rx) = async_channel::unbounded();
// Epoch
let epoch = Epoch::from(11);
let epoch_slice = EpochSlice::from(42);
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
// Limits
let rate_limit = RateLimit::from(1);
// User db
let user_db_service =
UserDbService::new(Default::default(), epoch_store.clone(), rate_limit);
let user_db = user_db_service.get_user_db();
user_db.on_new_user(ADDR_1).unwrap();
let user_addr_1 = user_db.get_user(&ADDR_1).unwrap();
user_db.on_new_user(ADDR_2).unwrap();
let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz"));
// Proof service
let proof_service = ProofService::new(
proof_rx,
broadcast_sender,
epoch_store,
user_db.clone(),
rate_limit,
);
// Verification
let proving_key = zkey_from_folder();
let verification_key = &proving_key.0.vk;
info!("Starting...");
let res = tokio::try_join!(
proof_service.serve().map_err(AppErrorExt::AppError),
proof_reveal_secret(&mut broadcast_receiver, verification_key),
proof_sender_2(
&mut proof_tx,
rln_identifier.clone(),
&user_db,
ADDR_1,
(TX_HASH_1, TX_HASH_1_2)
),
);
// Success means the recovered secret matches ADDR_1's stored secret hash
match res {
Err(AppErrorExt::RecoveredSecret(secret_hash)) => {
assert_eq!(secret_hash, user_addr_1.secret_hash);
}
_ => {
panic!("Unexpected result");
}
}
}
#[tokio::test]
#[ignore]
#[tracing_test::traced_test]
async fn test_user_spamming_same_signal() {
    // Recover secret from a user spamming the system
    // Variant of test_user_spamming: both proofs are generated over the SAME
    // tx hash (TX_HASH_1 twice), i.e. an identical signal. Marked #[ignore]
    // because the expected behavior is pending Zerokit 0.8 (see TODO below).
    // Queues
    let (broadcast_sender, _broadcast_receiver) = broadcast::channel(2);
    let mut broadcast_receiver = broadcast_sender.subscribe();
    let (mut proof_tx, proof_rx) = async_channel::unbounded();
    // Epoch
    let epoch = Epoch::from(11);
    let epoch_slice = EpochSlice::from(42);
    let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
    // Limits
    let rate_limit = RateLimit::from(1);
    // User db - limit is 1 message per epoch
    let user_db_service =
        UserDbService::new(Default::default(), epoch_store.clone(), rate_limit.into());
    let user_db = user_db_service.get_user_db();
    user_db.on_new_user(ADDR_1).unwrap();
    let user_addr_1 = user_db.get_user(&ADDR_1).unwrap();
    debug!("user_addr_1: {:?}", user_addr_1);
    user_db.on_new_user(ADDR_2).unwrap();
    let rln_identifier = Arc::new(RlnIdentifier::new(b"foo bar baz"));
    // Proof service
    let proof_service = ProofService::new(
        proof_rx,
        broadcast_sender,
        epoch_store,
        user_db.clone(),
        rate_limit,
    );
    // Verification
    let proving_key = zkey_from_folder();
    let verification_key = &proving_key.0.vk;
    info!("Starting...");
    // NOTE(review): `res` is unused while the final assertion is commented
    // out; consider renaming to `_res` to silence the warning until then.
    let res = tokio::try_join!(
        proof_service.serve().map_err(AppErrorExt::AppError),
        proof_reveal_secret(&mut broadcast_receiver, verification_key),
        proof_sender_2(
            &mut proof_tx,
            rln_identifier.clone(),
            &user_db,
            ADDR_1,
            // same hash twice: identical signal in the same epoch slice
            (TX_HASH_1, TX_HASH_1)
        ),
    );
    // TODO: wait for Zerokit 0.8
    // assert_matches!(res, Err(AppErrorExt::Exit));
}
}

View File

@@ -1,51 +0,0 @@
// use alloy::primitives::Address;
// use ark_bn254::Fr;
// use dashmap::DashMap;
// use dashmap::mapref::one::Ref;
// use rln::protocol::keygen;
/*
#[derive(Debug)]
pub(crate) struct UserRegistry {
inner: DashMap<Address, (Fr, Fr)>,
}
impl UserRegistry {
fn new() -> Self {
Self {
inner: DashMap::new(),
}
}
pub(crate) fn get(&self, address: &Address) -> Option<Ref<Address, (Fr, Fr)>> {
self.inner.get(address)
}
fn register(&self, address: Address) {
let (identity_secret_hash, id_commitment) = keygen();
self.inner
.insert(address, (identity_secret_hash, id_commitment));
}
}
impl Default for UserRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use alloy::primitives::address;
#[test]
fn test_user_registration() {
let address = address!("0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f");
let reg = UserRegistry::default();
reg.register(address);
assert!(reg.get(&address).is_some());
}
}
*/

View File

@@ -4,31 +4,107 @@ use std::ops::Deref;
use std::sync::Arc;
// third-party
use alloy::primitives::{Address, U256};
use ark_bn254::Fr;
use derive_more::{Add, From, Into};
use parking_lot::RwLock;
use rln::hashers::poseidon_hash;
use rln::pm_tree_adapter::{PmTree, PmTreeProof};
use rln::protocol::keygen;
use scc::HashMap;
use tokio::sync::Notify;
use tracing::debug;
// internal
use crate::epoch_service::{Epoch, EpochSlice};
use crate::error::AppError;
use crate::error::{AppError, GetMerkleTreeProofError, RegisterError};
use crate::tier::{KarmaAmount, TIER_LIMITS, TierLimit, TierName};
use rln_proof::RlnUserIdentity;
use rln_proof::{RlnUserIdentity, ZerokitMerkleTree};
#[derive(Debug, Default, Clone)]
pub(crate) struct UserRegistry {
inner: HashMap<Address, RlnUserIdentity>,
// Height of the shared registration merkle tree.
const MERKLE_TREE_HEIGHT: usize = 20;
// Position of a user's rate-commitment leaf in the shared merkle tree.
#[derive(Debug, Clone, Copy, From, Into)]
struct MerkleTreeIndex(usize);
/// A per-user rate limit, folded into the RLN rate commitment at registration.
#[derive(Debug, Clone, Copy, Default, PartialOrd, PartialEq, From, Into)]
pub struct RateLimit(u64);
impl RateLimit {
    // Convenience constant for a zero (fully throttled) limit.
    pub(crate) const ZERO: RateLimit = RateLimit(0);
    /// Const constructor, usable in const contexts (unlike `From<u64>`).
    pub(crate) const fn new(value: u64) -> Self {
        Self(value)
    }
}
impl From<RateLimit> for Fr {
    /// Converts the limit into a BN254 field element (used when building
    /// the user's RLN identity and rate commitment).
    fn from(rate_limit: RateLimit) -> Self {
        Fr::from(rate_limit.0)
    }
}
#[derive(Clone)]
pub(crate) struct UserRegistry {
inner: HashMap<Address, (RlnUserIdentity, MerkleTreeIndex)>,
tree: Arc<RwLock<PmTree>>,
rate_limit: RateLimit,
}
impl std::fmt::Debug for UserRegistry {
    // Manual impl: only `inner` is printed; the merkle tree and rate limit
    // are omitted (presumably because PmTree does not implement Debug —
    // confirm before deriving instead).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "UserRegistry {{ inner: {:?} }}", self.inner)
    }
}
impl Default for UserRegistry {
fn default() -> Self {
Self {
inner: Default::default(),
// unwrap safe - no config
tree: Arc::new(RwLock::new(
PmTree::new(MERKLE_TREE_HEIGHT, Default::default(), Default::default()).unwrap(),
)),
rate_limit: Default::default(),
}
}
}
impl From<RateLimit> for UserRegistry {
fn from(rate_limit: RateLimit) -> Self {
Self {
inner: Default::default(),
// unwrap safe - no config
tree: Arc::new(RwLock::new(
PmTree::new(MERKLE_TREE_HEIGHT, Default::default(), Default::default()).unwrap(),
)),
rate_limit,
}
}
}
impl UserRegistry {
fn register(&self, address: Address) -> bool {
fn register(&self, address: Address) -> Result<(), RegisterError> {
let (identity_secret_hash, id_commitment) = keygen();
self.inner
let index = self.inner.len();
let res = self
.inner
.insert(
address,
RlnUserIdentity::from((identity_secret_hash, id_commitment)),
(
RlnUserIdentity::from((
identity_secret_hash,
id_commitment,
Fr::from(self.rate_limit),
)),
MerkleTreeIndex(index),
),
)
.is_ok()
.map_err(|_e| RegisterError::AlreadyRegistered(address));
let rate_commit = poseidon_hash(&[id_commitment, Fr::from(u64::from(self.rate_limit))]);
self.tree
.write()
.set(index, rate_commit)
.map_err(|e| RegisterError::TreeError(e.to_string()))?;
res
}
fn has_user(&self, address: &Address) -> bool {
@@ -36,7 +112,19 @@ impl UserRegistry {
}
fn get_user(&self, address: &Address) -> Option<RlnUserIdentity> {
self.inner.get(address).map(|entry| entry.clone())
self.inner.get(address).map(|entry| entry.0.clone())
}
/// Returns a merkle membership proof for `address`'s rate-commitment leaf,
/// or `NotRegistered` if the address was never registered.
fn get_merkle_proof(&self, address: &Address) -> Result<PmTreeProof, GetMerkleTreeProofError> {
    // Look up the user's leaf index first, then build the proof under a
    // read lock on the shared tree.
    let index = self
        .inner
        .get(address)
        .map(|entry| entry.1)
        .ok_or(GetMerkleTreeProofError::NotRegistered)?;
    self.tree
        .read()
        .proof(index.into())
        .map_err(|e| GetMerkleTreeProofError::TreeError(e.to_string()))
}
}
@@ -129,10 +217,21 @@ impl UserDb {
}
}
pub fn on_new_user(&self, address: Address) -> Result<(), RegisterError> {
self.user_registry.register(address)
}
pub fn get_user(&self, address: &Address) -> Option<RlnUserIdentity> {
self.user_registry.get_user(address)
}
pub fn get_merkle_proof(
&self,
address: &Address,
) -> Result<PmTreeProof, GetMerkleTreeProofError> {
self.user_registry.get_merkle_proof(address)
}
pub(crate) fn on_new_tx(&self, address: &Address) -> Option<EpochSliceCounter> {
if self.user_registry.has_user(address) {
Some(self.tx_registry.incr_counter(address, None))
@@ -249,10 +348,11 @@ impl UserDbService {
pub(crate) fn new(
epoch_changes_notifier: Arc<Notify>,
epoch_store: Arc<RwLock<(Epoch, EpochSlice)>>,
rate_limit: RateLimit,
) -> Self {
Self {
user_db: UserDb {
user_registry: Default::default(),
user_registry: Arc::new(UserRegistry::from(rate_limit)),
tx_registry: Default::default(),
tier_limits: Arc::new(RwLock::new(TIER_LIMITS.clone())),
tier_limits_next: Arc::new(Default::default()),
@@ -308,7 +408,7 @@ impl UserDbService {
mod tests {
use super::*;
use alloy::primitives::address;
use claims::assert_err;
use claims::{assert_err, assert_matches};
struct MockKarmaSc {}
@@ -335,6 +435,23 @@ mod tests {
}
}
#[test]
fn test_user_register() {
    // Registering the same address twice must fail with AlreadyRegistered.
    let user_db = UserDb {
        user_registry: Default::default(),
        tx_registry: Default::default(),
        tier_limits: Arc::new(RwLock::new(TIER_LIMITS.clone())),
        tier_limits_next: Arc::new(Default::default()),
        epoch_store: Arc::new(RwLock::new(Default::default())),
    };
    let address = Address::new([0; 20]);
    // First registration succeeds (panics the test otherwise).
    user_db.user_registry.register(address).unwrap();
    // Second registration of the same address is rejected.
    let second = user_db.user_registry.register(address);
    assert_matches!(second, Err(RegisterError::AlreadyRegistered(_)));
}
#[tokio::test]
async fn test_incr_tx_counter() {
let user_db = UserDb {
@@ -354,8 +471,9 @@ mod tests {
tier_info,
Err(UserTierInfoError::NotRegistered(_))
));
// Register user + update tx counter
user_db.user_registry.register(addr);
// Register user
user_db.user_registry.register(addr).unwrap();
// Now update user tx counter
assert_eq!(user_db.on_new_tx(&addr), Some(EpochSliceCounter(1)));
let tier_info = user_db.user_tier_info(&addr, MockKarmaSc {}).await.unwrap();
assert_eq!(tier_info.epoch_tx_count, 1);
@@ -367,16 +485,16 @@ mod tests {
let mut epoch = Epoch::from(11);
let mut epoch_slice = EpochSlice::from(42);
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
let user_db_service = UserDbService::new(Default::default(), epoch_store);
let user_db_service = UserDbService::new(Default::default(), epoch_store, 10.into());
let user_db = user_db_service.get_user_db();
let addr_1_tx_count = 2;
let addr_2_tx_count = 820;
user_db.user_registry.register(ADDR_1);
user_db.user_registry.register(ADDR_1).unwrap();
user_db
.tx_registry
.incr_counter(&ADDR_1, Some(addr_1_tx_count));
user_db.user_registry.register(ADDR_2);
user_db.user_registry.register(ADDR_2).unwrap();
user_db
.tx_registry
.incr_counter(&ADDR_2, Some(addr_2_tx_count));
@@ -450,7 +568,7 @@ mod tests {
let mut epoch = Epoch::from(11);
let mut epoch_slice = EpochSlice::from(42);
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
let user_db_service = UserDbService::new(Default::default(), epoch_store);
let user_db_service = UserDbService::new(Default::default(), epoch_store, 10.into());
let user_db = user_db_service.get_user_db();
let tier_limits_original = user_db.tier_limits.read().clone();
@@ -498,7 +616,7 @@ mod tests {
let epoch = Epoch::from(11);
let epoch_slice = EpochSlice::from(42);
let epoch_store = Arc::new(RwLock::new((epoch, epoch_slice)));
let user_db_service = UserDbService::new(Default::default(), epoch_store);
let user_db_service = UserDbService::new(Default::default(), epoch_store, 10.into());
let user_db = user_db_service.get_user_db();
let tier_limits_original = user_db.tier_limits.read().clone();

View File

@@ -7,8 +7,10 @@ use ark_relations::r1cs::ConstraintMatrices;
use rln::circuit::ZKEY_BYTES;
use rln::circuit::zkey::read_zkey;
use rln::hashers::{hash_to_field, poseidon_hash};
use rln::pm_tree_adapter::PmTree;
use rln::protocol::{
ProofError, RLNProofValues, generate_proof, proof_values_from_witness, rln_witness_from_values,
ProofError, RLNProofValues, compute_id_secret, generate_proof, keygen,
proof_values_from_witness, rln_witness_from_values,
};
/// A RLN user identity & limit
@@ -19,12 +21,12 @@ pub struct RlnUserIdentity {
pub user_limit: Fr,
}
impl From<(Fr, Fr)> for RlnUserIdentity {
fn from((commitment, secret_hash): (Fr, Fr)) -> Self {
impl From<(Fr, Fr, Fr)> for RlnUserIdentity {
fn from((commitment, secret_hash, user_limit): (Fr, Fr, Fr)) -> Self {
Self {
commitment,
secret_hash,
user_limit: Fr::from(0),
user_limit,
}
}
}
@@ -89,3 +91,61 @@ pub fn compute_rln_proof_and_values(
)?;
Ok((proof, proof_values))
}
#[cfg(test)]
mod tests {
use super::*;
use zerokit_utils::ZerokitMerkleTree;
#[test]
fn test_recover_secret_hash() {
    // Two RLN proofs produced with the same (identity, epoch, message_id)
    // leak two distinct (x, y) shares of the degree-1 secret polynomial;
    // `compute_id_secret` must reconstruct the identity secret from them.
    // NOTE(review): zerokit's keygen() is documented to return
    // (identity_secret_hash, id_commitment); the binding names below
    // suggest the opposite order — confirm which element is which.
    let (user_co, user_sh) = keygen();
    let epoch = hash_to_field(b"foo");
    let spam_limit = Fr::from(10);
    let mut tree = PmTree::new(20, Default::default(), Default::default()).unwrap();
    tree.set(0, spam_limit).unwrap();
    let m_proof = tree.proof(0).unwrap();
    let rln_identifier = RlnIdentifier::new(b"rln id test");
    let message_id = Fr::from(1);
    // Single identity shared by both proofs (was duplicated inline before).
    let identity = RlnUserIdentity {
        commitment: user_co,
        secret_hash: user_sh,
        user_limit: spam_limit,
    };
    // Two proofs over different signals but identical message_id + epoch:
    // exactly the "spam" condition RLN is designed to punish.
    let (_proof_0, proof_values_0) = compute_rln_proof_and_values(
        &identity,
        &rln_identifier,
        RlnData {
            message_id,
            data: hash_to_field(b"sig"),
        },
        epoch,
        &m_proof,
    )
    .unwrap();
    let (_proof_1, proof_values_1) = compute_rln_proof_and_values(
        &identity,
        &rln_identifier,
        RlnData {
            message_id,
            data: hash_to_field(b"sig 2"),
        },
        epoch,
        &m_proof,
    )
    .unwrap();
    // Each proof exposes one share; two shares recover the secret.
    let share1 = (proof_values_0.x, proof_values_0.y);
    let share2 = (proof_values_1.x, proof_values_1.y);
    let recovered_identity_secret_hash = compute_id_secret(share1, share2).unwrap();
    assert_eq!(user_sh, recovered_identity_secret_hash);
}
}