Compare commits

...

25 Commits

Author SHA1 Message Date
sydhds
f4c085ae42 Add missing test file 2025-07-18 10:53:09 +02:00
sydhds
ee712ea84f Add unit tests of BE keygen related functions against Q value 2025-07-18 10:52:48 +02:00
sydhds
2749be14c6 Fix fr_to_bytes_be function 2025-07-11 09:37:16 +02:00
sydhds
0f67f0ecd5 Add be functions in ffi 2025-07-10 08:24:45 +02:00
sydhds
acf313e032 Add hash_be and poseidon_hash_be functions 2025-06-30 12:18:47 +02:00
sydhds
2e3528c9b2 Add BE functions in public API too 2025-06-30 12:18:47 +02:00
sydhds
833bbd1fc3 Cargo fmt pass 2025-06-30 12:18:47 +02:00
sydhds
ce9e05484e Add BE serialization for Fr and identity pair/tuple 2025-06-30 12:18:47 +02:00
Sydhds
baf474e747 Use Vec::with_capacity for bytes_le_to_vec_fr (#321) 2025-06-23 10:13:39 +02:00
Ekaterina Broslavskaya
dc0b31752c release v0.8.0 (#315) 2025-06-05 12:23:06 +03:00
Sydhds
36013bf4ba Remove not explicit use statement (#317) 2025-06-05 10:32:43 +02:00
Sydhds
211b2d4830 Add error for return type of compute_id_secret function (#316) 2025-06-04 09:00:27 +02:00
Sydhds
5f4bcb74ce Eyre removal 2 (#311)
Co-authored-by: Ekaterina Broslavskaya <seemenkina@gmail.com>
2025-06-02 10:32:13 +02:00
Jakub Sokołowski
de5fd36add nix: add RLN targets for different platforms
Wanted to be able to build `wakucanary` without having to build `zerokit` manually.
Also adds the `release` flag which can be set to `false` for a debug build.

Signed-off-by: Jakub Sokołowski <jakub@status.im>
2025-05-29 10:30:02 +02:00
Jakub Sokołowski
19c0f551c8 nix: use rust tooling from rust-overlay for builds
Noticed the builds in `nix/default.nix` were not using the tooling
from `rust-overlay` but instead using older one from `pkgs`.

This also removes the need to compile LLVM before building Zerokit.

Signed-off-by: Jakub Sokołowski <jakub@status.im>
2025-05-29 09:56:31 +02:00
vinhtc27
4133f1f8c3 fix: bumps deps, downgrade hex-literal to avoid Rust edition 2024 issue
Signed-off-by: Jakub Sokołowski <jakub@status.im>
2025-05-29 09:56:30 +02:00
markoburcul
149096f7a6 flake: add rust overlay and shell dependencies 2025-05-15 11:51:55 +02:00
Vinh Trịnh
7023e85fce Enable parallel execution for Merkle Tree (#306) 2025-05-14 12:19:37 +07:00
Vinh Trịnh
a4cafa6adc Enable parallel execution for rln-wasm module (#296)
## Changes

- Enabled parallelism in the browser for `rln-wasm` with the
`multithread` feature flag.
- Added browser tests for both single-threaded and multi-threaded modes.
- Enabled browser tests in the CI workflow.
- Pending: resolving hanging issue with `wasm-bindgen-rayon`
([comment](https://github.com/RReverser/wasm-bindgen-rayon/issues/6#issuecomment-2814372940)).
- Forked [this
commit](42887c80e6)
into a separate
[branch](https://github.com/vacp2p/zerokit/tree/benchmark-v0.8.0), which
includes an HTML benchmark file and a test case for the multithreaded
feature in `rln-wasm`.
- The test case still has the known issue above, so it's temporarily
disabled in this PR and will be addressed in the future.
- Improve the `make installdeps` which resolves the issue of NVM not
enabling Node.js in the current terminal session.
- Reduce the build size of the `.wasm` blob using the `wasm-opt` tool
from [Binaryen](https://github.com/WebAssembly/binaryen).
- Maybe we can close this draft
[PR](https://github.com/vacp2p/zerokit/pull/226), which is already very
outdated?
2025-05-13 13:15:05 +07:00
Sydhds
4077357e3f Add merkle tree glossary + mermaid graph example (#298) 2025-04-18 15:11:55 +02:00
Sydhds
84d9799d09 Add poseidon hash benchmark + optim (#300) 2025-04-18 15:08:43 +02:00
Sydhds
c576af8e62 fix(tree): fix OptimalMerkleTree set_range & override_range performance issue (#295) 2025-04-16 17:40:58 +02:00
Sydhds
81470b9678 Add poseidon hash unit test (against ref values) (#299) 2025-04-16 16:23:58 +02:00
Vinh Trịnh
9d4198c205 feat(rln-wasm): bring back wasm support for zerokit
# Bring Back WebAssembly Support for ZeroKit

- Update minor versions of all dependencies.
- Update documentation to reflect these changes.
- ~~Vendor `wasmer` v4.4.0 in [my git
repository](https://github.com/vinhtc27/wasmer) for `ark-circom`
v0.5.0.~~
- Resolve `wasm-pack` build failures (`os error 2`) caused by a Node.js
version mismatch.
- Restore the previous CI pipeline for the `rln-wasm` feature and update
to the stable toolchain.
- ~~Use `ark-circom` with the `wasm` feature for WebAssembly
compatibility and the `rln.wasm` file for witness calculation.~~
- ~~Fix dependency issues related to `ark-circom` v0.5.0, which
currently uses `wasmer` v4.4.0 and is affected by this
[issue](https://github.com/rust-lang/rust/issues/91632#issuecomment-1477914703).~~
- Install WABT with `brew` and `apt-get` instead of cloning to fix
`wasm-strip not found` issue in the CI workflow.
- Install `wasm-pack` with `curl` instead of using `wasm-pack-action` to
fix parse exception error in the CI workflow.
- Use the `.wasm` file with JS bindings for witness calculation, which
is generated from [`iden3/circom`](https://github.com/iden3/circom)
during circuit compilation. This allows witness computation outside RLN
instance.
- Refactor the `rln` module by moving circuit-related files to the
`src/circuit` folder for better organization.
- Remove `ark-circom` and `wasmer` by cloning the
[CircomReduction](3c95ed98e2/src/circom/qap.rs (L12))
struct and the
[read_zkey](3c95ed98e2/src/zkey.rs (L53))
function into the `rln` module, which reduces the repository's build
size and speeds up compilation time and the CI workflow duration.
- These change also address
[#282](https://github.com/vacp2p/zerokit/issues/282) by removing
`wasmer` and `wasmer-wasix`, which lack x32 system support.
- Benchmark `rln-wasm` with `wasm_bindgen_test`, covering RLN instance
creation, key generation, witness calculation, proving, and
verification. Also, add them to `v0.6.1` in
[benchmark-v0.6.1](https://github.com/vacp2p/zerokit/tree/benchmark-v0.6.1)
for comparison.
- Add `arkzkey` feature for rln-wasm, including tests, benchmarks, CI
workflow updates, and related documentation.
- Benchmark rln-wasm in the browser using HTML, covering initialization,
RLN instance creation, proving, and verification; fork to the
`benchmark-v0.7.0` branch for later use
[here](https://github.com/vacp2p/zerokit/tree/benchmark-v0.7.0).
- Fix clippy error: "this `repeat().take()` can be written more
concisely" on CI workflow for `utils` module.
([error](https://github.com/vacp2p/zerokit/actions/runs/14258579070/job/39965568013))
- Update Makefile.toml to be able to run `make build`, `make test`, and
`make bench` from root and inside each modules.
2025-04-08 13:37:18 +07:00
markoburcul
c60e0c33fc nix: add flake and derivation for android-arm64 arch
Referenced issue: https://github.com/waku-org/nwaku/issues/3232
2025-04-04 10:50:26 +02:00
69 changed files with 5095 additions and 4719 deletions

View File

@@ -5,6 +5,7 @@ on:
paths-ignore:
- "**.md"
- "!.github/workflows/*.yml"
- "!rln-wasm/**"
- "!rln/src/**"
- "!rln/resources/**"
- "!utils/src/**"
@@ -12,6 +13,7 @@ on:
paths-ignore:
- "**.md"
- "!.github/workflows/*.yml"
- "!rln-wasm/**"
- "!rln/src/**"
- "!rln/resources/**"
- "!utils/src/**"
@@ -22,8 +24,8 @@ jobs:
utils-test:
strategy:
matrix:
platform: [ ubuntu-latest, macos-latest ]
crate: [ utils ]
platform: [ubuntu-latest, macos-latest]
crate: [utils]
runs-on: ${{ matrix.platform }}
timeout-minutes: 60
@@ -48,9 +50,9 @@ jobs:
rln-test:
strategy:
matrix:
platform: [ ubuntu-latest, macos-latest ]
crate: [ rln ]
feature: [ "default", "arkzkey", "stateless" ]
platform: [ubuntu-latest, macos-latest]
crate: [rln]
feature: ["default", "arkzkey", "stateless"]
runs-on: ${{ matrix.platform }}
timeout-minutes: 60
@@ -76,12 +78,98 @@ jobs:
fi
working-directory: ${{ matrix.crate }}
rln-wasm-test:
strategy:
matrix:
platform: [ubuntu-latest, macos-latest]
crate: [rln-wasm]
feature: ["default", "arkzkey"]
runs-on: ${{ matrix.platform }}
timeout-minutes: 60
name: test - ${{ matrix.crate }} - ${{ matrix.platform }} - ${{ matrix.feature }}
steps:
- uses: actions/checkout@v3
- name: Install stable toolchain
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: Swatinem/rust-cache@v2
- name: Install Dependencies
run: make installdeps
- name: cargo-make build
run: |
if [ ${{ matrix.feature }} == default ]; then
cargo make build
else
cargo make build_${{ matrix.feature }}
fi
working-directory: ${{ matrix.crate }}
- name: cargo-make test
run: |
if [ ${{ matrix.feature }} == default ]; then
cargo make test --release
else
cargo make test_${{ matrix.feature }} --release
fi
working-directory: ${{ matrix.crate }}
- name: cargo-make test browser
run: |
if [ ${{ matrix.feature }} == default ]; then
cargo make test_browser --release
else
cargo make test_browser_${{ matrix.feature }} --release
fi
working-directory: ${{ matrix.crate }}
# rln-wasm-multihread-test:
# strategy:
# matrix:
# platform: [ubuntu-latest, macos-latest]
# crate: [rln-wasm]
# feature: ["multithread", "multithread_arkzkey"]
# runs-on: ${{ matrix.platform }}
# timeout-minutes: 60
# name: test - ${{ matrix.crate }} - ${{ matrix.platform }} - ${{ matrix.feature }}
# steps:
# - uses: actions/checkout@v3
# - name: Install nightly toolchain
# uses: actions-rs/toolchain@v1
# with:
# profile: minimal
# toolchain: nightly
# override: true
# components: rust-src
# target: wasm32-unknown-unknown
# - uses: Swatinem/rust-cache@v2
# - name: Install Dependencies
# run: make installdeps
# - name: cargo-make build
# run: |
# if [ ${{ matrix.feature }} == default ]; then
# cargo make build
# else
# cargo make build_${{ matrix.feature }}
# fi
# working-directory: ${{ matrix.crate }}
# - name: cargo-make test
# run: |
# if [ ${{ matrix.feature }} == default ]; then
# cargo make test --release
# else
# cargo make test_${{ matrix.feature }} --release
# fi
# working-directory: ${{ matrix.crate }}
lint:
strategy:
matrix:
# we run lint tests only on ubuntu
platform: [ ubuntu-latest ]
crate: [ rln, utils ]
platform: [ubuntu-latest]
crate: [rln, rln-wasm, utils]
runs-on: ${{ matrix.platform }}
timeout-minutes: 60
@@ -106,7 +194,7 @@ jobs:
- name: cargo clippy
if: success() || failure()
run: |
cargo clippy --release -- -D warnings
cargo clippy --release
working-directory: ${{ matrix.crate }}
benchmark-utils:
@@ -115,12 +203,12 @@ jobs:
strategy:
matrix:
# we run benchmark tests only on ubuntu
platform: [ ubuntu-latest ]
crate: [ utils ]
platform: [ubuntu-latest]
crate: [utils]
runs-on: ${{ matrix.platform }}
timeout-minutes: 60
name: benchmark - ${{ matrix.platform }} - ${{ matrix.crate }}
name: benchmark - ${{ matrix.crate }} - ${{ matrix.platform }}
steps:
- name: Checkout sources
uses: actions/checkout@v3
@@ -136,13 +224,13 @@ jobs:
strategy:
matrix:
# we run benchmark tests only on ubuntu
platform: [ ubuntu-latest ]
crate: [ rln ]
feature: [ "default", "arkzkey" ]
platform: [ubuntu-latest]
crate: [rln]
feature: ["default", "arkzkey"]
runs-on: ${{ matrix.platform }}
timeout-minutes: 60
name: benchmark - ${{ matrix.platform }} - ${{ matrix.crate }} - ${{ matrix.feature }}
name: benchmark - ${{ matrix.crate }} - ${{ matrix.platform }} - ${{ matrix.feature }}
steps:
- name: Checkout sources
uses: actions/checkout@v3
@@ -151,4 +239,4 @@ jobs:
with:
branchName: ${{ github.base_ref }}
cwd: ${{ matrix.crate }}
features: ${{ matrix.feature }}
features: ${{ matrix.feature }}

View File

@@ -84,9 +84,39 @@ jobs:
path: ${{ matrix.target }}-${{ matrix.feature }}-rln.tar.gz
retention-days: 2
browser-rln-wasm:
name: Browser build (RLN WASM)
runs-on: ubuntu-latest
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Install stable toolchain
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: Swatinem/rust-cache@v2
- name: Install dependencies
run: make installdeps
- name: cross make build
run: |
cross make build
mkdir release
cp pkg/** release/
tar -czvf browser-rln-wasm.tar.gz release/
working-directory: rln-wasm
- name: Upload archive artifact
uses: actions/upload-artifact@v4
with:
name: browser-rln-wasm-archive
path: rln-wasm/browser-rln-wasm.tar.gz
retention-days: 2
prepare-prerelease:
name: Prepare pre-release
needs: [ linux, macos ]
needs: [ linux, macos, browser-rln-wasm ]
runs-on: ubuntu-latest
steps:
- name: Checkout code

4
.gitignore vendored
View File

@@ -9,7 +9,9 @@ rln-cli/database
# will have compiled files and executables
debug/
target/
wabt/
# Generated by Nix
result
# These are backup files generated by rustfmt
**/*.rs.bk

View File

@@ -1,3 +1,5 @@
# CHANGE LOG
## 2023-02-28 v0.2
This release contains:
@@ -10,7 +12,6 @@ This release contains:
- Dual License under Apache 2.0 and MIT
- RLN compiles as a static library, which can be consumed through a C FFI
## 2022-09-19 v0.1
Initial beta release.

4123
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[workspace]
members = ["rln", "rln-cli", "utils"]
default-members = ["rln", "rln-cli", "utils"]
members = ["rln", "rln-cli", "rln-wasm", "utils"]
default-members = ["rln", "rln-cli", "rln-wasm", "utils"]
resolver = "2"
# Compilation profile for any non-workspace member.

View File

@@ -1,6 +1,6 @@
.PHONY: all installdeps build test clean
.PHONY: all installdeps build test bench clean
all: .pre-build build
all: installdeps build
.fetch-submodules:
@git submodule update --init --recursive
@@ -13,26 +13,27 @@ endif
installdeps: .pre-build
ifeq ($(shell uname),Darwin)
# commented due to https://github.com/orgs/Homebrew/discussions/4612
# @brew update
@brew install cmake ninja
@git -C "wabt" pull || git clone --recursive https://github.com/WebAssembly/wabt.git "wabt"
@cd wabt && mkdir -p build && make
@brew install cmake ninja binaryen
else ifeq ($(shell uname),Linux)
@sudo apt-get update
@sudo apt-get install -y cmake ninja-build
@git -C "wabt" pull || git clone --recursive https://github.com/WebAssembly/wabt.git "wabt"
@cd wabt && mkdir -p build && cd build && cmake .. -GNinja && ninja && sudo ninja install
@if [ -f /etc/os-release ] && grep -q "ID=nixos" /etc/os-release; then \
echo "Detected NixOS, skipping apt-get installation."; \
else \
sudo apt-get install -y cmake ninja-build binaryen; \
fi
endif
# nvm already checks if it's installed, and no-ops if it is
@curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.7/install.sh | bash
@. ${HOME}/.nvm/nvm.sh && nvm install 18.20.2 && nvm use 18.20.2;
@which wasm-pack > /dev/null && wasm-pack --version | grep -q "0.13.1" || cargo install wasm-pack --version=0.13.1
@which wasm-bindgen > /dev/null && wasm-bindgen --version | grep -q "0.2.100" || cargo install wasm-bindgen-cli --version=0.2.100
@test -s "$$HOME/.nvm/nvm.sh" || curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash
@bash -c '. "$$HOME/.nvm/nvm.sh"; [ "$$(node -v 2>/dev/null)" = "v22.14.0" ] || nvm install 22.14.0; nvm use 22.14.0; nvm alias default 22.14.0'
build: .pre-build
build: installdeps
@cargo make build
test: .pre-build
test: build
@cargo make test
bench: build
@cargo make bench
clean:
@cargo clean

View File

@@ -37,6 +37,12 @@ Zerokit currently focuses on RLN (Rate-Limiting Nullifier) implementation using
make installdeps
```
#### Use Nix to install dependencies
```bash
nix develop
```
### Build and Test All Crates
```bash

48
flake.lock generated Normal file
View File

@@ -0,0 +1,48 @@
{
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1740603184,
"narHash": "sha256-t+VaahjQAWyA+Ctn2idyo1yxRIYpaDxMgHkgCNiMJa4=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "f44bd8ca21e026135061a0a57dcf3d0775b67a49",
"type": "github"
},
"original": {
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "f44bd8ca21e026135061a0a57dcf3d0775b67a49",
"type": "github"
}
},
"root": {
"inputs": {
"nixpkgs": "nixpkgs",
"rust-overlay": "rust-overlay"
}
},
"rust-overlay": {
"inputs": {
"nixpkgs": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1748399823,
"narHash": "sha256-kahD8D5hOXOsGbNdoLLnqCL887cjHkx98Izc37nDjlA=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "d68a69dc71bc19beb3479800392112c2f6218159",
"type": "github"
},
"original": {
"owner": "oxalica",
"repo": "rust-overlay",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

71
flake.nix Normal file
View File

@@ -0,0 +1,71 @@
{
description = "A flake for building zerokit";
inputs = {
# Version 24.11
nixpkgs.url = "github:NixOS/nixpkgs?rev=f44bd8ca21e026135061a0a57dcf3d0775b67a49";
rust-overlay = {
url = "github:oxalica/rust-overlay";
inputs.nixpkgs.follows = "nixpkgs";
};
};
outputs = { self, nixpkgs, rust-overlay }:
let
stableSystems = [
"x86_64-linux" "aarch64-linux"
"x86_64-darwin" "aarch64-darwin"
"x86_64-windows" "i686-linux"
"i686-windows"
];
forAllSystems = nixpkgs.lib.genAttrs stableSystems;
overlays = [
(import rust-overlay)
(f: p: { inherit rust-overlay; })
];
pkgsFor = forAllSystems (system: import nixpkgs { inherit system overlays; });
in rec
{
packages = forAllSystems (system: let
pkgs = pkgsFor.${system};
buildPackage = pkgs.callPackage ./nix/default.nix;
buildRln = (buildPackage { src = self; project = "rln"; }).override;
in rec {
rln = buildRln {
features = "arkzkey";
};
rln-linux-arm64 = buildRln {
target-platform = "aarch64-multiplatform";
rust-target = "aarch64-unknown-linux-gnu";
};
rln-android-arm64 = buildRln {
target-platform = "aarch64-android-prebuilt";
rust-target = "aarch64-linux-android";
};
rln-ios-arm64 = buildRln {
target-platform = "aarch64-darwin";
rust-target = "aarch64-apple-ios";
};
# TODO: Remove legacy name for RLN android library
zerokit-android-arm64 = rln-android-arm64;
default = rln;
});
devShells = forAllSystems (system: let
pkgs = pkgsFor.${system};
in {
default = pkgs.mkShell {
buildInputs = with pkgs; [
git cmake cargo-make rustup
binaryen ninja gnuplot
rust-bin.stable.latest.default
];
};
});
};
}

62
nix/default.nix Normal file
View File

@@ -0,0 +1,62 @@
{
pkgs,
rust-overlay,
project,
src ? ../.,
release ? true,
features ? "arkzkey",
target-platform ? null,
rust-target ? null,
}:
let
# Use cross-compilation if target-platform is specified.
targetPlatformPkgs = if target-platform != null
then pkgs.pkgsCross.${target-platform}
else pkgs;
rust-bin = rust-overlay.lib.mkRustBin { } targetPlatformPkgs.buildPackages;
# Use Rust and Cargo versions from rust-overlay.
rustPlatform = targetPlatformPkgs.makeRustPlatform {
cargo = rust-bin.stable.latest.minimal;
rustc = rust-bin.stable.latest.minimal;
};
in rustPlatform.buildRustPackage {
pname = "zerokit";
version = if src ? rev then src.rev else "nightly";
# Improve caching of sources
src = builtins.path { path = src; name = "zerokit"; };
cargoLock = {
lockFile = ../Cargo.lock;
allowBuiltinFetchGit = true;
};
doCheck = false;
CARGO_HOME = "/tmp";
buildPhase = ''
cargo build --lib \
${if release then "--release" else ""} \
${if rust-target != null then "--target=${rust-target}" else ""} \
${if features != null then "--features=${features}" else ""} \
--manifest-path ${project}/Cargo.toml
'';
installPhase = ''
mkdir -p $out/
for file in $(find target -name 'librln.*' | grep -v deps/); do
mkdir -p $out/$(dirname $file)
cp -r $file $out/$file
done
'';
meta = with pkgs.lib; {
description = "Zerokit";
license = licenses.mit;
};
}

View File

@@ -13,14 +13,14 @@ path = "src/examples/stateless.rs"
required-features = ["stateless"]
[dependencies]
rln = { path = "../rln", default-features = true, features = ["pmtree-ft"] }
rln = { path = "../rln", default-features = true }
zerokit_utils = { path = "../utils" }
clap = { version = "4.5.29", features = ["cargo", "derive", "env"] }
clap_derive = { version = "4.5.28" }
color-eyre = "0.6.2"
serde_json = "1.0.138"
serde = { version = "1.0.217", features = ["derive"] }
clap = { version = "4.5.38", features = ["cargo", "derive", "env"] }
color-eyre = "0.6.4"
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
[features]
default = []
arkzkey = ["rln/arkzkey"]
stateless = ["rln/stateless"]

View File

@@ -1,8 +1,9 @@
[tasks.build]
command = "cargo"
args = ["build", "--release"]
args = ["build"]
[tasks.test]
command = "cargo"
args = ["test", "--release"]
disabled = true
[tasks.bench]
disabled = true

View File

@@ -6,7 +6,7 @@ use std::{
};
use clap::{Parser, Subcommand};
use color_eyre::{eyre::eyre, Result};
use color_eyre::{eyre::eyre, Report, Result};
use rln::{
circuit::Fr,
hashers::{hash_to_field, poseidon_hash},
@@ -182,7 +182,7 @@ impl RLNSystem {
Ok(false) => {
println!("Verification failed: message_id must be unique within the epoch and satisfy 0 <= message_id < MESSAGE_LIMIT: {MESSAGE_LIMIT}");
}
Err(err) => return Err(err),
Err(err) => return Err(Report::new(err)),
}
Ok(())
}

View File

@@ -127,11 +127,12 @@ impl RLNSystem {
};
let merkle_proof = self.tree.proof(user_index)?;
let x = hash_to_field(signal.as_bytes());
let rln_witness = rln_witness_from_values(
identity.identity_secret_hash,
&merkle_proof,
hash_to_field(signal.as_bytes()),
x,
external_nullifier,
Fr::from(MESSAGE_LIMIT),
Fr::from(message_id),
@@ -157,12 +158,12 @@ impl RLNSystem {
let mut input_buffer = Cursor::new(proof_with_signal);
let root = self.tree.root();
let root_serialized = fr_to_bytes_le(&root);
let mut root_buffer = Cursor::new(root_serialized);
let roots_serialized = fr_to_bytes_le(&root);
let mut roots_buffer = Cursor::new(roots_serialized);
match self
.rln
.verify_with_roots(&mut input_buffer, &mut root_buffer)
.verify_with_roots(&mut input_buffer, &mut roots_buffer)
{
Ok(true) => {
let nullifier = &proof_data[256..288];

6
rln-wasm/.gitignore vendored Normal file
View File

@@ -0,0 +1,6 @@
/target
**/*.rs.bk
Cargo.lock
bin/
pkg/
wasm-pack.log

46
rln-wasm/Cargo.toml Normal file
View File

@@ -0,0 +1,46 @@
[package]
name = "rln-wasm"
version = "0.2.0"
edition = "2021"
license = "MIT or Apache2"
[lib]
crate-type = ["cdylib", "rlib"]
required-features = ["stateless"]
[dependencies]
rln = { path = "../rln", version = "0.8.0", default-features = false }
zerokit_utils = { path = "../utils", version = "0.6.0" }
num-bigint = { version = "0.4.6", default-features = false }
js-sys = "0.3.77"
wasm-bindgen = "0.2.100"
serde-wasm-bindgen = "0.6.5"
wasm-bindgen-rayon = { version = "1.2.0", optional = true }
# The `console_error_panic_xhook` crate provides better debugging of panics by
# logging them with `console.error`. This is great for development, but requires
# all the `std::fmt` and `std::panicking` infrastructure, so isn't great for
# code size when deploying.
console_error_panic_hook = { version = "0.1.7", optional = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2.16", features = ["js"] }
[dev-dependencies]
serde_json = "1.0"
wasm-bindgen-test = "0.3.50"
wasm-bindgen-futures = "0.4.50"
[dev-dependencies.web-sys]
version = "0.3.77"
features = ["Window", "Navigator"]
[features]
default = ["console_error_panic_hook"]
stateless = ["rln/stateless"]
arkzkey = ["rln/arkzkey"]
multithread = ["wasm-bindgen-rayon"]
[package.metadata.docs.rs]
all-features = true

241
rln-wasm/Makefile.toml Normal file
View File

@@ -0,0 +1,241 @@
[tasks.build]
clear = true
dependencies = ["pack_build", "pack_rename", "pack_resize"]
[tasks.build_arkzkey]
clear = true
dependencies = ["pack_build_arkzkey", "pack_rename", "pack_resize"]
[tasks.build_multithread]
clear = true
dependencies = [
"pack_build_multithread",
"post_build_multithread",
"pack_rename",
"pack_resize",
]
[tasks.build_multithread_arkzkey]
clear = true
dependencies = [
"pack_build_multithread_arkzkey",
"post_build_multithread",
"pack_rename",
"pack_resize",
]
[tasks.pack_build]
command = "wasm-pack"
args = [
"build",
"--release",
"--target",
"web",
"--scope",
"waku",
"--features",
"stateless",
]
[tasks.pack_build_arkzkey]
command = "wasm-pack"
args = [
"build",
"--release",
"--target",
"web",
"--scope",
"waku",
"--features",
"stateless,arkzkey",
]
[tasks.pack_build_multithread]
command = "env"
args = [
"RUSTFLAGS=-C target-feature=+atomics,+bulk-memory,+mutable-globals",
"rustup",
"run",
"nightly",
"wasm-pack",
"build",
"--release",
"--target",
"web",
"--scope",
"waku",
"--features",
"stateless,multithread",
"-Z",
"build-std=panic_abort,std",
]
[tasks.pack_build_multithread_arkzkey]
command = "env"
args = [
"RUSTFLAGS=-C target-feature=+atomics,+bulk-memory,+mutable-globals",
"rustup",
"run",
"nightly",
"wasm-pack",
"build",
"--release",
"--target",
"web",
"--scope",
"waku",
"--features",
"stateless,multithread,arkzkey",
"-Z",
"build-std=panic_abort,std",
]
[tasks.post_build_multithread]
script = '''
wasm-bindgen --target web --split-linked-modules --out-dir ./pkg ../target/wasm32-unknown-unknown/release/rln_wasm.wasm && \
find ./pkg/snippets -name "workerHelpers.worker.js" -exec sed -i.bak 's|from '\''\.\.\/\.\.\/\.\.\/'\'';|from "../../../rln_wasm.js";|g' {} \; -exec rm -f {}.bak \; && \
find ./pkg/snippets -name "workerHelpers.worker.js" -exec sed -i.bak 's|await initWbg(module, memory);|await initWbg({ module, memory });|g' {} \; -exec rm -f {}.bak \;
'''
[tasks.pack_rename]
script = "sed -i.bak 's/rln-wasm/zerokit-rln-wasm/g' pkg/package.json && rm pkg/package.json.bak"
[tasks.pack_resize]
command = "wasm-opt"
args = [
"pkg/rln_wasm_bg.wasm",
"-Oz",
"--strip-debug",
"--strip-dwarf",
"--remove-unused-module-elements",
"--vacuum",
"-o",
"pkg/rln_wasm_bg.wasm",
]
[tasks.test]
command = "wasm-pack"
args = [
"test",
"--release",
"--node",
"--target",
"wasm32-unknown-unknown",
"--features",
"stateless",
"--",
"--nocapture",
]
dependencies = ["build"]
[tasks.test_arkzkey]
command = "wasm-pack"
args = [
"test",
"--release",
"--node",
"--target",
"wasm32-unknown-unknown",
"--features",
"stateless,arkzkey",
"--",
"--nocapture",
]
dependencies = ["build_arkzkey"]
[tasks.test_browser]
command = "wasm-pack"
args = [
"test",
"--release",
"--chrome",
# "--firefox",
# "--safari",
"--headless",
"--target",
"wasm32-unknown-unknown",
"--features",
"stateless",
"--",
"--nocapture",
]
dependencies = ["build"]
[tasks.test_browser_arkzkey]
command = "wasm-pack"
args = [
"test",
"--release",
"--chrome",
# "--firefox",
# "--safari",
"--headless",
"--target",
"wasm32-unknown-unknown",
"--features",
"stateless,arkzkey",
"--",
"--nocapture",
]
dependencies = ["build_arkzkey"]
[tasks.test_multithread]
command = "env"
args = [
"RUSTFLAGS=-C target-feature=+atomics,+bulk-memory,+mutable-globals",
"rustup",
"run",
"nightly",
"wasm-pack",
"test",
"--release",
"--chrome",
# "--firefox",
# "--safari",
"--headless",
"--target",
"wasm32-unknown-unknown",
"--features",
"stateless,multithread",
"-Z",
"build-std=panic_abort,std",
"--",
"--nocapture",
]
dependencies = ["build_multithread"]
[tasks.test_multithread_arkzkey]
command = "env"
args = [
"RUSTFLAGS=-C target-feature=+atomics,+bulk-memory,+mutable-globals",
"rustup",
"run",
"nightly",
"wasm-pack",
"test",
"--release",
"--chrome",
# "--firefox",
# "--safari",
"--headless",
"--target",
"wasm32-unknown-unknown",
"--features",
"stateless,multithread,arkzkey",
"-Z",
"build-std=panic_abort,std",
"--",
"--nocapture",
]
dependencies = ["build_multithread_arkzkey"]
[tasks.bench]
disabled = true
[tasks.login]
command = "wasm-pack"
args = ["login"]
[tasks.publish]
command = "wasm-pack"
args = ["publish", "--access", "public", "--target", "web"]

171
rln-wasm/README.md Normal file
View File

@@ -0,0 +1,171 @@
# RLN for WASM
[![npm version](https://badge.fury.io/js/@waku%2Fzerokit-rln-wasm.svg)](https://badge.fury.io/js/@waku%2Fzerokit-rln-wasm)
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
The Zerokit RLN WASM Module provides WebAssembly bindings for working with
Rate-Limiting Nullifier [RLN](https://rfc.vac.dev/spec/32/) zkSNARK proofs and primitives.
This module is used by [waku-org/js-rln](https://github.com/waku-org/js-rln/) to enable
RLN functionality in JavaScript/TypeScript applications.
## Install Dependencies
> [!NOTE]
> This project requires the following tools:
>
> - `wasm-pack` - for compiling Rust to WebAssembly
> - `cargo-make` - for running build commands
> - `nvm` - to install and manage Node.js
>
> Ensure all dependencies are installed before proceeding.
### Manually
#### Install `wasm-pack`
```bash
cargo install wasm-pack --version=0.13.1
```
#### Install `cargo-make`
```bash
cargo install cargo-make
```
#### Install `Node.js`
If you don't have `nvm` (Node Version Manager), install it by following
the [installation instructions](https://github.com/nvm-sh/nvm?tab=readme-ov-file#install--update-script).
After installing `nvm`, install and use Node.js `v22.14.0`:
```bash
nvm install 22.14.0
nvm use 22.14.0
nvm alias default 22.14.0
```
If you already have Node.js installed,
check your version with `node -v` command — the version must be strictly greater than 22.
### Or install everything
You can run the following command from the root of the repository to install all required dependencies for `zerokit`
```bash
make installdeps
```
## Building the library
First, navigate to the rln-wasm directory:
```bash
cd rln-wasm
```
Compile zerokit for `wasm32-unknown-unknown`:
```bash
cargo make build
```
Or compile with the **arkzkey** feature enabled
```bash
cargo make build_arkzkey
```
## Running tests and benchmarks
```bash
cargo make test
```
Or test with the **arkzkey** feature enabled
```bash
cargo make test_arkzkey
```
If you want to run the tests in browser headless mode, you can use the following command:
```bash
cargo make test_browser
cargo make test_browser_arkzkey
```
## Parallel computation
The library supports parallel computation using the `wasm-bindgen-rayon` crate,
enabling multi-threaded execution in the browser.
> [!NOTE]
> Parallel support is not enabled by default due to WebAssembly and browser limitations. \
> Compiling this feature requires `nightly` Rust and the `wasm-bindgen-cli` tool.
### Build Setup
#### Install `nightly` Rust
```bash
rustup install nightly
```
#### Install `wasm-bindgen-cli`
```bash
cargo install wasm-bindgen-cli --version=0.2.100
```
### Build Commands
To enable parallel computation for WebAssembly threads, you can use the following command:
```bash
cargo make build_multithread
```
Or with the **arkzkey** feature enabled:
```bash
cargo make build_multithread_arkzkey
```
### WebAssembly Threading Support
Most modern browsers support WebAssembly threads,
but they require the following headers to enable `SharedArrayBuffer`, which is necessary for multithreading:
- Cross-Origin-Opener-Policy: same-origin
- Cross-Origin-Embedder-Policy: require-corp
Without these, the application will fall back to single-threaded mode.
## Feature detection
If you're targeting [older browser versions that didn't support WebAssembly threads yet](https://webassembly.org/roadmap/),
you'll likely want to create two builds - one with thread support and one without -
and use feature detection to choose the right one on the JavaScript side.
You can use [wasm-feature-detect](https://github.com/GoogleChromeLabs/wasm-feature-detect)library for this purpose.
For example, your code might look like this:
```js
import { threads } from 'wasm-feature-detect';
let wasmPkg;
if (await threads()) {
wasmPkg = await import('./pkg-with-threads/index.js');
await wasmPkg.default();
await wasmPkg.initThreadPool(navigator.hardwareConcurrency);
} else {
wasmPkg = await import('./pkg-without-threads/index.js');
await wasmPkg.default();
}
wasmPkg.nowCallAnyExportedFuncs();
```

View File

@@ -0,0 +1,335 @@
// Browser compatible witness calculator
(function (global) {
async function builder(code, options) {
options = options || {};
let wasmModule;
try {
wasmModule = await WebAssembly.compile(code);
} catch (err) {
console.log(err);
console.log(
"\nTry to run circom --c in order to generate c++ code instead\n"
);
throw new Error(err);
}
let wc;
let errStr = "";
let msgStr = "";
const instance = await WebAssembly.instantiate(wasmModule, {
runtime: {
exceptionHandler: function (code) {
let err;
if (code == 1) {
err = "Signal not found.\n";
} else if (code == 2) {
err = "Too many signals set.\n";
} else if (code == 3) {
err = "Signal already set.\n";
} else if (code == 4) {
err = "Assert Failed.\n";
} else if (code == 5) {
err = "Not enough memory.\n";
} else if (code == 6) {
err = "Input signal array access exceeds the size.\n";
} else {
err = "Unknown error.\n";
}
throw new Error(err + errStr);
},
printErrorMessage: function () {
errStr += getMessage() + "\n";
// console.error(getMessage());
},
writeBufferMessage: function () {
const msg = getMessage();
// Any calls to `log()` will always end with a `\n`, so that's when we print and reset
if (msg === "\n") {
console.log(msgStr);
msgStr = "";
} else {
// If we've buffered other content, put a space in between the items
if (msgStr !== "") {
msgStr += " ";
}
// Then append the message to the message we are creating
msgStr += msg;
}
},
showSharedRWMemory: function () {
printSharedRWMemory();
},
},
});
const sanityCheck = options;
// options &&
// (
// options.sanityCheck ||
// options.logGetSignal ||
// options.logSetSignal ||
// options.logStartComponent ||
// options.logFinishComponent
// );
wc = new WitnessCalculator(instance, sanityCheck);
return wc;
function getMessage() {
var message = "";
var c = instance.exports.getMessageChar();
while (c != 0) {
message += String.fromCharCode(c);
c = instance.exports.getMessageChar();
}
return message;
}
function printSharedRWMemory() {
const shared_rw_memory_size = instance.exports.getFieldNumLen32();
const arr = new Uint32Array(shared_rw_memory_size);
for (let j = 0; j < shared_rw_memory_size; j++) {
arr[shared_rw_memory_size - 1 - j] =
instance.exports.readSharedRWMemory(j);
}
// If we've buffered other content, put a space in between the items
if (msgStr !== "") {
msgStr += " ";
}
// Then append the value to the message we are creating
msgStr += fromArray32(arr).toString();
}
}
class WitnessCalculator {
constructor(instance, sanityCheck) {
this.instance = instance;
this.version = this.instance.exports.getVersion();
this.n32 = this.instance.exports.getFieldNumLen32();
this.instance.exports.getRawPrime();
const arr = new Uint32Array(this.n32);
for (let i = 0; i < this.n32; i++) {
arr[this.n32 - 1 - i] = this.instance.exports.readSharedRWMemory(i);
}
this.prime = fromArray32(arr);
this.witnessSize = this.instance.exports.getWitnessSize();
this.sanityCheck = sanityCheck;
}
circom_version() {
return this.instance.exports.getVersion();
}
async _doCalculateWitness(input, sanityCheck) {
//input is assumed to be a map from signals to arrays of bigints
this.instance.exports.init(this.sanityCheck || sanityCheck ? 1 : 0);
const keys = Object.keys(input);
var input_counter = 0;
keys.forEach((k) => {
const h = fnvHash(k);
const hMSB = parseInt(h.slice(0, 8), 16);
const hLSB = parseInt(h.slice(8, 16), 16);
const fArr = flatArray(input[k]);
let signalSize = this.instance.exports.getInputSignalSize(hMSB, hLSB);
if (signalSize < 0) {
throw new Error(`Signal ${k} not found\n`);
}
if (fArr.length < signalSize) {
throw new Error(`Not enough values for input signal ${k}\n`);
}
if (fArr.length > signalSize) {
throw new Error(`Too many values for input signal ${k}\n`);
}
for (let i = 0; i < fArr.length; i++) {
const arrFr = toArray32(BigInt(fArr[i]) % this.prime, this.n32);
for (let j = 0; j < this.n32; j++) {
this.instance.exports.writeSharedRWMemory(
j,
arrFr[this.n32 - 1 - j]
);
}
try {
this.instance.exports.setInputSignal(hMSB, hLSB, i);
input_counter++;
} catch (err) {
// console.log(`After adding signal ${i} of ${k}`)
throw new Error(err);
}
}
});
if (input_counter < this.instance.exports.getInputSize()) {
throw new Error(
`Not all inputs have been set. Only ${input_counter} out of ${this.instance.exports.getInputSize()}`
);
}
}
async calculateWitness(input, sanityCheck) {
const w = [];
await this._doCalculateWitness(input, sanityCheck);
for (let i = 0; i < this.witnessSize; i++) {
this.instance.exports.getWitness(i);
const arr = new Uint32Array(this.n32);
for (let j = 0; j < this.n32; j++) {
arr[this.n32 - 1 - j] = this.instance.exports.readSharedRWMemory(j);
}
w.push(fromArray32(arr));
}
return w;
}
async calculateBinWitness(input, sanityCheck) {
const buff32 = new Uint32Array(this.witnessSize * this.n32);
const buff = new Uint8Array(buff32.buffer);
await this._doCalculateWitness(input, sanityCheck);
for (let i = 0; i < this.witnessSize; i++) {
this.instance.exports.getWitness(i);
const pos = i * this.n32;
for (let j = 0; j < this.n32; j++) {
buff32[pos + j] = this.instance.exports.readSharedRWMemory(j);
}
}
return buff;
}
async calculateWTNSBin(input, sanityCheck) {
const buff32 = new Uint32Array(
this.witnessSize * this.n32 + this.n32 + 11
);
const buff = new Uint8Array(buff32.buffer);
await this._doCalculateWitness(input, sanityCheck);
//"wtns"
buff[0] = "w".charCodeAt(0);
buff[1] = "t".charCodeAt(0);
buff[2] = "n".charCodeAt(0);
buff[3] = "s".charCodeAt(0);
//version 2
buff32[1] = 2;
//number of sections: 2
buff32[2] = 2;
//id section 1
buff32[3] = 1;
const n8 = this.n32 * 4;
//id section 1 length in 64bytes
const idSection1length = 8 + n8;
const idSection1lengthHex = idSection1length.toString(16);
buff32[4] = parseInt(idSection1lengthHex.slice(0, 8), 16);
buff32[5] = parseInt(idSection1lengthHex.slice(8, 16), 16);
//this.n32
buff32[6] = n8;
//prime number
this.instance.exports.getRawPrime();
var pos = 7;
for (let j = 0; j < this.n32; j++) {
buff32[pos + j] = this.instance.exports.readSharedRWMemory(j);
}
pos += this.n32;
// witness size
buff32[pos] = this.witnessSize;
pos++;
//id section 2
buff32[pos] = 2;
pos++;
// section 2 length
const idSection2length = n8 * this.witnessSize;
const idSection2lengthHex = idSection2length.toString(16);
buff32[pos] = parseInt(idSection2lengthHex.slice(0, 8), 16);
buff32[pos + 1] = parseInt(idSection2lengthHex.slice(8, 16), 16);
pos += 2;
for (let i = 0; i < this.witnessSize; i++) {
this.instance.exports.getWitness(i);
for (let j = 0; j < this.n32; j++) {
buff32[pos + j] = this.instance.exports.readSharedRWMemory(j);
}
pos += this.n32;
}
return buff;
}
}
function toArray32(rem, size) {
const res = []; //new Uint32Array(size); //has no unshift
const radix = BigInt(0x100000000);
while (rem) {
res.unshift(Number(rem % radix));
rem = rem / radix;
}
if (size) {
var i = size - res.length;
while (i > 0) {
res.unshift(0);
i--;
}
}
return res;
}
function fromArray32(arr) {
//returns a BigInt
var res = BigInt(0);
const radix = BigInt(0x100000000);
for (let i = 0; i < arr.length; i++) {
res = res * radix + BigInt(arr[i]);
}
return res;
}
function flatArray(a) {
var res = [];
fillArray(res, a);
return res;
function fillArray(res, a) {
if (Array.isArray(a)) {
for (let i = 0; i < a.length; i++) {
fillArray(res, a[i]);
}
} else {
res.push(a);
}
}
}
function fnvHash(str) {
const uint64_max = BigInt(2) ** BigInt(64);
let hash = BigInt("0xCBF29CE484222325");
for (var i = 0; i < str.length; i++) {
hash ^= BigInt(str[i].charCodeAt());
hash *= BigInt(0x100000001b3);
hash %= uint64_max;
}
let shash = hash.toString(16);
let n = 16 - shash.length;
shash = "0".repeat(n).concat(shash);
return shash;
}
// Make it globally available
global.witnessCalculatorBuilder = builder;
})(typeof self !== "undefined" ? self : window);

View File

@@ -0,0 +1,325 @@
// Node.js module compatible witness calculator
module.exports = async function builder(code, options) {
options = options || {};
let wasmModule;
try {
wasmModule = await WebAssembly.compile(code);
} catch (err) {
console.log(err);
console.log(
"\nTry to run circom --c in order to generate c++ code instead\n"
);
throw new Error(err);
}
let wc;
let errStr = "";
let msgStr = "";
const instance = await WebAssembly.instantiate(wasmModule, {
runtime: {
exceptionHandler: function (code) {
let err;
if (code == 1) {
err = "Signal not found.\n";
} else if (code == 2) {
err = "Too many signals set.\n";
} else if (code == 3) {
err = "Signal already set.\n";
} else if (code == 4) {
err = "Assert Failed.\n";
} else if (code == 5) {
err = "Not enough memory.\n";
} else if (code == 6) {
err = "Input signal array access exceeds the size.\n";
} else {
err = "Unknown error.\n";
}
throw new Error(err + errStr);
},
printErrorMessage: function () {
errStr += getMessage() + "\n";
// console.error(getMessage());
},
writeBufferMessage: function () {
const msg = getMessage();
// Any calls to `log()` will always end with a `\n`, so that's when we print and reset
if (msg === "\n") {
console.log(msgStr);
msgStr = "";
} else {
// If we've buffered other content, put a space in between the items
if (msgStr !== "") {
msgStr += " ";
}
// Then append the message to the message we are creating
msgStr += msg;
}
},
showSharedRWMemory: function () {
printSharedRWMemory();
},
},
});
const sanityCheck = options;
// options &&
// (
// options.sanityCheck ||
// options.logGetSignal ||
// options.logSetSignal ||
// options.logStartComponent ||
// options.logFinishComponent
// );
wc = new WitnessCalculator(instance, sanityCheck);
return wc;
function getMessage() {
var message = "";
var c = instance.exports.getMessageChar();
while (c != 0) {
message += String.fromCharCode(c);
c = instance.exports.getMessageChar();
}
return message;
}
function printSharedRWMemory() {
const shared_rw_memory_size = instance.exports.getFieldNumLen32();
const arr = new Uint32Array(shared_rw_memory_size);
for (let j = 0; j < shared_rw_memory_size; j++) {
arr[shared_rw_memory_size - 1 - j] =
instance.exports.readSharedRWMemory(j);
}
// If we've buffered other content, put a space in between the items
if (msgStr !== "") {
msgStr += " ";
}
// Then append the value to the message we are creating
msgStr += fromArray32(arr).toString();
}
};
class WitnessCalculator {
constructor(instance, sanityCheck) {
this.instance = instance;
this.version = this.instance.exports.getVersion();
this.n32 = this.instance.exports.getFieldNumLen32();
this.instance.exports.getRawPrime();
const arr = new Uint32Array(this.n32);
for (let i = 0; i < this.n32; i++) {
arr[this.n32 - 1 - i] = this.instance.exports.readSharedRWMemory(i);
}
this.prime = fromArray32(arr);
this.witnessSize = this.instance.exports.getWitnessSize();
this.sanityCheck = sanityCheck;
}
circom_version() {
return this.instance.exports.getVersion();
}
async _doCalculateWitness(input, sanityCheck) {
//input is assumed to be a map from signals to arrays of bigints
this.instance.exports.init(this.sanityCheck || sanityCheck ? 1 : 0);
const keys = Object.keys(input);
var input_counter = 0;
keys.forEach((k) => {
const h = fnvHash(k);
const hMSB = parseInt(h.slice(0, 8), 16);
const hLSB = parseInt(h.slice(8, 16), 16);
const fArr = flatArray(input[k]);
let signalSize = this.instance.exports.getInputSignalSize(hMSB, hLSB);
if (signalSize < 0) {
throw new Error(`Signal ${k} not found\n`);
}
if (fArr.length < signalSize) {
throw new Error(`Not enough values for input signal ${k}\n`);
}
if (fArr.length > signalSize) {
throw new Error(`Too many values for input signal ${k}\n`);
}
for (let i = 0; i < fArr.length; i++) {
const arrFr = toArray32(BigInt(fArr[i]) % this.prime, this.n32);
for (let j = 0; j < this.n32; j++) {
this.instance.exports.writeSharedRWMemory(j, arrFr[this.n32 - 1 - j]);
}
try {
this.instance.exports.setInputSignal(hMSB, hLSB, i);
input_counter++;
} catch (err) {
// console.log(`After adding signal ${i} of ${k}`)
throw new Error(err);
}
}
});
if (input_counter < this.instance.exports.getInputSize()) {
throw new Error(
`Not all inputs have been set. Only ${input_counter} out of ${this.instance.exports.getInputSize()}`
);
}
}
async calculateWitness(input, sanityCheck) {
const w = [];
await this._doCalculateWitness(input, sanityCheck);
for (let i = 0; i < this.witnessSize; i++) {
this.instance.exports.getWitness(i);
const arr = new Uint32Array(this.n32);
for (let j = 0; j < this.n32; j++) {
arr[this.n32 - 1 - j] = this.instance.exports.readSharedRWMemory(j);
}
w.push(fromArray32(arr));
}
return w;
}
async calculateBinWitness(input, sanityCheck) {
const buff32 = new Uint32Array(this.witnessSize * this.n32);
const buff = new Uint8Array(buff32.buffer);
await this._doCalculateWitness(input, sanityCheck);
for (let i = 0; i < this.witnessSize; i++) {
this.instance.exports.getWitness(i);
const pos = i * this.n32;
for (let j = 0; j < this.n32; j++) {
buff32[pos + j] = this.instance.exports.readSharedRWMemory(j);
}
}
return buff;
}
async calculateWTNSBin(input, sanityCheck) {
const buff32 = new Uint32Array(this.witnessSize * this.n32 + this.n32 + 11);
const buff = new Uint8Array(buff32.buffer);
await this._doCalculateWitness(input, sanityCheck);
//"wtns"
buff[0] = "w".charCodeAt(0);
buff[1] = "t".charCodeAt(0);
buff[2] = "n".charCodeAt(0);
buff[3] = "s".charCodeAt(0);
//version 2
buff32[1] = 2;
//number of sections: 2
buff32[2] = 2;
//id section 1
buff32[3] = 1;
const n8 = this.n32 * 4;
//id section 1 length in 64bytes
const idSection1length = 8 + n8;
const idSection1lengthHex = idSection1length.toString(16);
buff32[4] = parseInt(idSection1lengthHex.slice(0, 8), 16);
buff32[5] = parseInt(idSection1lengthHex.slice(8, 16), 16);
//this.n32
buff32[6] = n8;
//prime number
this.instance.exports.getRawPrime();
var pos = 7;
for (let j = 0; j < this.n32; j++) {
buff32[pos + j] = this.instance.exports.readSharedRWMemory(j);
}
pos += this.n32;
// witness size
buff32[pos] = this.witnessSize;
pos++;
//id section 2
buff32[pos] = 2;
pos++;
// section 2 length
const idSection2length = n8 * this.witnessSize;
const idSection2lengthHex = idSection2length.toString(16);
buff32[pos] = parseInt(idSection2lengthHex.slice(0, 8), 16);
buff32[pos + 1] = parseInt(idSection2lengthHex.slice(8, 16), 16);
pos += 2;
for (let i = 0; i < this.witnessSize; i++) {
this.instance.exports.getWitness(i);
for (let j = 0; j < this.n32; j++) {
buff32[pos + j] = this.instance.exports.readSharedRWMemory(j);
}
pos += this.n32;
}
return buff;
}
}
function toArray32(rem, size) {
const res = []; //new Uint32Array(size); //has no unshift
const radix = BigInt(0x100000000);
while (rem) {
res.unshift(Number(rem % radix));
rem = rem / radix;
}
if (size) {
var i = size - res.length;
while (i > 0) {
res.unshift(0);
i--;
}
}
return res;
}
function fromArray32(arr) {
//returns a BigInt
var res = BigInt(0);
const radix = BigInt(0x100000000);
for (let i = 0; i < arr.length; i++) {
res = res * radix + BigInt(arr[i]);
}
return res;
}
function flatArray(a) {
var res = [];
fillArray(res, a);
return res;
function fillArray(res, a) {
if (Array.isArray(a)) {
for (let i = 0; i < a.length; i++) {
fillArray(res, a[i]);
}
} else {
res.push(a);
}
}
}
function fnvHash(str) {
const uint64_max = BigInt(2) ** BigInt(64);
let hash = BigInt("0xCBF29CE484222325");
for (var i = 0; i < str.length; i++) {
hash ^= BigInt(str[i].charCodeAt());
hash *= BigInt(0x100000001b3);
hash %= uint64_max;
}
let shash = hash.toString(16);
let n = 16 - shash.length;
shash = "0".repeat(n).concat(shash);
return shash;
}

297
rln-wasm/src/lib.rs Normal file
View File

@@ -0,0 +1,297 @@
#![cfg(target_arch = "wasm32")]
use js_sys::{BigInt as JsBigInt, Object, Uint8Array};
use num_bigint::BigInt;
use rln::public::{hash, poseidon_hash, RLN};
use std::vec::Vec;
use wasm_bindgen::prelude::*;
#[cfg(feature = "multithread")]
pub use wasm_bindgen_rayon::init_thread_pool;
#[wasm_bindgen(js_name = initPanicHook)]
pub fn init_panic_hook() {
console_error_panic_hook::set_once();
}
#[wasm_bindgen(js_name = RLN)]
pub struct RLNWrapper {
// The purpose of this wrapper is to hold a RLN instance with the 'static lifetime
// because wasm_bindgen does not allow returning elements with lifetimes
instance: RLN,
}
// Macro to call methods with arbitrary amount of arguments,
// which have the last argument is output buffer pointer
// First argument to the macro is context,
// second is the actual method on `RLN`
// third is the aforementioned output buffer argument
// rest are all other arguments to the method
macro_rules! call_with_output_and_error_msg {
// this variant is needed for the case when
// there are zero other arguments
($instance:expr, $method:ident, $error_msg:expr) => {
{
let mut output_data: Vec<u8> = Vec::new();
let new_instance = $instance.process();
if let Err(err) = new_instance.instance.$method(&mut output_data) {
std::mem::forget(output_data);
Err(format!("Msg: {:#?}, Error: {:#?}", $error_msg, err))
} else {
let result = Uint8Array::from(&output_data[..]);
std::mem::forget(output_data);
Ok(result)
}
}
};
($instance:expr, $method:ident, $error_msg:expr, $( $arg:expr ),* ) => {
{
let mut output_data: Vec<u8> = Vec::new();
let new_instance = $instance.process();
if let Err(err) = new_instance.instance.$method($($arg.process()),*, &mut output_data) {
std::mem::forget(output_data);
Err(format!("Msg: {:#?}, Error: {:#?}", $error_msg, err))
} else {
let result = Uint8Array::from(&output_data[..]);
std::mem::forget(output_data);
Ok(result)
}
}
};
}
macro_rules! call {
($instance:expr, $method:ident $(, $arg:expr)*) => {
{
let new_instance: &mut RLNWrapper = $instance.process();
new_instance.instance.$method($($arg.process()),*)
}
}
}
macro_rules! call_bool_method_with_error_msg {
($instance:expr, $method:ident, $error_msg:expr $(, $arg:expr)*) => {
{
let new_instance: &RLNWrapper = $instance.process();
new_instance.instance.$method($($arg.process()),*).map_err(|err| format!("Msg: {:#?}, Error: {:#?}", $error_msg, err))
}
}
}
// Macro to execute a function with arbitrary amount of arguments,
// First argument is the function to execute
// Rest are all other arguments to the method
macro_rules! fn_call_with_output_and_error_msg {
// this variant is needed for the case when
// there are zero other arguments
($func:ident, $error_msg:expr) => {
{
let mut output_data: Vec<u8> = Vec::new();
if let Err(err) = $func(&mut output_data) {
std::mem::forget(output_data);
Err(format!("Msg: {:#?}, Error: {:#?}", $error_msg, err))
} else {
let result = Uint8Array::from(&output_data[..]);
std::mem::forget(output_data);
Ok(result)
}
}
};
($func:ident, $error_msg:expr, $( $arg:expr ),* ) => {
{
let mut output_data: Vec<u8> = Vec::new();
if let Err(err) = $func($($arg.process()),*, &mut output_data) {
std::mem::forget(output_data);
Err(format!("Msg: {:#?}, Error: {:#?}", $error_msg, err))
} else {
let result = Uint8Array::from(&output_data[..]);
std::mem::forget(output_data);
Ok(result)
}
}
};
}
trait ProcessArg {
type ReturnType;
fn process(self) -> Self::ReturnType;
}
impl ProcessArg for usize {
type ReturnType = usize;
fn process(self) -> Self::ReturnType {
self
}
}
impl<T> ProcessArg for Vec<T> {
type ReturnType = Vec<T>;
fn process(self) -> Self::ReturnType {
self
}
}
impl ProcessArg for *const RLN {
type ReturnType = &'static RLN;
fn process(self) -> Self::ReturnType {
unsafe { &*self }
}
}
impl ProcessArg for *const RLNWrapper {
type ReturnType = &'static RLNWrapper;
fn process(self) -> Self::ReturnType {
unsafe { &*self }
}
}
impl ProcessArg for *mut RLNWrapper {
type ReturnType = &'static mut RLNWrapper;
fn process(self) -> Self::ReturnType {
unsafe { &mut *self }
}
}
impl<'a> ProcessArg for &'a [u8] {
type ReturnType = &'a [u8];
fn process(self) -> Self::ReturnType {
self
}
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = newRLN)]
pub fn wasm_new(zkey: Uint8Array) -> Result<*mut RLNWrapper, String> {
let instance = RLN::new_with_params(zkey.to_vec()).map_err(|err| format!("{:#?}", err))?;
let wrapper = RLNWrapper { instance };
Ok(Box::into_raw(Box::new(wrapper)))
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = rlnWitnessToJson)]
pub fn wasm_rln_witness_to_json(
ctx: *mut RLNWrapper,
serialized_witness: Uint8Array,
) -> Result<Object, String> {
let inputs = call!(
ctx,
get_rln_witness_bigint_json,
&serialized_witness.to_vec()[..]
)
.map_err(|err| err.to_string())?;
let js_value = serde_wasm_bindgen::to_value(&inputs).map_err(|err| err.to_string())?;
Object::from_entries(&js_value).map_err(|err| format!("{:#?}", err))
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = generateRLNProofWithWitness)]
pub fn wasm_generate_rln_proof_with_witness(
ctx: *mut RLNWrapper,
calculated_witness: Vec<JsBigInt>,
serialized_witness: Uint8Array,
) -> Result<Uint8Array, String> {
let mut witness_vec: Vec<BigInt> = vec![];
for v in calculated_witness {
witness_vec.push(
v.to_string(10)
.map_err(|err| format!("{:#?}", err))?
.as_string()
.ok_or("not a string error")?
.parse::<BigInt>()
.map_err(|err| format!("{:#?}", err))?,
);
}
call_with_output_and_error_msg!(
ctx,
generate_rln_proof_with_witness,
"could not generate proof",
witness_vec,
serialized_witness.to_vec()
)
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = generateMembershipKey)]
pub fn wasm_key_gen(ctx: *const RLNWrapper) -> Result<Uint8Array, String> {
call_with_output_and_error_msg!(ctx, key_gen, "could not generate membership keys")
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = generateExtendedMembershipKey)]
pub fn wasm_extended_key_gen(ctx: *const RLNWrapper) -> Result<Uint8Array, String> {
call_with_output_and_error_msg!(ctx, extended_key_gen, "could not generate membership keys")
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = generateSeededMembershipKey)]
pub fn wasm_seeded_key_gen(ctx: *const RLNWrapper, seed: Uint8Array) -> Result<Uint8Array, String> {
call_with_output_and_error_msg!(
ctx,
seeded_key_gen,
"could not generate membership key",
&seed.to_vec()[..]
)
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = generateSeededExtendedMembershipKey)]
pub fn wasm_seeded_extended_key_gen(
ctx: *const RLNWrapper,
seed: Uint8Array,
) -> Result<Uint8Array, String> {
call_with_output_and_error_msg!(
ctx,
seeded_extended_key_gen,
"could not generate membership key",
&seed.to_vec()[..]
)
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = recovedIDSecret)]
pub fn wasm_recover_id_secret(
ctx: *const RLNWrapper,
input_proof_data_1: Uint8Array,
input_proof_data_2: Uint8Array,
) -> Result<Uint8Array, String> {
call_with_output_and_error_msg!(
ctx,
recover_id_secret,
"could not recover id secret",
&input_proof_data_1.to_vec()[..],
&input_proof_data_2.to_vec()[..]
)
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[wasm_bindgen(js_name = verifyWithRoots)]
pub fn wasm_verify_with_roots(
ctx: *const RLNWrapper,
proof: Uint8Array,
roots: Uint8Array,
) -> Result<bool, String> {
call_bool_method_with_error_msg!(
ctx,
verify_with_roots,
"error while verifying proof with roots".to_string(),
&proof.to_vec()[..],
&roots.to_vec()[..]
)
}
#[wasm_bindgen(js_name = hash)]
pub fn wasm_hash(input: Uint8Array) -> Result<Uint8Array, String> {
fn_call_with_output_and_error_msg!(hash, "could not generate hash", &input.to_vec()[..])
}
#[wasm_bindgen(js_name = poseidonHash)]
pub fn wasm_poseidon_hash(input: Uint8Array) -> Result<Uint8Array, String> {
fn_call_with_output_and_error_msg!(
poseidon_hash,
"could not generate poseidon hash",
&input.to_vec()[..]
)
}

257
rln-wasm/tests/browser.rs Normal file
View File

@@ -0,0 +1,257 @@
#![cfg(target_arch = "wasm32")]
#[cfg(test)]
mod tests {
use js_sys::{BigInt as JsBigInt, Date, Object, Uint8Array};
use rln::circuit::{Fr, TEST_TREE_HEIGHT};
use rln::hashers::{hash_to_field, poseidon_hash};
use rln::poseidon_tree::PoseidonTree;
use rln::protocol::{prepare_verify_input, rln_witness_from_values, serialize_witness};
use rln::utils::{bytes_le_to_fr, fr_to_bytes_le};
use rln_wasm::{
wasm_generate_rln_proof_with_witness, wasm_key_gen, wasm_new, wasm_rln_witness_to_json,
wasm_verify_with_roots,
};
use wasm_bindgen::{prelude::wasm_bindgen, JsValue};
use wasm_bindgen_test::{console_log, wasm_bindgen_test, wasm_bindgen_test_configure};
use zerokit_utils::merkle_tree::merkle_tree::ZerokitMerkleTree;
#[cfg(feature = "multithread")]
use {rln_wasm::init_thread_pool, wasm_bindgen_futures::JsFuture, web_sys::window};
#[wasm_bindgen(inline_js = r#"
export function isThreadpoolSupported() {
return typeof SharedArrayBuffer !== 'undefined' &&
typeof Atomics !== 'undefined' &&
typeof crossOriginIsolated !== 'undefined' &&
crossOriginIsolated;
}
export function initWitnessCalculator(jsCode) {
eval(jsCode);
if (typeof window.witnessCalculatorBuilder !== 'function') {
return false;
}
return true;
}
export function readFile(data) {
return new Uint8Array(data);
}
export async function calculateWitness(circom_data, inputs) {
const wasmBuffer = circom_data instanceof Uint8Array ? circom_data : new Uint8Array(circom_data);
const witnessCalculator = await window.witnessCalculatorBuilder(wasmBuffer);
const calculatedWitness = await witnessCalculator.calculateWitness(inputs, false);
return JSON.stringify(calculatedWitness, (key, value) =>
typeof value === "bigint" ? value.toString() : value
);
}
"#)]
extern "C" {
#[wasm_bindgen(catch)]
fn isThreadpoolSupported() -> Result<bool, JsValue>;
#[wasm_bindgen(catch)]
fn initWitnessCalculator(js: &str) -> Result<bool, JsValue>;
#[wasm_bindgen(catch)]
fn readFile(data: &[u8]) -> Result<Uint8Array, JsValue>;
#[wasm_bindgen(catch)]
async fn calculateWitness(circom_data: &[u8], inputs: Object) -> Result<JsValue, JsValue>;
}
const WITNESS_CALCULATOR_JS: &str = include_str!("../resources/witness_calculator_browser.js");
#[cfg(feature = "arkzkey")]
const ZKEY_BYTES: &[u8] =
include_bytes!("../../rln/resources/tree_height_20/rln_final.arkzkey");
#[cfg(not(feature = "arkzkey"))]
const ZKEY_BYTES: &[u8] = include_bytes!("../../rln/resources/tree_height_20/rln_final.zkey");
const CIRCOM_BYTES: &[u8] = include_bytes!("../../rln/resources/tree_height_20/rln.wasm");
wasm_bindgen_test_configure!(run_in_browser);
#[wasm_bindgen_test]
pub async fn rln_wasm_benchmark() {
// Check if thread pool is supported
#[cfg(feature = "multithread")]
if !isThreadpoolSupported().expect("Failed to check thread pool support") {
panic!("Thread pool is NOT supported");
} else {
// Initialize thread pool
let cpu_count = window()
.expect("Failed to get window")
.navigator()
.hardware_concurrency() as usize;
JsFuture::from(init_thread_pool(cpu_count))
.await
.expect("Failed to initialize thread pool");
}
// Initialize the witness calculator
initWitnessCalculator(WITNESS_CALCULATOR_JS)
.expect("Failed to initialize witness calculator");
let mut results = String::from("\nbenchmarks:\n");
let iterations = 10;
let zkey = readFile(&ZKEY_BYTES).expect("Failed to read zkey file");
// Benchmark wasm_new
let start_wasm_new = Date::now();
for _ in 0..iterations {
let _ = wasm_new(zkey.clone()).expect("Failed to create RLN instance");
}
let wasm_new_result = Date::now() - start_wasm_new;
// Create RLN instance for other benchmarks
let rln_instance = wasm_new(zkey).expect("Failed to create RLN instance");
let mut tree = PoseidonTree::default(TEST_TREE_HEIGHT).expect("Failed to create tree");
// Benchmark wasm_key_gen
let start_wasm_key_gen = Date::now();
for _ in 0..iterations {
let _ = wasm_key_gen(rln_instance).expect("Failed to generate keys");
}
let wasm_key_gen_result = Date::now() - start_wasm_key_gen;
// Generate identity pair for other benchmarks
let mem_keys = wasm_key_gen(rln_instance).expect("Failed to generate keys");
let id_key = mem_keys.subarray(0, 32);
let (identity_secret_hash, _) = bytes_le_to_fr(&id_key.to_vec());
let (id_commitment, _) = bytes_le_to_fr(&mem_keys.subarray(32, 64).to_vec());
let epoch = hash_to_field(b"test-epoch");
let rln_identifier = hash_to_field(b"test-rln-identifier");
let external_nullifier = poseidon_hash(&[epoch, rln_identifier]);
let identity_index = tree.leaves_set();
let user_message_limit = Fr::from(100);
let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]);
tree.update_next(rate_commitment)
.expect("Failed to update tree");
let message_id = Fr::from(0);
let signal: [u8; 32] = [0; 32];
let x = hash_to_field(&signal);
let merkle_proof = tree
.proof(identity_index)
.expect("Failed to generate merkle proof");
let rln_witness = rln_witness_from_values(
identity_secret_hash,
&merkle_proof,
x,
external_nullifier,
user_message_limit,
message_id,
)
.expect("Failed to create RLN witness");
let serialized_witness =
serialize_witness(&rln_witness).expect("Failed to serialize witness");
let witness_buffer = Uint8Array::from(&serialized_witness[..]);
let json_inputs = wasm_rln_witness_to_json(rln_instance, witness_buffer.clone())
.expect("Failed to convert witness to JSON");
// Benchmark calculateWitness
let start_calculate_witness = Date::now();
for _ in 0..iterations {
let _ = calculateWitness(&CIRCOM_BYTES, json_inputs.clone())
.await
.expect("Failed to calculate witness");
}
let calculate_witness_result = Date::now() - start_calculate_witness;
// Calculate witness for other benchmarks
let calculated_witness_json = calculateWitness(&CIRCOM_BYTES, json_inputs)
.await
.expect("Failed to calculate witness")
.as_string()
.expect("Failed to convert calculated witness to string");
let calculated_witness_vec_str: Vec<String> =
serde_json::from_str(&calculated_witness_json).expect("Failed to parse JSON");
let calculated_witness: Vec<JsBigInt> = calculated_witness_vec_str
.iter()
.map(|x| JsBigInt::new(&x.into()).expect("Failed to create JsBigInt"))
.collect();
// Benchmark wasm_generate_rln_proof_with_witness
let start_wasm_generate_rln_proof_with_witness = Date::now();
for _ in 0..iterations {
let _ = wasm_generate_rln_proof_with_witness(
rln_instance,
calculated_witness.clone(),
witness_buffer.clone(),
)
.expect("Failed to generate proof");
}
let wasm_generate_rln_proof_with_witness_result =
Date::now() - start_wasm_generate_rln_proof_with_witness;
// Generate a proof for other benchmarks
let proof =
wasm_generate_rln_proof_with_witness(rln_instance, calculated_witness, witness_buffer)
.expect("Failed to generate proof");
let proof_data = proof.to_vec();
let verify_input = prepare_verify_input(proof_data, &signal);
let input_buffer = Uint8Array::from(&verify_input[..]);
let root = tree.root();
let roots_serialized = fr_to_bytes_le(&root);
let roots_buffer = Uint8Array::from(&roots_serialized[..]);
// Benchmark wasm_verify_with_roots
let start_wasm_verify_with_roots = Date::now();
for _ in 0..iterations {
let _ =
wasm_verify_with_roots(rln_instance, input_buffer.clone(), roots_buffer.clone())
.expect("Failed to verify proof");
}
let wasm_verify_with_roots_result = Date::now() - start_wasm_verify_with_roots;
// Verify the proof with the root
let is_proof_valid = wasm_verify_with_roots(rln_instance, input_buffer, roots_buffer)
.expect("Failed to verify proof");
assert!(is_proof_valid, "verification failed");
// Format and display results
let format_duration = |duration_ms: f64| -> String {
let avg_ms = duration_ms / (iterations as f64);
if avg_ms >= 1000.0 {
format!("{:.3} s", avg_ms / 1000.0)
} else {
format!("{:.3} ms", avg_ms)
}
};
results.push_str(&format!("wasm_new: {}\n", format_duration(wasm_new_result)));
results.push_str(&format!(
"wasm_key_gen: {}\n",
format_duration(wasm_key_gen_result)
));
results.push_str(&format!(
"calculateWitness: {}\n",
format_duration(calculate_witness_result)
));
results.push_str(&format!(
"wasm_generate_rln_proof_with_witness: {}\n",
format_duration(wasm_generate_rln_proof_with_witness_result)
));
results.push_str(&format!(
"wasm_verify_with_roots: {}\n",
format_duration(wasm_verify_with_roots_result)
));
// Log the results
console_log!("{results}");
}
}

222
rln-wasm/tests/node.rs Normal file
View File

@@ -0,0 +1,222 @@
#![cfg(not(feature = "multithread"))]
#![cfg(target_arch = "wasm32")]
#[cfg(test)]
mod tests {
use js_sys::{BigInt as JsBigInt, Date, Object, Uint8Array};
use rln::circuit::{Fr, TEST_TREE_HEIGHT};
use rln::hashers::{hash_to_field, poseidon_hash};
use rln::poseidon_tree::PoseidonTree;
use rln::protocol::{prepare_verify_input, rln_witness_from_values, serialize_witness};
use rln::utils::{bytes_le_to_fr, fr_to_bytes_le};
use rln_wasm::{
wasm_generate_rln_proof_with_witness, wasm_key_gen, wasm_new, wasm_rln_witness_to_json,
wasm_verify_with_roots,
};
use wasm_bindgen::{prelude::wasm_bindgen, JsValue};
use wasm_bindgen_test::{console_log, wasm_bindgen_test};
use zerokit_utils::merkle_tree::merkle_tree::ZerokitMerkleTree;
#[wasm_bindgen(inline_js = r#"
const fs = require("fs");
module.exports = {
readFile: function (path) {
return fs.readFileSync(path);
},
calculateWitness: async function (circom_path, inputs) {
const wc = require("resources/witness_calculator_node.js");
const wasmFile = fs.readFileSync(circom_path);
const wasmFileBuffer = wasmFile.slice(
wasmFile.byteOffset,
wasmFile.byteOffset + wasmFile.byteLength
);
const witnessCalculator = await wc(wasmFileBuffer);
const calculatedWitness = await witnessCalculator.calculateWitness(
inputs,
false
);
return JSON.stringify(calculatedWitness, (key, value) =>
typeof value === "bigint" ? value.toString() : value
);
},
};
"#)]
extern "C" {
#[wasm_bindgen(catch)]
fn readFile(path: &str) -> Result<Uint8Array, JsValue>;
#[wasm_bindgen(catch)]
async fn calculateWitness(circom_path: &str, input: Object) -> Result<JsValue, JsValue>;
}
#[cfg(feature = "arkzkey")]
const ZKEY_PATH: &str = "../rln/resources/tree_height_20/rln_final.arkzkey";
#[cfg(not(feature = "arkzkey"))]
const ZKEY_PATH: &str = "../rln/resources/tree_height_20/rln_final.zkey";
const CIRCOM_PATH: &str = "../rln/resources/tree_height_20/rln.wasm";
#[wasm_bindgen_test]
pub async fn rln_wasm_benchmark() {
let mut results = String::from("\nbenchmarks:\n");
let iterations = 10;
let zkey = readFile(&ZKEY_PATH).expect("Failed to read zkey file");
// Benchmark wasm_new
let start_wasm_new = Date::now();
for _ in 0..iterations {
let _ = wasm_new(zkey.clone()).expect("Failed to create RLN instance");
}
let wasm_new_result = Date::now() - start_wasm_new;
// Create RLN instance for other benchmarks
let rln_instance = wasm_new(zkey).expect("Failed to create RLN instance");
let mut tree = PoseidonTree::default(TEST_TREE_HEIGHT).expect("Failed to create tree");
// Benchmark wasm_key_gen
let start_wasm_key_gen = Date::now();
for _ in 0..iterations {
let _ = wasm_key_gen(rln_instance).expect("Failed to generate keys");
}
let wasm_key_gen_result = Date::now() - start_wasm_key_gen;
// Generate identity pair for other benchmarks
let mem_keys = wasm_key_gen(rln_instance).expect("Failed to generate keys");
let id_key = mem_keys.subarray(0, 32);
let (identity_secret_hash, _) = bytes_le_to_fr(&id_key.to_vec());
let (id_commitment, _) = bytes_le_to_fr(&mem_keys.subarray(32, 64).to_vec());
let epoch = hash_to_field(b"test-epoch");
let rln_identifier = hash_to_field(b"test-rln-identifier");
let external_nullifier = poseidon_hash(&[epoch, rln_identifier]);
let identity_index = tree.leaves_set();
let user_message_limit = Fr::from(100);
let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]);
tree.update_next(rate_commitment)
.expect("Failed to update tree");
let message_id = Fr::from(0);
let signal: [u8; 32] = [0; 32];
let x = hash_to_field(&signal);
let merkle_proof = tree
.proof(identity_index)
.expect("Failed to generate merkle proof");
let rln_witness = rln_witness_from_values(
identity_secret_hash,
&merkle_proof,
x,
external_nullifier,
user_message_limit,
message_id,
)
.expect("Failed to create RLN witness");
let serialized_witness =
serialize_witness(&rln_witness).expect("Failed to serialize witness");
let witness_buffer = Uint8Array::from(&serialized_witness[..]);
let json_inputs = wasm_rln_witness_to_json(rln_instance, witness_buffer.clone())
.expect("Failed to convert witness to JSON");
// Benchmark calculateWitness
let start_calculate_witness = Date::now();
for _ in 0..iterations {
let _ = calculateWitness(&CIRCOM_PATH, json_inputs.clone())
.await
.expect("Failed to calculate witness");
}
let calculate_witness_result = Date::now() - start_calculate_witness;
// Calculate witness for other benchmarks
let calculated_witness_json = calculateWitness(&CIRCOM_PATH, json_inputs)
.await
.expect("Failed to calculate witness")
.as_string()
.expect("Failed to convert calculated witness to string");
let calculated_witness_vec_str: Vec<String> =
serde_json::from_str(&calculated_witness_json).expect("Failed to parse JSON");
let calculated_witness: Vec<JsBigInt> = calculated_witness_vec_str
.iter()
.map(|x| JsBigInt::new(&x.into()).expect("Failed to create JsBigInt"))
.collect();
// Benchmark wasm_generate_rln_proof_with_witness
let start_wasm_generate_rln_proof_with_witness = Date::now();
for _ in 0..iterations {
let _ = wasm_generate_rln_proof_with_witness(
rln_instance,
calculated_witness.clone(),
witness_buffer.clone(),
)
.expect("Failed to generate proof");
}
let wasm_generate_rln_proof_with_witness_result =
Date::now() - start_wasm_generate_rln_proof_with_witness;
// Generate a proof for other benchmarks
let proof =
wasm_generate_rln_proof_with_witness(rln_instance, calculated_witness, witness_buffer)
.expect("Failed to generate proof");
let proof_data = proof.to_vec();
let verify_input = prepare_verify_input(proof_data, &signal);
let input_buffer = Uint8Array::from(&verify_input[..]);
let root = tree.root();
let roots_serialized = fr_to_bytes_le(&root);
let roots_buffer = Uint8Array::from(&roots_serialized[..]);
// Benchmark wasm_verify_with_roots
let start_wasm_verify_with_roots = Date::now();
for _ in 0..iterations {
let _ =
wasm_verify_with_roots(rln_instance, input_buffer.clone(), roots_buffer.clone())
.expect("Failed to verify proof");
}
let wasm_verify_with_roots_result = Date::now() - start_wasm_verify_with_roots;
// Verify the proof with the root
let is_proof_valid = wasm_verify_with_roots(rln_instance, input_buffer, roots_buffer)
.expect("Failed to verify proof");
assert!(is_proof_valid, "verification failed");
// Format and display results
let format_duration = |duration_ms: f64| -> String {
let avg_ms = duration_ms / (iterations as f64);
if avg_ms >= 1000.0 {
format!("{:.3} s", avg_ms / 1000.0)
} else {
format!("{:.3} ms", avg_ms)
}
};
results.push_str(&format!("wasm_new: {}\n", format_duration(wasm_new_result)));
results.push_str(&format!(
"wasm_key_gen: {}\n",
format_duration(wasm_key_gen_result)
));
results.push_str(&format!(
"calculate_witness: {}\n",
format_duration(calculate_witness_result)
));
results.push_str(&format!(
"wasm_generate_rln_proof_with_witness: {}\n",
format_duration(wasm_generate_rln_proof_with_witness_result)
));
results.push_str(&format!(
"wasm_verify_with_roots: {}\n",
format_duration(wasm_verify_with_roots_result)
));
// Log the results
console_log!("{results}");
}
}

View File

@@ -1,6 +1,6 @@
[package]
name = "rln"
version = "0.7.0"
version = "0.8.0"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "APIs to manage, compute and verify zkSNARK proofs and RLN primitives"
@@ -9,7 +9,7 @@ homepage = "https://vac.dev"
repository = "https://github.com/vacp2p/zerokit"
[lib]
crate-type = ["rlib", "staticlib"]
crate-type = ["rlib", "staticlib", "cdylib"]
bench = false
# This flag disable cargo doctests, i.e. testing example code-snippets in documentation
@@ -18,78 +18,76 @@ doctest = false
[dependencies]
# ZKP Generation
ark-bn254 = { version = "0.5.0", features = ["std"] }
ark-ff = { version = "0.5.0", features = ["std", "asm"] }
ark-serialize = { version = "0.5.0", features = ["derive"] }
ark-ec = { version = "0.5.0", default-features = false }
ark-std = { version = "0.5.0", default-features = false }
ark-groth16 = { version = "0.5.0", features = [
ark-relations = { version = "0.5.1", features = ["std"] }
ark-ff = { version = "0.5.0", default-features = false, features = [
"parallel",
] }
ark-ec = { version = "0.5.0", default-features = false, features = [
"parallel",
] }
ark-std = { version = "0.5.0", default-features = false, features = [
"parallel",
] }
ark-poly = { version = "0.5.0", default-features = false, features = [
"parallel",
] }
ark-groth16 = { version = "0.5.0", default-features = false, features = [
"parallel",
] }
ark-serialize = { version = "0.5.0", default-features = false, features = [
"parallel",
], default-features = false }
ark-relations = { version = "0.5.0", default-features = false, features = [
"std",
] }
ark-circom = { version = "0.5.0" }
ark-r1cs-std = { version = "0.5.0" }
# error handling
color-eyre = "0.6.2"
thiserror = "2.0.11"
thiserror = "2.0.12"
# utilities
byteorder = "1.4.3"
byteorder = "1.5.0"
cfg-if = "1.0"
num-bigint = { version = "0.4.6", default-features = false, features = [
"rand",
] }
num-bigint = { version = "0.4.6", default-features = false, features = ["std"] }
num-traits = "0.2.19"
once_cell = "1.19.0"
lazy_static = "1.4.0"
once_cell = "1.21.3"
lazy_static = "1.5.0"
rand = "0.8.5"
rand_chacha = "0.3.1"
ruint = { version = "1.12.4", features = ["rand", "serde", "ark-ff-04"] }
ruint = { version = "1.15.0", features = ["rand", "serde", "ark-ff-04"] }
tiny-keccak = { version = "2.0.2", features = ["keccak"] }
utils = { package = "zerokit_utils", version = "0.5.2", path = "../utils/", default-features = false }
utils = { package = "zerokit_utils", version = "0.6.0", path = "../utils", default-features = false }
# serialization
prost = "0.13.1"
prost = "0.13.5"
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
document-features = { version = "=0.2.10", optional = true }
document-features = { version = "0.2.11", optional = true }
[dev-dependencies]
sled = "0.34.7"
criterion = { version = "0.4.0", features = ["html_reports"] }
criterion = { version = "0.6.0", features = ["html_reports"] }
claims = "0.8.0"
[features]
default = ["parallel", "pmtree-ft"]
parallel = [
"ark-ec/parallel",
"ark-ff/parallel",
"ark-std/parallel",
"ark-groth16/parallel",
"utils/parallel",
]
fullmerkletree = ["default"]
arkzkey = []
default = ["pmtree-ft"]
fullmerkletree = []
stateless = []
arkzkey = []
# Note: pmtree feature is still experimental
pmtree-ft = ["utils/pmtree-ft"]
[[bench]]
name = "pmtree_benchmark"
harness = false
[[bench]]
name = "circuit_loading_benchmark"
harness = false
[[bench]]
name = "circuit_loading_arkzkey_benchmark"
harness = false
required-features = ["arkzkey"]
[[bench]]
name = "circuit_loading_benchmark"
harness = false
[[bench]]
name = "pmtree_benchmark"
harness = false
[[bench]]
name = "poseidon_tree_benchmark"
harness = false

View File

@@ -1,8 +1,12 @@
# Zerokit RLN Module
[![Crates.io](https://img.shields.io/crates/v/rln.svg)](https://crates.io/crates/rln)
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
The Zerokit RLN Module provides a Rust implementation for working with Rate-Limiting Nullifier [RLN](https://rfc.vac.dev/spec/32/) zkSNARK proofs and primitives. This module allows you to:
The Zerokit RLN Module provides a Rust implementation for working with
Rate-Limiting Nullifier [RLN](https://rfc.vac.dev/spec/32/) zkSNARK proofs and primitives.
This module allows you to:
- Generate and verify RLN proofs
- Work with Merkle trees for commitment storage
@@ -11,7 +15,8 @@ The Zerokit RLN Module provides a Rust implementation for working with Rate-Limi
## Quick Start
> [!IMPORTANT]
> Version 0.6.1 is required for WASM support or x32 architecture. Current version doesn't support these platforms due to dependency issues. WASM support will return in a future release.
> Version 0.7.0 is the only version that does not support WASM and x32 architecture.
> WASM support is available in version 0.8.0 and above.
### Add RLN as dependency
@@ -24,9 +29,16 @@ rln = { git = "https://github.com/vacp2p/zerokit" }
## Basic Usage Example
Note that we need to pass to RLN object constructor the path where the graph file (`graph.bin`, built for the input tree size), the corresponding proving key (`rln_final.zkey`) or (`rln_final_uncompr.arkzkey`) and verification key (`verification_key.arkvkey`, optional) are found.
The RLN object constructor requires the following files:
In the following we will use [cursors](https://doc.rust-lang.org/std/io/struct.Cursor.html) as readers/writers for interfacing with RLN public APIs.
- `graph.bin`: The graph file built for the input tree size
- `rln_final.zkey` or `rln_final_uncompr.arkzkey`: The proving key
- `verification_key.arkvkey`: The verification key (optional)
Additionally, `rln.wasm` is used for testing in the rln-wasm module.
In the following we will use [cursors](https://doc.rust-lang.org/std/io/struct.Cursor.html)
as readers/writers for interfacing with RLN public APIs.
```rust
use std::io::Cursor;
@@ -34,9 +46,9 @@ use std::io::Cursor;
use rln::{
circuit::Fr,
hashers::{hash_to_field, poseidon_hash},
protocol::{keygen, prepare_verify_input},
protocol::{keygen, prepare_prove_input, prepare_verify_input},
public::RLN,
utils::{fr_to_bytes_le, normalize_usize},
utils::fr_to_bytes_le,
};
use serde_json::json;
@@ -53,8 +65,8 @@ fn main() {
// 3. Add a rate commitment to the Merkle tree
let id_index = 10;
let user_message_limit = 10;
let rate_commitment = poseidon_hash(&[id_commitment, Fr::from(user_message_limit)]);
let user_message_limit = Fr::from(10);
let rate_commitment = poseidon_hash(&[id_commitment, user_message_limit]);
let mut buffer = Cursor::new(fr_to_bytes_le(&rate_commitment));
rln.set_leaf(id_index, &mut buffer).unwrap();
@@ -65,35 +77,41 @@ fn main() {
// We generate rln_identifier from a date seed and we ensure is
// mapped to a field element by hashing-to-field its content
let rln_identifier = hash_to_field(b"test-rln-identifier");
// We generate a external nullifier
let external_nullifier = poseidon_hash(&[epoch, rln_identifier]);
// We choose a message_id satisfy 0 <= message_id < user_message_limit
let message_id = Fr::from(1);
// 5. Generate and verify a proof for a message
let signal = b"RLN is awesome";
// 6. Prepare input for generate_rln_proof API
// input_data is [ identity_secret<32> | id_index<8> | external_nullifier<32> | user_message_limit<32> | message_id<32> | signal_len<8> | signal<var> ]
let mut serialized: Vec<u8> = Vec::new();
serialized.append(&mut fr_to_bytes_le(&identity_secret_hash));
serialized.append(&mut normalize_usize(id_index));
serialized.append(&mut fr_to_bytes_le(&Fr::from(user_message_limit)));
serialized.append(&mut fr_to_bytes_le(&Fr::from(1)));
serialized.append(&mut fr_to_bytes_le(&external_nullifier));
serialized.append(&mut normalize_usize(signal.len()));
serialized.append(&mut signal.to_vec());
// input_data is [ identity_secret<32> | id_index<8> | external_nullifier<32>
// | user_message_limit<32> | message_id<32> | signal_len<8> | signal<var> ]
let prove_input = prepare_prove_input(
identity_secret_hash,
id_index,
user_message_limit,
message_id,
external_nullifier,
signal,
);
// 7. Generate a RLN proof
// We generate a RLN proof for proof_input
let mut input_buffer = Cursor::new(serialized);
let mut input_buffer = Cursor::new(prove_input);
let mut output_buffer = Cursor::new(Vec::<u8>::new());
rln.generate_rln_proof(&mut input_buffer, &mut output_buffer)
.unwrap();
// We get the public outputs returned by the circuit evaluation
// The byte vector `proof_data` is serialized as `[ zk-proof | tree_root | external_nullifier | share_x | share_y | nullifier ]`.
// The byte vector `proof_data` is serialized as
// `[ zk-proof | tree_root | external_nullifier | share_x | share_y | nullifier ]`.
let proof_data = output_buffer.into_inner();
// 8. Verify a RLN proof
// Input buffer is serialized as `[proof_data | signal_len | signal ]`, where `proof_data` is (computed as) the output obtained by `generate_rln_proof`.
// Input buffer is serialized as `[proof_data | signal_len | signal ]`,
// where `proof_data` is (computed as) the output obtained by `generate_rln_proof`.
let verify_data = prepare_verify_input(proof_data, signal);
// We verify the zk-proof against the provided proof values
@@ -110,16 +128,22 @@ fn main() {
The `external nullifier` includes two parameters.
The first one is `epoch` and it's used to identify messages received in a certain time frame.
It usually corresponds to the current UNIX time but can also be set to a random value or generated by a seed, provided that it corresponds to a field element.
It usually corresponds to the current UNIX time but can also be set to a random value or generated by a seed,
provided that it corresponds to a field element.
The second one is `rln_identifier` and it's used to prevent a RLN ZK proof generated for one application to be re-used in another one.
The second one is `rln_identifier` and it's used to prevent a RLN ZK proof generated
for one application to be re-used in another one.
### Features
- **Multiple Backend Support**: Choose between different zkey formats with feature flags
- `arkzkey`: Use the optimized Arkworks-compatible zkey format (faster loading)
- `stateless`: For stateless proof verification
- **Pre-compiled Circuits**: Ready-to-use circuits with Merkle tree height of 20
- **Pre-compiled Circuits**: Ready-to-use circuits with Merkle tree depth of 20
- **Wasm Support**: WebAssembly bindings via rln-wasm crate with features like:
- Browser and Node.js compatibility
- Optional multi-threading support using wasm-bindgen-rayon
- Headless browser testing capabilities
## Building and Testing
@@ -147,13 +171,16 @@ cargo make test_stateless # For stateless feature
## Advanced: Custom Circuit Compilation
The `rln` (<https://github.com/rate-limiting-nullifier/circom-rln>) repository, which contains the RLN circuit implementation is using for pre-compiled RLN circuit for zerokit RLN.
The `rln` (<https://github.com/rate-limiting-nullifier/circom-rln>) repository,
which contains the RLN circuit implementation is using for pre-compiled RLN circuit for zerokit RLN.
If you want to compile your own RLN circuit, you can follow the instructions below.
### 1. Compile ZK Circuits for getting the zkey and verification key files
This script actually generates not only the zkey and verification key files for the RLN circuit, but also the execution wasm file used for witness calculation.
However, the wasm file is not needed for the `rln` module, because current implementation uses the iden3 graph file for witness calculation.
This script actually generates not only the zkey and verification key files for the RLN circuit,
but also the execution wasm file used for witness calculation.
However, the wasm file is not needed for the `rln` module,
because current implementation uses the iden3 graph file for witness calculation.
This graph file is generated by the `circom-witnesscalc` tool in [step 2](#2-generate-witness-calculation-graph).
To customize the circuit parameters, modify `circom-rln/circuits/rln.circom`:
@@ -166,19 +193,27 @@ component main { public [x, externalNullifier] } = RLN(N, M);
Where:
- `N`: Merkle tree height, determining the maximum membership capacity (2^N members).
- `N`: Merkle tree depth, determining the maximum membership capacity (2^N members).
- `M`: Bit size for range checks, setting an upper bound for the number of messages per epoch (2^M messages).
> [!NOTE]
> However, if `N` is too big, this might require a larger Powers of Tau ceremony than the one hardcoded in `./scripts/build-circuits.sh`, which is `2^14`. \
> In such case, we refer to the official [Circom documentation](https://docs.circom.io/getting-started/proving-circuits/#powers-of-tau) for instructions on how to run an appropriate Powers of Tau ceremony and Phase 2 in order to compile the desired circuit. \
> Additionally, while `M` sets an upper bound on the number of messages per epoch (`2^M`), you can configure lower message limit for your use case, as long as it satisfies `user_message_limit ≤ 2^M`. \
> Currently, the `rln` module comes with a [pre-compiled](https://github.com/vacp2p/zerokit/tree/master/rln/resources) RLN circuit with a Merkle tree of height `20` and a bit size of `16`, allowing up to `2^20` registered members and a `2^16` message limit per epoch.
> However, if `N` is too big, this might require a larger Powers of Tau ceremony
> than the one hardcoded in `./scripts/build-circuits.sh`, which is `2^14`.
> In such case, we refer to the official
> [Circom documentation](https://docs.circom.io/getting-started/proving-circuits/#powers-of-tau)
> for instructions on how to run an appropriate Powers of Tau ceremony and Phase 2 in order to compile the desired circuit. \
> Additionally, while `M` sets an upper bound on the number of messages per epoch (`2^M`),
> you can configure lower message limit for your use case, as long as it satisfies `user_message_limit ≤ 2^M`. \
> Currently, the `rln` module comes with a [pre-compiled](https://github.com/vacp2p/zerokit/tree/master/rln/resources)
> RLN circuit with a Merkle tree of depth `20` and a bit size of `16`,
> allowing up to `2^20` registered members and a `2^16` message limit per epoch.
#### Install circom compiler
You can follow the instructions below or refer to the [installing Circom](https://docs.circom.io/getting-started/installation/#installing-circom) guide for more details, but make sure to use the specific version `v2.1.0`.
You can follow the instructions below or refer to the
[installing Circom](https://docs.circom.io/getting-started/installation/#installing-circom) guide for more details,
but make sure to use the specific version `v2.1.0`.
```sh
# Clone the circom repository
@@ -215,7 +250,8 @@ cp zkeyFiles/rln/final.zkey <path_to_rln_final.zkey>
### 2. Generate Witness Calculation Graph
The execution graph file used for witness calculation can be compiled following instructions in the [circom-witnesscalc](https://github.com/iden3/circom-witnesscalc) repository.
The execution graph file used for witness calculation can be compiled following instructions
in the [circom-witnesscalc](https://github.com/iden3/circom-witnesscalc) repository.
As mentioned in step 1, we should use `rln.circom` file from `circom-rln` repository.
```sh
@@ -232,11 +268,14 @@ cargo build
cargo run --package circom_witnesscalc --bin build-circuit ../circom-rln/circuits/rln.circom <path_to_graph.bin>
```
The `rln` module comes with [pre-compiled](https://github.com/vacp2p/zerokit/tree/master/rln/resources) execution graph files for the RLN circuit.
The `rln` module comes with [pre-compiled](https://github.com/vacp2p/zerokit/tree/master/rln/resources)
execution graph files for the RLN circuit.
### 3. Generate Arkzkey Representation for zkey and verification key files
For faster loading, compile the zkey file into the arkzkey format using [ark-zkey](https://github.com/seemenkina/ark-zkey). This is fork of the [original](https://github.com/zkmopro/ark-zkey) repository with the uncompressed zkey support.
For faster loading, compile the zkey file into the arkzkey format using
[ark-zkey](https://github.com/seemenkina/ark-zkey).
This is fork of the [original](https://github.com/zkmopro/ark-zkey) repository with the uncompressed zkey support.
```sh
# Clone the ark-zkey repository
@@ -249,7 +288,8 @@ cd ark-zkey && cargo build
cargo run --bin arkzkey-util <path_to_rln_final.zkey>
```
Currently, the `rln` module comes with [pre-compiled](https://github.com/vacp2p/zerokit/tree/master/rln/resources) arkzkey keys for the RLN circuit.
Currently, the `rln` module comes with
[pre-compiled](https://github.com/vacp2p/zerokit/tree/master/rln/resources) arkzkey keys for the RLN circuit.
## Get involved

View File

@@ -1,5 +1,5 @@
use ark_circom::read_zkey;
use criterion::{criterion_group, criterion_main, Criterion};
use rln::circuit::zkey::read_zkey;
use std::io::Cursor;
pub fn zkey_load_benchmark(c: &mut Criterion) {

View File

@@ -21,17 +21,11 @@ pub fn pmtree_benchmark(c: &mut Criterion) {
c.bench_function("Pmtree::override_range", |b| {
b.iter(|| {
tree.override_range(0, leaves.clone(), [0, 1, 2, 3])
tree.override_range(0, leaves.clone().into_iter(), [0, 1, 2, 3].into_iter())
.unwrap();
})
});
c.bench_function("Pmtree::compute_root", |b| {
b.iter(|| {
tree.compute_root().unwrap();
})
});
c.bench_function("Pmtree::get", |b| {
b.iter(|| {
tree.get(0).unwrap();

Binary file not shown.

7
rln/src/circuit/error.rs Normal file
View File

@@ -0,0 +1,7 @@
#[derive(Debug, thiserror::Error)]
pub enum ZKeyReadError {
#[error("No proving key found!")]
EmptyBytes,
#[error("{0}")]
SerializationError(#[from] ark_serialize::SerializationError),
}

View File

@@ -10,11 +10,11 @@ use std::{
cmp::Ordering,
collections::HashMap,
error::Error,
ops::{BitAnd, BitOr, BitXor, Deref, Shl, Shr},
ops::{Deref, Shl, Shr},
};
use crate::circuit::iden3calc::proto;
use crate::circuit::Fr;
use crate::iden3calc::proto;
pub const M: U256 =
uint!(21888242871839275222246405745257275088548364400416034343698204186575808495617_U256);

View File

@@ -7,7 +7,7 @@ use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use prost::Message;
use std::io::{Read, Write};
use crate::iden3calc::{
use crate::circuit::iden3calc::{
graph,
graph::{Operation, TresOperation, UnoOperation},
proto, InputSignalsInfo,

View File

@@ -1,31 +1,35 @@
// This crate provides interfaces for the zero-knowledge circuit and keys
pub mod error;
pub mod iden3calc;
pub mod qap;
pub mod zkey;
use ::lazy_static::lazy_static;
use ark_bn254::{
Bn254, Fq as ArkFq, Fq2 as ArkFq2, Fr as ArkFr, G1Affine as ArkG1Affine,
G1Projective as ArkG1Projective, G2Affine as ArkG2Affine, G2Projective as ArkG2Projective,
};
use ark_groth16::{ProvingKey, VerifyingKey};
use ark_groth16::ProvingKey;
use ark_relations::r1cs::ConstraintMatrices;
use cfg_if::cfg_if;
use color_eyre::{Report, Result};
use crate::iden3calc::calc_witness;
use crate::circuit::error::ZKeyReadError;
use crate::circuit::iden3calc::calc_witness;
#[cfg(feature = "arkzkey")]
use {
ark_ff::Field, ark_serialize::CanonicalDeserialize, ark_serialize::CanonicalSerialize,
color_eyre::eyre::WrapErr,
};
use {ark_ff::Field, ark_serialize::CanonicalDeserialize, ark_serialize::CanonicalSerialize};
#[cfg(not(feature = "arkzkey"))]
use {ark_circom::read_zkey, std::io::Cursor};
use {crate::circuit::zkey::read_zkey, std::io::Cursor};
#[cfg(feature = "arkzkey")]
pub const ARKZKEY_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/rln_final.arkzkey");
pub const ARKZKEY_BYTES: &[u8] = include_bytes!("../../resources/tree_height_20/rln_final.arkzkey");
pub const ZKEY_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/rln_final.zkey");
const GRAPH_BYTES: &[u8] = include_bytes!("../resources/tree_height_20/graph.bin");
pub const ZKEY_BYTES: &[u8] = include_bytes!("../../resources/tree_height_20/rln_final.zkey");
#[cfg(not(target_arch = "wasm32"))]
const GRAPH_BYTES: &[u8] = include_bytes!("../../resources/tree_height_20/graph.bin");
lazy_static! {
static ref ZKEY: (ProvingKey<Curve>, ConstraintMatrices<Fr>) = {
@@ -54,9 +58,11 @@ pub type G2Affine = ArkG2Affine;
pub type G2Projective = ArkG2Projective;
// Loads the proving key using a bytes vector
pub fn zkey_from_raw(zkey_data: &[u8]) -> Result<(ProvingKey<Curve>, ConstraintMatrices<Fr>)> {
pub fn zkey_from_raw(
zkey_data: &[u8],
) -> Result<(ProvingKey<Curve>, ConstraintMatrices<Fr>), ZKeyReadError> {
if zkey_data.is_empty() {
return Err(Report::msg("No proving key found!"));
return Err(ZKeyReadError::EmptyBytes);
}
let proving_key_and_matrices = match () {
@@ -73,30 +79,11 @@ pub fn zkey_from_raw(zkey_data: &[u8]) -> Result<(ProvingKey<Curve>, ConstraintM
}
// Loads the proving key
#[cfg(not(target_arch = "wasm32"))]
pub fn zkey_from_folder() -> &'static (ProvingKey<Curve>, ConstraintMatrices<Fr>) {
&ZKEY
}
// Loads the verification key from a bytes vector
pub fn vk_from_raw(zkey_data: &[u8]) -> Result<VerifyingKey<Curve>> {
if !zkey_data.is_empty() {
let (proving_key, _matrices) = zkey_from_raw(zkey_data)?;
return Ok(proving_key.vk);
}
Err(Report::msg("No proving/verification key found!"))
}
// Checks verification key to be correct with respect to proving key
pub fn check_vk_from_zkey(verifying_key: VerifyingKey<Curve>) -> Result<()> {
let (proving_key, _matrices) = zkey_from_folder();
if proving_key.vk == verifying_key {
Ok(())
} else {
Err(Report::msg("verifying_keys are not equal"))
}
}
pub fn calculate_rln_witness<I: IntoIterator<Item = (String, Vec<Fr>)>>(
inputs: I,
graph_data: &[u8],
@@ -104,6 +91,7 @@ pub fn calculate_rln_witness<I: IntoIterator<Item = (String, Vec<Fr>)>>(
calc_witness(inputs, graph_data)
}
#[cfg(not(target_arch = "wasm32"))]
pub fn graph_from_folder() -> &'static [u8] {
GRAPH_BYTES
}
@@ -140,20 +128,18 @@ pub struct SerializableMatrix<F: Field> {
#[cfg(feature = "arkzkey")]
pub fn read_arkzkey_from_bytes_uncompressed(
arkzkey_data: &[u8],
) -> Result<(ProvingKey<Curve>, ConstraintMatrices<Fr>)> {
) -> Result<(ProvingKey<Curve>, ConstraintMatrices<Fr>), ZKeyReadError> {
if arkzkey_data.is_empty() {
return Err(Report::msg("No proving key found!"));
return Err(ZKeyReadError::EmptyBytes);
}
let mut cursor = std::io::Cursor::new(arkzkey_data);
let serialized_proving_key =
SerializableProvingKey::deserialize_uncompressed_unchecked(&mut cursor)
.wrap_err("Failed to deserialize proving key")?;
SerializableProvingKey::deserialize_uncompressed_unchecked(&mut cursor)?;
let serialized_constraint_matrices =
SerializableConstraintMatrices::deserialize_uncompressed_unchecked(&mut cursor)
.wrap_err("Failed to deserialize constraint matrices")?;
SerializableConstraintMatrices::deserialize_uncompressed_unchecked(&mut cursor)?;
// Get on right form for API
let proving_key: ProvingKey<Bn254> = serialized_proving_key.0;

114
rln/src/circuit/qap.rs Normal file
View File

@@ -0,0 +1,114 @@
// This file is based on the code by arkworks. Its preimage can be found here:
// https://github.com/arkworks-rs/circom-compat/blob/3c95ed98e23a408b4d99a53e483a9bba39685a4e/src/circom/qap.rs
use ark_ff::PrimeField;
use ark_groth16::r1cs_to_qap::{evaluate_constraint, LibsnarkReduction, R1CSToQAP};
use ark_poly::EvaluationDomain;
use ark_relations::r1cs::{ConstraintMatrices, ConstraintSystemRef, SynthesisError};
use ark_std::{cfg_into_iter, cfg_iter, cfg_iter_mut, vec};
/// Implements the witness map used by snarkjs. The arkworks witness map calculates the
/// coefficients of H through computing (AB-C)/Z in the evaluation domain and going back to the
/// coefficients domain. snarkjs instead precomputes the Lagrange form of the powers of tau bases
/// in a domain twice as large and the witness map is computed as the odd coefficients of (AB-C)
/// in that domain. This serves as HZ when computing the C proof element.
pub struct CircomReduction;
impl R1CSToQAP for CircomReduction {
#[allow(clippy::type_complexity)]
fn instance_map_with_evaluation<F: PrimeField, D: EvaluationDomain<F>>(
cs: ConstraintSystemRef<F>,
t: &F,
) -> Result<(Vec<F>, Vec<F>, Vec<F>, F, usize, usize), SynthesisError> {
LibsnarkReduction::instance_map_with_evaluation::<F, D>(cs, t)
}
fn witness_map_from_matrices<F: PrimeField, D: EvaluationDomain<F>>(
matrices: &ConstraintMatrices<F>,
num_inputs: usize,
num_constraints: usize,
full_assignment: &[F],
) -> Result<Vec<F>, SynthesisError> {
let zero = F::zero();
let domain =
D::new(num_constraints + num_inputs).ok_or(SynthesisError::PolynomialDegreeTooLarge)?;
let domain_size = domain.size();
let mut a = vec![zero; domain_size];
let mut b = vec![zero; domain_size];
#[allow(unexpected_cfgs)]
cfg_iter_mut!(a[..num_constraints])
.zip(cfg_iter_mut!(b[..num_constraints]))
.zip(cfg_iter!(&matrices.a))
.zip(cfg_iter!(&matrices.b))
.for_each(|(((a, b), at_i), bt_i)| {
*a = evaluate_constraint(at_i, full_assignment);
*b = evaluate_constraint(bt_i, full_assignment);
});
{
let start = num_constraints;
let end = start + num_inputs;
a[start..end].clone_from_slice(&full_assignment[..num_inputs]);
}
let mut c = vec![zero; domain_size];
#[allow(unexpected_cfgs)]
cfg_iter_mut!(c[..num_constraints])
.zip(&a)
.zip(&b)
.for_each(|((c_i, &a), &b)| {
*c_i = a * b;
});
domain.ifft_in_place(&mut a);
domain.ifft_in_place(&mut b);
let root_of_unity = {
let domain_size_double = 2 * domain_size;
let domain_double =
D::new(domain_size_double).ok_or(SynthesisError::PolynomialDegreeTooLarge)?;
domain_double.element(1)
};
D::distribute_powers_and_mul_by_const(&mut a, root_of_unity, F::one());
D::distribute_powers_and_mul_by_const(&mut b, root_of_unity, F::one());
domain.fft_in_place(&mut a);
domain.fft_in_place(&mut b);
let mut ab = domain.mul_polynomials_in_evaluation_domain(&a, &b);
drop(a);
drop(b);
domain.ifft_in_place(&mut c);
D::distribute_powers_and_mul_by_const(&mut c, root_of_unity, F::one());
domain.fft_in_place(&mut c);
#[allow(unexpected_cfgs)]
cfg_iter_mut!(ab)
.zip(c)
.for_each(|(ab_i, c_i)| *ab_i -= &c_i);
Ok(ab)
}
fn h_query_scalars<F: PrimeField, D: EvaluationDomain<F>>(
max_power: usize,
t: F,
_: F,
delta_inverse: F,
) -> Result<Vec<F>, SynthesisError> {
// the usual H query has domain-1 powers. Z has domain powers. So HZ has 2*domain-1 powers.
#[allow(unexpected_cfgs)]
let mut scalars = cfg_into_iter!(0..2 * max_power + 1)
.map(|i| delta_inverse * t.pow([i as u64]))
.collect::<Vec<_>>();
let domain_size = scalars.len();
let domain = D::new(domain_size).ok_or(SynthesisError::PolynomialDegreeTooLarge)?;
// generate the lagrange coefficients
domain.ifft_in_place(&mut scalars);
#[allow(unexpected_cfgs)]
Ok(cfg_into_iter!(scalars).skip(1).step_by(2).collect())
}
}

371
rln/src/circuit/zkey.rs Normal file
View File

@@ -0,0 +1,371 @@
// This file is based on the code by arkworks. Its preimage can be found here:
// https://github.com/arkworks-rs/circom-compat/blob/3c95ed98e23a408b4d99a53e483a9bba39685a4e/src/zkey.rs
//! ZKey Parsing
//!
//! Each ZKey file is broken into sections:
//! Header(1)
//! Prover Type 1 Groth
//! HeaderGroth(2)
//! n8q
//! q
//! n8r
//! r
//! NVars
//! NPub
//! DomainSize (multiple of 2
//! alpha1
//! beta1
//! delta1
//! beta2
//! gamma2
//! delta2
//! IC(3)
//! Coefs(4)
//! PointsA(5)
//! PointsB1(6)
//! PointsB2(7)
//! PointsC(8)
//! PointsH(9)
//! Contributions(10)
use ark_ff::{BigInteger256, PrimeField};
use ark_relations::r1cs::ConstraintMatrices;
use ark_serialize::{CanonicalDeserialize, SerializationError};
use ark_std::log2;
use byteorder::{LittleEndian, ReadBytesExt};
use std::{
collections::HashMap,
io::{Read, Seek, SeekFrom},
};
use ark_bn254::{Bn254, Fq, Fq2, Fr, G1Affine, G2Affine};
use ark_groth16::{ProvingKey, VerifyingKey};
use num_traits::Zero;
type IoResult<T> = Result<T, SerializationError>;
#[derive(Clone, Debug)]
struct Section {
position: u64,
#[allow(dead_code)]
size: usize,
}
/// Reads a SnarkJS ZKey file into an Arkworks ProvingKey.
pub fn read_zkey<R: Read + Seek>(
reader: &mut R,
) -> IoResult<(ProvingKey<Bn254>, ConstraintMatrices<Fr>)> {
let mut binfile = BinFile::new(reader)?;
let proving_key = binfile.proving_key()?;
let matrices = binfile.matrices()?;
Ok((proving_key, matrices))
}
#[derive(Debug)]
struct BinFile<'a, R> {
#[allow(dead_code)]
ftype: String,
#[allow(dead_code)]
version: u32,
sections: HashMap<u32, Vec<Section>>,
reader: &'a mut R,
}
impl<'a, R: Read + Seek> BinFile<'a, R> {
fn new(reader: &'a mut R) -> IoResult<Self> {
let mut magic = [0u8; 4];
reader.read_exact(&mut magic)?;
let version = reader.read_u32::<LittleEndian>()?;
let num_sections = reader.read_u32::<LittleEndian>()?;
let mut sections = HashMap::new();
for _ in 0..num_sections {
let section_id = reader.read_u32::<LittleEndian>()?;
let section_length = reader.read_u64::<LittleEndian>()?;
let section = sections.entry(section_id).or_insert_with(Vec::new);
section.push(Section {
position: reader.stream_position()?,
size: section_length as usize,
});
reader.seek(SeekFrom::Current(section_length as i64))?;
}
Ok(Self {
ftype: std::str::from_utf8(&magic[..]).unwrap().to_string(),
version,
sections,
reader,
})
}
fn proving_key(&mut self) -> IoResult<ProvingKey<Bn254>> {
let header = self.groth_header()?;
let ic = self.ic(header.n_public)?;
let a_query = self.a_query(header.n_vars)?;
let b_g1_query = self.b_g1_query(header.n_vars)?;
let b_g2_query = self.b_g2_query(header.n_vars)?;
let l_query = self.l_query(header.n_vars - header.n_public - 1)?;
let h_query = self.h_query(header.domain_size as usize)?;
let vk = VerifyingKey::<Bn254> {
alpha_g1: header.verifying_key.alpha_g1,
beta_g2: header.verifying_key.beta_g2,
gamma_g2: header.verifying_key.gamma_g2,
delta_g2: header.verifying_key.delta_g2,
gamma_abc_g1: ic,
};
let pk = ProvingKey::<Bn254> {
vk,
beta_g1: header.verifying_key.beta_g1,
delta_g1: header.verifying_key.delta_g1,
a_query,
b_g1_query,
b_g2_query,
h_query,
l_query,
};
Ok(pk)
}
fn get_section(&self, id: u32) -> Section {
self.sections.get(&id).unwrap()[0].clone()
}
fn groth_header(&mut self) -> IoResult<HeaderGroth> {
let section = self.get_section(2);
let header = HeaderGroth::new(&mut self.reader, &section)?;
Ok(header)
}
fn ic(&mut self, n_public: usize) -> IoResult<Vec<G1Affine>> {
// the range is non-inclusive so we do +1 to get all inputs
self.g1_section(n_public + 1, 3)
}
/// Returns the [`ConstraintMatrices`] corresponding to the zkey
pub fn matrices(&mut self) -> IoResult<ConstraintMatrices<Fr>> {
let header = self.groth_header()?;
let section = self.get_section(4);
self.reader.seek(SeekFrom::Start(section.position))?;
let num_coeffs: u32 = self.reader.read_u32::<LittleEndian>()?;
// insantiate AB
let mut matrices = vec![vec![vec![]; header.domain_size as usize]; 2];
let mut max_constraint_index = 0;
for _ in 0..num_coeffs {
let matrix: u32 = self.reader.read_u32::<LittleEndian>()?;
let constraint: u32 = self.reader.read_u32::<LittleEndian>()?;
let signal: u32 = self.reader.read_u32::<LittleEndian>()?;
let value: Fr = deserialize_field_fr(&mut self.reader)?;
max_constraint_index = std::cmp::max(max_constraint_index, constraint);
matrices[matrix as usize][constraint as usize].push((value, signal as usize));
}
let num_constraints = max_constraint_index as usize - header.n_public;
// Remove the public input constraints, Arkworks adds them later
matrices.iter_mut().for_each(|m| {
m.truncate(num_constraints);
});
// This is taken from Arkworks' to_matrices() function
let a = matrices[0].clone();
let b = matrices[1].clone();
let a_num_non_zero: usize = a.iter().map(|lc| lc.len()).sum();
let b_num_non_zero: usize = b.iter().map(|lc| lc.len()).sum();
let matrices = ConstraintMatrices {
num_instance_variables: header.n_public + 1,
num_witness_variables: header.n_vars - header.n_public,
num_constraints,
a_num_non_zero,
b_num_non_zero,
c_num_non_zero: 0,
a,
b,
c: vec![],
};
Ok(matrices)
}
fn a_query(&mut self, n_vars: usize) -> IoResult<Vec<G1Affine>> {
self.g1_section(n_vars, 5)
}
fn b_g1_query(&mut self, n_vars: usize) -> IoResult<Vec<G1Affine>> {
self.g1_section(n_vars, 6)
}
fn b_g2_query(&mut self, n_vars: usize) -> IoResult<Vec<G2Affine>> {
self.g2_section(n_vars, 7)
}
fn l_query(&mut self, n_vars: usize) -> IoResult<Vec<G1Affine>> {
self.g1_section(n_vars, 8)
}
fn h_query(&mut self, n_vars: usize) -> IoResult<Vec<G1Affine>> {
self.g1_section(n_vars, 9)
}
fn g1_section(&mut self, num: usize, section_id: usize) -> IoResult<Vec<G1Affine>> {
let section = self.get_section(section_id as u32);
self.reader.seek(SeekFrom::Start(section.position))?;
deserialize_g1_vec(self.reader, num as u32)
}
fn g2_section(&mut self, num: usize, section_id: usize) -> IoResult<Vec<G2Affine>> {
let section = self.get_section(section_id as u32);
self.reader.seek(SeekFrom::Start(section.position))?;
deserialize_g2_vec(self.reader, num as u32)
}
}
#[derive(Default, Clone, Debug, CanonicalDeserialize)]
pub struct ZVerifyingKey {
alpha_g1: G1Affine,
beta_g1: G1Affine,
beta_g2: G2Affine,
gamma_g2: G2Affine,
delta_g1: G1Affine,
delta_g2: G2Affine,
}
impl ZVerifyingKey {
fn new<R: Read>(reader: &mut R) -> IoResult<Self> {
let alpha_g1 = deserialize_g1(reader)?;
let beta_g1 = deserialize_g1(reader)?;
let beta_g2 = deserialize_g2(reader)?;
let gamma_g2 = deserialize_g2(reader)?;
let delta_g1 = deserialize_g1(reader)?;
let delta_g2 = deserialize_g2(reader)?;
Ok(Self {
alpha_g1,
beta_g1,
beta_g2,
gamma_g2,
delta_g1,
delta_g2,
})
}
}
#[derive(Clone, Debug)]
struct HeaderGroth {
#[allow(dead_code)]
n8q: u32,
#[allow(dead_code)]
q: BigInteger256,
#[allow(dead_code)]
n8r: u32,
#[allow(dead_code)]
r: BigInteger256,
n_vars: usize,
n_public: usize,
domain_size: u32,
#[allow(dead_code)]
power: u32,
verifying_key: ZVerifyingKey,
}
impl HeaderGroth {
fn new<R: Read + Seek>(reader: &mut R, section: &Section) -> IoResult<Self> {
reader.seek(SeekFrom::Start(section.position))?;
Self::read(reader)
}
fn read<R: Read>(mut reader: &mut R) -> IoResult<Self> {
// TODO: Impl From<u32> in Arkworks
let n8q: u32 = u32::deserialize_uncompressed(&mut reader)?;
// group order r of Bn254
let q = BigInteger256::deserialize_uncompressed(&mut reader)?;
let n8r: u32 = u32::deserialize_uncompressed(&mut reader)?;
// Prime field modulus
let r = BigInteger256::deserialize_uncompressed(&mut reader)?;
let n_vars = u32::deserialize_uncompressed(&mut reader)? as usize;
let n_public = u32::deserialize_uncompressed(&mut reader)? as usize;
let domain_size: u32 = u32::deserialize_uncompressed(&mut reader)?;
let power = log2(domain_size as usize);
let verifying_key = ZVerifyingKey::new(&mut reader)?;
Ok(Self {
n8q,
q,
n8r,
r,
n_vars,
n_public,
domain_size,
power,
verifying_key,
})
}
}
// need to divide by R, since snarkjs outputs the zkey with coefficients
// multiplieid by R^2
fn deserialize_field_fr<R: Read>(reader: &mut R) -> IoResult<Fr> {
let bigint = BigInteger256::deserialize_uncompressed(reader)?;
Ok(Fr::new_unchecked(Fr::new_unchecked(bigint).into_bigint()))
}
// skips the multiplication by R because Circom points are already in Montgomery form
fn deserialize_field<R: Read>(reader: &mut R) -> IoResult<Fq> {
let bigint = BigInteger256::deserialize_uncompressed(reader)?;
// if you use Fq::new it multiplies by R
Ok(Fq::new_unchecked(bigint))
}
pub fn deserialize_field2<R: Read>(reader: &mut R) -> IoResult<Fq2> {
let c0 = deserialize_field(reader)?;
let c1 = deserialize_field(reader)?;
Ok(Fq2::new(c0, c1))
}
fn deserialize_g1<R: Read>(reader: &mut R) -> IoResult<G1Affine> {
let x = deserialize_field(reader)?;
let y = deserialize_field(reader)?;
let infinity = x.is_zero() && y.is_zero();
if infinity {
Ok(G1Affine::identity())
} else {
Ok(G1Affine::new(x, y))
}
}
fn deserialize_g2<R: Read>(reader: &mut R) -> IoResult<G2Affine> {
let f1 = deserialize_field2(reader)?;
let f2 = deserialize_field2(reader)?;
let infinity = f1.is_zero() && f2.is_zero();
if infinity {
Ok(G2Affine::identity())
} else {
Ok(G2Affine::new(f1, f2))
}
}
fn deserialize_g1_vec<R: Read>(reader: &mut R, n_vars: u32) -> IoResult<Vec<G1Affine>> {
(0..n_vars).map(|_| deserialize_g1(reader)).collect()
}
fn deserialize_g2_vec<R: Read>(reader: &mut R, n_vars: u32) -> IoResult<Vec<G2Affine>> {
(0..n_vars).map(|_| deserialize_g2(reader)).collect()
}

77
rln/src/error.rs Normal file
View File

@@ -0,0 +1,77 @@
use crate::circuit::error::ZKeyReadError;
use ark_bn254::Fr;
use ark_relations::r1cs::SynthesisError;
use ark_serialize::SerializationError;
use num_bigint::{BigInt, ParseBigIntError};
use std::array::TryFromSliceError;
use std::num::TryFromIntError;
use std::string::FromUtf8Error;
use thiserror::Error;
use utils::error::{FromConfigError, ZerokitMerkleTreeError};
#[derive(Debug, thiserror::Error)]
pub enum ConversionError {
#[error("Expected radix 10 or 16")]
WrongRadix,
#[error("{0}")]
ParseBigInt(#[from] ParseBigIntError),
#[error("{0}")]
ToUsize(#[from] TryFromIntError),
#[error("{0}")]
FromSlice(#[from] TryFromSliceError),
}
#[derive(Error, Debug)]
pub enum ProofError {
#[error("{0}")]
ProtocolError(#[from] ProtocolError),
#[error("Error producing proof: {0}")]
SynthesisError(#[from] SynthesisError),
}
#[derive(Debug, thiserror::Error)]
pub enum ProtocolError {
#[error("{0}")]
Conversion(#[from] ConversionError),
#[error("Expected to read {0} bytes but read only {1} bytes")]
InvalidReadLen(usize, usize),
#[error("Cannot convert bigint {0:?} to biguint")]
BigUintConversion(BigInt),
#[error("{0}")]
JsonError(#[from] serde_json::Error),
#[error("Message id ({0}) is not within user_message_limit ({1})")]
InvalidMessageId(Fr, Fr),
}
#[derive(Debug, thiserror::Error)]
pub enum ComputeIdSecretError {
/// Usually it means that the same signal is used to recover the user secret hash
#[error("Cannot recover secret: division by zero")]
DivisionByZero,
}
#[derive(Debug, thiserror::Error)]
pub enum RLNError {
#[error("I/O error: {0}")]
IO(#[from] std::io::Error),
#[error("Utf8 error: {0}")]
Utf8(#[from] FromUtf8Error),
#[error("Serde json error: {0}")]
JSON(#[from] serde_json::Error),
#[error("Config error: {0}")]
Config(#[from] FromConfigError),
#[error("Serialization error: {0}")]
Serialization(#[from] SerializationError),
#[error("Merkle tree error: {0}")]
MerkleTree(#[from] ZerokitMerkleTreeError),
#[error("ZKey error: {0}")]
ZKey(#[from] ZKeyReadError),
#[error("Conversion error: {0}")]
Conversion(#[from] ConversionError),
#[error("Protocol error: {0}")]
Protocol(#[from] ProtocolError),
#[error("Proof error: {0}")]
Proof(#[from] ProofError),
#[error("Unable to extract secret")]
RecoverSecret(#[from] ComputeIdSecretError),
}

View File

@@ -466,6 +466,12 @@ pub extern "C" fn key_gen(ctx: *const RLN, output_buffer: *mut Buffer) -> bool {
call_with_output_arg!(ctx, key_gen, output_buffer)
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[no_mangle]
pub extern "C" fn key_gen_be(ctx: *const RLN, output_buffer: *mut Buffer) -> bool {
call_with_output_arg!(ctx, key_gen_be, output_buffer)
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[no_mangle]
pub extern "C" fn seeded_key_gen(
@@ -476,12 +482,28 @@ pub extern "C" fn seeded_key_gen(
call_with_output_arg!(ctx, seeded_key_gen, output_buffer, input_buffer)
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[no_mangle]
pub extern "C" fn seeded_key_gen_be(
ctx: *const RLN,
input_buffer: *const Buffer,
output_buffer: *mut Buffer,
) -> bool {
call_with_output_arg!(ctx, seeded_key_gen_be, output_buffer, input_buffer)
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[no_mangle]
pub extern "C" fn extended_key_gen(ctx: *const RLN, output_buffer: *mut Buffer) -> bool {
call_with_output_arg!(ctx, extended_key_gen, output_buffer)
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[no_mangle]
pub extern "C" fn extended_key_gen_be(ctx: *const RLN, output_buffer: *mut Buffer) -> bool {
call_with_output_arg!(ctx, extended_key_gen_be, output_buffer)
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[no_mangle]
pub extern "C" fn seeded_extended_key_gen(
@@ -492,6 +514,16 @@ pub extern "C" fn seeded_extended_key_gen(
call_with_output_arg!(ctx, seeded_extended_key_gen, output_buffer, input_buffer)
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[no_mangle]
pub extern "C" fn seeded_extended_key_gen_be(
ctx: *const RLN,
input_buffer: *const Buffer,
output_buffer: *mut Buffer,
) -> bool {
call_with_output_arg!(ctx, seeded_extended_key_gen_be, output_buffer, input_buffer)
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[no_mangle]
pub extern "C" fn recover_id_secret(

View File

@@ -23,7 +23,7 @@ static POSEIDON: Lazy<Poseidon<Fr>> = Lazy::new(|| Poseidon::<Fr>::from(&ROUND_P
pub fn poseidon_hash(input: &[Fr]) -> Fr {
POSEIDON
.hash(input.to_vec())
.hash(input)
.expect("hash with fixed input size can't fail")
}

View File

@@ -1,8 +1,8 @@
#![allow(dead_code)]
pub mod circuit;
pub mod error;
#[cfg(not(target_arch = "wasm32"))]
pub mod ffi;
pub mod hashers;
pub mod iden3calc;
#[cfg(feature = "pmtree-ft")]
pub mod pm_tree_adapter;
pub mod poseidon_tree;
@@ -11,5 +11,3 @@ pub mod public;
#[cfg(test)]
pub mod public_api_tests;
pub mod utils;
pub mod ffi;

View File

@@ -1,17 +1,15 @@
use serde_json::Value;
use std::fmt::Debug;
use std::path::PathBuf;
use std::str::FromStr;
use color_eyre::{Report, Result};
use serde_json::Value;
use utils::pmtree::tree::Key;
use utils::pmtree::{Database, Hasher};
use utils::*;
use crate::circuit::Fr;
use crate::hashers::{poseidon_hash, PoseidonHash};
use crate::utils::{bytes_le_to_fr, fr_to_bytes_le};
use utils::error::{FromConfigError, ZerokitMerkleTreeError};
use utils::pmtree::tree::Key;
use utils::pmtree::{Database, Hasher, PmtreeErrorKind};
use utils::{pmtree, Config, Mode, SledDB, ZerokitMerkleProof, ZerokitMerkleTree};
const METADATA_KEY: [u8; 8] = *b"metadata";
@@ -63,9 +61,9 @@ fn get_tmp() -> bool {
pub struct PmtreeConfig(Config);
impl FromStr for PmtreeConfig {
type Err = Report;
type Err = FromConfigError;
fn from_str(s: &str) -> Result<Self> {
fn from_str(s: &str) -> Result<Self, Self::Err> {
let config: Value = serde_json::from_str(s)?;
let path = config["path"].as_str();
@@ -85,10 +83,7 @@ impl FromStr for PmtreeConfig {
&& temporary.unwrap()
&& path.as_ref().unwrap().exists()
{
return Err(Report::msg(format!(
"Path {:?} already exists, cannot use temporary",
path.unwrap()
)));
return Err(FromConfigError::PathExists);
}
let config = Config::new()
@@ -133,12 +128,16 @@ impl ZerokitMerkleTree for PmTree {
type Hasher = PoseidonHash;
type Config = PmtreeConfig;
fn default(depth: usize) -> Result<Self> {
fn default(depth: usize) -> Result<Self, ZerokitMerkleTreeError> {
let default_config = PmtreeConfig::default();
PmTree::new(depth, Self::Hasher::default_leaf(), default_config)
}
fn new(depth: usize, _default_leaf: FrOf<Self::Hasher>, config: Self::Config) -> Result<Self> {
fn new(
depth: usize,
_default_leaf: FrOf<Self::Hasher>,
config: Self::Config,
) -> Result<Self, ZerokitMerkleTreeError> {
let tree_loaded = pmtree::MerkleTree::load(config.clone().0);
let tree = match tree_loaded {
Ok(tree) => tree,
@@ -168,14 +167,12 @@ impl ZerokitMerkleTree for PmTree {
self.tree.root()
}
fn compute_root(&mut self) -> Result<FrOf<Self::Hasher>> {
Ok(self.tree.root())
}
fn set(&mut self, index: usize, leaf: FrOf<Self::Hasher>) -> Result<()> {
self.tree
.set(index, leaf)
.map_err(|e| Report::msg(e.to_string()))?;
fn set(
&mut self,
index: usize,
leaf: FrOf<Self::Hasher>,
) -> Result<(), ZerokitMerkleTreeError> {
self.tree.set(index, leaf)?;
self.cached_leaves_indices[index] = 1;
Ok(())
}
@@ -184,27 +181,31 @@ impl ZerokitMerkleTree for PmTree {
&mut self,
start: usize,
values: I,
) -> Result<()> {
) -> Result<(), ZerokitMerkleTreeError> {
let v = values.into_iter().collect::<Vec<_>>();
self.tree
.set_range(start, v.clone().into_iter())
.map_err(|e| Report::msg(e.to_string()))?;
self.tree.set_range(start, v.clone().into_iter())?;
for i in start..v.len() {
self.cached_leaves_indices[i] = 1
}
Ok(())
}
fn get(&self, index: usize) -> Result<FrOf<Self::Hasher>> {
self.tree.get(index).map_err(|e| Report::msg(e.to_string()))
fn get(&self, index: usize) -> Result<FrOf<Self::Hasher>, ZerokitMerkleTreeError> {
self.tree
.get(index)
.map_err(ZerokitMerkleTreeError::PmtreeErrorKind)
}
fn get_subtree_root(&self, n: usize, index: usize) -> Result<FrOf<Self::Hasher>> {
fn get_subtree_root(
&self,
n: usize,
index: usize,
) -> Result<FrOf<Self::Hasher>, ZerokitMerkleTreeError> {
if n > self.depth() {
return Err(Report::msg("level exceeds depth size"));
return Err(ZerokitMerkleTreeError::InvalidLevel);
}
if index >= self.capacity() {
return Err(Report::msg("index exceeds set size"));
return Err(ZerokitMerkleTreeError::InvalidLeaf);
}
if n == 0 {
Ok(self.root())
@@ -235,55 +236,66 @@ impl ZerokitMerkleTree for PmTree {
start: usize,
leaves: I,
indices: J,
) -> Result<()> {
) -> Result<(), ZerokitMerkleTreeError> {
let leaves = leaves.into_iter().collect::<Vec<_>>();
let mut indices = indices.into_iter().collect::<Vec<_>>();
indices.sort();
match (leaves.len(), indices.len()) {
(0, 0) => Err(Report::msg("no leaves or indices to be removed")),
(0, 0) => Err(ZerokitMerkleTreeError::InvalidLeaf),
(1, 0) => self.set(start, leaves[0]),
(0, 1) => self.delete(indices[0]),
(_, 0) => self.set_range(start, leaves),
(0, _) => self.remove_indices(&indices),
(_, _) => self.remove_indices_and_set_leaves(start, leaves, &indices),
(_, 0) => self.set_range(start, leaves.into_iter()),
(0, _) => self
.remove_indices(&indices)
.map_err(ZerokitMerkleTreeError::PmtreeErrorKind),
(_, _) => self
.remove_indices_and_set_leaves(start, leaves, &indices)
.map_err(ZerokitMerkleTreeError::PmtreeErrorKind),
}
}
fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<()> {
fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<(), ZerokitMerkleTreeError> {
self.tree
.update_next(leaf)
.map_err(|e| Report::msg(e.to_string()))
.map_err(ZerokitMerkleTreeError::PmtreeErrorKind)
}
fn delete(&mut self, index: usize) -> Result<()> {
fn delete(&mut self, index: usize) -> Result<(), ZerokitMerkleTreeError> {
self.tree
.delete(index)
.map_err(|e| Report::msg(e.to_string()))?;
.map_err(ZerokitMerkleTreeError::PmtreeErrorKind)?;
self.cached_leaves_indices[index] = 0;
Ok(())
}
fn proof(&self, index: usize) -> Result<Self::Proof> {
fn proof(&self, index: usize) -> Result<Self::Proof, ZerokitMerkleTreeError> {
let proof = self.tree.proof(index)?;
Ok(PmTreeProof { proof })
}
fn verify(&self, leaf: &FrOf<Self::Hasher>, witness: &Self::Proof) -> Result<bool> {
fn verify(
&self,
leaf: &FrOf<Self::Hasher>,
witness: &Self::Proof,
) -> Result<bool, ZerokitMerkleTreeError> {
if self.tree.verify(leaf, &witness.proof) {
Ok(true)
} else {
Err(Report::msg("verify failed"))
Err(ZerokitMerkleTreeError::InvalidWitness)
}
}
fn set_metadata(&mut self, metadata: &[u8]) -> Result<()> {
self.tree.db.put(METADATA_KEY, metadata.to_vec())?;
fn set_metadata(&mut self, metadata: &[u8]) -> Result<(), ZerokitMerkleTreeError> {
self.tree
.db
.put(METADATA_KEY, metadata.to_vec())
.map_err(ZerokitMerkleTreeError::PmtreeErrorKind)?;
self.metadata = metadata.to_vec();
Ok(())
}
fn metadata(&self) -> Result<Vec<u8>> {
fn metadata(&self) -> Result<Vec<u8>, ZerokitMerkleTreeError> {
if !self.metadata.is_empty() {
return Ok(self.metadata.clone());
}
@@ -297,8 +309,11 @@ impl ZerokitMerkleTree for PmTree {
Ok(data.unwrap())
}
fn close_db_connection(&mut self) -> Result<()> {
self.tree.db.close().map_err(|e| Report::msg(e.to_string()))
fn close_db_connection(&mut self) -> Result<(), ZerokitMerkleTreeError> {
self.tree
.db
.close()
.map_err(ZerokitMerkleTreeError::PmtreeErrorKind)
}
}
@@ -306,15 +321,13 @@ type PmTreeHasher = <PmTree as ZerokitMerkleTree>::Hasher;
type FrOfPmTreeHasher = FrOf<PmTreeHasher>;
impl PmTree {
fn remove_indices(&mut self, indices: &[usize]) -> Result<()> {
fn remove_indices(&mut self, indices: &[usize]) -> Result<(), PmtreeErrorKind> {
let start = indices[0];
let end = indices.last().unwrap() + 1;
let new_leaves = (start..end).map(|_| PmTreeHasher::default_leaf());
self.tree
.set_range(start, new_leaves)
.map_err(|e| Report::msg(e.to_string()))?;
self.tree.set_range(start, new_leaves)?;
for i in start..end {
self.cached_leaves_indices[i] = 0
@@ -327,7 +340,7 @@ impl PmTree {
start: usize,
leaves: Vec<FrOfPmTreeHasher>,
indices: &[usize],
) -> Result<()> {
) -> Result<(), PmtreeErrorKind> {
let min_index = *indices.first().unwrap();
let max_index = start + leaves.len();
@@ -344,9 +357,7 @@ impl PmTree {
set_values[start - min_index + i] = leaf;
}
self.tree
.set_range(start, set_values)
.map_err(|e| Report::msg(e.to_string()))?;
self.tree.set_range(start, set_values)?;
for i in indices {
self.cached_leaves_indices[*i] = 0;

View File

@@ -4,26 +4,33 @@
use cfg_if::cfg_if;
cfg_if! {
if #[cfg(feature = "pmtree-ft")] {
use crate::pm_tree_adapter::*;
} else {
use crate::hashers::{PoseidonHash};
use utils::merkle_tree::*;
}
}
// The zerokit RLN default Merkle tree implementation is the OptimalMerkleTree.
// To switch to FullMerkleTree implementation, it is enough to enable the fullmerkletree feature
cfg_if! {
if #[cfg(feature = "fullmerkletree")] {
use utils::{
FullMerkleTree,
FullMerkleProof,
};
use crate::hashers::PoseidonHash;
pub type PoseidonTree = FullMerkleTree<PoseidonHash>;
pub type MerkleProof = FullMerkleProof<PoseidonHash>;
} else if #[cfg(feature = "pmtree-ft")] {
use crate::pm_tree_adapter::{PmTree, PmTreeProof};
pub type PoseidonTree = PmTree;
pub type MerkleProof = PmTreeProof;
} else {
use crate::hashers::{PoseidonHash};
use utils::{
OptimalMerkleTree,
OptimalMerkleProof,
};
pub type PoseidonTree = OptimalMerkleTree<PoseidonHash>;
pub type MerkleProof = OptimalMerkleProof<PoseidonHash>;
}

View File

@@ -1,28 +1,30 @@
// This crate collects all the underlying primitives used to implement RLN
use ark_bn254::Fr;
use ark_circom::CircomReduction;
use ark_ff::AdditiveGroup;
use ark_groth16::{prepare_verifying_key, Groth16, Proof as ArkProof, ProvingKey, VerifyingKey};
use ark_relations::r1cs::{ConstraintMatrices, SynthesisError};
use ark_relations::r1cs::ConstraintMatrices;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::{rand::thread_rng, UniformRand};
use color_eyre::{Report, Result};
use num_bigint::BigInt;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use serde::{Deserialize, Serialize};
#[cfg(test)]
use std::time::Instant;
use thiserror::Error;
use tiny_keccak::{Hasher as _, Keccak};
use crate::circuit::{calculate_rln_witness, Curve};
use crate::circuit::{calculate_rln_witness, qap::CircomReduction, Curve};
use crate::error::{ComputeIdSecretError, ConversionError, ProofError, ProtocolError};
use crate::hashers::{hash_to_field, poseidon_hash};
use crate::poseidon_tree::*;
use crate::poseidon_tree::{MerkleProof, PoseidonTree};
use crate::public::RLN_IDENTIFIER;
use crate::utils::*;
use crate::utils::{
bytes_be_to_fr, bytes_le_to_fr, bytes_le_to_vec_fr, bytes_le_to_vec_u8, fr_byte_size,
fr_to_bytes_be, fr_to_bytes_le, normalize_usize, to_bigint, vec_fr_to_bytes_le,
vec_u8_to_bytes_le,
};
use utils::{ZerokitMerkleProof, ZerokitMerkleTree};
///////////////////////////////////////////////////////
// RLN Witness data structure and utility functions
///////////////////////////////////////////////////////
@@ -59,9 +61,17 @@ pub fn serialize_field_element(element: Fr) -> Vec<u8> {
fr_to_bytes_le(&element)
}
pub fn serialize_field_element_be(element: Fr) -> Vec<u8> {
fr_to_bytes_be(&element)
}
pub fn deserialize_field_element(serialized: Vec<u8>) -> Fr {
let (element, _) = bytes_le_to_fr(&serialized);
element
}
pub fn deserialize_field_element_be(serialized: Vec<u8>) -> Fr {
let (element, _) = bytes_be_to_fr(&serialized);
element
}
@@ -72,6 +82,13 @@ pub fn deserialize_identity_pair(serialized: Vec<u8>) -> (Fr, Fr) {
(identity_secret_hash, id_commitment)
}
pub fn deserialize_identity_pair_be(serialized: Vec<u8>) -> (Fr, Fr) {
let (identity_secret_hash, read) = bytes_be_to_fr(&serialized);
let (id_commitment, _) = bytes_be_to_fr(&serialized[read..]);
(identity_secret_hash, id_commitment)
}
pub fn deserialize_identity_tuple(serialized: Vec<u8>) -> (Fr, Fr, Fr, Fr) {
let mut all_read = 0;
@@ -94,13 +111,35 @@ pub fn deserialize_identity_tuple(serialized: Vec<u8>) -> (Fr, Fr, Fr, Fr) {
)
}
pub fn deserialize_identity_tuple_be(serialized: Vec<u8>) -> (Fr, Fr, Fr, Fr) {
let mut all_read = 0;
let (identity_trapdoor, read) = bytes_be_to_fr(&serialized[all_read..]);
all_read += read;
let (identity_nullifier, read) = bytes_be_to_fr(&serialized[all_read..]);
all_read += read;
let (identity_secret_hash, read) = bytes_be_to_fr(&serialized[all_read..]);
all_read += read;
let (identity_commitment, _) = bytes_be_to_fr(&serialized[all_read..]);
(
identity_trapdoor,
identity_nullifier,
identity_secret_hash,
identity_commitment,
)
}
/// Serializes witness
///
/// # Errors
///
/// Returns an error if `rln_witness.message_id` is not within `rln_witness.user_message_limit`.
/// input data is [ identity_secret<32> | user_message_limit<32> | message_id<32> | path_elements[<32>] | identity_path_index<8> | x<32> | external_nullifier<32> ]
pub fn serialize_witness(rln_witness: &RLNWitnessInput) -> Result<Vec<u8>> {
/// input data is [ identity_secret<32> | user_message_limit<32> | message_id<32> | path_elements<32> | identity_path_index<8> | x<32> | external_nullifier<32> ]
pub fn serialize_witness(rln_witness: &RLNWitnessInput) -> Result<Vec<u8>, ProtocolError> {
// Check if message_id is within user_message_limit
message_id_range_check(&rln_witness.message_id, &rln_witness.user_message_limit)?;
@@ -115,8 +154,8 @@ pub fn serialize_witness(rln_witness: &RLNWitnessInput) -> Result<Vec<u8>> {
serialized.extend_from_slice(&fr_to_bytes_le(&rln_witness.identity_secret));
serialized.extend_from_slice(&fr_to_bytes_le(&rln_witness.user_message_limit));
serialized.extend_from_slice(&fr_to_bytes_le(&rln_witness.message_id));
serialized.extend_from_slice(&vec_fr_to_bytes_le(&rln_witness.path_elements)?);
serialized.extend_from_slice(&vec_u8_to_bytes_le(&rln_witness.identity_path_index)?);
serialized.extend_from_slice(&vec_fr_to_bytes_le(&rln_witness.path_elements));
serialized.extend_from_slice(&vec_u8_to_bytes_le(&rln_witness.identity_path_index));
serialized.extend_from_slice(&fr_to_bytes_le(&rln_witness.x));
serialized.extend_from_slice(&fr_to_bytes_le(&rln_witness.external_nullifier));
@@ -128,7 +167,7 @@ pub fn serialize_witness(rln_witness: &RLNWitnessInput) -> Result<Vec<u8>> {
/// # Errors
///
/// Returns an error if `message_id` is not within `user_message_limit`.
pub fn deserialize_witness(serialized: &[u8]) -> Result<(RLNWitnessInput, usize)> {
pub fn deserialize_witness(serialized: &[u8]) -> Result<(RLNWitnessInput, usize), ProtocolError> {
let mut all_read: usize = 0;
let (identity_secret, read) = bytes_le_to_fr(&serialized[all_read..]);
@@ -155,7 +194,7 @@ pub fn deserialize_witness(serialized: &[u8]) -> Result<(RLNWitnessInput, usize)
all_read += read;
if serialized.len() != all_read {
return Err(Report::msg("serialized length is not equal to all_read"));
return Err(ProtocolError::InvalidReadLen(serialized.len(), all_read));
}
Ok((
@@ -179,15 +218,18 @@ pub fn deserialize_witness(serialized: &[u8]) -> Result<(RLNWitnessInput, usize)
pub fn proof_inputs_to_rln_witness(
tree: &mut PoseidonTree,
serialized: &[u8],
) -> Result<(RLNWitnessInput, usize)> {
) -> Result<(RLNWitnessInput, usize), ProtocolError> {
let mut all_read: usize = 0;
let (identity_secret, read) = bytes_le_to_fr(&serialized[all_read..]);
all_read += read;
let id_index = usize::try_from(u64::from_le_bytes(
serialized[all_read..all_read + 8].try_into()?,
))?;
serialized[all_read..all_read + 8]
.try_into()
.map_err(ConversionError::FromSlice)?,
))
.map_err(ConversionError::ToUsize)?;
all_read += 8;
let (user_message_limit, read) = bytes_le_to_fr(&serialized[all_read..]);
@@ -200,8 +242,11 @@ pub fn proof_inputs_to_rln_witness(
all_read += read;
let signal_len = usize::try_from(u64::from_le_bytes(
serialized[all_read..all_read + 8].try_into()?,
))?;
serialized[all_read..all_read + 8]
.try_into()
.map_err(ConversionError::FromSlice)?,
))
.map_err(ConversionError::ToUsize)?;
all_read += 8;
let signal: Vec<u8> = serialized[all_read..all_read + signal_len].to_vec();
@@ -238,7 +283,7 @@ pub fn rln_witness_from_values(
external_nullifier: Fr,
user_message_limit: Fr,
message_id: Fr,
) -> Result<RLNWitnessInput> {
) -> Result<RLNWitnessInput, ProtocolError> {
message_id_range_check(&message_id, &user_message_limit)?;
let path_elements = merkle_proof.get_path_elements();
@@ -285,7 +330,9 @@ pub fn random_rln_witness(tree_height: usize) -> RLNWitnessInput {
}
}
pub fn proof_values_from_witness(rln_witness: &RLNWitnessInput) -> Result<RLNProofValues> {
pub fn proof_values_from_witness(
rln_witness: &RLNWitnessInput,
) -> Result<RLNProofValues, ProtocolError> {
message_id_range_check(&rln_witness.message_id, &rln_witness.user_message_limit)?;
// y share
@@ -503,7 +550,7 @@ pub fn extended_seeded_keygen(signal: &[u8]) -> (Fr, Fr, Fr, Fr) {
)
}
pub fn compute_id_secret(share1: (Fr, Fr), share2: (Fr, Fr)) -> Result<Fr, String> {
pub fn compute_id_secret(share1: (Fr, Fr), share2: (Fr, Fr)) -> Result<Fr, ComputeIdSecretError> {
// Assuming a0 is the identity secret and a1 = poseidonHash([a0, external_nullifier]),
// a (x,y) share satisfies the following relation
// y = a_0 + x * a_1
@@ -513,30 +560,25 @@ pub fn compute_id_secret(share1: (Fr, Fr), share2: (Fr, Fr)) -> Result<Fr, Strin
// If the two input shares were computed for the same external_nullifier and identity secret, we can recover the latter
// y1 = a_0 + x1 * a_1
// y2 = a_0 + x2 * a_1
let a_1 = (y1 - y2) / (x1 - x2);
let a_0 = y1 - x1 * a_1;
// If shares come from the same polynomial, a0 is correctly recovered and a1 = poseidonHash([a0, external_nullifier])
Ok(a_0)
if (x1 - x2) != Fr::ZERO {
let a_1 = (y1 - y2) / (x1 - x2);
let a_0 = y1 - x1 * a_1;
// If shares come from the same polynomial, a0 is correctly recovered and a1 = poseidonHash([a0, external_nullifier])
Ok(a_0)
} else {
Err(ComputeIdSecretError::DivisionByZero)
}
}
///////////////////////////////////////////////////////
// zkSNARK utility functions
///////////////////////////////////////////////////////
#[derive(Error, Debug)]
pub enum ProofError {
#[error("Error reading circuit key: {0}")]
CircuitKeyError(#[from] Report),
#[error("Error producing witness: {0}")]
WitnessError(Report),
#[error("Error producing proof: {0}")]
SynthesisError(#[from] SynthesisError),
}
fn calculate_witness_element<E: ark_ec::pairing::Pairing>(
witness: Vec<BigInt>,
) -> Result<Vec<E::ScalarField>> {
) -> Result<Vec<E::ScalarField>, ProtocolError> {
use ark_ff::PrimeField;
let modulus = <E::ScalarField as PrimeField>::MODULUS;
@@ -549,9 +591,9 @@ fn calculate_witness_element<E: ark_ec::pairing::Pairing>(
modulus.into()
- w.abs()
.to_biguint()
.ok_or(Report::msg("not a biguint value"))?
.ok_or(ProtocolError::BigUintConversion(w))?
} else {
w.to_biguint().ok_or(Report::msg("not a biguint value"))?
w.to_biguint().ok_or(ProtocolError::BigUintConversion(w))?
};
witness_vec.push(E::ScalarField::from(w))
}
@@ -567,8 +609,7 @@ pub fn generate_proof_with_witness(
#[cfg(test)]
let now = Instant::now();
let full_assignment =
calculate_witness_element::<Curve>(witness).map_err(ProofError::WitnessError)?;
let full_assignment = calculate_witness_element::<Curve>(witness)?;
#[cfg(test)]
println!("witness generation took: {:.2?}", now.elapsed());
@@ -605,7 +646,7 @@ pub fn generate_proof_with_witness(
/// Returns an error if `rln_witness.message_id` is not within `rln_witness.user_message_limit`.
pub fn inputs_for_witness_calculation(
rln_witness: &RLNWitnessInput,
) -> Result<[(&str, Vec<Fr>); 7]> {
) -> Result<[(&str, Vec<Fr>); 7], ProtocolError> {
message_id_range_check(&rln_witness.message_id, &rln_witness.user_message_limit)?;
let mut identity_path_index = Vec::with_capacity(rln_witness.identity_path_index.len());
@@ -733,7 +774,9 @@ where
/// # Errors
///
/// Returns an error if `rln_witness.message_id` is not within `rln_witness.user_message_limit`.
pub fn rln_witness_from_json(input_json: serde_json::Value) -> Result<RLNWitnessInput> {
pub fn rln_witness_from_json(
input_json: serde_json::Value,
) -> Result<RLNWitnessInput, ProtocolError> {
let rln_witness: RLNWitnessInput = serde_json::from_value(input_json).unwrap();
message_id_range_check(&rln_witness.message_id, &rln_witness.user_message_limit)?;
@@ -745,7 +788,9 @@ pub fn rln_witness_from_json(input_json: serde_json::Value) -> Result<RLNWitness
/// # Errors
///
/// Returns an error if `message_id` is not within `user_message_limit`.
pub fn rln_witness_to_json(rln_witness: &RLNWitnessInput) -> Result<serde_json::Value> {
pub fn rln_witness_to_json(
rln_witness: &RLNWitnessInput,
) -> Result<serde_json::Value, ProtocolError> {
message_id_range_check(&rln_witness.message_id, &rln_witness.user_message_limit)?;
let rln_witness_json = serde_json::to_value(rln_witness)?;
@@ -758,13 +803,15 @@ pub fn rln_witness_to_json(rln_witness: &RLNWitnessInput) -> Result<serde_json::
/// # Errors
///
/// Returns an error if `message_id` is not within `user_message_limit`.
pub fn rln_witness_to_bigint_json(rln_witness: &RLNWitnessInput) -> Result<serde_json::Value> {
pub fn rln_witness_to_bigint_json(
rln_witness: &RLNWitnessInput,
) -> Result<serde_json::Value, ProtocolError> {
message_id_range_check(&rln_witness.message_id, &rln_witness.user_message_limit)?;
let mut path_elements = Vec::new();
for v in rln_witness.path_elements.iter() {
path_elements.push(to_bigint(v)?.to_str_radix(10));
path_elements.push(to_bigint(v).to_str_radix(10));
}
let mut identity_path_index = Vec::new();
@@ -774,22 +821,26 @@ pub fn rln_witness_to_bigint_json(rln_witness: &RLNWitnessInput) -> Result<serde
.for_each(|v| identity_path_index.push(BigInt::from(*v).to_str_radix(10)));
let inputs = serde_json::json!({
"identitySecret": to_bigint(&rln_witness.identity_secret)?.to_str_radix(10),
"userMessageLimit": to_bigint(&rln_witness.user_message_limit)?.to_str_radix(10),
"messageId": to_bigint(&rln_witness.message_id)?.to_str_radix(10),
"identitySecret": to_bigint(&rln_witness.identity_secret).to_str_radix(10),
"userMessageLimit": to_bigint(&rln_witness.user_message_limit).to_str_radix(10),
"messageId": to_bigint(&rln_witness.message_id).to_str_radix(10),
"pathElements": path_elements,
"identityPathIndex": identity_path_index,
"x": to_bigint(&rln_witness.x)?.to_str_radix(10),
"externalNullifier": to_bigint(&rln_witness.external_nullifier)?.to_str_radix(10),
"x": to_bigint(&rln_witness.x).to_str_radix(10),
"externalNullifier": to_bigint(&rln_witness.external_nullifier).to_str_radix(10),
});
Ok(inputs)
}
pub fn message_id_range_check(message_id: &Fr, user_message_limit: &Fr) -> Result<()> {
pub fn message_id_range_check(
message_id: &Fr,
user_message_limit: &Fr,
) -> Result<(), ProtocolError> {
if message_id > user_message_limit {
return Err(color_eyre::Report::msg(
"message_id is not within user_message_limit",
return Err(ProtocolError::InvalidMessageId(
*message_id,
*user_message_limit,
));
}
Ok(())

View File

@@ -1,3 +1,24 @@
use crate::circuit::{zkey_from_raw, Curve, Fr};
use crate::hashers::{hash_to_field, poseidon_hash as utils_poseidon_hash};
use crate::protocol::{
compute_id_secret, deserialize_proof_values, deserialize_witness, extended_keygen,
extended_seeded_keygen, generate_proof, keygen, proof_inputs_to_rln_witness,
proof_values_from_witness, rln_witness_to_bigint_json, rln_witness_to_json, seeded_keygen,
serialize_proof_values, serialize_witness, verify_proof,
};
use crate::utils::{
bytes_be_to_vec_fr, bytes_le_to_fr, bytes_le_to_vec_fr, bytes_le_to_vec_u8, fr_byte_size,
fr_to_bytes_be, fr_to_bytes_le, vec_fr_to_bytes_le, vec_u8_to_bytes_le,
};
#[cfg(not(target_arch = "wasm32"))]
use {
crate::circuit::{graph_from_folder, zkey_from_folder},
std::default::Default,
};
#[cfg(target_arch = "wasm32")]
use crate::protocol::generate_proof_with_witness;
/// This is the main public API for RLN module. It is used by the FFI, and should be
/// used by tests etc. as well
#[cfg(not(feature = "stateless"))]
@@ -8,16 +29,14 @@ use {
utils::{Hasher, ZerokitMerkleProof, ZerokitMerkleTree},
};
use crate::circuit::{graph_from_folder, vk_from_raw, zkey_from_folder, zkey_from_raw, Curve, Fr};
use crate::hashers::{hash_to_field, poseidon_hash as utils_poseidon_hash};
use crate::protocol::*;
use crate::utils::*;
use crate::error::{ConversionError, ProtocolError, RLNError};
use ark_groth16::{Proof as ArkProof, ProvingKey, VerifyingKey};
use ark_relations::r1cs::ConstraintMatrices;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, Write};
use color_eyre::{Report, Result};
use std::{default::Default, io::Cursor};
#[cfg(target_arch = "wasm32")]
use num_bigint::BigInt;
use std::io::Cursor;
use utils::error::ZerokitMerkleTreeError;
/// The application-specific RLN identifier.
///
@@ -32,6 +51,7 @@ pub const RLN_IDENTIFIER: &[u8] = b"zerokit/rln/010203040506070809";
pub struct RLN {
proving_key: (ProvingKey<Curve>, ConstraintMatrices<Fr>),
pub(crate) verification_key: VerifyingKey<Curve>,
#[cfg(not(target_arch = "wasm32"))]
pub(crate) graph_data: Vec<u8>,
#[cfg(not(feature = "stateless"))]
pub(crate) tree: PoseidonTree,
@@ -54,8 +74,8 @@ impl RLN {
/// // We create a new RLN instance
/// let mut rln = RLN::new(tree_height, input);
/// ```
#[cfg(not(feature = "stateless"))]
pub fn new<R: Read>(tree_height: usize, mut input_data: R) -> Result<RLN> {
#[cfg(all(not(target_arch = "wasm32"), not(feature = "stateless")))]
pub fn new<R: Read>(tree_height: usize, mut input_data: R) -> Result<RLN, RLNError> {
// We read input
let mut input: Vec<u8> = Vec::new();
input_data.read_to_end(&mut input)?;
@@ -63,9 +83,9 @@ impl RLN {
let rln_config: Value = serde_json::from_str(&String::from_utf8(input)?)?;
let tree_config = rln_config["tree_config"].to_string();
let proving_key = zkey_from_folder();
let verification_key = &proving_key.0.vk;
let graph_data = graph_from_folder();
let proving_key = zkey_from_folder().to_owned();
let verification_key = proving_key.0.vk.to_owned();
let graph_data = graph_from_folder().to_owned();
let tree_config: <PoseidonTree as ZerokitMerkleTree>::Config = if tree_config.is_empty() {
<PoseidonTree as ZerokitMerkleTree>::Config::default()
@@ -81,9 +101,9 @@ impl RLN {
)?;
Ok(RLN {
proving_key: proving_key.to_owned(),
verification_key: verification_key.to_owned(),
graph_data: graph_data.to_vec(),
proving_key,
verification_key,
graph_data,
#[cfg(not(feature = "stateless"))]
tree,
})
@@ -97,17 +117,16 @@ impl RLN {
/// // We create a new RLN instance
/// let mut rln = RLN::new();
/// ```
#[cfg_attr(docsrs, doc(cfg(feature = "stateless")))]
#[cfg(feature = "stateless")]
pub fn new() -> Result<RLN> {
let proving_key = zkey_from_folder();
let verification_key = &proving_key.0.vk;
let graph_data = graph_from_folder();
#[cfg(all(not(target_arch = "wasm32"), feature = "stateless"))]
pub fn new() -> Result<RLN, RLNError> {
let proving_key = zkey_from_folder().to_owned();
let verification_key = proving_key.0.vk.to_owned();
let graph_data = graph_from_folder().to_owned();
Ok(RLN {
proving_key: proving_key.to_owned(),
verification_key: verification_key.to_owned(),
graph_data: graph_data.to_vec(),
proving_key,
verification_key,
graph_data,
})
}
@@ -116,6 +135,7 @@ impl RLN {
/// Input parameters are
/// - `tree_height`: the height of the internal Merkle tree
/// - `zkey_vec`: a byte vector containing to the proving key (`rln_final.zkey`) or (`rln_final.arkzkey`) as binary file
/// - `graph_data`: a byte vector containing the graph data (`graph.bin`) as binary file
/// - `tree_config_input`: a reader for a string containing a json with the merkle tree configuration
///
/// Example:
@@ -143,19 +163,18 @@ impl RLN {
/// tree_height,
/// resources[0].clone(),
/// resources[1].clone(),
/// resources[2].clone(),
/// tree_config_buffer,
/// );
/// ```
#[cfg(not(feature = "stateless"))]
#[cfg(all(not(target_arch = "wasm32"), not(feature = "stateless")))]
pub fn new_with_params<R: Read>(
tree_height: usize,
zkey_vec: Vec<u8>,
graph_data: Vec<u8>,
mut tree_config_input: R,
) -> Result<RLN> {
) -> Result<RLN, RLNError> {
let proving_key = zkey_from_raw(&zkey_vec)?;
let verification_key = vk_from_raw(&zkey_vec)?;
let verification_key = proving_key.0.vk.to_owned();
let mut tree_config_vec: Vec<u8> = Vec::new();
tree_config_input.read_to_end(&mut tree_config_vec)?;
@@ -187,6 +206,7 @@ impl RLN {
///
/// Input parameters are
/// - `zkey_vec`: a byte vector containing to the proving key (`rln_final.zkey`) or (`rln_final.arkzkey`) as binary file
/// - `graph_data`: a byte vector containing the graph data (`graph.bin`) as binary file
///
/// Example:
/// ```
@@ -208,13 +228,12 @@ impl RLN {
/// let mut rln = RLN::new_with_params(
/// resources[0].clone(),
/// resources[1].clone(),
/// resources[2].clone(),
/// );
/// ```
#[cfg(feature = "stateless")]
pub fn new_with_params(zkey_vec: Vec<u8>, graph_data: Vec<u8>) -> Result<RLN> {
#[cfg(all(not(target_arch = "wasm32"), feature = "stateless"))]
pub fn new_with_params(zkey_vec: Vec<u8>, graph_data: Vec<u8>) -> Result<RLN, RLNError> {
let proving_key = zkey_from_raw(&zkey_vec)?;
let verification_key = vk_from_raw(&zkey_vec)?;
let verification_key = proving_key.0.vk.to_owned();
Ok(RLN {
proving_key,
@@ -223,6 +242,36 @@ impl RLN {
})
}
/// Creates a new stateless RLN object by passing circuit resources as a byte vector.
///
/// Input parameters are
/// - `zkey_vec`: a byte vector containing the proving key (`rln_final.zkey`) or (`rln_final.arkzkey`) as binary file
///
/// Example:
/// ```
/// use std::fs::File;
/// use std::io::Read;
///
/// let zkey_path = "./resources/tree_height_20/rln_final.zkey";
///
/// let mut file = File::open(zkey_path).expect("Failed to open file");
/// let metadata = std::fs::metadata(zkey_path).expect("Failed to read metadata");
/// let mut zkey_vec = vec![0; metadata.len() as usize];
/// file.read_exact(&mut zkey_vec).expect("Failed to read file");
///
/// let mut rln = RLN::new_with_params(zkey_vec)?;
/// ```
#[cfg(all(target_arch = "wasm32", feature = "stateless"))]
pub fn new_with_params(zkey_vec: Vec<u8>) -> Result<RLN, RLNError> {
let proving_key = zkey_from_raw(&zkey_vec)?;
let verification_key = proving_key.0.vk.to_owned();
Ok(RLN {
proving_key,
verification_key,
})
}
////////////////////////////////////////////////////////
// Merkle-tree APIs
////////////////////////////////////////////////////////
@@ -233,7 +282,7 @@ impl RLN {
/// Input values are:
/// - `tree_height`: the height of the Merkle tree.
#[cfg(not(feature = "stateless"))]
pub fn set_tree(&mut self, tree_height: usize) -> Result<()> {
pub fn set_tree(&mut self, tree_height: usize) -> Result<(), RLNError> {
// We compute a default empty tree of desired height
self.tree = PoseidonTree::default(tree_height)?;
@@ -264,7 +313,7 @@ impl RLN {
/// rln.set_leaf(id_index, &mut buffer).unwrap();
/// ```
#[cfg(not(feature = "stateless"))]
pub fn set_leaf<R: Read>(&mut self, index: usize, mut input_data: R) -> Result<()> {
pub fn set_leaf<R: Read>(&mut self, index: usize, mut input_data: R) -> Result<(), RLNError> {
// We read input
let mut leaf_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut leaf_byte)?;
@@ -294,7 +343,7 @@ impl RLN {
/// rln.get_leaf(id_index, &mut buffer).unwrap();
/// let rate_commitment = deserialize_field_element(&buffer.into_inner()).unwrap();
#[cfg(not(feature = "stateless"))]
pub fn get_leaf<W: Write>(&self, index: usize, mut output_data: W) -> Result<()> {
pub fn get_leaf<W: Write>(&self, index: usize, mut output_data: W) -> Result<(), RLNError> {
// We get the leaf at input index
let leaf = self.tree.get(index)?;
@@ -337,7 +386,11 @@ impl RLN {
/// rln.set_leaves_from(index, &mut buffer).unwrap();
/// ```
#[cfg(not(feature = "stateless"))]
pub fn set_leaves_from<R: Read>(&mut self, index: usize, mut input_data: R) -> Result<()> {
pub fn set_leaves_from<R: Read>(
&mut self,
index: usize,
mut input_data: R,
) -> Result<(), RLNError> {
// We read input
let mut leaves_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut leaves_byte)?;
@@ -346,8 +399,7 @@ impl RLN {
// We set the leaves
self.tree
.override_range(index, leaves, [])
.map_err(|_| Report::msg("Could not set leaves"))?;
.override_range(index, leaves.into_iter(), [].into_iter())?;
Ok(())
}
@@ -358,7 +410,7 @@ impl RLN {
/// Input values are:
/// - `input_data`: a reader for the serialization of multiple leaf values (serialization done with [`rln::utils::vec_fr_to_bytes_le`](crate::utils::vec_fr_to_bytes_le))
#[cfg(not(feature = "stateless"))]
pub fn init_tree_with_leaves<R: Read>(&mut self, input_data: R) -> Result<()> {
pub fn init_tree_with_leaves<R: Read>(&mut self, input_data: R) -> Result<(), RLNError> {
// reset the tree
// NOTE: this requires the tree to be initialized with the correct height initially
// TODO: accept tree_height as a parameter and initialize the tree with that height
@@ -414,7 +466,7 @@ impl RLN {
index: usize,
mut input_leaves: R,
mut input_indices: R,
) -> Result<()> {
) -> Result<(), RLNError> {
// We read input
let mut leaves_byte: Vec<u8> = Vec::new();
input_leaves.read_to_end(&mut leaves_byte)?;
@@ -429,8 +481,7 @@ impl RLN {
// We set the leaves
self.tree
.override_range(index, leaves, indices)
.map_err(|e| Report::msg(format!("Could not perform the batch operation: {e}")))?;
.override_range(index, leaves.into_iter(), indices.into_iter())?;
Ok(())
}
@@ -483,7 +534,7 @@ impl RLN {
/// rln.set_next_leaf(&mut buffer).unwrap();
/// ```
#[cfg(not(feature = "stateless"))]
pub fn set_next_leaf<R: Read>(&mut self, mut input_data: R) -> Result<()> {
pub fn set_next_leaf<R: Read>(&mut self, mut input_data: R) -> Result<(), RLNError> {
// We read input
let mut leaf_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut leaf_byte)?;
@@ -509,7 +560,7 @@ impl RLN {
/// rln.delete_leaf(index).unwrap();
/// ```
#[cfg(not(feature = "stateless"))]
pub fn delete_leaf(&mut self, index: usize) -> Result<()> {
pub fn delete_leaf(&mut self, index: usize) -> Result<(), RLNError> {
self.tree.delete(index)?;
Ok(())
}
@@ -527,7 +578,7 @@ impl RLN {
/// rln.set_metadata(metadata).unwrap();
/// ```
#[cfg(not(feature = "stateless"))]
pub fn set_metadata(&mut self, metadata: &[u8]) -> Result<()> {
pub fn set_metadata(&mut self, metadata: &[u8]) -> Result<(), RLNError> {
self.tree.set_metadata(metadata)?;
Ok(())
}
@@ -547,7 +598,7 @@ impl RLN {
/// let metadata = buffer.into_inner();
/// ```
#[cfg(not(feature = "stateless"))]
pub fn get_metadata<W: Write>(&self, mut output_data: W) -> Result<()> {
pub fn get_metadata<W: Write>(&self, mut output_data: W) -> Result<(), RLNError> {
let metadata = self.tree.metadata()?;
output_data.write_all(&metadata)?;
Ok(())
@@ -567,10 +618,9 @@ impl RLN {
/// let (root, _) = bytes_le_to_fr(&buffer.into_inner());
/// ```
#[cfg(not(feature = "stateless"))]
pub fn get_root<W: Write>(&self, mut output_data: W) -> Result<()> {
pub fn get_root<W: Write>(&self, mut output_data: W) -> Result<(), RLNError> {
let root = self.tree.root();
output_data.write_all(&fr_to_bytes_le(&root))?;
Ok(())
}
@@ -595,7 +645,7 @@ impl RLN {
level: usize,
index: usize,
mut output_data: W,
) -> Result<()> {
) -> Result<(), RLNError> {
let subroot = self.tree.get_subtree_root(level, index)?;
output_data.write_all(&fr_to_bytes_le(&subroot))?;
@@ -624,13 +674,14 @@ impl RLN {
/// let (identity_path_index, _) = bytes_le_to_vec_u8(&buffer_inner[read..].to_vec());
/// ```
#[cfg(not(feature = "stateless"))]
pub fn get_proof<W: Write>(&self, index: usize, mut output_data: W) -> Result<()> {
pub fn get_proof<W: Write>(&self, index: usize, mut output_data: W) -> Result<(), RLNError> {
let merkle_proof = self.tree.proof(index).expect("proof should exist");
let path_elements = merkle_proof.get_path_elements();
let identity_path_index = merkle_proof.get_path_index();
output_data.write_all(&vec_fr_to_bytes_le(&path_elements)?)?;
output_data.write_all(&vec_u8_to_bytes_le(&identity_path_index)?)?;
// Note: unwrap safe - vec_fr_to_bytes_le & vec_u8_to_bytes_le are infallible
output_data.write_all(&vec_fr_to_bytes_le(&path_elements))?;
output_data.write_all(&vec_u8_to_bytes_le(&identity_path_index))?;
Ok(())
}
@@ -668,7 +719,7 @@ impl RLN {
/// assert_eq!(idxs, [0, 1, 2, 3, 4]);
/// ```
#[cfg(not(feature = "stateless"))]
pub fn get_empty_leaves_indices<W: Write>(&self, mut output_data: W) -> Result<()> {
pub fn get_empty_leaves_indices<W: Write>(&self, mut output_data: W) -> Result<(), RLNError> {
let idxs = self.tree.get_empty_leaves_indices();
idxs.serialize_compressed(&mut output_data)?;
Ok(())
@@ -698,15 +749,16 @@ impl RLN {
/// rln.prove(&mut input_buffer, &mut output_buffer).unwrap();
/// let zk_proof = output_buffer.into_inner();
/// ```
#[cfg(not(target_arch = "wasm32"))]
pub fn prove<R: Read, W: Write>(
&mut self,
mut input_data: R,
mut output_data: W,
) -> Result<()> {
) -> Result<(), RLNError> {
// We read input RLN witness and we serialize_compressed it
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
let (rln_witness, _) = deserialize_witness(&serialized)?;
let mut serialized_witness: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized_witness)?;
let (rln_witness, _) = deserialize_witness(&serialized_witness)?;
let proof = generate_proof(&self.proving_key, &rln_witness, &self.graph_data)?;
@@ -752,7 +804,7 @@ impl RLN {
///
/// assert!(verified);
/// ```
pub fn verify<R: Read>(&self, mut input_data: R) -> Result<bool> {
pub fn verify<R: Read>(&self, mut input_data: R) -> Result<bool, RLNError> {
// Input data is serialized for Curve as:
// serialized_proof (compressed, 4*32 bytes) || serialized_proof_values (6*32 bytes), i.e.
// [ proof<128> | root<32> | external_nullifier<32> | x<32> | y<32> | nullifier<32> ]
@@ -818,12 +870,12 @@ impl RLN {
/// // proof_data is [ proof<128> | root<32> | external_nullifier<32> | x<32> | y<32> | nullifier<32>]
/// let mut proof_data = output_buffer.into_inner();
/// ```
#[cfg(not(feature = "stateless"))]
#[cfg(all(not(target_arch = "wasm32"), not(feature = "stateless")))]
pub fn generate_rln_proof<R: Read, W: Write>(
&mut self,
mut input_data: R,
mut output_data: W,
) -> Result<()> {
) -> Result<(), RLNError> {
// We read input RLN witness and we serialize_compressed it
let mut witness_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut witness_byte)?;
@@ -840,18 +892,18 @@ impl RLN {
Ok(())
}
// Generate RLN Proof using a witness calculated from outside zerokit
//
// output_data is [ proof<128> | root<32> | external_nullifier<32> | x<32> | y<32> | nullifier<32>]
// we skip it from documentation for now
/// Generate RLN Proof using a witness calculated from outside zerokit
///
/// output_data is [ proof<128> | root<32> | external_nullifier<32> | x<32> | y<32> | nullifier<32>]
#[cfg(not(target_arch = "wasm32"))]
pub fn generate_rln_proof_with_witness<R: Read, W: Write>(
&mut self,
mut input_data: R,
mut output_data: W,
) -> Result<()> {
let mut witness_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut witness_byte)?;
let (rln_witness, _) = deserialize_witness(&witness_byte)?;
) -> Result<(), RLNError> {
let mut serialized_witness: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized_witness)?;
let (rln_witness, _) = deserialize_witness(&serialized_witness)?;
let proof_values = proof_values_from_witness(&rln_witness)?;
let proof = generate_proof(&self.proving_key, &rln_witness, &self.graph_data)?;
@@ -863,6 +915,28 @@ impl RLN {
Ok(())
}
/// Generate RLN Proof using a witness calculated from outside zerokit
///
/// output_data is [ proof<128> | root<32> | external_nullifier<32> | x<32> | y<32> | nullifier<32>]
#[cfg(target_arch = "wasm32")]
pub fn generate_rln_proof_with_witness<W: Write>(
&mut self,
calculated_witness: Vec<BigInt>,
serialized_witness: Vec<u8>,
mut output_data: W,
) -> Result<(), RLNError> {
let (rln_witness, _) = deserialize_witness(&serialized_witness[..])?;
let proof_values = proof_values_from_witness(&rln_witness)?;
let proof = generate_proof_with_witness(calculated_witness, &self.proving_key).unwrap();
// Note: we export a serialization of ark-groth16::Proof not semaphore::Proof
// This proof is compressed, i.e. 128 bytes long
proof.serialize_compressed(&mut output_data)?;
output_data.write_all(&serialize_proof_values(&proof_values))?;
Ok(())
}
/// Verifies a zkSNARK RLN proof against the provided proof values and the state of the internal Merkle tree.
///
/// Input values are:
@@ -891,7 +965,7 @@ impl RLN {
/// assert!(verified);
/// ```
#[cfg(not(feature = "stateless"))]
pub fn verify_rln_proof<R: Read>(&self, mut input_data: R) -> Result<bool> {
pub fn verify_rln_proof<R: Read>(&self, mut input_data: R) -> Result<bool, RLNError> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
let mut all_read = 0;
@@ -902,8 +976,11 @@ impl RLN {
all_read += read;
let signal_len = usize::try_from(u64::from_le_bytes(
serialized[all_read..all_read + 8].try_into()?,
))?;
serialized[all_read..all_read + 8]
.try_into()
.map_err(ConversionError::FromSlice)?,
))
.map_err(ConversionError::from)?;
all_read += 8;
let signal: Vec<u8> = serialized[all_read..all_read + signal_len].to_vec();
@@ -966,7 +1043,11 @@ impl RLN {
///
/// assert!(verified);
/// ```
pub fn verify_with_roots<R: Read>(&self, mut input_data: R, mut roots_data: R) -> Result<bool> {
pub fn verify_with_roots<R: Read>(
&self,
mut input_data: R,
mut roots_data: R,
) -> Result<bool, RLNError> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
let mut all_read = 0;
@@ -977,8 +1058,11 @@ impl RLN {
all_read += read;
let signal_len = usize::try_from(u64::from_le_bytes(
serialized[all_read..all_read + 8].try_into()?,
))?;
serialized[all_read..all_read + 8]
.try_into()
.map_err(ConversionError::FromSlice)?,
))
.map_err(ConversionError::ToUsize)?;
all_read += 8;
let signal: Vec<u8> = serialized[all_read..all_read + signal_len].to_vec();
@@ -1047,7 +1131,7 @@ impl RLN {
/// // We serialize_compressed the keygen output
/// let (identity_secret_hash, id_commitment) = deserialize_identity_pair(buffer.into_inner());
/// ```
pub fn key_gen<W: Write>(&self, mut output_data: W) -> Result<()> {
pub fn key_gen<W: Write>(&self, mut output_data: W) -> Result<(), RLNError> {
let (identity_secret_hash, id_commitment) = keygen();
output_data.write_all(&fr_to_bytes_le(&identity_secret_hash))?;
output_data.write_all(&fr_to_bytes_le(&id_commitment))?;
@@ -1055,6 +1139,23 @@ impl RLN {
Ok(())
}
/// Same as key_gen but serialized in BE format
pub fn key_gen_be<W: Write>(&self, mut output_data: W) -> Result<(), RLNError> {
let (identity_secret_hash, id_commitment) = keygen();
output_data.write_all(&fr_to_bytes_be(&identity_secret_hash))?;
output_data.write_all(&fr_to_bytes_be(&id_commitment))?;
Ok(())
}
pub fn key_gen_be_2<W: Write>(&self, mut output_data: W) -> Result<(Fr, Fr), RLNError> {
let (identity_secret_hash, id_commitment) = keygen();
output_data.write_all(&fr_to_bytes_be(&identity_secret_hash))?;
output_data.write_all(&fr_to_bytes_be(&id_commitment))?;
Ok((identity_secret_hash, id_commitment))
}
/// Returns an identity trapdoor, nullifier, secret and commitment tuple.
///
/// The identity secret is the Poseidon hash of the identity trapdoor and identity nullifier.
@@ -1077,7 +1178,7 @@ impl RLN {
/// // We serialize_compressed the keygen output
/// let (identity_trapdoor, identity_nullifier, identity_secret_hash, id_commitment) = deserialize_identity_tuple(buffer.into_inner());
/// ```
pub fn extended_key_gen<W: Write>(&self, mut output_data: W) -> Result<()> {
pub fn extended_key_gen<W: Write>(&self, mut output_data: W) -> Result<(), RLNError> {
let (identity_trapdoor, identity_nullifier, identity_secret_hash, id_commitment) =
extended_keygen();
output_data.write_all(&fr_to_bytes_le(&identity_trapdoor))?;
@@ -1088,6 +1189,29 @@ impl RLN {
Ok(())
}
/// Same as extend_key_gen but serialized in BE format.
pub fn extended_key_gen_be<W: Write>(&self, mut output_data: W) -> Result<(), RLNError> {
let (identity_trapdoor, identity_nullifier, identity_secret_hash, id_commitment) =
extended_keygen();
output_data.write_all(&fr_to_bytes_be(&identity_trapdoor))?;
output_data.write_all(&fr_to_bytes_be(&identity_nullifier))?;
output_data.write_all(&fr_to_bytes_be(&identity_secret_hash))?;
output_data.write_all(&fr_to_bytes_be(&id_commitment))?;
Ok(())
}
pub fn extended_key_gen_be_2<W: Write>(&self, mut output_data: W) -> Result<(Fr, Fr, Fr, Fr), RLNError> {
let (identity_trapdoor, identity_nullifier, identity_secret_hash, id_commitment) =
extended_keygen();
output_data.write_all(&fr_to_bytes_be(&identity_trapdoor))?;
output_data.write_all(&fr_to_bytes_be(&identity_nullifier))?;
output_data.write_all(&fr_to_bytes_be(&identity_secret_hash))?;
output_data.write_all(&fr_to_bytes_be(&id_commitment))?;
Ok((identity_trapdoor, identity_nullifier, identity_secret_hash, id_commitment))
}
/// Returns an identity secret and identity commitment pair generated using a seed.
///
/// The identity commitment is the Poseidon hash of the identity secret.
@@ -1116,7 +1240,7 @@ impl RLN {
&self,
mut input_data: R,
mut output_data: W,
) -> Result<()> {
) -> Result<(), RLNError> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
@@ -1127,6 +1251,22 @@ impl RLN {
Ok(())
}
/// Same as seeded_key_gen but in BE format
pub fn seeded_key_gen_be<R: Read, W: Write>(
&self,
mut input_data: R,
mut output_data: W,
) -> Result<(), RLNError> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
let (identity_secret_hash, id_commitment) = seeded_keygen(&serialized);
output_data.write_all(&fr_to_bytes_be(&identity_secret_hash))?;
output_data.write_all(&fr_to_bytes_be(&id_commitment))?;
Ok(())
}
/// Returns an identity trapdoor, nullifier, secret and commitment tuple generated using a seed.
///
/// The identity secret is the Poseidon hash of the identity trapdoor and identity nullifier.
@@ -1159,7 +1299,7 @@ impl RLN {
&self,
mut input_data: R,
mut output_data: W,
) -> Result<()> {
) -> Result<(), RLNError> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
@@ -1173,6 +1313,25 @@ impl RLN {
Ok(())
}
/// same as seeded_extended_key_gen but in BE format
pub fn seeded_extended_key_gen_be<R: Read, W: Write>(
&self,
mut input_data: R,
mut output_data: W,
) -> Result<(), RLNError> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
let (identity_trapdoor, identity_nullifier, identity_secret_hash, id_commitment) =
extended_seeded_keygen(&serialized);
output_data.write_all(&fr_to_bytes_be(&identity_trapdoor))?;
output_data.write_all(&fr_to_bytes_be(&identity_nullifier))?;
output_data.write_all(&fr_to_bytes_be(&identity_secret_hash))?;
output_data.write_all(&fr_to_bytes_be(&id_commitment))?;
Ok(())
}
/// Recovers the identity secret from two set of proof values computed for same secret in same epoch with same rln identifier.
///
/// Input values are:
@@ -1212,7 +1371,7 @@ impl RLN {
mut input_proof_data_1: R,
mut input_proof_data_2: R,
mut output_data: W,
) -> Result<()> {
) -> Result<(), RLNError> {
// We serialize_compressed the two proofs, and we get the corresponding RLNProofValues objects
let mut serialized: Vec<u8> = Vec::new();
input_proof_data_1.read_to_end(&mut serialized)?;
@@ -1236,14 +1395,11 @@ impl RLN {
let share2 = (proof_values_2.x, proof_values_2.y);
// We recover the secret
let recovered_identity_secret_hash = compute_id_secret(share1, share2);
let recovered_identity_secret_hash =
compute_id_secret(share1, share2).map_err(RLNError::RecoverSecret)?;
// If an identity secret hash is recovered, we write it to output_data, otherwise nothing will be written.
if let Ok(identity_secret_hash) = recovered_identity_secret_hash {
output_data.write_all(&fr_to_bytes_le(&identity_secret_hash))?;
} else {
return Err(Report::msg("could not extract secret"));
}
output_data.write_all(&fr_to_bytes_le(&recovered_identity_secret_hash))?;
}
Ok(())
@@ -1256,13 +1412,16 @@ impl RLN {
///
/// The function returns the corresponding [`RLNWitnessInput`] object serialized using [`rln::protocol::serialize_witness`](crate::protocol::serialize_witness).
#[cfg(not(feature = "stateless"))]
pub fn get_serialized_rln_witness<R: Read>(&mut self, mut input_data: R) -> Result<Vec<u8>> {
pub fn get_serialized_rln_witness<R: Read>(
&mut self,
mut input_data: R,
) -> Result<Vec<u8>, RLNError> {
// We read input RLN witness and we serialize_compressed it
let mut witness_byte: Vec<u8> = Vec::new();
input_data.read_to_end(&mut witness_byte)?;
let (rln_witness, _) = proof_inputs_to_rln_witness(&mut self.tree, &witness_byte)?;
serialize_witness(&rln_witness)
serialize_witness(&rln_witness).map_err(RLNError::Protocol)
}
/// Converts a byte serialization of a [`RLNWitnessInput`] object to the corresponding JSON serialization.
@@ -1271,7 +1430,10 @@ impl RLN {
/// - `serialized_witness`: the byte serialization of a [`RLNWitnessInput`] object (serialization done with [`rln::protocol::serialize_witness`](crate::protocol::serialize_witness)).
///
/// The function returns the corresponding JSON encoding of the input [`RLNWitnessInput`] object.
pub fn get_rln_witness_json(&mut self, serialized_witness: &[u8]) -> Result<serde_json::Value> {
pub fn get_rln_witness_json(
&mut self,
serialized_witness: &[u8],
) -> Result<serde_json::Value, ProtocolError> {
let (rln_witness, _) = deserialize_witness(serialized_witness)?;
rln_witness_to_json(&rln_witness)
}
@@ -1286,7 +1448,7 @@ impl RLN {
pub fn get_rln_witness_bigint_json(
&mut self,
serialized_witness: &[u8],
) -> Result<serde_json::Value> {
) -> Result<serde_json::Value, ProtocolError> {
let (rln_witness, _) = deserialize_witness(serialized_witness)?;
rln_witness_to_bigint_json(&rln_witness)
}
@@ -1296,11 +1458,12 @@ impl RLN {
/// If not called, the connection will be closed when the RLN object is dropped.
/// This improves robustness of the tree.
#[cfg(not(feature = "stateless"))]
pub fn flush(&mut self) -> Result<()> {
pub fn flush(&mut self) -> Result<(), ZerokitMerkleTreeError> {
self.tree.close_db_connection()
}
}
#[cfg(not(target_arch = "wasm32"))]
impl Default for RLN {
fn default() -> Self {
#[cfg(not(feature = "stateless"))]
@@ -1336,7 +1499,10 @@ impl Default for RLN {
/// // We serialize_compressed the keygen output
/// let field_element = deserialize_field_element(output_buffer.into_inner());
/// ```
pub fn hash<R: Read, W: Write>(mut input_data: R, mut output_data: W) -> Result<()> {
pub fn hash<R: Read, W: Write>(
mut input_data: R,
mut output_data: W,
) -> Result<(), std::io::Error> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
@@ -1346,6 +1512,20 @@ pub fn hash<R: Read, W: Write>(mut input_data: R, mut output_data: W) -> Result<
Ok(())
}
/// same as hash function but in BE format
pub fn hash_be<R: Read, W: Write>(
mut input_data: R,
mut output_data: W,
) -> Result<(), std::io::Error> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
let hash = hash_to_field(&serialized);
output_data.write_all(&fr_to_bytes_be(&hash))?;
Ok(())
}
/// Hashes a set of elements to a single element in the working prime field, using Poseidon.
///
/// The result is computed as the Poseidon Hash of the input signal.
@@ -1369,7 +1549,10 @@ pub fn hash<R: Read, W: Write>(mut input_data: R, mut output_data: W) -> Result<
/// // We serialize_compressed the hash output
/// let hash_result = deserialize_field_element(output_buffer.into_inner());
/// ```
pub fn poseidon_hash<R: Read, W: Write>(mut input_data: R, mut output_data: W) -> Result<()> {
pub fn poseidon_hash<R: Read, W: Write>(
mut input_data: R,
mut output_data: W,
) -> Result<(), RLNError> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
@@ -1379,3 +1562,18 @@ pub fn poseidon_hash<R: Read, W: Write>(mut input_data: R, mut output_data: W) -
Ok(())
}
/// same as poseidon_hash function but in BE format. Note that input is expected in BE format too.
pub fn poseidon_hash_be<R: Read, W: Write>(
mut input_data: R,
mut output_data: W,
) -> Result<(), RLNError> {
let mut serialized: Vec<u8> = Vec::new();
input_data.read_to_end(&mut serialized)?;
let (inputs, _) = bytes_be_to_vec_fr(&serialized)?;
let hash = utils_poseidon_hash(inputs.as_ref());
output_data.write_all(&fr_to_bytes_be(&hash))?;
Ok(())
}

View File

@@ -1,7 +1,10 @@
use crate::circuit::TEST_TREE_HEIGHT;
use crate::protocol::*;
use crate::protocol::{
proof_values_from_witness, random_rln_witness, serialize_proof_values, serialize_witness,
verify_proof, RLNProofValues,
};
use crate::public::RLN;
use crate::utils::*;
use crate::utils::{generate_input_buffer, str_to_fr};
use ark_groth16::Proof as ArkProof;
use ark_serialize::CanonicalDeserialize;
use std::io::Cursor;
@@ -231,7 +234,7 @@ mod tree_test {
rln.set_tree(tree_height).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
rln.init_tree_with_leaves(&mut buffer).unwrap();
// We check if number of leaves set is consistent
@@ -289,7 +292,7 @@ mod tree_test {
let mut rln = RLN::new(tree_height, generate_input_buffer()).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
rln.init_tree_with_leaves(&mut buffer).unwrap();
// We check if number of leaves set is consistent
@@ -303,11 +306,11 @@ mod tree_test {
// `init_tree_with_leaves` resets the tree to the height it was initialized with, using `set_tree`
// We add leaves in a batch starting from index 0..set_index
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves[0..set_index]).unwrap());
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves[0..set_index]));
rln.init_tree_with_leaves(&mut buffer).unwrap();
// We add the remaining n leaves in a batch starting from index m
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves[set_index..]).unwrap());
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves[set_index..]));
rln.set_leaves_from(set_index, &mut buffer).unwrap();
// We check if number of leaves set is consistent
@@ -359,7 +362,7 @@ mod tree_test {
let mut rln = RLN::new(tree_height, generate_input_buffer()).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
rln.init_tree_with_leaves(&mut buffer).unwrap();
// We check if number of leaves set is consistent
@@ -377,8 +380,8 @@ mod tree_test {
let last_leaf_index = no_of_leaves - 1;
let indices = vec![last_leaf_index as u8];
let last_leaf = vec![*last_leaf];
let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices).unwrap());
let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&last_leaf).unwrap());
let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices));
let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&last_leaf));
rln.atomic_operation(last_leaf_index, leaves_buffer, indices_buffer)
.unwrap();
@@ -408,7 +411,7 @@ mod tree_test {
let mut rln = RLN::new(tree_height, generate_input_buffer()).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
rln.init_tree_with_leaves(&mut buffer).unwrap();
// We check if number of leaves set is consistent
@@ -422,8 +425,8 @@ mod tree_test {
let zero_index = 0;
let indices = vec![zero_index as u8];
let zero_leaf: Vec<Fr> = vec![];
let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices).unwrap());
let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&zero_leaf).unwrap());
let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices));
let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&zero_leaf));
rln.atomic_operation(0, leaves_buffer, indices_buffer)
.unwrap();
@@ -452,7 +455,7 @@ mod tree_test {
let mut rln = RLN::new(tree_height, generate_input_buffer()).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
rln.init_tree_with_leaves(&mut buffer).unwrap();
// We check if number of leaves set is consistent
@@ -466,8 +469,8 @@ mod tree_test {
let set_index = rng.gen_range(0..no_of_leaves) as usize;
let indices = vec![set_index as u8];
let zero_leaf: Vec<Fr> = vec![];
let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices).unwrap());
let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&zero_leaf).unwrap());
let indices_buffer = Cursor::new(vec_u8_to_bytes_le(&indices));
let leaves_buffer = Cursor::new(vec_fr_to_bytes_le(&zero_leaf));
rln.atomic_operation(0, leaves_buffer, indices_buffer)
.unwrap();
@@ -486,7 +489,6 @@ mod tree_test {
assert_eq!(received_leaf, Fr::from(0));
}
#[allow(unused_must_use)]
#[test]
// This test checks if `set_leaves_from` throws an error when the index is out of bounds
fn test_set_leaves_bad_index() {
@@ -510,7 +512,9 @@ mod tree_test {
let (root_empty, _) = bytes_le_to_fr(&buffer.into_inner());
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
#[allow(unused_must_use)]
rln.set_leaves_from(bad_index, &mut buffer)
.expect_err("Should throw an error");
@@ -598,7 +602,7 @@ mod tree_test {
let mut rln = RLN::new(tree_height, generate_input_buffer()).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
rln.init_tree_with_leaves(&mut buffer).unwrap();
// Generate identity pair
@@ -670,7 +674,7 @@ mod tree_test {
let mut rln = RLN::new(tree_height, generate_input_buffer()).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
rln.init_tree_with_leaves(&mut buffer).unwrap();
// Generate identity pair
@@ -753,7 +757,7 @@ mod tree_test {
let mut rln = RLN::new(tree_height, generate_input_buffer()).unwrap();
// We add leaves in a batch into the tree
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves).unwrap());
let mut buffer = Cursor::new(vec_fr_to_bytes_le(&leaves));
rln.init_tree_with_leaves(&mut buffer).unwrap();
// Generate identity pair

View File

@@ -1,29 +1,28 @@
// This crate provides cross-module useful utilities (mainly type conversions) not necessarily specific to RLN
use crate::circuit::Fr;
use crate::error::ConversionError;
use ark_ff::PrimeField;
use color_eyre::{Report, Result};
use num_bigint::{BigInt, BigUint};
use num_traits::Num;
use serde_json::json;
use std::io::Cursor;
use crate::circuit::Fr;
#[inline(always)]
pub fn to_bigint(el: &Fr) -> Result<BigInt> {
Ok(BigUint::from(*el).into())
pub fn to_bigint(el: &Fr) -> BigInt {
BigUint::from(*el).into()
}
#[inline(always)]
pub fn fr_byte_size() -> usize {
pub const fn fr_byte_size() -> usize {
let mbs = <Fr as PrimeField>::MODULUS_BIT_SIZE;
((mbs + 64 - (mbs % 64)) / 8) as usize
}
#[inline(always)]
pub fn str_to_fr(input: &str, radix: u32) -> Result<Fr> {
pub fn str_to_fr(input: &str, radix: u32) -> Result<Fr, ConversionError> {
if !(radix == 10 || radix == 16) {
return Err(Report::msg("wrong radix"));
return Err(ConversionError::WrongRadix);
}
// We remove any quote present and we trim
@@ -48,6 +47,15 @@ pub fn bytes_le_to_fr(input: &[u8]) -> (Fr, usize) {
)
}
#[inline(always)]
pub fn bytes_be_to_fr(input: &[u8]) -> (Fr, usize) {
let el_size = fr_byte_size();
(
Fr::from(BigUint::from_bytes_be(&input[0..el_size])),
el_size,
)
}
#[inline(always)]
pub fn fr_to_bytes_le(input: &Fr) -> Vec<u8> {
let input_biguint: BigUint = (*input).into();
@@ -58,7 +66,20 @@ pub fn fr_to_bytes_le(input: &Fr) -> Vec<u8> {
}
#[inline(always)]
pub fn vec_fr_to_bytes_le(input: &[Fr]) -> Result<Vec<u8>> {
pub fn fr_to_bytes_be(input: &Fr) -> Vec<u8> {
let input_biguint: BigUint = (*input).into();
let mut res = input_biguint.to_bytes_be();
// For BE, insert 0 at the start of the Vec (see also fr_to_bytes_le comments)
let to_insert_count = fr_byte_size().saturating_sub(res.len());
if to_insert_count > 0 {
// Insert multi 0 at index 0
res.splice(0..0, std::iter::repeat_n(0, to_insert_count));
}
res
}
#[inline(always)]
pub fn vec_fr_to_bytes_le(input: &[Fr]) -> Vec<u8> {
// Calculate capacity for Vec:
// - 8 bytes for normalized vector length (usize)
// - each Fr element requires fr_byte_size() bytes (typically 32 bytes)
@@ -72,11 +93,11 @@ pub fn vec_fr_to_bytes_le(input: &[Fr]) -> Result<Vec<u8>> {
bytes.extend_from_slice(&fr_to_bytes_le(el));
}
Ok(bytes)
bytes
}
#[inline(always)]
pub fn vec_u8_to_bytes_le(input: &[u8]) -> Result<Vec<u8>> {
pub fn vec_u8_to_bytes_le(input: &[u8]) -> Vec<u8> {
// Calculate capacity for Vec:
// - 8 bytes for normalized vector length (usize)
// - variable length input data
@@ -88,11 +109,11 @@ pub fn vec_u8_to_bytes_le(input: &[u8]) -> Result<Vec<u8>> {
// We store the input
bytes.extend_from_slice(input);
Ok(bytes)
bytes
}
#[inline(always)]
pub fn bytes_le_to_vec_u8(input: &[u8]) -> Result<(Vec<u8>, usize)> {
pub fn bytes_le_to_vec_u8(input: &[u8]) -> Result<(Vec<u8>, usize), ConversionError> {
let mut read: usize = 0;
let len = usize::try_from(u64::from_le_bytes(input[0..8].try_into()?))?;
@@ -105,12 +126,11 @@ pub fn bytes_le_to_vec_u8(input: &[u8]) -> Result<(Vec<u8>, usize)> {
}
#[inline(always)]
pub fn bytes_le_to_vec_fr(input: &[u8]) -> Result<(Vec<Fr>, usize)> {
pub fn bytes_le_to_vec_fr(input: &[u8]) -> Result<(Vec<Fr>, usize), ConversionError> {
let mut read: usize = 0;
let mut res: Vec<Fr> = Vec::new();
let len = usize::try_from(u64::from_le_bytes(input[0..8].try_into()?))?;
read += 8;
let mut res: Vec<Fr> = Vec::with_capacity(len);
let el_size = fr_byte_size();
for i in 0..len {
@@ -123,7 +143,25 @@ pub fn bytes_le_to_vec_fr(input: &[u8]) -> Result<(Vec<Fr>, usize)> {
}
#[inline(always)]
pub fn bytes_le_to_vec_usize(input: &[u8]) -> Result<Vec<usize>> {
pub fn bytes_be_to_vec_fr(input: &[u8]) -> Result<(Vec<Fr>, usize), ConversionError> {
let mut read: usize = 0;
let mut res: Vec<Fr> = Vec::new();
let len = usize::try_from(u64::from_be_bytes(input[0..8].try_into()?))?;
read += 8;
let el_size = fr_byte_size();
for i in 0..len {
let (curr_el, _) = bytes_be_to_fr(&input[8 + el_size * i..8 + el_size * (i + 1)]);
res.push(curr_el);
read += el_size;
}
Ok((res, read))
}
#[inline(always)]
pub fn bytes_le_to_vec_usize(input: &[u8]) -> Result<Vec<usize>, ConversionError> {
let nof_elem = usize::try_from(u64::from_le_bytes(input[0..8].try_into()?))?;
if nof_elem == 0 {
Ok(vec![])
@@ -136,12 +174,32 @@ pub fn bytes_le_to_vec_usize(input: &[u8]) -> Result<Vec<usize>> {
}
}
/// Normalizes a `usize` into an 8-byte array, ensuring consistency across architectures.
/// On 32-bit systems, the result is zero-padded to 8 bytes.
/// On 64-bit systems, it directly represents the `usize` value.
#[inline(always)]
pub fn normalize_usize(input: usize) -> [u8; 8] {
input.to_le_bytes()
let mut bytes = [0u8; 8];
let input_bytes = input.to_le_bytes();
bytes[..input_bytes.len()].copy_from_slice(&input_bytes);
bytes
}
#[inline(always)] // using for test
pub fn generate_input_buffer() -> Cursor<String> {
Cursor::new(json!({}).to_string())
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_fr_be() {
let fr_1 = Fr::from(255);
let b = fr_to_bytes_be(&fr_1);
let fr_1_de = bytes_be_to_fr(&b).0;
assert_eq!(fr_1, fr_1_de);
}
}

View File

@@ -3,7 +3,7 @@
mod test {
use ark_std::{rand::thread_rng, UniformRand};
use rand::Rng;
use rln::circuit::*;
use rln::circuit::{Fr, TEST_TREE_HEIGHT};
use rln::ffi::{hash as ffi_hash, poseidon_hash as ffi_poseidon_hash, *};
use rln::hashers::{hash_to_field, poseidon_hash as utils_poseidon_hash, ROUND_PARAMS};
use rln::protocol::*;
@@ -13,7 +13,9 @@ mod test {
use std::fs::File;
use std::io::Read;
use std::mem::MaybeUninit;
use std::str::FromStr;
use std::time::{Duration, Instant};
use claims::assert_lt;
const NO_OF_LEAVES: usize = 256;
@@ -27,7 +29,7 @@ mod test {
}
fn set_leaves_init(rln_pointer: &mut RLN, leaves: &[Fr]) {
let leaves_ser = vec_fr_to_bytes_le(&leaves).unwrap();
let leaves_ser = vec_fr_to_bytes_le(&leaves);
let input_buffer = &Buffer::from(leaves_ser.as_ref());
let success = init_tree_with_leaves(rln_pointer, input_buffer);
assert!(success, "init tree with leaves call failed");
@@ -156,6 +158,7 @@ mod test {
// random number between 0..no_of_leaves
let mut rng = thread_rng();
let set_index = rng.gen_range(0..NO_OF_LEAVES) as usize;
println!("set_index: {}", set_index);
// We add leaves in a batch into the tree
set_leaves_init(rln_pointer, &leaves);
@@ -169,14 +172,17 @@ mod test {
set_leaves_init(rln_pointer, &leaves[0..set_index]);
// We add the remaining n leaves in a batch starting from index set_index
let leaves_n = vec_fr_to_bytes_le(&leaves[set_index..]).unwrap();
let leaves_n = vec_fr_to_bytes_le(&leaves[set_index..]);
let buffer = &Buffer::from(leaves_n.as_ref());
let success = set_leaves_from(rln_pointer, set_index, buffer);
assert!(success, "set leaves from call failed");
// We get the root of the tree obtained adding leaves in batch
let root_batch_with_custom_index = get_tree_root(rln_pointer);
assert_eq!(root_batch_with_init, root_batch_with_custom_index);
assert_eq!(
root_batch_with_init, root_batch_with_custom_index,
"root batch !="
);
// We reset the tree to default
let success = set_tree(rln_pointer, TEST_TREE_HEIGHT);
@@ -192,7 +198,10 @@ mod test {
// We get the root of the tree obtained adding leaves using the internal index
let root_single_additions = get_tree_root(rln_pointer);
assert_eq!(root_batch_with_init, root_single_additions);
assert_eq!(
root_batch_with_init, root_single_additions,
"root single additions !="
);
}
#[test]
@@ -213,9 +222,9 @@ mod test {
let last_leaf_index = NO_OF_LEAVES - 1;
let indices = vec![last_leaf_index as u8];
let last_leaf = vec![*last_leaf];
let indices = vec_u8_to_bytes_le(&indices).unwrap();
let indices = vec_u8_to_bytes_le(&indices);
let indices_buffer = &Buffer::from(indices.as_ref());
let leaves = vec_fr_to_bytes_le(&last_leaf).unwrap();
let leaves = vec_fr_to_bytes_le(&last_leaf);
let leaves_buffer = &Buffer::from(leaves.as_ref());
let success = atomic_operation(
@@ -246,7 +255,7 @@ mod test {
let root_empty = get_tree_root(rln_pointer);
// We add leaves in a batch into the tree
let leaves = vec_fr_to_bytes_le(&leaves).unwrap();
let leaves = vec_fr_to_bytes_le(&leaves);
let buffer = &Buffer::from(leaves.as_ref());
let success = set_leaves_from(rln_pointer, bad_index, buffer);
assert!(!success, "set leaves from call succeeded");
@@ -853,6 +862,72 @@ mod test {
assert_eq!(id_commitment, expected_id_commitment_seed_bytes.unwrap());
}
#[test]
// Tests hash to field using FFI APIs
fn test_extended_keygen_be_ffi() {
let q = ark_bn254::Fr::from_str("21888242871839275222246405745257275088548364400416034343698204186575808495616").unwrap();
let mut c = 0;
loop {
// We create a RLN instance
let rln_pointer = create_rln_instance();
// We generate a new identity tuple from an input seed
// let seed_bytes: &[u8] = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
// let input_buffer = &Buffer::from(seed_bytes);
let mut output_buffer = MaybeUninit::<Buffer>::uninit();
let success =
extended_key_gen_be(rln_pointer, output_buffer.as_mut_ptr());
// assert!(success, "seeded key gen call failed");
let output_buffer = unsafe { output_buffer.assume_init() };
let result_data = <&[u8]>::from(&output_buffer).to_vec();
let (identity_secret_hash, id_commitment) =
deserialize_identity_pair_be(result_data);
// We check against expected values
// let expected_identity_trapdoor_seed_bytes = str_to_fr(
// "0x766ce6c7e7a01bdf5b3f257616f603918c30946fa23480f2859c597817e6716",
// 16,
// );
// let expected_identity_nullifier_seed_bytes = str_to_fr(
// "0x1f18714c7bc83b5bca9e89d404cf6f2f585bc4c0f7ed8b53742b7e2b298f50b4",
// 16,
// );
// let expected_identity_secret_hash_seed_bytes = str_to_fr(
// "0x2aca62aaa7abaf3686fff2caf00f55ab9462dc12db5b5d4bcf3994e671f8e521",
// 16,
// );
// let expected_id_commitment_seed_bytes = str_to_fr(
// "0x68b66aa0a8320d2e56842581553285393188714c48f9b17acd198b4f1734c5c",
// 16,
// );
// assert_eq!(
// identity_trapdoor,
// expected_identity_trapdoor_seed_bytes.unwrap()
// );
// assert_eq!(
// identity_nullifier,
// expected_identity_nullifier_seed_bytes.unwrap()
// );
// assert_eq!(
// identity_secret_hash,
// expected_identity_secret_hash_seed_bytes.unwrap()
// );
// assert_eq!(id_commitment, expected_id_commitment_seed_bytes.unwrap());
assert_lt!(identity_secret_hash, q);
assert_lt!(id_commitment, q);
c+=1;
if c > 1000 {
break;
}
}
}
#[test]
// Tests hash to field using FFI APIs
fn test_hash_to_field_ffi() {
@@ -885,7 +960,7 @@ mod test {
for _ in 0..number_of_inputs {
inputs.push(Fr::rand(&mut rng));
}
let inputs_ser = vec_fr_to_bytes_le(&inputs).unwrap();
let inputs_ser = vec_fr_to_bytes_le(&inputs);
let input_buffer = &Buffer::from(inputs_ser.as_ref());
let expected_hash = utils_poseidon_hash(inputs.as_ref());
@@ -1426,7 +1501,7 @@ mod stateless_test {
for _ in 0..number_of_inputs {
inputs.push(Fr::rand(&mut rng));
}
let inputs_ser = vec_fr_to_bytes_le(&inputs).unwrap();
let inputs_ser = vec_fr_to_bytes_le(&inputs);
let input_buffer = &Buffer::from(inputs_ser.as_ref());
let expected_hash = utils_poseidon_hash(inputs.as_ref());

View File

@@ -5,7 +5,10 @@
#[cfg(test)]
mod test {
use rln::hashers::{poseidon_hash, PoseidonHash};
use rln::{circuit::*, poseidon_tree::PoseidonTree};
use rln::{
circuit::{Fr, TEST_TREE_HEIGHT},
poseidon_tree::PoseidonTree,
};
use utils::{FullMerkleTree, OptimalMerkleTree, ZerokitMerkleProof, ZerokitMerkleTree};
#[test]
@@ -23,6 +26,7 @@ mod test {
assert_eq!(proof.leaf_index(), i);
tree_opt.set(i, leaves[i]).unwrap();
assert_eq!(tree_opt.root(), tree_full.root());
let proof = tree_opt.proof(i).expect("index should be set");
assert_eq!(proof.leaf_index(), i);
}
@@ -37,11 +41,11 @@ mod test {
#[test]
fn test_subtree_root() {
const DEPTH: usize = 3;
const LEAVES_LEN: usize = 6;
const LEAVES_LEN: usize = 8;
let mut tree = PoseidonTree::default(DEPTH).unwrap();
let leaves: Vec<Fr> = (0..LEAVES_LEN).map(|s| Fr::from(s as i32)).collect();
let _ = tree.set_range(0, leaves);
let _ = tree.set_range(0, leaves.into_iter());
for i in 0..LEAVES_LEN {
// check leaves
@@ -78,7 +82,7 @@ mod test {
let leaves: Vec<Fr> = (0..nof_leaves).map(|s| Fr::from(s as i32)).collect();
// check set_range
let _ = tree.set_range(0, leaves.clone());
let _ = tree.set_range(0, leaves.clone().into_iter());
assert!(tree.get_empty_leaves_indices().is_empty());
let mut vec_idxs = Vec::new();
@@ -98,26 +102,28 @@ mod test {
// check remove_indices_and_set_leaves inside override_range function
assert!(tree.get_empty_leaves_indices().is_empty());
let leaves_2: Vec<Fr> = (0..2).map(|s| Fr::from(s as i32)).collect();
tree.override_range(0, leaves_2.clone(), [0, 1, 2, 3])
tree.override_range(0, leaves_2.clone().into_iter(), [0, 1, 2, 3].into_iter())
.unwrap();
assert_eq!(tree.get_empty_leaves_indices(), vec![2, 3]);
// check remove_indices inside override_range function
tree.override_range(0, [], [0, 1]).unwrap();
tree.override_range(0, [].into_iter(), [0, 1].into_iter())
.unwrap();
assert_eq!(tree.get_empty_leaves_indices(), vec![0, 1, 2, 3]);
// check set_range inside override_range function
tree.override_range(0, leaves_2.clone(), []).unwrap();
tree.override_range(0, leaves_2.clone().into_iter(), [].into_iter())
.unwrap();
assert_eq!(tree.get_empty_leaves_indices(), vec![2, 3]);
let leaves_4: Vec<Fr> = (0..4).map(|s| Fr::from(s as i32)).collect();
// check if the indexes for write and delete are the same
tree.override_range(0, leaves_4.clone(), [0, 1, 2, 3])
tree.override_range(0, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter())
.unwrap();
assert!(tree.get_empty_leaves_indices().is_empty());
// check if indexes for deletion are before indexes for overwriting
tree.override_range(4, leaves_4.clone(), [0, 1, 2, 3])
tree.override_range(4, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter())
.unwrap();
// The result will be like this, because in the set_range function in pmtree
// the next_index value is increased not by the number of elements to insert,
@@ -128,7 +134,7 @@ mod test {
);
// check if the indices for write and delete do not overlap completely
tree.override_range(2, leaves_4.clone(), [0, 1, 2, 3])
tree.override_range(2, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter())
.unwrap();
// The result will be like this, because in the set_range function in pmtree
// the next_index value is increased not by the number of elements to insert,

View File

@@ -5,7 +5,12 @@ mod test {
use rln::circuit::{Fr, TEST_TREE_HEIGHT};
use rln::hashers::{hash_to_field, poseidon_hash};
use rln::poseidon_tree::PoseidonTree;
use rln::protocol::*;
use rln::protocol::{
deserialize_proof_values, deserialize_witness, generate_proof, keygen,
proof_values_from_witness, rln_witness_from_json, rln_witness_from_values,
rln_witness_to_json, seeded_keygen, serialize_proof_values, serialize_witness,
verify_proof, RLNWitnessInput,
};
use rln::utils::str_to_fr;
use utils::{ZerokitMerkleProof, ZerokitMerkleTree};

View File

@@ -12,7 +12,10 @@ mod test {
use rln::hashers::{hash_to_field, poseidon_hash as utils_poseidon_hash, ROUND_PARAMS};
use rln::protocol::deserialize_identity_tuple;
use rln::public::{hash as public_hash, poseidon_hash as public_poseidon_hash, RLN};
use rln::utils::*;
use rln::utils::{
bytes_le_to_fr, bytes_le_to_vec_fr, bytes_le_to_vec_u8, bytes_le_to_vec_usize,
fr_to_bytes_le, generate_input_buffer, str_to_fr, vec_fr_to_bytes_le,
};
use std::io::Cursor;
#[test]
@@ -243,7 +246,7 @@ mod test {
}
let expected_hash = utils_poseidon_hash(&inputs);
let mut input_buffer = Cursor::new(vec_fr_to_bytes_le(&inputs).unwrap());
let mut input_buffer = Cursor::new(vec_fr_to_bytes_le(&inputs));
let mut output_buffer = Cursor::new(Vec::<u8>::new());
public_poseidon_hash(&mut input_buffer, &mut output_buffer).unwrap();

78
rln/tests/public_be.rs Normal file
View File

@@ -0,0 +1,78 @@
#[cfg(test)]
mod test {
use std::io::Cursor;
use std::str::FromStr;
use ark_bn254::Fr;
use claims::assert_lt;
use rln::public::RLN;
use rln::utils::{bytes_be_to_fr, bytes_le_to_fr};
#[test]
fn test_keygen_be() {
// let q = Fr::from_str("21888242871839275222246405745257275088548364400416034343698204186575808495617").unwrap();
let q = Fr::from_str("21888242871839275222246405745257275088548364400416034343698204186575808495616").unwrap();
// println!("q: {}", q);
let mut c = 0;
loop {
let rln = RLN::default();
let mut output_buffer = Cursor::new(Vec::<u8>::new());
let (secret, id_co) = rln.key_gen_be_2(&mut output_buffer).expect("TODO: panic message");
let serialized_output = output_buffer.into_inner();
let (secret_de, read) = bytes_be_to_fr(&serialized_output);
let (id_commitment, _) = bytes_be_to_fr(&serialized_output[read..].to_vec());
assert_eq!(secret, secret_de);
assert_lt!(secret, q);
assert_eq!(id_commitment, id_co);
assert_lt!(id_commitment, q);
c+=1;
if c > 10_000 {
break;
}
}
println!("c: {}", c);
}
#[test]
fn test_extended_keygen_be() {
// let q = Fr::from_str("21888242871839275222246405745257275088548364400416034343698204186575808495617").unwrap();
let q = Fr::from_str("21888242871839275222246405745257275088548364400416034343698204186575808495616").unwrap();
// println!("q: {}", q);
let mut c = 0;
loop {
let rln = RLN::default();
let mut output_buffer = Cursor::new(Vec::<u8>::new());
let (trap, nulli, secret, id_co) = rln.extended_key_gen_be_2(&mut output_buffer).expect("TODO: panic message");
let serialized_output = output_buffer.into_inner();
let (trap_de, read) = bytes_be_to_fr(&serialized_output);
let (nulli_de, read) = bytes_be_to_fr(&serialized_output[read..].to_vec());
let (secret_de, read) = bytes_be_to_fr(&serialized_output[read+read..].to_vec());
let (id_co_de, _read) = bytes_be_to_fr(&serialized_output[read+read+read..].to_vec());
assert_eq!(trap, trap_de);
assert_eq!(nulli, nulli_de);
assert_eq!(secret, secret_de);
assert_lt!(secret, q);
assert_eq!(id_co, id_co_de);
assert_lt!(id_co, q);
c+=1;
if c > 5_000 {
break;
}
}
println!("c: {}", c);
}
}

View File

@@ -1,6 +1,6 @@
[package]
name = "zerokit_utils"
version = "0.5.2"
version = "0.6.0"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "Various utilities for Zerokit"
@@ -12,29 +12,36 @@ repository = "https://github.com/vacp2p/zerokit"
bench = false
[dependencies]
ark-ff = { version = "0.5.0", features = ["asm"] }
num-bigint = { version = "0.4.6", default-features = false, features = [
"rand",
ark-ff = { version = "0.5.0", default-features = false, features = [
"parallel",
] }
color-eyre = "0.6.2"
pmtree = { package = "vacp2p_pmtree", version = "=2.0.2", optional = true }
num-bigint = { version = "0.4.6", default-features = false }
pmtree = { package = "vacp2p_pmtree", version = "2.0.2", optional = true }
sled = "0.34.7"
serde = "1.0"
lazy_static = "1.4.0"
hex = "0.4"
serde_json = "1.0"
lazy_static = "1.5.0"
hex = "0.4.3"
rayon = "1.7.0"
thiserror = "2.0"
[dev-dependencies]
ark-bn254 = { version = "0.5.0", features = ["std"] }
num-traits = "0.2.19"
hex-literal = "0.3.4"
hex-literal = "0.4.1"
tiny-keccak = { version = "2.0.2", features = ["keccak"] }
criterion = { version = "0.4.0", features = ["html_reports"] }
criterion = { version = "0.6.0", features = ["html_reports"] }
[features]
default = ["parallel"]
parallel = ["ark-ff/parallel"]
default = []
pmtree-ft = ["pmtree"]
[[bench]]
name = "merkle_tree_benchmark"
harness = false
[[bench]]
name = "poseidon_benchmark"
harness = false
[package.metadata.docs.rs]
all-features = true

View File

@@ -8,4 +8,4 @@ args = ["test", "--release"]
[tasks.bench]
command = "cargo"
args = ["bench"]
args = ["bench"]

View File

@@ -1,68 +1,124 @@
# Zerokit Utils Crate
[![Crates.io](https://img.shields.io/crates/v/zerokit_utils.svg)](https://crates.io/crates/zerokit_utils)
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
Cryptographic primitives for zero-knowledge applications, featuring efficient Merkle tree implementations and a Poseidon hash function.
**Zerokit Utils** provides essential cryptographic primitives optimized for zero-knowledge applications.
This crate features efficient Merkle tree implementations and a Poseidon hash function,
designed to be robust and performant.
## Overview
This crate provides core cryptographic components optimized for zero-knowledge proof systems:
1. Multiple Merkle tree implementations with different space/time tradeoffs
2. A Poseidon hash implementation
- **Multiple Merkle Trees**: Various implementations optimised for the trade-off between space and time.
- **Poseidon Hash Function**: An efficient hashing algorithm suitable for ZK contexts, with customizable parameters.
- **Parallel Performance**: Leverages Rayon for significant speed-ups in Merkle tree computations.
- **Arkworks Compatibility**: Poseidon hash implementation is designed to work seamlessly
with Arkworks field traits and data structures.
## Merkle Tree Implementations
The crate supports two interchangeable Merkle tree implementations:
Merkle trees are fundamental data structures for verifying data integrity and set membership.
Zerokit Utils offers two interchangeable implementations:
### Understanding Merkle Tree Terminology
To better understand the structure and parameters of our Merkle trees, here's a quick glossary:
- **Depth (`depth`)**: level of leaves if we count from root.
If the root is at level 0, leaves are at level `depth`.
- **Number of Levels**: `depth + 1`.
- **Capacity (Number of Leaves)**: $2^{\text{depth}}$. This is the maximum number of leaves the tree can hold.
- **Total Number of Nodes**: $2^{(\text{depth} + 1)} - 1$ for a full binary tree.
**Example for a tree with `depth: 3`**:
- Number of Levels: 4 (levels 0, 1, 2, 3)
- Capacity (Number of Leaves): $2^3 = 8$
- Total Number of Nodes: $2^{(3+1)} - 1 = 15$
Visual representation of a Merkle tree with `depth: 3`:
```mermaid
flowchart TD
A[Root] --> N1
A[Root] --> N2
N1 --> N3
N1 --> N4
N2 --> N5
N2 --> N6
N3 -->|Leaf| L1
N3 -->|Leaf| L2
N4 -->|Leaf| L3
N4 -->|Leaf| L4
N5 -->|Leaf| L5
N5 -->|Leaf| L6
N6 -->|Leaf| L7
N6 -->|Leaf| L8
```
### Available Implementations
- **FullMerkleTree**
- Stores each tree node in memory
- Stores all tree nodes in memory.
- Use Case: Use when memory is abundant and operation speed is critical.
- **OptimalMerkleTree**
- Only stores nodes used to prove accumulation of set leaves
- Stores only the nodes required to prove the accumulation of set leaves (i.e., authentication paths).
- Use Case: Suited for environments where memory efficiency is a higher priority than raw speed.
#### Parallel Processing with Rayon
Both `OptimalMerkleTree` and `FullMerkleTree` internally utilize the Rayon crate
to accelerate computations through data parallelism.
This can lead to significant performance improvements, particularly during updates to large Merkle trees.
## Poseidon Hash Implementation
This crate provides an implementation to compute the Poseidon hash round constants and MDS matrices:
This crate provides an implementation for computing Poseidon hash round constants and MDS matrices.
Key characteristics include:
- **Customizable parameters**: Supports different security levels and input sizes
- **Arkworks-friendly**: Adapted to work over arkworks field traits and custom data structures
- **Customizable parameters**: Supports various security levels and input sizes,
allowing you to tailor the hash function to your specific needs.
- **Arkworks-friendly**: Adapted to integrate smoothly with Arkworks field traits and custom data structures.
### Security Note
### ⚠️ Security Note
The MDS matrices are generated iteratively using the Grain LFSR until certain criteria are met.
According to the paper, such matrices must respect specific conditions which are checked by 3 different algorithms in the reference implementation.
The MDS matrices used in the Poseidon hash function are generated iteratively
using the Grain LFSR (Linear Feedback Shift Register) algorithm until specific cryptographic criteria are met.
These validation algorithms are not currently implemented in this crate.
For the hardcoded parameters, the first random matrix generated satisfies these conditions.
If using different parameters, you should check against the reference implementation how many matrices are generated before outputting the correct one,
and pass this number to the `skip_matrices` parameter of the `find_poseidon_ark_and_mds` function.
- The reference Poseidon implementation includes validation algorithms to ensure these criteria are satisfied.
These validation algorithms are not currently implemented in this crate.
- For the hardcoded parameters provided within this crate,
the initially generated random matrix has been verified to meet these conditions.
- If you intend to use custom parameters, it is crucial to verify your generated MDS matrix.
You should consult the Poseidon reference implementation to determine
how many matrices are typically skipped before a valid one is found.
This count should then be passed as the `skip_matrices parameter` to the `find_poseidon_ark_and_mds`
function in this crate.
## Installation
Add Zerokit Utils to your Rust project:
Add zerokit-utils as a dependency to your Cargo.toml file:
```toml
[dependencies]
zerokit-utils = "0.5.1"
zerokit-utils = "0.6.0"
```
## Performance Considerations
- **FullMerkleTree**: Use when memory is abundant and operation speed is critical
- **OptimalMerkleTree**: Use when memory efficiency is more important than raw speed
- **Poseidon**: Offers a good balance between security and performance for ZK applications
## Building and Testing
```bash
# Build the crate
cargo build
cargo make build
# Run tests
cargo test
cargo make test
# Run benchmarks
cargo bench
cargo make bench
```
To view the results of the benchmark, open the `target/criterion/report/index.html` file generated after the bench

View File

@@ -1,7 +1,7 @@
use criterion::{criterion_group, criterion_main, Criterion};
use hex_literal::hex;
use lazy_static::lazy_static;
use std::{fmt::Display, str::FromStr};
use criterion::{criterion_group, criterion_main, Criterion};
use lazy_static::lazy_static;
use tiny_keccak::{Hasher as _, Keccak};
use zerokit_utils::{
FullMerkleConfig, FullMerkleTree, Hasher, OptimalMerkleConfig, OptimalMerkleTree,
@@ -47,54 +47,78 @@ impl FromStr for TestFr {
}
lazy_static! {
static ref LEAVES: [TestFr; 4] = [
hex!("0000000000000000000000000000000000000000000000000000000000000001"),
hex!("0000000000000000000000000000000000000000000000000000000000000002"),
hex!("0000000000000000000000000000000000000000000000000000000000000003"),
hex!("0000000000000000000000000000000000000000000000000000000000000004"),
]
.map(TestFr);
static ref LEAVES: Vec<TestFr> = {
let mut leaves = Vec::with_capacity(1 << 20);
for i in 0..(1 << 20) {
let mut bytes = [0u8; 32];
bytes[28..].copy_from_slice(&(i as u32).to_be_bytes());
leaves.push(TestFr(bytes));
}
leaves
};
static ref INDICES: Vec<usize> = (0..(1 << 20)).collect();
}
const NOF_LEAVES: usize = 8192;
pub fn optimal_merkle_tree_benchmark(c: &mut Criterion) {
let mut tree =
OptimalMerkleTree::<Keccak256>::new(2, TestFr([0; 32]), OptimalMerkleConfig::default())
OptimalMerkleTree::<Keccak256>::new(20, TestFr([0; 32]), OptimalMerkleConfig::default())
.unwrap();
for i in 0..NOF_LEAVES {
tree.set(i, LEAVES[i % LEAVES.len()]).unwrap();
}
c.bench_function("OptimalMerkleTree::set", |b| {
let mut index = NOF_LEAVES;
b.iter(|| {
tree.set(0, LEAVES[0]).unwrap();
tree.set(index % (1 << 20), LEAVES[index % LEAVES.len()])
.unwrap();
index = (index + 1) % (1 << 20);
})
});
c.bench_function("OptimalMerkleTree::delete", |b| {
let mut index = 0;
b.iter(|| {
tree.delete(0).unwrap();
tree.delete(index % NOF_LEAVES).unwrap();
tree.set(index % NOF_LEAVES, LEAVES[index % LEAVES.len()])
.unwrap();
index = (index + 1) % NOF_LEAVES;
})
});
c.bench_function("OptimalMerkleTree::override_range", |b| {
let mut offset = 0;
b.iter(|| {
tree.override_range(0, *LEAVES, [0, 1, 2, 3]).unwrap();
})
});
c.bench_function("OptimalMerkleTree::compute_root", |b| {
b.iter(|| {
tree.compute_root().unwrap();
let range = offset..offset + NOF_LEAVES;
tree.override_range(
offset,
LEAVES[range.clone()].iter().cloned(),
INDICES[range.clone()].iter().cloned(),
)
.unwrap();
offset = (offset + NOF_LEAVES) % (1 << 20);
})
});
c.bench_function("OptimalMerkleTree::get", |b| {
let mut index = 0;
b.iter(|| {
tree.get(0).unwrap();
tree.get(index % NOF_LEAVES).unwrap();
index = (index + 1) % NOF_LEAVES;
})
});
// check intermediate node getter which required additional computation of sub root index
c.bench_function("OptimalMerkleTree::get_subtree_root", |b| {
let mut level = 1;
let mut index = 0;
b.iter(|| {
tree.get_subtree_root(1, 0).unwrap();
tree.get_subtree_root(level % 20, index % (1 << (20 - (level % 20))))
.unwrap();
index = (index + 1) % (1 << (20 - (level % 20)));
level = 1 + (level % 20);
})
});
@@ -107,42 +131,61 @@ pub fn optimal_merkle_tree_benchmark(c: &mut Criterion) {
pub fn full_merkle_tree_benchmark(c: &mut Criterion) {
let mut tree =
FullMerkleTree::<Keccak256>::new(2, TestFr([0; 32]), FullMerkleConfig::default()).unwrap();
FullMerkleTree::<Keccak256>::new(20, TestFr([0; 32]), FullMerkleConfig::default()).unwrap();
for i in 0..NOF_LEAVES {
tree.set(i, LEAVES[i % LEAVES.len()]).unwrap();
}
c.bench_function("FullMerkleTree::set", |b| {
let mut index = NOF_LEAVES;
b.iter(|| {
tree.set(0, LEAVES[0]).unwrap();
tree.set(index % (1 << 20), LEAVES[index % LEAVES.len()])
.unwrap();
index = (index + 1) % (1 << 20);
})
});
c.bench_function("FullMerkleTree::delete", |b| {
let mut index = 0;
b.iter(|| {
tree.delete(0).unwrap();
tree.delete(index % NOF_LEAVES).unwrap();
tree.set(index % NOF_LEAVES, LEAVES[index % LEAVES.len()])
.unwrap();
index = (index + 1) % NOF_LEAVES;
})
});
c.bench_function("FullMerkleTree::override_range", |b| {
let mut offset = 0;
b.iter(|| {
tree.override_range(0, *LEAVES, [0, 1, 2, 3]).unwrap();
})
});
c.bench_function("FullMerkleTree::compute_root", |b| {
b.iter(|| {
tree.compute_root().unwrap();
let range = offset..offset + NOF_LEAVES;
tree.override_range(
offset,
LEAVES[range.clone()].iter().cloned(),
INDICES[range.clone()].iter().cloned(),
)
.unwrap();
offset = (offset + NOF_LEAVES) % (1 << 20);
})
});
c.bench_function("FullMerkleTree::get", |b| {
let mut index = 0;
b.iter(|| {
tree.get(0).unwrap();
tree.get(index % NOF_LEAVES).unwrap();
index = (index + 1) % NOF_LEAVES;
})
});
// check intermediate node getter which required additional computation of sub root index
c.bench_function("FullMerkleTree::get_subtree_root", |b| {
let mut level = 1;
let mut index = 0;
b.iter(|| {
tree.get_subtree_root(1, 0).unwrap();
tree.get_subtree_root(level % 20, index % (1 << (20 - (level % 20))))
.unwrap();
index = (index + 1) % (1 << (20 - (level % 20)));
level = 1 + (level % 20);
})
});

View File

@@ -0,0 +1,65 @@
use std::hint::black_box;
use ark_bn254::Fr;
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput};
use zerokit_utils::Poseidon;
const ROUND_PARAMS: [(usize, usize, usize, usize); 8] = [
(2, 8, 56, 0),
(3, 8, 57, 0),
(4, 8, 56, 0),
(5, 8, 60, 0),
(6, 8, 60, 0),
(7, 8, 63, 0),
(8, 8, 64, 0),
(9, 8, 63, 0),
];
pub fn poseidon_benchmark(c: &mut Criterion) {
let hasher = Poseidon::<Fr>::from(&ROUND_PARAMS);
let mut group = c.benchmark_group("poseidon Fr");
for size in [10u32, 100, 1000].iter() {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("Array hash", size), size, |b, &size| {
b.iter_batched(
// Setup: create values for each benchmark iteration
|| {
let mut values = Vec::with_capacity(size as usize);
for i in 0..size {
values.push([Fr::from(i)]);
}
values
},
// Actual benchmark
|values| {
for v in values.iter() {
let _ = hasher.hash(black_box(&v[..]));
}
},
BatchSize::SmallInput,
)
});
}
// Benchmark single hash operation separately
group.bench_function("Single hash", |b| {
let input = [Fr::from(u64::MAX)];
b.iter(|| {
let _ = hasher.hash(black_box(&input[..]));
})
});
group.finish();
}
criterion_group! {
name = benches;
config = Criterion::default()
.warm_up_time(std::time::Duration::from_millis(500))
.measurement_time(std::time::Duration::from_secs(4))
.sample_size(20);
targets = poseidon_benchmark
}
criterion_main!(benches);

View File

@@ -0,0 +1,31 @@
#[derive(thiserror::Error, Debug)]
pub enum ZerokitMerkleTreeError {
#[error("Invalid index")]
InvalidIndex,
// InvalidProof,
#[error("Leaf index out of bounds")]
InvalidLeaf,
#[error("Level exceeds tree depth")]
InvalidLevel,
#[error("Subtree index out of bounds")]
InvalidSubTreeIndex,
#[error("Start level is != from end level")]
InvalidStartAndEndLevel,
#[error("set_range got too many leaves")]
TooManySet,
#[error("Unknown error while computing merkle proof")]
ComputingProofError,
#[error("Invalid witness length (!= tree depth)")]
InvalidWitness,
#[cfg(feature = "pmtree-ft")]
#[error("Pmtree error: {0}")]
PmtreeErrorKind(#[from] pmtree::PmtreeErrorKind),
}
#[derive(Debug, thiserror::Error)]
pub enum FromConfigError {
#[error("Error while reading pmtree config: {0}")]
JsonError(#[from] serde_json::Error),
#[error("Error while creating pmtree config: path already exists")]
PathExists,
}

View File

@@ -1,28 +1,29 @@
use crate::merkle_tree::{FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree};
use color_eyre::{Report, Result};
use std::{
cmp::max,
fmt::Debug,
iter::{once, repeat, successors},
iter::{once, repeat_n},
str::FromStr,
};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use crate::merkle_tree::{
error::{FromConfigError, ZerokitMerkleTreeError},
FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree, MIN_PARALLEL_NODES,
};
////////////////////////////////////////////////////////////
///// Full Merkle Tree Implementation
////////////////////////////////////////////////////////////
/// Merkle tree with all leaf and intermediate hashes stored
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct FullMerkleTree<H: Hasher> {
pub struct FullMerkleTree<H>
where
H: Hasher,
{
/// The depth of the tree, i.e. the number of levels from leaf to root
depth: usize,
/// The nodes cached from the empty part of the tree (where leaves are set to default).
/// Since the rightmost part of the tree is usually changed much later than its creation,
/// we can prove accumulation of elements in the leftmost part, with no need to initialize the full tree
/// and by caching few intermediate nodes to the root computed from default leaves
cached_nodes: Vec<H::Fr>,
/// The tree nodes
nodes: Vec<H::Fr>,
@@ -30,11 +31,11 @@ pub struct FullMerkleTree<H: Hasher> {
/// Set to 0 if the leaf is empty and set to 1 in otherwise.
cached_leaves_indices: Vec<u8>,
// The next available (i.e., never used) tree index. Equivalently, the number of leaves added to the tree
// (deletions leave next_index unchanged)
/// The next available (i.e., never used) tree index. Equivalently, the number of leaves added to the tree
/// (deletions leave next_index unchanged)
next_index: usize,
// metadata that an application may use to store additional information
/// metadata that an application may use to store additional information
metadata: Vec<u8>,
}
@@ -56,9 +57,9 @@ pub struct FullMerkleProof<H: Hasher>(pub Vec<FullMerkleBranch<H>>);
pub struct FullMerkleConfig(());
impl FromStr for FullMerkleConfig {
type Err = Report;
type Err = FromConfigError;
fn from_str(_s: &str) -> Result<Self> {
fn from_str(_s: &str) -> Result<Self, Self::Err> {
Ok(FullMerkleConfig::default())
}
}
@@ -72,86 +73,89 @@ where
type Hasher = H;
type Config = FullMerkleConfig;
fn default(depth: usize) -> Result<Self> {
fn default(depth: usize) -> Result<Self, ZerokitMerkleTreeError> {
FullMerkleTree::<H>::new(depth, Self::Hasher::default_leaf(), Self::Config::default())
}
/// Creates a new `MerkleTree`
/// depth - the height of the tree made only of hash nodes. 2^depth is the maximum number of leaves hash nodes
fn new(depth: usize, initial_leaf: FrOf<Self::Hasher>, _config: Self::Config) -> Result<Self> {
fn new(
depth: usize,
default_leaf: FrOf<Self::Hasher>,
_config: Self::Config,
) -> Result<Self, ZerokitMerkleTreeError> {
// Compute cache node values, leaf to root
let cached_nodes = successors(Some(initial_leaf), |prev| Some(H::hash(&[*prev, *prev])))
.take(depth + 1)
.collect::<Vec<_>>();
let mut cached_nodes: Vec<H::Fr> = Vec::with_capacity(depth + 1);
cached_nodes.push(default_leaf);
for i in 0..depth {
cached_nodes.push(H::hash(&[cached_nodes[i]; 2]));
}
cached_nodes.reverse();
// Compute node values
let nodes = cached_nodes
.iter()
.rev()
.enumerate()
.flat_map(|(levels, hash)| repeat(hash).take(1 << levels))
.flat_map(|(levels, hash)| repeat_n(hash, 1 << levels))
.cloned()
.collect::<Vec<_>>();
debug_assert!(nodes.len() == (1 << (depth + 1)) - 1);
let next_index = 0;
Ok(Self {
depth,
cached_nodes,
nodes,
cached_leaves_indices: vec![0; 1 << depth],
next_index,
next_index: 0,
metadata: Vec::new(),
})
}
fn close_db_connection(&mut self) -> Result<()> {
fn close_db_connection(&mut self) -> Result<(), ZerokitMerkleTreeError> {
Ok(())
}
// Returns the depth of the tree
/// Returns the depth of the tree
fn depth(&self) -> usize {
self.depth
}
// Returns the capacity of the tree, i.e. the maximum number of accumulatable leaves
/// Returns the capacity of the tree, i.e. the maximum number of accumulatable leaves
fn capacity(&self) -> usize {
1 << self.depth
}
// Returns the total number of leaves set
/// Returns the total number of leaves set
fn leaves_set(&self) -> usize {
self.next_index
}
#[must_use]
// Returns the root of the tree
/// Returns the root of the tree
fn root(&self) -> FrOf<Self::Hasher> {
self.nodes[0]
}
// Sets a leaf at the specified tree index
fn set(&mut self, leaf: usize, hash: FrOf<Self::Hasher>) -> Result<()> {
/// Sets a leaf at the specified tree index
fn set(&mut self, leaf: usize, hash: FrOf<Self::Hasher>) -> Result<(), ZerokitMerkleTreeError> {
self.set_range(leaf, once(hash))?;
self.next_index = max(self.next_index, leaf + 1);
Ok(())
}
// Get a leaf from the specified tree index
fn get(&self, leaf: usize) -> Result<FrOf<Self::Hasher>> {
/// Get a leaf from the specified tree index
fn get(&self, leaf: usize) -> Result<FrOf<Self::Hasher>, ZerokitMerkleTreeError> {
if leaf >= self.capacity() {
return Err(Report::msg("leaf index out of bounds"));
return Err(ZerokitMerkleTreeError::InvalidLeaf);
}
Ok(self.nodes[self.capacity() + leaf - 1])
}
fn get_subtree_root(&self, n: usize, index: usize) -> Result<H::Fr> {
/// Returns the root of the subtree at level n and index
fn get_subtree_root(&self, n: usize, index: usize) -> Result<H::Fr, ZerokitMerkleTreeError> {
if n > self.depth() {
return Err(Report::msg("level exceeds depth size"));
return Err(ZerokitMerkleTreeError::InvalidIndex);
}
if index >= self.capacity() {
return Err(Report::msg("index exceeds set size"));
return Err(ZerokitMerkleTreeError::InvalidLeaf);
}
if n == 0 {
Ok(self.root())
@@ -161,7 +165,7 @@ where
let mut idx = self.capacity() + index - 1;
let mut nd = self.depth;
loop {
let parent = self.parent(idx).unwrap();
let parent = self.parent(idx).expect("parent should exist");
nd -= 1;
if nd == n {
return Ok(self.nodes[parent]);
@@ -172,6 +176,8 @@ where
}
}
}
/// Returns the indices of the leaves that are empty
fn get_empty_leaves_indices(&self) -> Vec<usize> {
self.cached_leaves_indices
.iter()
@@ -182,40 +188,45 @@ where
.collect()
}
// Sets tree nodes, starting from start index
// Function proper of FullMerkleTree implementation
fn set_range<I: IntoIterator<Item = FrOf<Self::Hasher>>>(
/// Sets multiple leaves from the specified tree index
fn set_range<I: ExactSizeIterator<Item = FrOf<Self::Hasher>>>(
&mut self,
start: usize,
hashes: I,
) -> Result<()> {
leaves: I,
) -> Result<(), ZerokitMerkleTreeError> {
let index = self.capacity() + start - 1;
let mut count = 0;
// first count number of hashes, and check that they fit in the tree
// first count number of leaves, and check that they fit in the tree
// then insert into the tree
let hashes = hashes.into_iter().collect::<Vec<_>>();
if hashes.len() + start > self.capacity() {
return Err(Report::msg("provided hashes do not fit in the tree"));
let leaves = leaves.into_iter().collect::<Vec<_>>();
if leaves.len() + start > self.capacity() {
return Err(ZerokitMerkleTreeError::TooManySet);
}
hashes.into_iter().for_each(|hash| {
leaves.into_iter().for_each(|hash| {
self.nodes[index + count] = hash;
self.cached_leaves_indices[start + count] = 1;
count += 1;
});
if count != 0 {
self.update_nodes(index, index + (count - 1))?;
self.update_hashes(index, index + (count - 1))?;
self.next_index = max(self.next_index, start + count);
}
Ok(())
}
fn override_range<I, J>(&mut self, start: usize, leaves: I, indices: J) -> Result<()>
/// Overrides a range of leaves while resetting specified indices to default and preserving unaffected values.
fn override_range<I, J>(
&mut self,
start: usize,
leaves: I,
indices: J,
) -> Result<(), ZerokitMerkleTreeError>
where
I: IntoIterator<Item = FrOf<Self::Hasher>>,
J: IntoIterator<Item = usize>,
I: ExactSizeIterator<Item = FrOf<Self::Hasher>>,
J: ExactSizeIterator<Item = usize>,
{
let indices = indices.into_iter().collect::<Vec<_>>();
let min_index = *indices.first().unwrap();
let min_index = *indices.first().expect("indices should not be empty");
let leaves_vec = leaves.into_iter().collect::<Vec<_>>();
let max_index = start + leaves_vec.len();
@@ -237,18 +248,17 @@ where
self.cached_leaves_indices[i] = 0;
}
self.set_range(start, set_values)
.map_err(|e| Report::msg(e.to_string()))
self.set_range(start, set_values.into_iter())
}
// Sets a leaf at the next available index
fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<()> {
/// Sets a leaf at the next available index
fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<(), ZerokitMerkleTreeError> {
self.set(self.next_index, leaf)?;
Ok(())
}
// Deletes a leaf at a certain index by setting it to its default value (next_index is not updated)
fn delete(&mut self, index: usize) -> Result<()> {
/// Deletes a leaf at a certain index by setting it to its default value (next_index is not updated)
fn delete(&mut self, index: usize) -> Result<(), ZerokitMerkleTreeError> {
// We reset the leaf only if we previously set a leaf at that index
if index < self.next_index {
self.set(index, H::default_leaf())?;
@@ -258,9 +268,9 @@ where
}
// Computes a merkle proof the leaf at the specified index
fn proof(&self, leaf: usize) -> Result<FullMerkleProof<H>> {
fn proof(&self, leaf: usize) -> Result<FullMerkleProof<H>, ZerokitMerkleTreeError> {
if leaf >= self.capacity() {
return Err(Report::msg("index exceeds set size"));
return Err(ZerokitMerkleTreeError::InvalidLeaf);
}
let mut index = self.capacity() + leaf - 1;
let mut path = Vec::with_capacity(self.depth + 1);
@@ -277,30 +287,29 @@ where
}
// Verifies a Merkle proof with respect to the input leaf and the tree root
fn verify(&self, hash: &FrOf<Self::Hasher>, proof: &FullMerkleProof<H>) -> Result<bool> {
fn verify(
&self,
hash: &FrOf<Self::Hasher>,
proof: &FullMerkleProof<H>,
) -> Result<bool, ZerokitMerkleTreeError> {
Ok(proof.compute_root_from(hash) == self.root())
}
fn compute_root(&mut self) -> Result<FrOf<Self::Hasher>> {
Ok(self.root())
}
fn set_metadata(&mut self, metadata: &[u8]) -> Result<()> {
fn set_metadata(&mut self, metadata: &[u8]) -> Result<(), ZerokitMerkleTreeError> {
self.metadata = metadata.to_vec();
Ok(())
}
fn metadata(&self) -> Result<Vec<u8>> {
fn metadata(&self) -> Result<Vec<u8>, ZerokitMerkleTreeError> {
Ok(self.metadata.to_vec())
}
}
// Utilities for updating the tree nodes
impl<H: Hasher> FullMerkleTree<H>
where
H: Hasher,
{
// Utilities for updating the tree nodes
/// For a given node index, return the parent node index
/// Returns None if there is no parent (root node)
fn parent(&self, index: usize) -> Option<usize> {
@@ -316,23 +325,60 @@ where
(index << 1) + 1
}
/// Returns the depth level of a node based on its index in the flattened tree.
fn levels(&self, index: usize) -> usize {
// `n.next_power_of_two()` will return `n` iff `n` is a power of two.
// The extra offset corrects this.
(index + 2).next_power_of_two().trailing_zeros() as usize - 1
}
fn update_nodes(&mut self, start: usize, end: usize) -> Result<()> {
if self.levels(start) != self.levels(end) {
return Err(Report::msg("self.levels(start) != self.levels(end)"));
/// Updates parent hashes after modifying a range of nodes at the same level.
///
/// - `start_index`: The first index at the current level that was updated.
/// - `end_index`: The last index (inclusive) at the same level that was updated.
fn update_hashes(
&mut self,
start_index: usize,
end_index: usize,
) -> Result<(), ZerokitMerkleTreeError> {
// Ensure the range is within the same tree level
if self.levels(start_index) != self.levels(end_index) {
return Err(ZerokitMerkleTreeError::InvalidStartAndEndLevel);
}
if let (Some(start), Some(end)) = (self.parent(start), self.parent(end)) {
for parent in start..=end {
let child = self.first_child(parent);
self.nodes[parent] = H::hash(&[self.nodes[child], self.nodes[child + 1]]);
// Compute parent indices for the range
if let (Some(start_parent), Some(end_parent)) =
(self.parent(start_index), self.parent(end_index))
{
// Use parallel processing when the number of pairs exceeds the threshold
if end_parent - start_parent + 1 >= MIN_PARALLEL_NODES {
let updates: Vec<(usize, H::Fr)> = (start_parent..=end_parent)
.into_par_iter()
.map(|parent| {
let left_child = self.first_child(parent);
let right_child = left_child + 1;
let hash = H::hash(&[self.nodes[left_child], self.nodes[right_child]]);
(parent, hash)
})
.collect();
for (parent, hash) in updates {
self.nodes[parent] = hash;
}
} else {
// Otherwise, fallback to sequential update for small ranges
for parent in start_parent..=end_parent {
let left_child = self.first_child(parent);
let right_child = left_child + 1;
self.nodes[parent] =
H::hash(&[self.nodes[left_child], self.nodes[right_child]]);
}
}
self.update_nodes(start, end)?;
// Recurse to update upper levels
self.update_hashes(start_parent, end_parent)?;
}
Ok(())
}
}
@@ -341,14 +387,12 @@ impl<H: Hasher> ZerokitMerkleProof for FullMerkleProof<H> {
type Index = u8;
type Hasher = H;
#[must_use]
// Returns the length of a Merkle proof
fn length(&self) -> usize {
self.0.len()
}
/// Computes the leaf index corresponding to a Merkle proof
#[must_use]
fn leaf_index(&self) -> usize {
self.0.iter().rev().fold(0, |index, branch| match branch {
FullMerkleBranch::Left(_) => index << 1,
@@ -356,7 +400,6 @@ impl<H: Hasher> ZerokitMerkleProof for FullMerkleProof<H> {
})
}
#[must_use]
/// Returns the path elements forming a Merkle proof
fn get_path_elements(&self) -> Vec<FrOf<Self::Hasher>> {
self.0
@@ -368,7 +411,6 @@ impl<H: Hasher> ZerokitMerkleProof for FullMerkleProof<H> {
}
/// Returns the path indexes forming a Merkle proof
#[must_use]
fn get_path_index(&self) -> Vec<Self::Index> {
self.0
.iter()
@@ -380,7 +422,6 @@ impl<H: Hasher> ZerokitMerkleProof for FullMerkleProof<H> {
}
/// Computes the Merkle root corresponding by iteratively hashing a Merkle proof with a given input leaf
#[must_use]
fn compute_root_from(&self, hash: &FrOf<Self::Hasher>) -> FrOf<Self::Hasher> {
self.0.iter().fold(*hash, |hash, branch| match branch {
FullMerkleBranch::Left(sibling) => H::hash(&[hash, *sibling]),

View File

@@ -8,20 +8,25 @@
// and https://github.com/worldcoin/semaphore-rs/blob/d462a4372f1fd9c27610f2acfe4841fab1d396aa/src/merkle_tree.rs
//!
//! # To do
//! # TODO
//!
//! * Disk based storage backend (using mmaped files should be easy)
//! * Implement serialization for tree and Merkle proof
use std::str::FromStr;
use crate::merkle_tree::error::ZerokitMerkleTreeError;
use std::{
fmt::{Debug, Display},
str::FromStr,
};
use color_eyre::Result;
/// Enables parallel hashing when there are at least 8 nodes (4 pairs to hash), justifying the overhead.
pub const MIN_PARALLEL_NODES: usize = 8;
/// In the Hasher trait we define the node type, the default leaf
/// and the hash function used to initialize a Merkle Tree implementation
pub trait Hasher {
/// Type of the leaf and tree node
type Fr: Clone + Copy + Eq + Default + std::fmt::Debug + std::fmt::Display + FromStr;
type Fr: Clone + Copy + Eq + Default + Debug + Display + FromStr + Send + Sync;
/// Returns the default tree leaf
fn default_leaf() -> Self::Fr;
@@ -39,35 +44,52 @@ pub trait ZerokitMerkleTree {
type Hasher: Hasher;
type Config: Default + FromStr;
fn default(depth: usize) -> Result<Self>
fn default(depth: usize) -> Result<Self, ZerokitMerkleTreeError>
where
Self: Sized;
fn new(depth: usize, default_leaf: FrOf<Self::Hasher>, config: Self::Config) -> Result<Self>
fn new(
depth: usize,
default_leaf: FrOf<Self::Hasher>,
config: Self::Config,
) -> Result<Self, ZerokitMerkleTreeError>
where
Self: Sized;
fn depth(&self) -> usize;
fn capacity(&self) -> usize;
fn leaves_set(&self) -> usize;
fn root(&self) -> FrOf<Self::Hasher>;
fn compute_root(&mut self) -> Result<FrOf<Self::Hasher>>;
fn get_subtree_root(&self, n: usize, index: usize) -> Result<FrOf<Self::Hasher>>;
fn set(&mut self, index: usize, leaf: FrOf<Self::Hasher>) -> Result<()>;
fn set_range<I>(&mut self, start: usize, leaves: I) -> Result<()>
fn get_subtree_root(
&self,
n: usize,
index: usize,
) -> Result<FrOf<Self::Hasher>, ZerokitMerkleTreeError>;
fn set(&mut self, index: usize, leaf: FrOf<Self::Hasher>)
-> Result<(), ZerokitMerkleTreeError>;
fn set_range<I>(&mut self, start: usize, leaves: I) -> Result<(), ZerokitMerkleTreeError>
where
I: IntoIterator<Item = FrOf<Self::Hasher>>;
fn get(&self, index: usize) -> Result<FrOf<Self::Hasher>>;
I: ExactSizeIterator<Item = FrOf<Self::Hasher>>;
fn get(&self, index: usize) -> Result<FrOf<Self::Hasher>, ZerokitMerkleTreeError>;
fn get_empty_leaves_indices(&self) -> Vec<usize>;
fn override_range<I, J>(&mut self, start: usize, leaves: I, to_remove_indices: J) -> Result<()>
fn override_range<I, J>(
&mut self,
start: usize,
leaves: I,
to_remove_indices: J,
) -> Result<(), ZerokitMerkleTreeError>
where
I: IntoIterator<Item = FrOf<Self::Hasher>>,
J: IntoIterator<Item = usize>;
fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<()>;
fn delete(&mut self, index: usize) -> Result<()>;
fn proof(&self, index: usize) -> Result<Self::Proof>;
fn verify(&self, leaf: &FrOf<Self::Hasher>, witness: &Self::Proof) -> Result<bool>;
fn set_metadata(&mut self, metadata: &[u8]) -> Result<()>;
fn metadata(&self) -> Result<Vec<u8>>;
fn close_db_connection(&mut self) -> Result<()>;
I: ExactSizeIterator<Item = FrOf<Self::Hasher>>,
J: ExactSizeIterator<Item = usize>;
fn update_next(&mut self, leaf: FrOf<Self::Hasher>) -> Result<(), ZerokitMerkleTreeError>;
fn delete(&mut self, index: usize) -> Result<(), ZerokitMerkleTreeError>;
fn proof(&self, index: usize) -> Result<Self::Proof, ZerokitMerkleTreeError>;
fn verify(
&self,
leaf: &FrOf<Self::Hasher>,
witness: &Self::Proof,
) -> Result<bool, ZerokitMerkleTreeError>;
fn set_metadata(&mut self, metadata: &[u8]) -> Result<(), ZerokitMerkleTreeError>;
fn metadata(&self) -> Result<Vec<u8>, ZerokitMerkleTreeError>;
fn close_db_connection(&mut self) -> Result<(), ZerokitMerkleTreeError>;
}
pub trait ZerokitMerkleProof {

View File

@@ -1,7 +1,11 @@
pub mod error;
pub mod full_merkle_tree;
#[allow(clippy::module_inception)]
pub mod merkle_tree;
pub mod optimal_merkle_tree;
pub use self::full_merkle_tree::*;
pub use self::merkle_tree::*;
pub use self::optimal_merkle_tree::*;
pub use self::full_merkle_tree::{FullMerkleConfig, FullMerkleProof, FullMerkleTree};
pub use self::merkle_tree::{
FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree, MIN_PARALLEL_NODES,
};
pub use self::optimal_merkle_tree::{OptimalMerkleConfig, OptimalMerkleProof, OptimalMerkleTree};

View File

@@ -1,10 +1,11 @@
use crate::merkle_tree::{Hasher, ZerokitMerkleProof, ZerokitMerkleTree};
use crate::FrOf;
use color_eyre::{Report, Result};
use std::collections::HashMap;
use std::str::FromStr;
use std::{cmp::max, fmt::Debug};
use std::{cmp::max, collections::HashMap, fmt::Debug, str::FromStr};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use crate::merkle_tree::{
error::{FromConfigError, ZerokitMerkleTreeError},
FrOf, Hasher, ZerokitMerkleProof, ZerokitMerkleTree, MIN_PARALLEL_NODES,
};
////////////////////////////////////////////////////////////
///// Optimal Merkle Tree Implementation
////////////////////////////////////////////////////////////
@@ -31,11 +32,11 @@ where
/// Set to 0 if the leaf is empty and set to 1 in otherwise.
cached_leaves_indices: Vec<u8>,
// The next available (i.e., never used) tree index. Equivalently, the number of leaves added to the tree
// (deletions leave next_index unchanged)
/// The next available (i.e., never used) tree index. Equivalently, the number of leaves added to the tree
/// (deletions leave next_index unchanged)
next_index: usize,
// metadata that an application may use to store additional information
/// metadata that an application may use to store additional information
metadata: Vec<u8>,
}
@@ -48,17 +49,14 @@ pub struct OptimalMerkleProof<H: Hasher>(pub Vec<(H::Fr, u8)>);
pub struct OptimalMerkleConfig(());
impl FromStr for OptimalMerkleConfig {
type Err = Report;
type Err = FromConfigError;
fn from_str(_s: &str) -> Result<Self> {
fn from_str(_s: &str) -> Result<Self, Self::Err> {
Ok(OptimalMerkleConfig::default())
}
}
////////////////////////////////////////////////////////////
///// Implementations
////////////////////////////////////////////////////////////
/// Implementations
impl<H: Hasher> ZerokitMerkleTree for OptimalMerkleTree<H>
where
H: Hasher,
@@ -67,60 +65,86 @@ where
type Hasher = H;
type Config = OptimalMerkleConfig;
fn default(depth: usize) -> Result<Self> {
fn default(depth: usize) -> Result<Self, ZerokitMerkleTreeError> {
OptimalMerkleTree::<H>::new(depth, H::default_leaf(), Self::Config::default())
}
/// Creates a new `MerkleTree`
/// depth - the height of the tree made only of hash nodes. 2^depth is the maximum number of leaves hash nodes
fn new(depth: usize, default_leaf: H::Fr, _config: Self::Config) -> Result<Self> {
fn new(
depth: usize,
default_leaf: H::Fr,
_config: Self::Config,
) -> Result<Self, ZerokitMerkleTreeError> {
// Compute cache node values, leaf to root
let mut cached_nodes: Vec<H::Fr> = Vec::with_capacity(depth + 1);
cached_nodes.push(default_leaf);
for i in 0..depth {
cached_nodes.push(H::hash(&[cached_nodes[i]; 2]));
}
cached_nodes.reverse();
Ok(OptimalMerkleTree {
cached_nodes: cached_nodes.clone(),
depth,
nodes: HashMap::new(),
cached_nodes,
nodes: HashMap::with_capacity(1 << depth),
cached_leaves_indices: vec![0; 1 << depth],
next_index: 0,
metadata: Vec::new(),
})
}
fn close_db_connection(&mut self) -> Result<()> {
fn close_db_connection(&mut self) -> Result<(), ZerokitMerkleTreeError> {
Ok(())
}
// Returns the depth of the tree
/// Returns the depth of the tree
fn depth(&self) -> usize {
self.depth
}
// Returns the capacity of the tree, i.e. the maximum number of accumulatable leaves
/// Returns the capacity of the tree, i.e. the maximum number of accumulatable leaves
fn capacity(&self) -> usize {
1 << self.depth
}
// Returns the total number of leaves set
/// Returns the total number of leaves set
fn leaves_set(&self) -> usize {
self.next_index
}
#[must_use]
// Returns the root of the tree
/// Returns the root of the tree
fn root(&self) -> H::Fr {
self.get_node(0, 0)
}
fn get_subtree_root(&self, n: usize, index: usize) -> Result<H::Fr> {
/// Sets a leaf at the specified tree index
fn set(&mut self, index: usize, leaf: H::Fr) -> Result<(), ZerokitMerkleTreeError> {
if index >= self.capacity() {
return Err(ZerokitMerkleTreeError::InvalidLeaf);
}
self.nodes.insert((self.depth, index), leaf);
self.update_hashes(index, 1)?;
self.next_index = max(self.next_index, index + 1);
self.cached_leaves_indices[index] = 1;
Ok(())
}
/// Get a leaf from the specified tree index
fn get(&self, index: usize) -> Result<H::Fr, ZerokitMerkleTreeError> {
if index >= self.capacity() {
return Err(ZerokitMerkleTreeError::InvalidLeaf);
}
Ok(self.get_node(self.depth, index))
}
/// Returns the root of the subtree at level n and index
fn get_subtree_root(&self, n: usize, index: usize) -> Result<H::Fr, ZerokitMerkleTreeError> {
if n > self.depth() {
return Err(Report::msg("level exceeds depth size"));
return Err(ZerokitMerkleTreeError::InvalidLevel);
}
if index >= self.capacity() {
return Err(Report::msg("index exceeds set size"));
return Err(ZerokitMerkleTreeError::InvalidLeaf);
}
if n == 0 {
Ok(self.root())
@@ -131,26 +155,7 @@ where
}
}
// Sets a leaf at the specified tree index
fn set(&mut self, index: usize, leaf: H::Fr) -> Result<()> {
if index >= self.capacity() {
return Err(Report::msg("index exceeds set size"));
}
self.nodes.insert((self.depth, index), leaf);
self.recalculate_from(index)?;
self.next_index = max(self.next_index, index + 1);
self.cached_leaves_indices[index] = 1;
Ok(())
}
// Get a leaf from the specified tree index
fn get(&self, index: usize) -> Result<H::Fr> {
if index >= self.capacity() {
return Err(Report::msg("index exceeds set size"));
}
Ok(self.get_node(self.depth, index))
}
/// Returns the indices of the leaves that are empty
fn get_empty_leaves_indices(&self) -> Vec<usize> {
self.cached_leaves_indices
.iter()
@@ -161,29 +166,39 @@ where
.collect()
}
// Sets multiple leaves from the specified tree index
fn set_range<I: IntoIterator<Item = H::Fr>>(&mut self, start: usize, leaves: I) -> Result<()> {
let leaves = leaves.into_iter().collect::<Vec<_>>();
/// Sets multiple leaves from the specified tree index
fn set_range<I: ExactSizeIterator<Item = H::Fr>>(
&mut self,
start: usize,
leaves: I,
) -> Result<(), ZerokitMerkleTreeError> {
// check if the range is valid
if start + leaves.len() > self.capacity() {
return Err(Report::msg("provided range exceeds set size"));
let leaves_len = leaves.len();
if start + leaves_len > self.capacity() {
return Err(ZerokitMerkleTreeError::TooManySet);
}
for (i, leaf) in leaves.iter().enumerate() {
self.nodes.insert((self.depth, start + i), *leaf);
for (i, leaf) in leaves.enumerate() {
self.nodes.insert((self.depth, start + i), leaf);
self.cached_leaves_indices[start + i] = 1;
self.recalculate_from(start + i)?;
}
self.next_index = max(self.next_index, start + leaves.len());
self.update_hashes(start, leaves_len)?;
self.next_index = max(self.next_index, start + leaves_len);
Ok(())
}
fn override_range<I, J>(&mut self, start: usize, leaves: I, indices: J) -> Result<()>
/// Overrides a range of leaves while resetting specified indices to default and preserving unaffected values.
fn override_range<I, J>(
&mut self,
start: usize,
leaves: I,
indices: J,
) -> Result<(), ZerokitMerkleTreeError>
where
I: IntoIterator<Item = FrOf<Self::Hasher>>,
J: IntoIterator<Item = usize>,
I: ExactSizeIterator<Item = FrOf<Self::Hasher>>,
J: ExactSizeIterator<Item = usize>,
{
let indices = indices.into_iter().collect::<Vec<_>>();
let min_index = *indices.first().unwrap();
let min_index = *indices.first().expect("indices should not be empty");
let leaves_vec = leaves.into_iter().collect::<Vec<_>>();
let max_index = start + leaves_vec.len();
@@ -192,7 +207,7 @@ where
for i in min_index..start {
if !indices.contains(&i) {
let value = self.get_leaf(i);
let value = self.get(i)?;
set_values[i - min_index] = value;
}
}
@@ -205,18 +220,17 @@ where
self.cached_leaves_indices[i] = 0;
}
self.set_range(start, set_values)
.map_err(|e| Report::msg(e.to_string()))
self.set_range(start, set_values.into_iter())
}
// Sets a leaf at the next available index
fn update_next(&mut self, leaf: H::Fr) -> Result<()> {
/// Sets a leaf at the next available index
fn update_next(&mut self, leaf: H::Fr) -> Result<(), ZerokitMerkleTreeError> {
self.set(self.next_index, leaf)?;
Ok(())
}
// Deletes a leaf at a certain index by setting it to its default value (next_index is not updated)
fn delete(&mut self, index: usize) -> Result<()> {
/// Deletes a leaf at a certain index by setting it to its default value (next_index is not updated)
fn delete(&mut self, index: usize) -> Result<(), ZerokitMerkleTreeError> {
// We reset the leaf only if we previously set a leaf at that index
if index < self.next_index {
self.set(index, H::default_leaf())?;
@@ -225,17 +239,20 @@ where
Ok(())
}
// Computes a merkle proof the leaf at the specified index
fn proof(&self, index: usize) -> Result<Self::Proof> {
/// Computes a merkle proof the leaf at the specified index
fn proof(&self, index: usize) -> Result<Self::Proof, ZerokitMerkleTreeError> {
if index >= self.capacity() {
return Err(Report::msg("index exceeds set size"));
return Err(ZerokitMerkleTreeError::InvalidLeaf);
}
let mut witness = Vec::<(H::Fr, u8)>::with_capacity(self.depth);
let mut i = index;
let mut depth = self.depth;
loop {
i ^= 1;
witness.push((self.get_node(depth, i), (1 - (i & 1)).try_into().unwrap()));
witness.push((
self.get_node(depth, i),
(1 - (i & 1)).try_into().expect("0 or 1 expected"),
));
i >>= 1;
depth -= 1;
if depth == 0 {
@@ -243,78 +260,102 @@ where
}
}
if i != 0 {
Err(Report::msg("i != 0"))
Err(ZerokitMerkleTreeError::ComputingProofError)
} else {
Ok(OptimalMerkleProof(witness))
}
}
// Verifies a Merkle proof with respect to the input leaf and the tree root
fn verify(&self, leaf: &H::Fr, witness: &Self::Proof) -> Result<bool> {
/// Verifies a Merkle proof with respect to the input leaf and the tree root
fn verify(&self, leaf: &H::Fr, witness: &Self::Proof) -> Result<bool, ZerokitMerkleTreeError> {
if witness.length() != self.depth {
return Err(Report::msg("witness length doesn't match tree depth"));
return Err(ZerokitMerkleTreeError::InvalidWitness);
}
let expected_root = witness.compute_root_from(leaf);
Ok(expected_root.eq(&self.root()))
}
fn compute_root(&mut self) -> Result<FrOf<Self::Hasher>> {
self.recalculate_from(0)?;
Ok(self.root())
}
fn set_metadata(&mut self, metadata: &[u8]) -> Result<()> {
fn set_metadata(&mut self, metadata: &[u8]) -> Result<(), ZerokitMerkleTreeError> {
self.metadata = metadata.to_vec();
Ok(())
}
fn metadata(&self) -> Result<Vec<u8>> {
fn metadata(&self) -> Result<Vec<u8>, ZerokitMerkleTreeError> {
Ok(self.metadata.to_vec())
}
}
// Utilities for updating the tree nodes
impl<H: Hasher> OptimalMerkleTree<H>
where
H: Hasher,
{
// Utilities for updating the tree nodes
/// Returns the value of a node at a specific (depth, index).
/// Falls back to a cached default if the node hasn't been set.
fn get_node(&self, depth: usize, index: usize) -> H::Fr {
let node = *self
*self
.nodes
.get(&(depth, index))
.unwrap_or_else(|| &self.cached_nodes[depth]);
node
.unwrap_or_else(|| &self.cached_nodes[depth])
}
pub fn get_leaf(&self, index: usize) -> H::Fr {
self.get_node(self.depth, index)
}
fn hash_couple(&mut self, depth: usize, index: usize) -> H::Fr {
/// Computes the hash of a nodes two children at the given depth.
/// If the index is odd, it is rounded down to the nearest even index.
fn hash_couple(&self, depth: usize, index: usize) -> H::Fr {
let b = index & !1;
H::hash(&[self.get_node(depth, b), self.get_node(depth, b + 1)])
}
fn recalculate_from(&mut self, index: usize) -> Result<()> {
let mut i = index;
let mut depth = self.depth;
loop {
let h = self.hash_couple(depth, i);
i >>= 1;
depth -= 1;
self.nodes.insert((depth, i), h);
self.cached_leaves_indices[index] = 1;
if depth == 0 {
break;
/// Updates parent hashes after modifying a range of leaf nodes.
///
/// - `start`: Starting leaf index that was updated.
/// - `length`: Number of consecutive leaves that were updated.
fn update_hashes(&mut self, start: usize, length: usize) -> Result<(), ZerokitMerkleTreeError> {
// Start at the leaf level
let mut current_depth = self.depth;
// Round down to include the left sibling in the pair (if start is odd)
let mut current_index = start & !1;
// Compute the max index at this level, round up to include the last updated leafs right sibling (if start + length is odd)
let mut current_index_max = (start + length + 1) & !1;
// Traverse from the leaf level up to the root
while current_depth > 0 {
// Compute the parent level (one level above the current)
let parent_depth = current_depth - 1;
// Use parallel processing when the number of pairs exceeds the threshold
if current_index_max - current_index >= MIN_PARALLEL_NODES {
let updates: Vec<((usize, usize), H::Fr)> = (current_index..current_index_max)
.step_by(2)
.collect::<Vec<_>>()
.into_par_iter()
.map(|index| {
// Hash two child nodes at positions (current_depth, index) and (current_depth, index + 1)
let hash = self.hash_couple(current_depth, index);
// Return the computed parent hash and its position at
((parent_depth, index >> 1), hash)
})
.collect();
for (parent, hash) in updates {
self.nodes.insert(parent, hash);
}
} else {
// Otherwise, fallback to sequential update for small ranges
for index in (current_index..current_index_max).step_by(2) {
let hash = self.hash_couple(current_depth, index);
self.nodes.insert((parent_depth, index >> 1), hash);
}
}
// Move up one level in the tree
current_index >>= 1;
current_index_max = (current_index_max + 1) >> 1;
current_depth -= 1;
}
if depth != 0 {
return Err(Report::msg("did not reach the depth"));
}
if i != 0 {
return Err(Report::msg("did not go through all indexes"));
}
Ok(())
}
}
@@ -326,14 +367,12 @@ where
type Index = u8;
type Hasher = H;
#[must_use]
// Returns the length of a Merkle proof
/// Returns the length of a Merkle proof
fn length(&self) -> usize {
self.0.len()
}
/// Computes the leaf index corresponding to a Merkle proof
#[must_use]
fn leaf_index(&self) -> usize {
// In current implementation the path indexes in a proof correspond to the binary representation of the leaf index
let mut binary_repr = self.get_path_index();
@@ -343,19 +382,16 @@ where
.fold(0, |acc, digit| (acc << 1) + usize::from(digit))
}
#[must_use]
/// Returns the path elements forming a Merkle proof
fn get_path_elements(&self) -> Vec<H::Fr> {
self.0.iter().map(|x| x.0).collect()
}
/// Returns the path indexes forming a Merkle proof
#[must_use]
fn get_path_index(&self) -> Vec<u8> {
self.0.iter().map(|x| x.1).collect()
}
#[must_use]
/// Computes the Merkle root corresponding by iteratively hashing a Merkle proof with a given input leaf
fn compute_root_from(&self, leaf: &H::Fr) -> H::Fr {
let mut acc: H::Fr = *leaf;

View File

@@ -1,4 +1,4 @@
pub mod sled_adapter;
pub use self::sled_adapter::*;
pub use self::sled_adapter::SledDB;
pub use pmtree;
pub use sled::*;
pub use sled::{Config, Mode};

View File

@@ -1,5 +1,4 @@
pub mod poseidon_hash;
pub use self::poseidon_hash::*;
pub use poseidon_hash::Poseidon;
pub mod poseidon_constants;
pub use self::poseidon_constants::*;

View File

@@ -9,8 +9,6 @@
// The following implementation was adapted from https://github.com/arkworks-rs/sponge/blob/7d9b3a474c9ddb62890014aeaefcb142ac2b3776/src/poseidon/grain_lfsr.rs
#![allow(dead_code)]
use ark_ff::PrimeField;
use num_bigint::BigUint;
@@ -20,7 +18,6 @@ pub struct PoseidonGrainLFSR {
pub head: usize,
}
#[allow(unused_variables)]
impl PoseidonGrainLFSR {
pub fn new(
is_field: u64,

View File

@@ -7,7 +7,7 @@ use crate::poseidon_constants::find_poseidon_ark_and_mds;
use ark_ff::PrimeField;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoundParamenters<F: PrimeField> {
pub struct RoundParameters<F: PrimeField> {
pub t: usize,
pub n_rounds_f: usize,
pub n_rounds_p: usize,
@@ -17,16 +17,16 @@ pub struct RoundParamenters<F: PrimeField> {
}
pub struct Poseidon<F: PrimeField> {
round_params: Vec<RoundParamenters<F>>,
round_params: Vec<RoundParameters<F>>,
}
impl<F: PrimeField> Poseidon<F> {
// Loads round parameters and generates round constants
// poseidon_params is a vector containing tuples (t, RF, RP, skip_matrices)
// where: t is the rate (input lenght + 1), RF is the number of full rounds, RP is the number of partial rounds
// where: t is the rate (input length + 1), RF is the number of full rounds, RP is the number of partial rounds
// and skip_matrices is a (temporary) parameter used to generate secure MDS matrices (see comments in the description of find_poseidon_ark_and_mds)
// TODO: implement automatic generation of round parameters
pub fn from(poseidon_params: &[(usize, usize, usize, usize)]) -> Self {
let mut read_params = Vec::<RoundParamenters<F>>::new();
let mut read_params = Vec::<RoundParameters<F>>::with_capacity(poseidon_params.len());
for &(t, n_rounds_f, n_rounds_p, skip_matrices) in poseidon_params {
let (ark, mds) = find_poseidon_ark_and_mds::<F>(
@@ -38,7 +38,7 @@ impl<F: PrimeField> Poseidon<F> {
n_rounds_p as u64,
skip_matrices,
);
let rp = RoundParamenters {
let rp = RoundParameters {
t,
n_rounds_p,
n_rounds_f,
@@ -54,24 +54,24 @@ impl<F: PrimeField> Poseidon<F> {
}
}
pub fn get_parameters(&self) -> Vec<RoundParamenters<F>> {
self.round_params.clone()
pub fn get_parameters(&self) -> &Vec<RoundParameters<F>> {
&self.round_params
}
pub fn ark(&self, state: &mut [F], c: &[F], it: usize) {
for i in 0..state.len() {
state[i] += c[it + i];
}
state.iter_mut().enumerate().for_each(|(i, elem)| {
*elem += c[it + i];
});
}
pub fn sbox(&self, n_rounds_f: usize, n_rounds_p: usize, state: &mut [F], i: usize) {
if (i < n_rounds_f / 2) || (i >= n_rounds_f / 2 + n_rounds_p) {
for current_state in &mut state.iter_mut() {
state.iter_mut().for_each(|current_state| {
let aux = *current_state;
*current_state *= *current_state;
*current_state *= *current_state;
*current_state *= aux;
}
})
} else {
let aux = state[0];
state[0] *= state[0];
@@ -80,21 +80,20 @@ impl<F: PrimeField> Poseidon<F> {
}
}
pub fn mix(&self, state: &[F], m: &[Vec<F>]) -> Vec<F> {
let mut new_state: Vec<F> = Vec::new();
pub fn mix_2(&self, state: &[F], m: &[Vec<F>], state_2: &mut [F]) {
for i in 0..state.len() {
new_state.push(F::zero());
for (j, state_item) in state.iter().enumerate() {
let mut mij = m[i][j];
mij *= state_item;
new_state[i] += mij;
// Cache the row reference
let row = &m[i];
let mut acc = F::ZERO;
for j in 0..state.len() {
acc += row[j] * state[j];
}
state_2[i] = acc;
}
new_state.clone()
}
pub fn hash(&self, inp: Vec<F>) -> Result<F, String> {
// Note that the rate t becomes input lenght + 1, hence for lenght N we pick parameters with T = N + 1
pub fn hash(&self, inp: &[F]) -> Result<F, String> {
// Note that the rate t becomes input length + 1; hence for length N we pick parameters with T = N + 1
let t = inp.len() + 1;
// We seek the index (Poseidon's round_params is an ordered vector) for the parameters corresponding to t
@@ -106,8 +105,9 @@ impl<F: PrimeField> Poseidon<F> {
let param_index = param_index.unwrap();
let mut state = vec![F::zero(); t];
state[1..].clone_from_slice(&inp);
let mut state = vec![F::ZERO; t];
let mut state_2 = state.clone();
state[1..].clone_from_slice(inp);
for i in 0..(self.round_params[param_index].n_rounds_f
+ self.round_params[param_index].n_rounds_p)
@@ -123,7 +123,8 @@ impl<F: PrimeField> Poseidon<F> {
&mut state,
i,
);
state = self.mix(&state, &self.round_params[param_index].m);
self.mix_2(&state, &self.round_params[param_index].m, &mut state_2);
std::mem::swap(&mut state, &mut state_2);
}
Ok(state[0])

View File

@@ -6,7 +6,7 @@ pub mod test {
use tiny_keccak::{Hasher as _, Keccak};
use zerokit_utils::{
FullMerkleConfig, FullMerkleTree, Hasher, OptimalMerkleConfig, OptimalMerkleTree,
ZerokitMerkleProof, ZerokitMerkleTree,
ZerokitMerkleProof, ZerokitMerkleTree, MIN_PARALLEL_NODES,
};
#[derive(Clone, Copy, Eq, PartialEq)]
struct Keccak256;
@@ -42,7 +42,7 @@ pub mod test {
type Err = std::string::FromUtf8Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(TestFr(s.as_bytes().try_into().unwrap()))
Ok(TestFr(s.as_bytes().try_into().expect("Invalid length")))
}
}
@@ -50,7 +50,7 @@ pub mod test {
fn from(value: u32) -> Self {
let mut bytes: Vec<u8> = vec![0; 28];
bytes.extend_from_slice(&value.to_be_bytes());
TestFr(bytes.as_slice().try_into().unwrap())
TestFr(bytes.as_slice().try_into().expect("Invalid length"))
}
}
@@ -58,12 +58,12 @@ pub mod test {
fn default_full_merkle_tree(depth: usize) -> FullMerkleTree<Keccak256> {
FullMerkleTree::<Keccak256>::new(depth, TestFr([0; 32]), FullMerkleConfig::default())
.unwrap()
.expect("Failed to create FullMerkleTree")
}
fn default_optimal_merkle_tree(depth: usize) -> OptimalMerkleTree<Keccak256> {
OptimalMerkleTree::<Keccak256>::new(depth, TestFr([0; 32]), OptimalMerkleConfig::default())
.unwrap()
.expect("Failed to create OptimalMerkleTree")
}
#[test]
@@ -83,21 +83,101 @@ pub mod test {
let nof_leaves = 4;
let leaves: Vec<TestFr> = (1..=nof_leaves as u32).map(TestFr::from).collect();
let mut tree = default_full_merkle_tree(DEFAULT_DEPTH);
assert_eq!(tree.root(), default_tree_root);
let mut tree_full = default_full_merkle_tree(DEFAULT_DEPTH);
assert_eq!(tree_full.root(), default_tree_root);
for i in 0..nof_leaves {
tree.set(i, leaves[i]).unwrap();
assert_eq!(tree.root(), roots[i]);
tree_full.set(i, leaves[i]).expect("Failed to set leaf");
assert_eq!(tree_full.root(), roots[i]);
}
let mut tree = default_optimal_merkle_tree(DEFAULT_DEPTH);
assert_eq!(tree.root(), default_tree_root);
let mut tree_opt = default_optimal_merkle_tree(DEFAULT_DEPTH);
assert_eq!(tree_opt.root(), default_tree_root);
for i in 0..nof_leaves {
tree.set(i, leaves[i]).unwrap();
assert_eq!(tree.root(), roots[i]);
tree_opt.set(i, leaves[i]).expect("Failed to set leaf");
assert_eq!(tree_opt.root(), roots[i]);
}
}
#[test]
fn test_set_range() {
let depth = 4;
let leaves: Vec<TestFr> = (0..(1 << depth) as u32).map(TestFr::from).collect();
let mut tree_full = default_full_merkle_tree(depth);
let root_before = tree_full.root();
tree_full
.set_range(0, leaves.iter().cloned())
.expect("Failed to set leaves");
let root_after = tree_full.root();
assert_ne!(root_before, root_after);
let mut tree_opt = default_optimal_merkle_tree(depth);
let root_before = tree_opt.root();
tree_opt
.set_range(0, leaves.iter().cloned())
.expect("Failed to set leaves");
let root_after = tree_opt.root();
assert_ne!(root_before, root_after);
}
#[test]
fn test_update_next() {
let mut tree_full = default_full_merkle_tree(DEFAULT_DEPTH);
let mut tree_opt = default_optimal_merkle_tree(DEFAULT_DEPTH);
for i in 0..4 {
let leaf = TestFr::from(i as u32);
tree_full.update_next(leaf).expect("Failed to update leaf");
tree_opt.update_next(leaf).expect("Failed to update leaf");
assert_eq!(tree_full.get(i).expect("Failed to get leaf"), leaf);
assert_eq!(tree_opt.get(i).expect("Failed to get leaf"), leaf);
}
assert_eq!(tree_full.leaves_set(), 4);
assert_eq!(tree_opt.leaves_set(), 4);
}
#[test]
fn test_delete_and_reset() {
let index = 1;
let original_leaf = TestFr::from(42);
let new_leaf = TestFr::from(99);
let mut tree_full = default_full_merkle_tree(DEFAULT_DEPTH);
tree_full
.set(index, original_leaf)
.expect("Failed to set leaf");
let root_with_original = tree_full.root();
tree_full.delete(index).expect("Failed to delete leaf");
let root_after_delete = tree_full.root();
assert_ne!(root_with_original, root_after_delete);
tree_full.set(index, new_leaf).expect("Failed to set leaf");
let root_after_reset = tree_full.root();
assert_ne!(root_after_delete, root_after_reset);
assert_ne!(root_with_original, root_after_reset);
assert_eq!(tree_full.get(index).expect("Failed to get leaf"), new_leaf);
let mut tree_opt = default_optimal_merkle_tree(DEFAULT_DEPTH);
tree_opt
.set(index, original_leaf)
.expect("Failed to set leaf");
let root_with_original = tree_opt.root();
tree_opt.delete(index).expect("Failed to delete leaf");
let root_after_delete = tree_opt.root();
assert_ne!(root_with_original, root_after_delete);
tree_opt.set(index, new_leaf).expect("Failed to set leaf");
let root_after_reset = tree_opt.root();
assert_ne!(root_after_delete, root_after_reset);
assert_ne!(root_with_original, root_after_reset);
assert_eq!(tree_opt.get(index).expect("Failed to get leaf"), new_leaf);
}
#[test]
fn test_get_empty_leaves_indices() {
let depth = 4;
@@ -107,7 +187,7 @@ pub mod test {
let leaves_4: Vec<TestFr> = (0u32..4).map(TestFr::from).collect();
let mut tree_full = default_full_merkle_tree(depth);
let _ = tree_full.set_range(0, leaves.clone());
let _ = tree_full.set_range(0, leaves.clone().into_iter());
assert!(tree_full.get_empty_leaves_indices().is_empty());
let mut vec_idxs = Vec::new();
@@ -123,33 +203,31 @@ pub mod test {
assert_eq!(tree_full.get_empty_leaves_indices(), vec_idxs);
}
// Check situation when the number of items to insert is less than the number of items to delete
// check situation when the number of items to insert is less than the number of items to delete
tree_full
.override_range(0, leaves_2.clone(), [0, 1, 2, 3])
.unwrap();
.override_range(0, leaves_2.clone().into_iter(), [0, 1, 2, 3].into_iter())
.expect("Failed to override range");
// check if the indexes for write and delete are the same
tree_full
.override_range(0, leaves_4.clone(), [0, 1, 2, 3])
.unwrap();
assert_eq!(tree_full.get_empty_leaves_indices(), vec![]);
.override_range(0, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter())
.expect("Failed to override range");
assert_eq!(tree_full.get_empty_leaves_indices(), Vec::<usize>::new());
// check if indexes for deletion are before indexes for overwriting
tree_full
.override_range(4, leaves_4.clone(), [0, 1, 2, 3])
.unwrap();
.override_range(4, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter())
.expect("Failed to override range");
assert_eq!(tree_full.get_empty_leaves_indices(), vec![0, 1, 2, 3]);
// check if the indices for write and delete do not overlap completely
tree_full
.override_range(2, leaves_4.clone(), [0, 1, 2, 3])
.unwrap();
.override_range(2, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter())
.expect("Failed to override range");
assert_eq!(tree_full.get_empty_leaves_indices(), vec![0, 1]);
//// Optimal Merkle Tree Trest
let mut tree_opt = default_optimal_merkle_tree(depth);
let _ = tree_opt.set_range(0, leaves.clone());
let _ = tree_opt.set_range(0, leaves.clone().into_iter());
assert!(tree_opt.get_empty_leaves_indices().is_empty());
let mut vec_idxs = Vec::new();
@@ -164,48 +242,55 @@ pub mod test {
assert_eq!(tree_opt.get_empty_leaves_indices(), vec_idxs);
}
// Check situation when the number of items to insert is less than the number of items to delete
// check situation when the number of items to insert is less than the number of items to delete
tree_opt
.override_range(0, leaves_2.clone(), [0, 1, 2, 3])
.unwrap();
.override_range(0, leaves_2.clone().into_iter(), [0, 1, 2, 3].into_iter())
.expect("Failed to override range");
// check if the indexes for write and delete are the same
tree_opt
.override_range(0, leaves_4.clone(), [0, 1, 2, 3])
.unwrap();
assert_eq!(tree_opt.get_empty_leaves_indices(), vec![]);
.override_range(0, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter())
.expect("Failed to override range");
assert_eq!(tree_opt.get_empty_leaves_indices(), Vec::<usize>::new());
// check if indexes for deletion are before indexes for overwriting
tree_opt
.override_range(4, leaves_4.clone(), [0, 1, 2, 3])
.unwrap();
.override_range(4, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter())
.expect("Failed to override range");
assert_eq!(tree_opt.get_empty_leaves_indices(), vec![0, 1, 2, 3]);
// check if the indices for write and delete do not overlap completely
tree_opt
.override_range(2, leaves_4.clone(), [0, 1, 2, 3])
.unwrap();
.override_range(2, leaves_4.clone().into_iter(), [0, 1, 2, 3].into_iter())
.expect("Failed to override range");
assert_eq!(tree_opt.get_empty_leaves_indices(), vec![0, 1]);
}
#[test]
fn test_subtree_root() {
let depth = 3;
let nof_leaves: usize = 6;
let nof_leaves: usize = 4;
let leaves: Vec<TestFr> = (0..nof_leaves as u32).map(TestFr::from).collect();
let mut tree_full = default_optimal_merkle_tree(depth);
let mut tree_full = default_full_merkle_tree(depth);
let _ = tree_full.set_range(0, leaves.iter().cloned());
for i in 0..nof_leaves {
// check leaves
assert_eq!(
tree_full.get(i).unwrap(),
tree_full.get_subtree_root(depth, i).unwrap()
tree_full.get(i).expect("Failed to get leaf"),
tree_full
.get_subtree_root(depth, i)
.expect("Failed to get subtree root")
);
// check root
assert_eq!(tree_full.root(), tree_full.get_subtree_root(0, i).unwrap());
assert_eq!(
tree_full.root(),
tree_full
.get_subtree_root(0, i)
.expect("Failed to get subtree root")
);
}
// check intermediate nodes
@@ -215,26 +300,39 @@ pub mod test {
let idx_r = (i + 1) * (1 << (depth - n));
let idx_sr = idx_l;
let prev_l = tree_full.get_subtree_root(n, idx_l).unwrap();
let prev_r = tree_full.get_subtree_root(n, idx_r).unwrap();
let subroot = tree_full.get_subtree_root(n - 1, idx_sr).unwrap();
let prev_l = tree_full
.get_subtree_root(n, idx_l)
.expect("Failed to get subtree root");
let prev_r = tree_full
.get_subtree_root(n, idx_r)
.expect("Failed to get subtree root");
let subroot = tree_full
.get_subtree_root(n - 1, idx_sr)
.expect("Failed to get subtree root");
// check intermediate nodes
assert_eq!(Keccak256::hash(&[prev_l, prev_r]), subroot);
}
}
let mut tree_opt = default_full_merkle_tree(depth);
let mut tree_opt = default_optimal_merkle_tree(depth);
let _ = tree_opt.set_range(0, leaves.iter().cloned());
for i in 0..nof_leaves {
// check leaves
assert_eq!(
tree_opt.get(i).unwrap(),
tree_opt.get_subtree_root(depth, i).unwrap()
tree_opt.get(i).expect("Failed to get leaf"),
tree_opt
.get_subtree_root(depth, i)
.expect("Failed to get subtree root")
);
// check root
assert_eq!(tree_opt.root(), tree_opt.get_subtree_root(0, i).unwrap());
assert_eq!(
tree_opt.root(),
tree_opt
.get_subtree_root(0, i)
.expect("Failed to get subtree root")
);
}
// check intermediate nodes
@@ -244,9 +342,15 @@ pub mod test {
let idx_r = (i + 1) * (1 << (depth - n));
let idx_sr = idx_l;
let prev_l = tree_opt.get_subtree_root(n, idx_l).unwrap();
let prev_r = tree_opt.get_subtree_root(n, idx_r).unwrap();
let subroot = tree_opt.get_subtree_root(n - 1, idx_sr).unwrap();
let prev_l = tree_opt
.get_subtree_root(n, idx_l)
.expect("Failed to get subtree root");
let prev_r = tree_opt
.get_subtree_root(n, idx_r)
.expect("Failed to get subtree root");
let subroot = tree_opt
.get_subtree_root(n - 1, idx_sr)
.expect("Failed to get subtree root");
// check intermediate nodes
assert_eq!(Keccak256::hash(&[prev_l, prev_r]), subroot);
@@ -259,61 +363,83 @@ pub mod test {
let nof_leaves = 4;
let leaves: Vec<TestFr> = (0..nof_leaves as u32).map(TestFr::from).collect();
// We thest the FullMerkleTree implementation
let mut tree = default_full_merkle_tree(DEFAULT_DEPTH);
// We test the FullMerkleTree implementation
let mut tree_full = default_full_merkle_tree(DEFAULT_DEPTH);
for i in 0..nof_leaves {
// We set the leaves
tree.set(i, leaves[i]).unwrap();
tree_full.set(i, leaves[i]).expect("Failed to set leaf");
// We compute a merkle proof
let proof = tree.proof(i).expect("index should be set");
let proof = tree_full.proof(i).expect("Failed to compute proof");
// We verify if the merkle proof corresponds to the right leaf index
assert_eq!(proof.leaf_index(), i);
// We verify the proof
assert!(tree.verify(&leaves[i], &proof).unwrap());
assert!(tree_full
.verify(&leaves[i], &proof)
.expect("Failed to verify proof"));
// We ensure that the Merkle proof and the leaf generate the same root as the tree
assert_eq!(proof.compute_root_from(&leaves[i]), tree.root());
assert_eq!(proof.compute_root_from(&leaves[i]), tree_full.root());
// We check that the proof is not valid for another leaf
assert!(!tree.verify(&leaves[(i + 1) % nof_leaves], &proof).unwrap());
assert!(!tree_full
.verify(&leaves[(i + 1) % nof_leaves], &proof)
.expect("Failed to verify proof"));
}
// We test the OptimalMerkleTree implementation
let mut tree = default_optimal_merkle_tree(DEFAULT_DEPTH);
let mut tree_opt = default_optimal_merkle_tree(DEFAULT_DEPTH);
for i in 0..nof_leaves {
// We set the leaves
tree.set(i, leaves[i]).unwrap();
tree_opt.set(i, leaves[i]).expect("Failed to set leaf");
// We compute a merkle proof
let proof = tree.proof(i).expect("index should be set");
let proof = tree_opt.proof(i).expect("Failed to compute proof");
// We verify if the merkle proof corresponds to the right leaf index
assert_eq!(proof.leaf_index(), i);
// We verify the proof
assert!(tree.verify(&leaves[i], &proof).unwrap());
assert!(tree_opt
.verify(&leaves[i], &proof)
.expect("Failed to verify proof"));
// We ensure that the Merkle proof and the leaf generate the same root as the tree
assert_eq!(proof.compute_root_from(&leaves[i]), tree.root());
assert_eq!(proof.compute_root_from(&leaves[i]), tree_opt.root());
// We check that the proof is not valid for another leaf
assert!(!tree.verify(&leaves[(i + 1) % nof_leaves], &proof).unwrap());
assert!(!tree_opt
.verify(&leaves[(i + 1) % nof_leaves], &proof)
.expect("Failed to verify proof"));
}
}
#[test]
fn test_proof_fail() {
let tree_full = default_full_merkle_tree(DEFAULT_DEPTH);
let tree_opt = default_optimal_merkle_tree(DEFAULT_DEPTH);
let invalid_leaf = TestFr::from(12345);
let proof_full = tree_full.proof(0).expect("Failed to compute proof");
let proof_opt = tree_opt.proof(0).expect("Failed to compute proof");
// Should fail because no leaf was set
assert!(!tree_full
.verify(&invalid_leaf, &proof_full)
.expect("Failed to verify proof"));
assert!(!tree_opt
.verify(&invalid_leaf, &proof_opt)
.expect("Failed to verify proof"));
}
#[test]
fn test_override_range() {
let nof_leaves = 4;
let leaves: Vec<TestFr> = (0..nof_leaves as u32).map(TestFr::from).collect();
let mut tree = default_optimal_merkle_tree(DEFAULT_DEPTH);
// We set the leaves
tree.set_range(0, leaves.iter().cloned()).unwrap();
let new_leaves = [
hex!("0000000000000000000000000000000000000000000000000000000000000005"),
hex!("0000000000000000000000000000000000000000000000000000000000000006"),
@@ -322,17 +448,70 @@ pub mod test {
let to_delete_indices: [usize; 2] = [0, 1];
// We override the leaves
tree.override_range(
0, // start from the end of the initial leaves
new_leaves.iter().cloned(),
to_delete_indices.iter().cloned(),
)
.unwrap();
let mut tree_full = default_full_merkle_tree(DEFAULT_DEPTH);
tree_full
.set_range(0, leaves.iter().cloned())
.expect("Failed to set leaves");
tree_full
.override_range(
0,
new_leaves.iter().cloned(),
to_delete_indices.iter().cloned(),
)
.expect("Failed to override range");
// ensure that the leaves are set correctly
for (i, &new_leaf) in new_leaves.iter().enumerate() {
assert_eq!(tree.get_leaf(i), new_leaf);
assert_eq!(tree_full.get(i).expect("Failed to get leaf"), new_leaf);
}
let mut tree_opt = default_optimal_merkle_tree(DEFAULT_DEPTH);
tree_opt
.set_range(0, leaves.iter().cloned())
.expect("Failed to set leaves");
tree_opt
.override_range(
0,
new_leaves.iter().cloned(),
to_delete_indices.iter().cloned(),
)
.expect("Failed to override range");
for (i, &new_leaf) in new_leaves.iter().enumerate() {
assert_eq!(tree_opt.get(i).expect("Failed to get leaf"), new_leaf);
}
}
#[test]
fn test_override_range_parallel_triggered() {
let depth = 13;
let nof_leaves = 8192;
// number of leaves larger than MIN_PARALLEL_NODES to trigger parallel hashing
assert!(MIN_PARALLEL_NODES < nof_leaves);
let leaves: Vec<TestFr> = (0..nof_leaves as u32).map(TestFr::from).collect();
let indices: Vec<usize> = (0..nof_leaves).collect();
let mut tree_full = default_full_merkle_tree(depth);
tree_full
.override_range(0, leaves.iter().cloned(), indices.iter().cloned())
.expect("Failed to override range");
for (i, &leaf) in leaves.iter().enumerate() {
assert_eq!(tree_full.get(i).expect("Failed to get leaf"), leaf);
}
let mut tree_opt = default_optimal_merkle_tree(depth);
tree_opt
.override_range(0, leaves.iter().cloned(), indices.iter().cloned())
.expect("Failed to override range");
for (i, &leaf) in leaves.iter().enumerate() {
assert_eq!(tree_opt.get(i).expect("Failed to get leaf"), leaf);
}
}
}

View File

@@ -3530,7 +3530,8 @@ mod test {
{
// We check if the round constants and matrices correspond to the one generated when instantiating Poseidon with ROUND_PARAMS
let (loaded_c, loaded_m) = load_constants();
let poseidon_parameters = Poseidon::<Fr>::from(&ROUND_PARAMS).get_parameters();
let poseidon_hasher = Poseidon::<Fr>::from(&ROUND_PARAMS);
let poseidon_parameters = poseidon_hasher.get_parameters();
for i in 0..poseidon_parameters.len() {
assert_eq!(loaded_c[i], poseidon_parameters[i].c);
assert_eq!(loaded_m[i], poseidon_parameters[i].m);

View File

@@ -0,0 +1,131 @@
#[cfg(test)]
mod test {
use ark_bn254::Fr;
use ark_ff::{AdditiveGroup, Field};
use std::collections::HashMap;
use std::str::FromStr;
use zerokit_utils::poseidon_hash::Poseidon;
const ROUND_PARAMS: [(usize, usize, usize, usize); 8] = [
(2, 8, 56, 0),
(3, 8, 57, 0),
(4, 8, 56, 0),
(5, 8, 60, 0),
(6, 8, 60, 0),
(7, 8, 63, 0),
(8, 8, 64, 0),
(9, 8, 63, 0),
];
#[test]
fn test_poseidon_hash_basic() {
let map = HashMap::from([
(
Fr::ZERO,
Fr::from_str(
"19014214495641488759237505126948346942972912379615652741039992445865937985820",
)
.unwrap(),
),
(
Fr::ONE,
Fr::from_str(
"18586133768512220936620570745912940619677854269274689475585506675881198879027",
)
.unwrap(),
),
(
Fr::from(255),
Fr::from_str(
"20026131459732984724454933360292530547665726761019872861025481903072111625788",
)
.unwrap(),
),
(
Fr::from(u16::MAX),
Fr::from_str(
"12358868638722666642632413418981275677998688723398440898957566982787708451243",
)
.unwrap(),
),
(
Fr::from(u64::MAX),
Fr::from_str(
"17449307747295017006142981453320720946812828330895590310359634430146721583189",
)
.unwrap(),
),
]);
// map (key: what to hash, value: expected value)
for (k, v) in map.into_iter() {
let hasher = Poseidon::<Fr>::from(&ROUND_PARAMS);
let h = hasher.hash(&[k]);
assert_eq!(h.unwrap(), v);
}
}
#[test]
fn test_poseidon_hash_multi() {
// All hashes done in a merkle tree (with leaves: [0, 1, 2, 3, 4, 5, 6, 7])
// ~ leaves
let fr_0 = Fr::ZERO;
let fr_1 = Fr::ONE;
let fr_2 = Fr::from(2);
let fr_3 = Fr::from(3);
let fr_4 = Fr::from(4);
let fr_5 = Fr::from(5);
let fr_6 = Fr::from(6);
let fr_7 = Fr::from(7);
let fr_0_1 = Fr::from_str(
"12583541437132735734108669866114103169564651237895298778035846191048104863326",
)
.unwrap();
let fr_2_3 = Fr::from_str(
"17197790661637433027297685226742709599380837544520340689137581733613433332983",
)
.unwrap();
let fr_4_5 = Fr::from_str(
"756592041685769348226045093946546956867261766023639881791475046640232555043",
)
.unwrap();
let fr_6_7 = Fr::from_str(
"5558359459771725727593826278265342308584225092343962757289948761260561575479",
)
.unwrap();
let fr_0_3 = Fr::from_str(
"3720616653028013822312861221679392249031832781774563366107458835261883914924",
)
.unwrap();
let fr_4_7 = Fr::from_str(
"7960741062684589801276390367952372418815534638314682948141519164356522829957",
)
.unwrap();
// ~ root
let fr_0_7 = Fr::from_str(
"11780650233517635876913804110234352847867393797952240856403268682492028497284",
)
.unwrap();
// map (key: what to hash, value: expected value)
let map = HashMap::from([
((fr_0, fr_1), fr_0_1),
((fr_2, fr_3), fr_2_3),
((fr_4, fr_5), fr_4_5),
((fr_6, fr_7), fr_6_7),
((fr_0_1, fr_2_3), fr_0_3),
((fr_4_5, fr_6_7), fr_4_7),
((fr_0_3, fr_4_7), fr_0_7),
]);
for (k, v) in map.into_iter() {
let hasher = Poseidon::<Fr>::from(&ROUND_PARAMS);
let h = hasher.hash(&[k.0, k.1]);
assert_eq!(h.unwrap(), v);
}
}
}