mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
Compare commits
4 Commits
am/chore/m
...
pa/32bit-l
| Author | SHA1 | Date | |
|---|---|---|---|
| | 359ebe02b8 | | |
| | b89d4419d7 | | |
| | 538692cc48 | | |
| | 4505c5209a | | |
1
.github/workflows/benchmark_cpu.yml
vendored
1
.github/workflows/benchmark_cpu.yml
vendored
@@ -14,6 +14,7 @@ on:
|
||||
- signed_integer
|
||||
- integer_compression
|
||||
- integer_zk
|
||||
- msm_zk
|
||||
- shortint
|
||||
- shortint_oprf
|
||||
- hlapi_unsigned
|
||||
|
||||
2
.github/workflows/benchmark_gpu.yml
vendored
2
.github/workflows/benchmark_gpu.yml
vendored
@@ -31,6 +31,8 @@ on:
|
||||
- pbs128
|
||||
- ks
|
||||
- ks_pbs
|
||||
- tfhe_zk_pok
|
||||
- msm_zk
|
||||
- integer_zk
|
||||
- integer_aes
|
||||
- integer_aes256
|
||||
|
||||
8
.github/workflows/gpu_zk_tests.yml
vendored
8
.github/workflows/gpu_zk_tests.yml
vendored
@@ -51,7 +51,12 @@ jobs:
|
||||
with:
|
||||
files_yaml: |
|
||||
gpu:
|
||||
- tfhe/Cargo.toml
|
||||
- tfhe/build.rs
|
||||
- backends/zk-cuda-backend/**
|
||||
- tfhe/src/integer/gpu/zk/**
|
||||
- tfhe-zk-pok/**
|
||||
- 'tfhe/docs/**/**.md'
|
||||
- '.github/workflows/gpu_zk_tests.yml'
|
||||
- ci/slab.toml
|
||||
|
||||
@@ -126,6 +131,9 @@ jobs:
|
||||
- name: Run zk-cuda-backend integration tests
|
||||
run: |
|
||||
make test_zk_cuda_backend
|
||||
make test_zk_pok_gpu
|
||||
make test_integer_zk_gpu
|
||||
make test_integer_zk_experimental_gpu
|
||||
|
||||
slack-notify:
|
||||
name: gpu_zk_tests/slack-notify
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -25,6 +25,7 @@ dieharder_run.log
|
||||
|
||||
# Cuda local build
|
||||
backends/tfhe-cuda-backend/cuda/cmake-build-debug/
|
||||
backends/tfhe-cuda-backend/cuda/build/
|
||||
|
||||
# WASM tests
|
||||
tfhe/web_wasm_parallel_tests/server.PID
|
||||
|
||||
68
Makefile
68
Makefile
@@ -353,14 +353,14 @@ check_typos: install_typos_checker
|
||||
.PHONY: clippy_gpu # Run clippy lints on tfhe with "gpu" enabled
|
||||
clippy_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types,zk-pok \
|
||||
--features=boolean,shortint,integer,internal-keycache,gpu,gpu-experimental-zk,pbs-stats,extended-types,zk-pok \
|
||||
--all-targets \
|
||||
-p tfhe -- --no-deps -D warnings
|
||||
|
||||
.PHONY: check_gpu # Run check on tfhe with "gpu" enabled
|
||||
check_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" check \
|
||||
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats \
|
||||
--features=boolean,shortint,integer,internal-keycache,gpu,gpu-experimental-zk,pbs-stats \
|
||||
--all-targets \
|
||||
-p tfhe
|
||||
|
||||
@@ -374,7 +374,7 @@ clippy_hpu: install_rs_check_toolchain
|
||||
.PHONY: clippy_gpu_hpu # Run clippy lints on tfhe with "gpu" and "hpu" enabled
|
||||
clippy_gpu_hpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
--features=boolean,shortint,integer,internal-keycache,gpu,hpu,pbs-stats,extended-types,zk-pok \
|
||||
--features=boolean,shortint,integer,internal-keycache,gpu,gpu-experimental-zk,hpu,pbs-stats,extended-types,zk-pok \
|
||||
--all-targets \
|
||||
-p tfhe -- --no-deps -D warnings
|
||||
|
||||
@@ -467,7 +467,7 @@ clippy_rustdoc_gpu: install_rs_check_toolchain
|
||||
fi && \
|
||||
CARGO_TERM_QUIET=true CLIPPYFLAGS="-D warnings" RUSTDOCFLAGS="--no-run --test-builder ./scripts/clippy_driver.sh -Z unstable-options" \
|
||||
cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" test --doc \
|
||||
--features=boolean,shortint,integer,zk-pok,pbs-stats,strings,experimental,gpu \
|
||||
--features=boolean,shortint,integer,zk-pok,pbs-stats,strings,experimental,gpu,gpu-experimental-zk \
|
||||
-p tfhe -- --nocapture
|
||||
|
||||
.PHONY: clippy_c_api # Run clippy lints enabling the boolean, shortint and the C API
|
||||
@@ -670,7 +670,7 @@ build_c_api: install_rs_check_toolchain
|
||||
.PHONY: build_c_api_gpu # Build the C API for boolean, shortint and integer with "gpu" enabled
|
||||
build_c_api_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) build --profile $(CARGO_PROFILE) \
|
||||
--features=boolean-c-api,shortint-c-api,high-level-c-api,zk-pok,extended-types,gpu \
|
||||
--features=boolean-c-api,shortint-c-api,high-level-c-api,zk-pok,extended-types,gpu,gpu-experimental-zk \
|
||||
-p tfhe
|
||||
|
||||
.PHONY: build_c_api_experimental_deterministic_fft # Build the C API for boolean, shortint and integer with experimental deterministic FFT
|
||||
@@ -769,7 +769,7 @@ test_zk_cuda_backend:
|
||||
|
||||
|
||||
.PHONY: test_gpu # Run all the GPU test targets listed as prerequisites (core_crypto, integer and CUDA backend tests)
|
||||
test_gpu: test_core_crypto_gpu test_integer_gpu test_cuda_backend
|
||||
test_gpu: test_core_crypto_gpu test_integer_gpu test_cuda_backend test_zk_cuda_backend
|
||||
|
||||
.PHONY: test_core_crypto_gpu # Run the tests of the core_crypto module including experimental on the gpu backend
|
||||
test_core_crypto_gpu:
|
||||
@@ -1205,12 +1205,31 @@ test_tfhe_csprng_big_endian: install_cargo_cross
|
||||
RUSTFLAGS="" cross test --profile $(CARGO_PROFILE) \
|
||||
-p tfhe-csprng --target=powerpc64-unknown-linux-gnu
|
||||
|
||||
|
||||
.PHONY: test_zk_pok # Run tfhe-zk-pok tests
|
||||
test_zk_pok:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo test --profile $(CARGO_PROFILE) \
|
||||
-p tfhe-zk-pok --features experimental
|
||||
|
||||
.PHONY: test_zk_pok_gpu # Run tfhe-zk-pok GPU-accelerated tests
|
||||
test_zk_pok_gpu:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo test --profile $(CARGO_PROFILE) \
|
||||
-p tfhe-zk-pok --features experimental,gpu-experimental -- gpu
|
||||
|
||||
# NOTE(review) the prerequisite below is install_rs_check_toolchain while the recipe invokes CARGO_RS_BUILD_TOOLCHAIN - confirm this pairing is intended.
.PHONY: test_integer_zk_gpu # Run the tfhe integer GPU ZK tests with the "gpu" and "zk-pok" features
|
||||
test_integer_zk_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile release \
|
||||
--features=integer,zk-pok,gpu -p tfhe -- \
|
||||
integer::gpu::zk::
|
||||
|
||||
# NOTE(review) the prerequisite below is install_rs_check_toolchain while the recipe invokes CARGO_RS_BUILD_TOOLCHAIN - confirm this pairing is intended.
.PHONY: test_integer_zk_experimental_gpu # Run the tfhe integer GPU ZK tests with "gpu-experimental-zk" also enabled
|
||||
test_integer_zk_experimental_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile release \
|
||||
--features=integer,zk-pok,gpu,gpu-experimental-zk -p tfhe -- \
|
||||
integer::gpu::zk::
|
||||
|
||||
.PHONY: test_zk_cuda # Run all GPU ZK test targets (zk-cuda-backend, tfhe-zk-pok GPU, tfhe integer ZK GPU with and without gpu-experimental-zk)
|
||||
test_zk_cuda: install_rs_check_toolchain test_zk_cuda_backend test_zk_pok_gpu test_integer_zk_gpu test_integer_zk_experimental_gpu
|
||||
|
||||
.PHONY: test_zk_wasm_x86_compat_ci
|
||||
test_zk_wasm_x86_compat_ci: check_nvm_installed
|
||||
source ~/.nvm/nvm.sh && \
|
||||
@@ -1503,27 +1522,47 @@ bench_integer_compression_128b_gpu: install_rs_check_toolchain
|
||||
--bench glwe_packing_compression_128b-integer-bench \
|
||||
--features=integer,internal-keycache,gpu,pbs-stats -p tfhe-benchmark --
|
||||
|
||||
.PHONY: bench_msm_zk # Run the zk-msm benchmark with only the "zk-pok" feature (no GPU features)
|
||||
bench_msm_zk: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench zk-msm \
|
||||
--features=zk-pok -p tfhe-benchmark --profile release --
|
||||
|
||||
.PHONY: bench_msm_zk_gpu # Run the zk-msm benchmark with "gpu" and "gpu-experimental-zk" enabled
|
||||
bench_msm_zk_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench zk-msm \
|
||||
--features=gpu,gpu-experimental-zk,zk-pok -p tfhe-benchmark --profile release --
|
||||
|
||||
.PHONY: bench_integer_zk_gpu # Run the integer-zk-pke benchmark with "gpu" enabled
|
||||
bench_integer_zk_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_BENCH_BIT_SIZES_SET=$(BIT_SIZES_SET) __TFHE_RS_BENCH_OP_FLAVOR=$(BENCH_OP_FLAVOR) \
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_BENCH_OP_FLAVOR=$(BENCH_OP_FLAVOR) \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench integer-zk-pke \
|
||||
--features=integer,internal-keycache,gpu,pbs-stats,zk-pok -p tfhe-benchmark --profile release_lto_off --
|
||||
--features=integer,internal-keycache,gpu,pbs-stats,zk-pok -p tfhe-benchmark --profile release --
|
||||
|
||||
.PHONY: bench_integer_zk_experimental_gpu # Run the integer-zk-pke benchmark with "gpu" and "gpu-experimental-zk" enabled
|
||||
bench_integer_zk_experimental_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_BENCH_BIT_SIZES_SET=$(BIT_SIZES_SET) __TFHE_RS_BENCH_OP_FLAVOR=$(BENCH_OP_FLAVOR) \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench integer-zk-pke \
|
||||
--features=integer,internal-keycache,gpu,gpu-experimental-zk,pbs-stats,zk-pok -p tfhe-benchmark --profile release --
|
||||
|
||||
.PHONY: bench_integer_aes_gpu # Run benchmarks for AES on GPU backend
|
||||
bench_integer_aes_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench integer-aes \
|
||||
--features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off --
|
||||
--features=integer,internal-keycache,gpu -p tfhe-benchmark --profile release_lto_off --
|
||||
|
||||
.PHONY: bench_integer_aes256_gpu # Run benchmarks for AES256 on GPU backend
|
||||
bench_integer_aes256_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--bench integer-aes256 \
|
||||
--features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off --
|
||||
--features=integer,internal-keycache,gpu -p tfhe-benchmark --profile release_lto_off --
|
||||
|
||||
.PHONY: bench_integer_trivium_gpu # Run benchmarks for trivium on GPU backend
|
||||
bench_integer_trivium_gpu: install_rs_check_toolchain
|
||||
@@ -1806,6 +1845,13 @@ bench_tfhe_zk_pok: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench -p tfhe-zk-pok --
|
||||
|
||||
.PHONY: bench_tfhe_zk_pok_gpu # Run benchmarks for the tfhe_zk_pok crate using GPU acceleration
|
||||
bench_tfhe_zk_pok_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
|
||||
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
|
||||
--package tfhe-zk-pok \
|
||||
--features=gpu-experimental --profile release
|
||||
|
||||
.PHONY: bench_hlapi_noise_squash # Run benchmarks for noise squash operation
|
||||
bench_hlapi_noise_squash: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_BENCH_BIT_SIZES_SET=$(BIT_SIZES_SET) \
|
||||
|
||||
@@ -46,7 +46,10 @@ fn main() {
|
||||
}
|
||||
|
||||
// Build CUDA library using cmake crate
|
||||
let limb_bits = std::env::var("ZK_CUDA_LIMB_BITS").unwrap_or_else(|_| "64".to_string());
|
||||
println!("cargo::rerun-if-env-changed=ZK_CUDA_LIMB_BITS");
|
||||
let mut cmake_config = cmake::Config::new("cuda");
|
||||
cmake_config.define("ZK_CUDA_LIMB_BITS", &limb_bits);
|
||||
let dest = cmake_config.build();
|
||||
|
||||
// cmake crate installs to dest/lib subdirectory
|
||||
|
||||
@@ -51,6 +51,16 @@ else()
|
||||
set(CMAKE_CUDA_ARCHITECTURES 70)
|
||||
endif()
|
||||
|
||||
# Limb size configuration: 32 or 64 (default: 64)
|
||||
# 32-bit limbs enable PTX carry-chain optimizations on GPU
|
||||
set(ZK_CUDA_LIMB_BITS "64" CACHE STRING "Limb size in bits for Fp arithmetic (32 or 64)")
|
||||
set_property(CACHE ZK_CUDA_LIMB_BITS PROPERTY STRINGS "32" "64")
|
||||
if(NOT ZK_CUDA_LIMB_BITS STREQUAL "32" AND NOT ZK_CUDA_LIMB_BITS STREQUAL "64")
|
||||
message(FATAL_ERROR "ZK_CUDA_LIMB_BITS must be 32 or 64, got: ${ZK_CUDA_LIMB_BITS}")
|
||||
endif()
|
||||
add_compile_definitions(LIMB_BITS_CONFIG=${ZK_CUDA_LIMB_BITS})
|
||||
message(STATUS "Limb size: ${ZK_CUDA_LIMB_BITS}-bit")
|
||||
|
||||
# Enable CUDA separable compilation for better optimization
|
||||
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
|
||||
|
||||
|
||||
@@ -17,7 +17,15 @@ __host__ __device__ void fp2_zero(Fp2 &a);
|
||||
|
||||
// G1 point: (x, y) coordinates in Fp
|
||||
// Curve equation: y^2 = x^3 + b (short Weierstrass form with a = 0)
|
||||
struct G1Affine {
|
||||
// alignas(8) ensures identical struct layout (size 120) in both 32-bit and
|
||||
// 64-bit limb modes, matching the Rust FFI bindings generated from 64-bit.
|
||||
// Without this, 32-bit mode produces 116-byte structs (4-byte alignment from
|
||||
// uint32_t limbs) vs 120 bytes in Rust FFI, causing array stride mismatches
|
||||
// that corrupt point data for n>1.
|
||||
// The 4-byte padding overhead is negligible: MSM is compute-bound (Montgomery
|
||||
// multiplications dominate), and point access patterns in Pippenger-style MSM
|
||||
// are non-coalescing regardless of struct size.
|
||||
struct alignas(8) G1Affine {
|
||||
Fp x;
|
||||
Fp y;
|
||||
bool infinity; // true if point at infinity (identity element)
|
||||
@@ -36,7 +44,9 @@ struct G1Affine {
|
||||
|
||||
// G2 point: (x, y) coordinates in Fp2
|
||||
// Curve equation: y^2 = x^3 + b' (twisted curve over Fp2)
|
||||
struct G2Affine {
|
||||
// alignas(8): same rationale as G1Affine above — ensures FFI layout
|
||||
// compatibility (size 232) between 32-bit and 64-bit limb modes.
|
||||
struct alignas(8) G2Affine {
|
||||
Fp2 x;
|
||||
Fp2 y;
|
||||
bool infinity; // true if point at infinity (identity element)
|
||||
|
||||
@@ -282,6 +282,8 @@ __host__ __device__ void fp_cmov(Fp &dst, const Fp &src, uint64_t condition);
|
||||
// Helper functions to access constants
|
||||
// Get modulus reference (device: from constant memory, host: static copy)
|
||||
__host__ __device__ const Fp &fp_modulus();
|
||||
// Get Montgomery reduction constant p' = -p^(-1) mod 2^LIMB_BITS
|
||||
__host__ __device__ UNSIGNED_LIMB fp_p_prime();
|
||||
|
||||
// ============================================================================
|
||||
// Async/Sync API for device memory operations
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "bls12_446_params.h"
|
||||
#include "device.h"
|
||||
#include "fp.h"
|
||||
#include "fp_ptx32.cuh"
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
@@ -187,6 +188,9 @@ __host__ __device__ void fp_copy(Fp &dst, const Fp &src) {
|
||||
// "Raw" means without modular reduction - performs a + b and returns carry.
|
||||
// This is an internal helper used by fp_add() which handles reduction.
|
||||
__host__ __device__ UNSIGNED_LIMB fp_add_raw(Fp &c, const Fp &a, const Fp &b) {
|
||||
#if LIMB_BITS_CONFIG == 32 && defined(__CUDA_ARCH__)
|
||||
return fp_add_raw_ptx32(c, a, b);
|
||||
#else
|
||||
UNSIGNED_LIMB carry = 0;
|
||||
|
||||
for (int i = 0; i < FP_LIMBS; i++) {
|
||||
@@ -199,12 +203,16 @@ __host__ __device__ UNSIGNED_LIMB fp_add_raw(Fp &c, const Fp &a, const Fp &b) {
|
||||
}
|
||||
|
||||
return carry;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Subtraction with borrow propagation
|
||||
// "Raw" means without modular reduction - performs a - b and returns borrow.
|
||||
// This is an internal helper used by fp_sub() which handles reduction.
|
||||
__host__ __device__ UNSIGNED_LIMB fp_sub_raw(Fp &c, const Fp &a, const Fp &b) {
|
||||
#if LIMB_BITS_CONFIG == 32 && defined(__CUDA_ARCH__)
|
||||
return fp_sub_raw_ptx32(c, a, b);
|
||||
#else
|
||||
UNSIGNED_LIMB borrow = 0;
|
||||
|
||||
for (int i = 0; i < FP_LIMBS; i++) {
|
||||
@@ -218,11 +226,15 @@ __host__ __device__ UNSIGNED_LIMB fp_sub_raw(Fp &c, const Fp &a, const Fp &b) {
|
||||
}
|
||||
|
||||
return borrow;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Addition with modular reduction: c = (a + b) mod p
|
||||
// MONTGOMERY: Both inputs and output must be in Montgomery form
|
||||
__host__ __device__ void fp_add(Fp &c, const Fp &a, const Fp &b) {
|
||||
#if LIMB_BITS_CONFIG == 32 && defined(__CUDA_ARCH__)
|
||||
fp_add_ptx32(c, a, b);
|
||||
#else
|
||||
Fp sum;
|
||||
UNSIGNED_LIMB carry = fp_add_raw(sum, a, b);
|
||||
|
||||
@@ -235,11 +247,15 @@ __host__ __device__ void fp_add(Fp &c, const Fp &a, const Fp &b) {
|
||||
} else {
|
||||
fp_copy(c, sum);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Subtraction with modular reduction: c = (a - b) mod p
|
||||
// MONTGOMERY: Both inputs and output must be in Montgomery form
|
||||
__host__ __device__ void fp_sub(Fp &c, const Fp &a, const Fp &b) {
|
||||
#if LIMB_BITS_CONFIG == 32 && defined(__CUDA_ARCH__)
|
||||
fp_sub_ptx32(c, a, b);
|
||||
#else
|
||||
Fp diff;
|
||||
UNSIGNED_LIMB borrow = fp_sub_raw(diff, a, b);
|
||||
|
||||
@@ -250,6 +266,7 @@ __host__ __device__ void fp_sub(Fp &c, const Fp &a, const Fp &b) {
|
||||
} else {
|
||||
fp_copy(c, diff);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Small-constant multiplication via addition chains.
|
||||
@@ -458,6 +475,9 @@ __host__ __device__ void fp_mont_reduce(Fp &c, const UNSIGNED_LIMB *a) {
|
||||
// Uses only FP_LIMBS+1 limbs of working space instead of 2*FP_LIMBS.
|
||||
// Both a and b are in Montgomery form, result is in Montgomery form.
|
||||
__host__ __device__ void fp_mont_mul_cios(Fp &c, const Fp &a, const Fp &b) {
|
||||
#if LIMB_BITS_CONFIG == 32 && defined(__CUDA_ARCH__)
|
||||
fp_mont_mul_cios_ptx32(c, a, b);
|
||||
#else
|
||||
const Fp &p = fp_modulus();
|
||||
UNSIGNED_LIMB p_prime = fp_p_prime();
|
||||
|
||||
@@ -545,6 +565,7 @@ __host__ __device__ void fp_mont_mul_cios(Fp &c, const Fp &a, const Fp &b) {
|
||||
fp_copy(c, reduced);
|
||||
}
|
||||
// Result is in Montgomery form
|
||||
#endif
|
||||
}
|
||||
|
||||
// Montgomery multiplication: c = (a * b * R_INV) mod p
|
||||
|
||||
@@ -29,6 +29,7 @@ rand = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
tfhe = { path = "../tfhe", default-features = false }
|
||||
tfhe-csprng = { path = "../tfhe-csprng" }
|
||||
tfhe-zk-pok = { path = "../tfhe-zk-pok", optional = true }
|
||||
cpu-time = "1.0"
|
||||
num_cpus = "1.17"
|
||||
gag = "1.0.0"
|
||||
@@ -39,12 +40,14 @@ boolean = ["tfhe/boolean"]
|
||||
shortint = ["tfhe/shortint"]
|
||||
integer = ["shortint", "tfhe/integer"]
|
||||
gpu = ["tfhe/gpu"]
|
||||
# gpu enables tfhe-cuda-backend which provides CUDA stream management used by tfhe-zk-pok
|
||||
gpu-experimental-zk = ["gpu", "zk-pok", "tfhe/gpu-experimental-zk", "tfhe-zk-pok/gpu-experimental"]
|
||||
hpu = ["tfhe/hpu"]
|
||||
hpu-v80 = ["tfhe/hpu-v80"]
|
||||
internal-keycache = ["tfhe/internal-keycache"]
|
||||
avx512 = ["tfhe/avx512"]
|
||||
pbs-stats = ["tfhe/pbs-stats"]
|
||||
zk-pok = ["tfhe/zk-pok"]
|
||||
zk-pok = ["tfhe/zk-pok", "dep:tfhe-zk-pok"]
|
||||
|
||||
[[bench]]
|
||||
name = "boolean"
|
||||
@@ -196,6 +199,12 @@ path = "benches/core_crypto/pbs128_bench.rs"
|
||||
harness = false
|
||||
required-features = ["shortint", "internal-keycache"]
|
||||
|
||||
[[bench]]
|
||||
name = "zk-msm"
|
||||
path = "benches/zk/msm.rs"
|
||||
harness = false
|
||||
required-features = ["zk-pok"]
|
||||
|
||||
[[bin]]
|
||||
name = "boolean_key_sizes"
|
||||
path = "src/bin/boolean_key_sizes.rs"
|
||||
|
||||
@@ -485,6 +485,24 @@ mod cuda {
|
||||
use tfhe::integer::gpu::zk::CudaProvenCompactCiphertextList;
|
||||
use tfhe::integer::gpu::CudaServerKey;
|
||||
use tfhe::integer::CompressedServerKey;
|
||||
use tfhe::GpuIndex;
|
||||
|
||||
/// Compute the per-GPU number of elements for GPU ZK throughput benchmarks.
|
||||
/// Values are tuned to avoid OOM on H100 GPUs while still saturating the GPU.
|
||||
/// Memory usage scales with both CRS size and bits being proven.
/// The caller multiplies the returned count by the number of GPUs to get the batch size.
|
||||
fn gpu_zk_throughput_elements(crs_size: usize, bits: usize) -> u64 {
|
||||
match (crs_size, bits) {
|
||||
// 64-bit CRS: smaller proofs, can handle more elements
|
||||
(64, _) => 30,
|
||||
// 2048-bit CRS: moderate memory usage (15 elements up to 256 bits proven, 10 above)
|
||||
(2048, b) if b <= 256 => 15,
|
||||
(2048, _) => 10,
|
||||
// 4096-bit CRS: largest proofs, most memory intensive
|
||||
(4096, _) => 6,
|
||||
// Default fallback for unknown configurations (any other CRS size gets 10)
|
||||
_ => 10,
|
||||
}
|
||||
}
|
||||
|
||||
fn gpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) {
|
||||
let bench_name = "integer::cuda::zk::pke_zk_verify";
|
||||
@@ -686,12 +704,8 @@ mod cuda {
|
||||
});
|
||||
}
|
||||
BenchmarkType::Throughput => {
|
||||
let mut elements_per_gpu = 100;
|
||||
if *bits == 4096 {
|
||||
elements_per_gpu /= 5;
|
||||
}
|
||||
// This value, found empirically, ensure saturation of 8XH100 SXM5
|
||||
let elements = elements_per_gpu * get_number_of_gpus() as u64;
|
||||
let elements = gpu_zk_throughput_elements(crs_size, *bits)
|
||||
* get_number_of_gpus() as u64;
|
||||
bench_group.throughput(Throughput::Elements(elements));
|
||||
|
||||
bench_id_verify = format!(
|
||||
@@ -716,15 +730,38 @@ mod cuda {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let local_streams = cuda_local_streams(num_block, elements as usize);
|
||||
let d_ksk_material_vec = local_streams
|
||||
.par_iter()
|
||||
.map(|local_stream| {
|
||||
CudaKeySwitchingKeyMaterial::from_key_switching_key(
|
||||
&ksk,
|
||||
local_stream,
|
||||
let gpu_count = get_number_of_gpus() as usize;
|
||||
|
||||
let gpu_sks_vec: Vec<CudaServerKey> = (0..gpu_count)
|
||||
.map(|gpu_idx| {
|
||||
let stream =
|
||||
CudaStreams::new_single_gpu(GpuIndex::new(gpu_idx as u32));
|
||||
CudaServerKey::decompress_from_cpu(
|
||||
&compressed_server_key,
|
||||
&stream,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
.collect();
|
||||
|
||||
let d_ksk_material_vec: Vec<CudaKeySwitchingKeyMaterial> = (0
|
||||
..gpu_count)
|
||||
.map(|gpu_idx| {
|
||||
let stream =
|
||||
CudaStreams::new_single_gpu(GpuIndex::new(gpu_idx as u32));
|
||||
CudaKeySwitchingKeyMaterial::from_key_switching_key(
|
||||
&ksk, &stream,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let d_ksks: Vec<CudaKeySwitchingKey> = (0..gpu_count)
|
||||
.map(|gpu_idx| {
|
||||
CudaKeySwitchingKey::from_cuda_key_switching_key_material(
|
||||
&d_ksk_material_vec[gpu_idx],
|
||||
&gpu_sks_vec[gpu_idx],
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
bench_group.bench_function(&bench_id_verify, |b| {
|
||||
b.iter(|| {
|
||||
@@ -750,17 +787,16 @@ mod cuda {
|
||||
|gpu_cts| {
|
||||
gpu_cts.par_iter().enumerate().for_each
|
||||
(|(i, gpu_ct)| {
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
|
||||
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, local_stream);
|
||||
let d_ksk =
|
||||
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % local_streams.len()], &gpu_sk);
|
||||
let stream_idx = i % local_streams.len();
|
||||
let local_stream = &local_streams[stream_idx];
|
||||
let gpu_idx = i % gpu_count;
|
||||
let d_ksk = &d_ksks[gpu_idx];
|
||||
|
||||
gpu_ct
|
||||
.expand_without_verification(&d_ksk, local_stream)
|
||||
.expand_without_verification(d_ksk, local_stream)
|
||||
.unwrap();
|
||||
});
|
||||
}, BatchSize::SmallInput);
|
||||
}, BatchSize::PerIteration);
|
||||
});
|
||||
|
||||
bench_group.bench_function(&bench_id_verify_and_expand, |b| {
|
||||
@@ -778,18 +814,18 @@ mod cuda {
|
||||
|gpu_cts| {
|
||||
gpu_cts.par_iter().enumerate().for_each
|
||||
(|(i, gpu_ct)| {
|
||||
let local_stream = &local_streams[i % local_streams.len()];
|
||||
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, local_stream);
|
||||
let d_ksk =
|
||||
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % local_streams.len()], &gpu_sk);
|
||||
let stream_idx = i % local_streams.len();
|
||||
let local_stream = &local_streams[stream_idx];
|
||||
let gpu_idx = i % gpu_count;
|
||||
let d_ksk = &d_ksks[gpu_idx];
|
||||
|
||||
gpu_ct
|
||||
.verify_and_expand(
|
||||
&crs, &pk, &metadata, &d_ksk, local_stream,
|
||||
&crs, &pk, &metadata, d_ksk, local_stream,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
}, BatchSize::SmallInput);
|
||||
}, BatchSize::PerIteration);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -816,11 +852,154 @@ mod cuda {
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn gpu_pke_zk_proof(c: &mut Criterion) {
|
||||
let bench_name = "zk::cuda::pke_zk_proof";
|
||||
let mut bench_group = c.benchmark_group(bench_name);
|
||||
bench_group
|
||||
.sample_size(15)
|
||||
.measurement_time(std::time::Duration::from_secs(60));
|
||||
|
||||
let params: [(
|
||||
CompactPublicKeyEncryptionParameters,
|
||||
ShortintKeySwitchingParameters,
|
||||
PBSParameters,
|
||||
); 2] = [
|
||||
(
|
||||
PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into(),
|
||||
),
|
||||
(
|
||||
BENCH_PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
BENCH_PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into(),
|
||||
),
|
||||
];
|
||||
|
||||
for (param_pke, _param_ksk, param_fhe) in params.iter() {
|
||||
let param_name = param_fhe.name();
|
||||
let param_name = param_name.as_str();
|
||||
let cks = ClientKey::new(*param_fhe);
|
||||
let sks = ServerKey::new_radix_server_key(&cks);
|
||||
let compact_private_key = CompactPrivateKey::new(*param_pke);
|
||||
let pk = CompactPublicKey::new(&compact_private_key);
|
||||
// Kept for consistency
|
||||
let _casting_key =
|
||||
KeySwitchingKey::new((&compact_private_key, None), (&cks, &sks), *_param_ksk);
|
||||
|
||||
// We have a use case with 320 bits of metadata
|
||||
let mut metadata = [0u8; (320 / u8::BITS) as usize];
|
||||
let mut rng = rand::thread_rng();
|
||||
metadata.fill_with(|| rng.gen());
|
||||
|
||||
let zk_vers = param_pke.zk_scheme;
|
||||
|
||||
for proof_config in default_proof_config().iter() {
|
||||
let msg_bits =
|
||||
(param_pke.message_modulus.0 * param_pke.carry_modulus.0).ilog2() as usize;
|
||||
println!("Generating CRS... ");
|
||||
let crs_size = proof_config.crs_size;
|
||||
let crs = CompactPkeCrs::from_shortint_params(
|
||||
*param_pke,
|
||||
LweCiphertextCount(crs_size / msg_bits),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
for bits in proof_config.bits_to_prove.iter() {
|
||||
assert_eq!(bits % 64, 0);
|
||||
// Packing, so we take the message and carry modulus to compute our block count
|
||||
let num_block = 64usize.div_ceil(msg_bits);
|
||||
|
||||
let fhe_uint_count = bits / 64;
|
||||
|
||||
for compute_load in [ZkComputeLoad::Proof, ZkComputeLoad::Verify] {
|
||||
let zk_load = match compute_load {
|
||||
ZkComputeLoad::Proof => "compute_load_proof",
|
||||
ZkComputeLoad::Verify => "compute_load_verify",
|
||||
};
|
||||
|
||||
let bench_id;
|
||||
|
||||
match get_bench_type() {
|
||||
BenchmarkType::Latency => {
|
||||
bench_id = format!(
|
||||
"{bench_name}::{param_name}_{bits}_bits_packed_{crs_size}_bits_crs_{zk_load}_ZK{zk_vers:?}"
|
||||
);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
let input_msg = rng.gen::<u64>();
|
||||
let messages = vec![input_msg; fhe_uint_count];
|
||||
|
||||
b.iter(|| {
|
||||
let _ct1 =
|
||||
tfhe::integer::ProvenCompactCiphertextList::builder(
|
||||
&pk,
|
||||
)
|
||||
.extend(messages.iter().copied())
|
||||
.build_with_proof_packed(&crs, &metadata, compute_load)
|
||||
.unwrap();
|
||||
})
|
||||
});
|
||||
}
|
||||
BenchmarkType::Throughput => {
|
||||
// The zk proof is currently not pooled, so we simply use the number
|
||||
// of threads as heuristic for the
|
||||
// batch size
|
||||
let elements =
|
||||
(rayon::current_num_threads() / num_block).max(1) + 1;
|
||||
bench_group.throughput(Throughput::Elements(elements as u64));
|
||||
|
||||
bench_id = format!(
|
||||
"{bench_name}::throughput::{param_name}_{bits}_bits_packed_{crs_size}_bits_crs_{zk_load}_ZK{zk_vers:?}"
|
||||
);
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
let messages = (0..elements)
|
||||
.map(|_| {
|
||||
let input_msg = rng.gen::<u64>();
|
||||
vec![input_msg; fhe_uint_count]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
b.iter(|| {
|
||||
messages.par_iter().for_each(|msg| {
|
||||
tfhe::integer::ProvenCompactCiphertextList::builder(
|
||||
&pk,
|
||||
)
|
||||
.extend(msg.iter().copied())
|
||||
.build_with_proof_packed(&crs, &metadata, compute_load)
|
||||
.unwrap();
|
||||
})
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let shortint_params: PBSParameters = *param_fhe;
|
||||
|
||||
write_to_json::<u64, _>(
|
||||
&bench_id,
|
||||
shortint_params,
|
||||
param_name,
|
||||
"pke_zk_proof",
|
||||
&OperatorType::Atomic,
|
||||
shortint_params.message_modulus().0 as u32,
|
||||
vec![shortint_params.message_modulus().0.ilog2(); num_block],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run `gpu_pke_zk_verify` with a Criterion instance configured from the command-line
/// arguments, passing `gpu_pke_zk_crs_sizes.csv` as the results file path.
pub fn gpu_zk_verify() {
|
||||
let results_file = Path::new("gpu_pke_zk_crs_sizes.csv");
|
||||
let mut criterion: Criterion<_> = (Criterion::default()).configure_from_args();
|
||||
gpu_pke_zk_verify(&mut criterion, results_file);
|
||||
}
|
||||
|
||||
/// Run `gpu_pke_zk_proof` with a Criterion instance configured from the command-line
/// arguments.
pub fn gpu_zk_proof() {
|
||||
let mut criterion: Criterion<_> = (Criterion::default()).configure_from_args();
|
||||
gpu_pke_zk_proof(&mut criterion);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn zk_verify_and_proof() {
|
||||
@@ -831,11 +1010,14 @@ pub fn zk_verify_and_proof() {
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "gpu", feature = "zk-pok"))]
|
||||
use crate::cuda::gpu_zk_verify;
|
||||
use crate::cuda::{gpu_zk_proof, gpu_zk_verify};
|
||||
|
||||
fn main() {
|
||||
#[cfg(all(feature = "gpu", feature = "zk-pok"))]
|
||||
gpu_zk_verify();
|
||||
{
|
||||
gpu_zk_proof();
|
||||
gpu_zk_verify();
|
||||
}
|
||||
#[cfg(not(feature = "gpu"))]
|
||||
zk_verify_and_proof();
|
||||
|
||||
|
||||
406
tfhe-benchmark/benches/zk/msm.rs
Normal file
406
tfhe-benchmark/benches/zk/msm.rs
Normal file
@@ -0,0 +1,406 @@
|
||||
//! Benchmark comparing CPU MSM vs GPU MSM for BLS12-446
|
||||
//!
|
||||
//! This benchmark measures the performance of multi-scalar multiplication (MSM)
|
||||
//! for both G1 and G2 points on the BLS12-446 curve.
|
||||
//!
|
||||
//! CPU benchmarks use the arkworks-based `G1Affine::multi_mul_scalar` /
|
||||
//! `G2Affine::multi_mul_scalar`. GPU benchmarks (gated behind the
|
||||
//! `gpu-experimental-zk` feature) call `tfhe_zk_pok::gpu::g1_msm_gpu` /
|
||||
//! `tfhe_zk_pok::gpu::g2_msm_gpu` directly, which dispatch to the
|
||||
//! zk-cuda-backend.
|
||||
//!
|
||||
//! ## Running the benchmarks
|
||||
//!
|
||||
//! ```bash
|
||||
//! # CPU only
|
||||
//! cargo bench --package tfhe-benchmark --bench zk-msm
|
||||
//!
|
||||
//! # CPU and GPU
|
||||
//! cargo bench --package tfhe-benchmark --bench zk-msm --features gpu-experimental-zk
|
||||
//! ```
|
||||
|
||||
use benchmark::utilities::{
|
||||
get_bench_type, write_to_json, BenchmarkType, CryptoParametersRecord, OperatorType,
|
||||
};
|
||||
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion, Throughput};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
use rayon::prelude::*;
|
||||
use std::time::Duration;
|
||||
|
||||
use tfhe_zk_pok::curve_api::bls12_446::{G1Affine, G2Affine, Zp, G1, G2};
|
||||
use tfhe_zk_pok::curve_api::CurveGroupOps;
|
||||
|
||||
/// Compute the number of parallel elements for MSM throughput benchmarks.
|
||||
/// Uses aggressive values to maximize throughput testing while keeping setup time reasonable.
/// Larger inputs get fewer parallel elements: 64 for `input_size <= 1000`, 32 for
/// `input_size <= 4096`, and 16 otherwise.
|
||||
fn msm_throughput_elements(input_size: usize) -> u64 {
|
||||
match input_size {
|
||||
n if n <= 1000 => 64,
|
||||
n if n <= 4096 => 32,
|
||||
_ => 16,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate random G1 affine points using tfhe-zk-pok
|
||||
fn generate_g1_affine_points(rng: &mut StdRng, n: usize) -> Vec<G1Affine> {
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
let point = G1::GENERATOR.mul_scalar(Zp::rand(rng));
|
||||
point.normalize()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Generate random G2 affine points using tfhe-zk-pok
|
||||
fn generate_g2_affine_points(rng: &mut StdRng, n: usize) -> Vec<G2Affine> {
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
let point = G2::GENERATOR.mul_scalar(Zp::rand(rng));
|
||||
point.normalize()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Generate random scalars using tfhe-zk-pok
|
||||
fn generate_scalars(rng: &mut StdRng, n: usize) -> Vec<Zp> {
|
||||
(0..n).map(|_| Zp::rand(rng)).collect()
|
||||
}
|
||||
|
||||
/// Benchmark CPU MSM for G1 points using tfhe-zk-pok entry points
|
||||
fn bench_cpu_g1_msm(c: &mut Criterion) {
|
||||
let curve_name = "bls12_446";
|
||||
let subgroup_name = "G1";
|
||||
let bench_name = format!("zk::msm::{curve_name}::{subgroup_name}");
|
||||
|
||||
let mut group = c.benchmark_group(&bench_name);
|
||||
group.sample_size(10);
|
||||
group.measurement_time(Duration::from_secs(30));
|
||||
|
||||
for size in [100, 1000, 2048, 4096, 10000].iter() {
|
||||
let n = *size;
|
||||
let bench_id;
|
||||
let bench_shortname = "zk::msm::bls12_446::g1";
|
||||
|
||||
match get_bench_type() {
|
||||
BenchmarkType::Latency => {
|
||||
let mut rng = StdRng::seed_from_u64(42);
|
||||
let bases = generate_g1_affine_points(&mut rng, n);
|
||||
let scalars = generate_scalars(&mut rng, n);
|
||||
|
||||
bench_id = format!("{bench_name}::{n}");
|
||||
group.bench_with_input(&bench_id, &n, |b, _| {
|
||||
b.iter(|| {
|
||||
let result =
|
||||
G1Affine::multi_mul_scalar(black_box(&bases), black_box(&scalars));
|
||||
black_box(result)
|
||||
});
|
||||
});
|
||||
}
|
||||
BenchmarkType::Throughput => {
|
||||
let elements = msm_throughput_elements(n);
|
||||
group.throughput(Throughput::Elements(elements));
|
||||
|
||||
bench_id = format!("{bench_name}::throughput::{n}");
|
||||
group.bench_with_input(&bench_id, &n, |b, _| {
|
||||
// Setup generates test data in parallel, excluded from measurement
|
||||
let setup = || {
|
||||
(0..elements)
|
||||
.into_par_iter()
|
||||
.map(|i| {
|
||||
let mut rng = StdRng::seed_from_u64(42 + i);
|
||||
let bases = generate_g1_affine_points(&mut rng, n);
|
||||
let scalars = generate_scalars(&mut rng, n);
|
||||
(bases, scalars)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
b.iter_batched(
|
||||
setup,
|
||||
|test_data| {
|
||||
test_data.par_iter().for_each(|(bases, scalars)| {
|
||||
let result = G1Affine::multi_mul_scalar(
|
||||
black_box(bases),
|
||||
black_box(scalars),
|
||||
);
|
||||
black_box(result);
|
||||
});
|
||||
},
|
||||
BatchSize::LargeInput,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// MSM benchmarks are curve operations, use minimal parameters
|
||||
let params: CryptoParametersRecord<u64> = CryptoParametersRecord::default();
|
||||
write_to_json(
|
||||
&bench_id,
|
||||
params,
|
||||
"MSM_BLS12_446_G1",
|
||||
bench_shortname,
|
||||
&OperatorType::Atomic,
|
||||
64, // bit_size for curve scalar operations
|
||||
vec![], // decomposition_basis not applicable for MSM
|
||||
);
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark CPU MSM for G2 points using tfhe-zk-pok entry points
|
||||
fn bench_cpu_g2_msm(c: &mut Criterion) {
|
||||
let curve_name = "bls12_446";
|
||||
let subgroup_name = "G2";
|
||||
let bench_name = format!("zk::msm::{curve_name}::{subgroup_name}");
|
||||
|
||||
let mut group = c.benchmark_group(&bench_name);
|
||||
group.sample_size(10);
|
||||
group.measurement_time(Duration::from_secs(30));
|
||||
|
||||
for size in [100, 1000, 2048, 4096, 10000].iter() {
|
||||
let n = *size;
|
||||
let bench_id;
|
||||
let bench_shortname = "zk::msm::bls12_446::g2";
|
||||
|
||||
match get_bench_type() {
|
||||
BenchmarkType::Latency => {
|
||||
let mut rng = StdRng::seed_from_u64(42);
|
||||
let bases = generate_g2_affine_points(&mut rng, n);
|
||||
let scalars = generate_scalars(&mut rng, n);
|
||||
|
||||
bench_id = format!("{bench_name}::{n}");
|
||||
group.bench_with_input(&bench_id, &n, |b, _| {
|
||||
b.iter(|| {
|
||||
let result =
|
||||
G2Affine::multi_mul_scalar(black_box(&bases), black_box(&scalars));
|
||||
black_box(result)
|
||||
});
|
||||
});
|
||||
}
|
||||
BenchmarkType::Throughput => {
|
||||
let elements = msm_throughput_elements(n);
|
||||
group.throughput(Throughput::Elements(elements));
|
||||
|
||||
bench_id = format!("{bench_name}::throughput::{n}");
|
||||
group.bench_with_input(&bench_id, &n, |b, _| {
|
||||
// Setup generates test data in parallel, excluded from measurement
|
||||
let setup = || {
|
||||
(0..elements)
|
||||
.into_par_iter()
|
||||
.map(|i| {
|
||||
let mut rng = StdRng::seed_from_u64(42 + i);
|
||||
let bases = generate_g2_affine_points(&mut rng, n);
|
||||
let scalars = generate_scalars(&mut rng, n);
|
||||
(bases, scalars)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
b.iter_batched(
|
||||
setup,
|
||||
|test_data| {
|
||||
test_data.par_iter().for_each(|(bases, scalars)| {
|
||||
let result = G2Affine::multi_mul_scalar(
|
||||
black_box(bases),
|
||||
black_box(scalars),
|
||||
);
|
||||
black_box(result);
|
||||
});
|
||||
},
|
||||
BatchSize::LargeInput,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// MSM benchmarks are curve operations, use minimal parameters
|
||||
let params: CryptoParametersRecord<u64> = CryptoParametersRecord::default();
|
||||
write_to_json(
|
||||
&bench_id,
|
||||
params,
|
||||
"MSM_BLS12_446_G2",
|
||||
bench_shortname,
|
||||
&OperatorType::Atomic,
|
||||
64, // bit_size for curve scalar operations
|
||||
vec![], // decomposition_basis not applicable for MSM
|
||||
);
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark GPU MSM for G1 points via `tfhe_zk_pok::gpu::g1_msm_gpu`.
///
/// For every input size, `get_bench_type()` decides between timing a single
/// MSM (latency) and timing a batch of independent MSMs dispatched through
/// rayon (throughput). Each benchmark id is then recorded with
/// `write_to_json`.
///
/// In throughput mode the target GPU is resolved on each rayon worker rather
/// than once up front: `select_gpu_for_msm` derives the device from the rayon
/// thread index, so a value resolved on the Criterion thread (which is not a
/// rayon worker) would send the whole batch to a single device.
#[cfg(feature = "gpu-experimental-zk")]
fn bench_gpu_g1_msm(c: &mut Criterion) {
    use tfhe_zk_pok::gpu::{g1_msm_gpu, select_gpu_for_msm};

    let curve_name = "bls12_446";
    let subgroup_name = "G1";
    let bench_name = format!("zk::cuda::msm::{curve_name}::{subgroup_name}");
    let bench_shortname = "zk::cuda::msm::bls12_446::g1";

    let mut group = c.benchmark_group(&bench_name);
    group.sample_size(10);
    group.measurement_time(Duration::from_secs(30));

    // Device used by the latency benchmarks, resolved once — stream
    // creation/destruction is handled inside g1_msm_gpu
    let gpu_index = select_gpu_for_msm();

    for n in [100, 1000, 2048, 4096, 10000] {
        let bench_id = match get_bench_type() {
            BenchmarkType::Latency => {
                let mut rng = StdRng::seed_from_u64(42);
                let bases = generate_g1_affine_points(&mut rng, n);
                let scalars = generate_scalars(&mut rng, n);

                let id = format!("{bench_name}::{n}");
                group.bench_with_input(&id, &n, |b, _| {
                    b.iter(|| {
                        let result = g1_msm_gpu(black_box(&bases), black_box(&scalars), gpu_index);
                        black_box(result)
                    });
                });
                id
            }
            BenchmarkType::Throughput => {
                let elements = msm_throughput_elements(n);
                group.throughput(Throughput::Elements(elements));

                let id = format!("{bench_name}::throughput::{n}");
                group.bench_with_input(&id, &n, |b, _| {
                    // Setup generates test data in parallel, excluded from measurement
                    let setup = || {
                        (0..elements)
                            .into_par_iter()
                            .map(|i| {
                                let mut rng = StdRng::seed_from_u64(42 + i);
                                let bases = generate_g1_affine_points(&mut rng, n);
                                let scalars = generate_scalars(&mut rng, n);
                                (bases, scalars)
                            })
                            .collect::<Vec<_>>()
                    };

                    b.iter_batched(
                        setup,
                        |test_data| {
                            test_data.par_iter().for_each(|(bases, scalars)| {
                                // Picked per worker so the batch is spread
                                // over every available GPU.
                                let worker_gpu = select_gpu_for_msm();
                                let result =
                                    g1_msm_gpu(black_box(bases), black_box(scalars), worker_gpu);
                                black_box(result);
                            });
                        },
                        BatchSize::LargeInput,
                    );
                });
                id
            }
        };

        // MSM benchmarks are curve operations, use minimal parameters
        let params: CryptoParametersRecord<u64> = CryptoParametersRecord::default();
        write_to_json(
            &bench_id,
            params,
            "MSM_BLS12_446_G1_CUDA",
            bench_shortname,
            &OperatorType::Atomic,
            64,     // bit_size for curve scalar operations
            vec![], // decomposition_basis not applicable for MSM
        );
    }

    group.finish();
}
|
||||
|
||||
/// Benchmark GPU MSM for G2 points via `tfhe_zk_pok::gpu::g2_msm_gpu`.
///
/// For every input size, `get_bench_type()` decides between timing a single
/// MSM (latency) and timing a batch of independent MSMs dispatched through
/// rayon (throughput). Each benchmark id is then recorded with
/// `write_to_json`.
///
/// In throughput mode the target GPU is resolved on each rayon worker rather
/// than once up front: `select_gpu_for_msm` derives the device from the rayon
/// thread index, so a value resolved on the Criterion thread (which is not a
/// rayon worker) would send the whole batch to a single device.
#[cfg(feature = "gpu-experimental-zk")]
fn bench_gpu_g2_msm(c: &mut Criterion) {
    use tfhe_zk_pok::gpu::{g2_msm_gpu, select_gpu_for_msm};

    let curve_name = "bls12_446";
    let subgroup_name = "G2";
    let bench_name = format!("zk::cuda::msm::{curve_name}::{subgroup_name}");
    let bench_shortname = "zk::cuda::msm::bls12_446::g2";

    let mut group = c.benchmark_group(&bench_name);
    group.sample_size(10);
    group.measurement_time(Duration::from_secs(30));

    // Device used by the latency benchmarks, resolved once — stream
    // creation/destruction is handled inside g2_msm_gpu
    let gpu_index = select_gpu_for_msm();

    for n in [100, 1000, 2048, 4096, 10000] {
        let bench_id = match get_bench_type() {
            BenchmarkType::Latency => {
                let mut rng = StdRng::seed_from_u64(42);
                let bases = generate_g2_affine_points(&mut rng, n);
                let scalars = generate_scalars(&mut rng, n);

                let id = format!("{bench_name}::{n}");
                group.bench_with_input(&id, &n, |b, _| {
                    b.iter(|| {
                        let result = g2_msm_gpu(black_box(&bases), black_box(&scalars), gpu_index);
                        black_box(result)
                    });
                });
                id
            }
            BenchmarkType::Throughput => {
                let elements = msm_throughput_elements(n);
                group.throughput(Throughput::Elements(elements));

                let id = format!("{bench_name}::throughput::{n}");
                group.bench_with_input(&id, &n, |b, _| {
                    // Setup generates test data in parallel, excluded from measurement
                    let setup = || {
                        (0..elements)
                            .into_par_iter()
                            .map(|i| {
                                let mut rng = StdRng::seed_from_u64(42 + i);
                                let bases = generate_g2_affine_points(&mut rng, n);
                                let scalars = generate_scalars(&mut rng, n);
                                (bases, scalars)
                            })
                            .collect::<Vec<_>>()
                    };

                    b.iter_batched(
                        setup,
                        |test_data| {
                            test_data.par_iter().for_each(|(bases, scalars)| {
                                // Picked per worker so the batch is spread
                                // over every available GPU.
                                let worker_gpu = select_gpu_for_msm();
                                let result =
                                    g2_msm_gpu(black_box(bases), black_box(scalars), worker_gpu);
                                black_box(result);
                            });
                        },
                        BatchSize::LargeInput,
                    );
                });
                id
            }
        };

        // MSM benchmarks are curve operations, use minimal parameters
        let params: CryptoParametersRecord<u64> = CryptoParametersRecord::default();
        write_to_json(
            &bench_id,
            params,
            "MSM_BLS12_446_G2_CUDA",
            bench_shortname,
            &OperatorType::Atomic,
            64,     // bit_size for curve scalar operations
            vec![], // decomposition_basis not applicable for MSM
        );
    }

    group.finish();
}
|
||||
|
||||
// CPU benchmarks (always available)
|
||||
criterion_group!(benches_cpu, bench_cpu_g1_msm, bench_cpu_g2_msm,);
|
||||
|
||||
// GPU benchmarks (only when GPU feature is enabled)
|
||||
#[cfg(feature = "gpu-experimental-zk")]
|
||||
criterion_group!(benches_gpu, bench_gpu_g1_msm, bench_gpu_g2_msm,);
|
||||
|
||||
// Conditionally include GPU benchmarks in main
|
||||
#[cfg(feature = "gpu-experimental-zk")]
|
||||
criterion_main!(benches_cpu, benches_gpu);
|
||||
|
||||
#[cfg(not(feature = "gpu-experimental-zk"))]
|
||||
criterion_main!(benches_cpu);
|
||||
@@ -14,8 +14,8 @@ rust-version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
ark-bls12-381 = "0.5.0"
|
||||
ark-ec = { version = "0.5.0", features = ["parallel"] }
|
||||
ark-ff = { version = "0.5.0", features = ["parallel"] }
|
||||
ark-ec = { workspace = true, features = ["parallel"] }
|
||||
ark-ff = { workspace = true, features = ["parallel"] }
|
||||
ark-poly = { version = "0.5.0", features = ["parallel"] }
|
||||
rand = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
@@ -24,9 +24,13 @@ serde = { workspace = true, features = ["default", "derive"] }
|
||||
zeroize = "1.7.0"
|
||||
num-bigint = "0.4.5"
|
||||
tfhe-versionable = { version = "0.7.0", path = "../utils/tfhe-versionable" }
|
||||
zk-cuda-backend = { version = "0.1.0", path = "../backends/zk-cuda-backend", optional = true }
|
||||
tfhe-cuda-backend = { version = "0.13.0", path = "../backends/tfhe-cuda-backend", optional = true }
|
||||
itertools.workspace = true
|
||||
|
||||
[features]
|
||||
experimental = []
|
||||
gpu-experimental = ["dep:zk-cuda-backend", "dep:tfhe-cuda-backend"]
|
||||
|
||||
[dev-dependencies]
|
||||
serde_json = "~1.0"
|
||||
|
||||
@@ -91,5 +91,110 @@ fn bench_pke_v1_verify(c: &mut Criterion) {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu-experimental")]
mod gpu {
    use super::*;
    use tfhe_zk_pok::proofs::pke;

    /// Times `pke::gpu::prove` for both test parameter sets and both compute
    /// loads, recording each benchmark id with `write_to_json`.
    pub fn bench_pke_v1_prove_gpu(c: &mut Criterion) {
        let bench_shortname = "pke_zk_proof_v1";
        let bench_name = format!("tfhe_zk_pok::cuda::{bench_shortname}");

        let mut bench_group = c.benchmark_group(&bench_name);
        bench_group.sample_size(15);
        bench_group.measurement_time(std::time::Duration::from_secs(60));

        let mut rng = rand::thread_rng();
        let parameter_sets = [
            (PKEV1_TEST_PARAMS, "PKEV1_TEST_PARAMS"),
            (PKEV2_TEST_PARAMS, "PKEV2_TEST_PARAMS"),
        ];

        for (params, param_name) in parameter_sets {
            let (public_param, public_commit, private_commit, metadata) = init_params_v1(params);
            let effective_t = params.t >> 1;
            let bits = (params.k as u32) * effective_t.ilog2();

            for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
                let bench_id = format!("{bench_name}::{param_name}_{bits}_bits_packed_{load}");

                // One seed per benchmark id, reused by every timed iteration.
                let seed = rng.gen::<u128>();

                bench_group.bench_function(&bench_id, |b| {
                    b.iter(|| {
                        pke::gpu::prove(
                            (&public_param, &public_commit),
                            &private_commit,
                            &metadata,
                            load,
                            &seed.to_le_bytes(),
                        )
                    })
                });

                write_to_json(&bench_id, params, param_name, bench_shortname);
            }
        }
    }

    /// Times `pke::gpu::verify` for both test parameter sets and both compute
    /// loads. The proof being verified is produced by `pke::gpu::prove`
    /// before the timed section starts.
    pub fn bench_pke_v1_verify_gpu(c: &mut Criterion) {
        let bench_shortname = "pke_zk_verify_v1";
        let bench_name = format!("tfhe_zk_pok::cuda::{bench_shortname}");

        let mut bench_group = c.benchmark_group(&bench_name);
        bench_group.sample_size(15);
        bench_group.measurement_time(std::time::Duration::from_secs(60));

        let mut rng = rand::thread_rng();
        let parameter_sets = [
            (PKEV1_TEST_PARAMS, "PKEV1_TEST_PARAMS"),
            (PKEV2_TEST_PARAMS, "PKEV2_TEST_PARAMS"),
        ];

        for (params, param_name) in parameter_sets {
            let (public_param, public_commit, private_commit, metadata) = init_params_v1(params);
            let effective_t = params.t >> 1;
            let bits = (params.k as u32) * effective_t.ilog2();

            for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
                let bench_id = format!("{bench_name}::{param_name}_{bits}_bits_packed_{load}");

                let seed = rng.gen::<u128>();

                // Use GPU prove to generate the proof
                let proof = pke::gpu::prove(
                    (&public_param, &public_commit),
                    &private_commit,
                    &metadata,
                    load,
                    &seed.to_le_bytes(),
                );

                bench_group.bench_function(&bench_id, |b| {
                    b.iter(|| {
                        pke::gpu::verify(&proof, (&public_param, &public_commit), &metadata)
                            .unwrap();
                    })
                });

                write_to_json(&bench_id, params, param_name, bench_shortname);
            }
        }
    }
}
|
||||
|
||||
criterion_group!(benches_pke_v1, bench_pke_v1_verify, bench_pke_v1_prove);
|
||||
|
||||
#[cfg(feature = "gpu-experimental")]
|
||||
use gpu::{bench_pke_v1_prove_gpu, bench_pke_v1_verify_gpu};
|
||||
|
||||
#[cfg(feature = "gpu-experimental")]
|
||||
criterion_group!(
|
||||
benches_pke_v1_gpu,
|
||||
bench_pke_v1_verify_gpu,
|
||||
bench_pke_v1_prove_gpu
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu-experimental")]
|
||||
criterion_main!(benches_pke_v1, benches_pke_v1_gpu);
|
||||
|
||||
#[cfg(not(feature = "gpu-experimental"))]
|
||||
criterion_main!(benches_pke_v1);
|
||||
|
||||
@@ -107,5 +107,130 @@ fn bench_pke_v2_verify(c: &mut Criterion) {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu-experimental")]
mod gpu {
    use super::*;
    use tfhe_zk_pok::proofs::pke_v2;

    /// Times `pke_v2::gpu::prove` over every combination of test parameter
    /// set, compute load and bound, recording each benchmark id with
    /// `write_to_json`.
    pub fn bench_pke_v2_prove_gpu(c: &mut Criterion) {
        let bench_shortname = "pke_zk_proof_v2";
        let bench_name = format!("tfhe_zk_pok::cuda::{bench_shortname}");

        let mut bench_group = c.benchmark_group(&bench_name);
        bench_group.sample_size(15);
        bench_group.measurement_time(std::time::Duration::from_secs(60));

        let mut rng = rand::thread_rng();

        let combinations = itertools::iproduct!(
            [
                (PKEV1_TEST_PARAMS, "PKEV1_TEST_PARAMS"),
                (PKEV2_TEST_PARAMS, "PKEV2_TEST_PARAMS"),
            ],
            [ComputeLoad::Proof, ComputeLoad::Verify],
            [Bound::CS, Bound::GHL]
        );

        for ((params, param_name), load, bound) in combinations {
            let (public_param, public_commit, private_commit, metadata) =
                init_params_v2(params, bound);
            let effective_t = params.t >> 1;
            let bits = (params.k as u32) * effective_t.ilog2();

            let bench_id =
                format!("{bench_name}::{param_name}_{bits}_bits_packed_{load}_{bound:?}");
            println!("{bench_id}");

            // One seed per benchmark id, reused by every timed iteration.
            let seed = rng.gen::<u128>();

            bench_group.bench_function(&bench_id, |b| {
                b.iter(|| {
                    pke_v2::gpu::prove(
                        (&public_param, &public_commit),
                        &private_commit,
                        &metadata,
                        load,
                        &seed.to_le_bytes(),
                    )
                })
            });

            write_to_json(&bench_id, params, param_name, bench_shortname);
        }
    }

    /// Times `pke_v2::gpu::verify` over every combination of test parameter
    /// set, compute load, bound and pairing mode. The proof being verified is
    /// produced by `pke_v2::gpu::prove` before the timed section starts.
    pub fn bench_pke_v2_verify_gpu(c: &mut Criterion) {
        let bench_shortname = "pke_zk_verify_v2";
        let bench_name = format!("tfhe_zk_pok::cuda::{bench_shortname}");

        let mut bench_group = c.benchmark_group(&bench_name);
        bench_group.sample_size(15);
        bench_group.measurement_time(std::time::Duration::from_secs(60));

        let mut rng = rand::thread_rng();

        let combinations = itertools::iproduct!(
            [
                (PKEV1_TEST_PARAMS, "PKEV1_TEST_PARAMS"),
                (PKEV2_TEST_PARAMS, "PKEV2_TEST_PARAMS"),
            ],
            [ComputeLoad::Proof, ComputeLoad::Verify],
            [Bound::CS, Bound::GHL],
            [
                VerificationPairingMode::TwoSteps,
                VerificationPairingMode::Batched
            ]
        );

        for ((params, param_name), load, bound, pairing_mode) in combinations {
            let (public_param, public_commit, private_commit, metadata) =
                init_params_v2(params, bound);
            let effective_t = params.t >> 1;
            let bits = (params.k as u32) * effective_t.ilog2();

            let bench_id = format!(
                "{bench_name}::{param_name}_{bits}_bits_packed_{load}_{bound:?}_{pairing_mode:?}"
            );
            println!("{bench_id}");

            let seed = rng.gen::<u128>();

            // Use GPU prove to generate the proof
            let proof = pke_v2::gpu::prove(
                (&public_param, &public_commit),
                &private_commit,
                &metadata,
                load,
                &seed.to_le_bytes(),
            );

            bench_group.bench_function(&bench_id, |b| {
                b.iter(|| {
                    pke_v2::gpu::verify(
                        &proof,
                        (&public_param, &public_commit),
                        &metadata,
                        pairing_mode,
                    )
                    .unwrap();
                })
            });

            write_to_json(&bench_id, params, param_name, bench_shortname);
        }
    }
}
|
||||
|
||||
criterion_group!(benches_pke_v2, bench_pke_v2_verify, bench_pke_v2_prove);
|
||||
|
||||
#[cfg(feature = "gpu-experimental")]
|
||||
use gpu::{bench_pke_v2_prove_gpu, bench_pke_v2_verify_gpu};
|
||||
|
||||
#[cfg(feature = "gpu-experimental")]
|
||||
criterion_group!(
|
||||
benches_pke_v2_gpu,
|
||||
bench_pke_v2_verify_gpu,
|
||||
bench_pke_v2_prove_gpu
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu-experimental")]
|
||||
criterion_main!(benches_pke_v2, benches_pke_v2_gpu);
|
||||
|
||||
#[cfg(not(feature = "gpu-experimental"))]
|
||||
criterion_main!(benches_pke_v2);
|
||||
|
||||
331
tfhe-zk-pok/src/gpu/mod.rs
Normal file
331
tfhe-zk-pok/src/gpu/mod.rs
Normal file
@@ -0,0 +1,331 @@
|
||||
//! GPU acceleration module for tfhe-zk-pok
|
||||
//!
|
||||
//! This module provides GPU-accelerated operations using zk-cuda-backend,
|
||||
//! type conversions between tfhe-zk-pok and zk-cuda-backend types,
|
||||
//! and GPU MSM helper functions used by `proofs::pke::gpu` and
|
||||
//! `proofs::pke_v2::gpu`.
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use crate::curve_446::{Fq, Fq2};
|
||||
use crate::curve_api::bls12_446::{G1Affine, G2Affine, Zp, G1, G2};
|
||||
use crate::curve_api::CurveGroupOps;
|
||||
use ark_ec::CurveGroup;
|
||||
use ark_ff::{BigInt, MontFp, PrimeField};
|
||||
use tfhe_cuda_backend::cuda_bind::{
|
||||
cuda_create_stream, cuda_destroy_stream, cuda_get_number_of_gpus,
|
||||
};
|
||||
use zk_cuda_backend::{G1Affine as ZkG1Affine, G2Affine as ZkG2Affine, Scalar as ZkScalar};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Compile-time assertions to verify transmute safety between wrapper types and inner arkworks
|
||||
// types. These ensure that G1Affine/G2Affine wrappers are truly repr(transparent) around their
|
||||
// inner types.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const _: () = {
|
||||
assert!(
|
||||
std::mem::size_of::<G1Affine>() == std::mem::size_of::<crate::curve_446::g1::G1Affine>()
|
||||
);
|
||||
assert!(
|
||||
std::mem::align_of::<G1Affine>() == std::mem::align_of::<crate::curve_446::g1::G1Affine>()
|
||||
);
|
||||
assert!(
|
||||
std::mem::size_of::<G2Affine>() == std::mem::size_of::<crate::curve_446::g2::G2Affine>()
|
||||
);
|
||||
assert!(
|
||||
std::mem::align_of::<G2Affine>() == std::mem::align_of::<crate::curve_446::g2::G2Affine>()
|
||||
);
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GPU helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Returns the number of available GPUs. Panics if no GPU is found.
|
||||
///
|
||||
/// The result is cached after the first call since GPU count cannot change
|
||||
/// during execution.
|
||||
pub(crate) fn get_num_gpus() -> u32 {
|
||||
static NUM_GPUS: std::sync::OnceLock<u32> = std::sync::OnceLock::new();
|
||||
*NUM_GPUS.get_or_init(|| {
|
||||
// SAFETY: cuda_get_number_of_gpus is a pure query with no preconditions
|
||||
let num_gpus = unsafe { cuda_get_number_of_gpus() };
|
||||
assert!(num_gpus > 0, "No GPU available");
|
||||
num_gpus
|
||||
.try_into()
|
||||
.expect("cuda_get_number_of_gpus returned negative value")
|
||||
})
|
||||
}
|
||||
|
||||
/// Selects a GPU for MSM based on the rayon thread index, distributing work
|
||||
/// across all available GPUs. Returns `None` (meaning GPU 0) when only one GPU
|
||||
/// is present.
|
||||
#[inline]
|
||||
pub fn select_gpu_for_msm() -> Option<u32> {
|
||||
let num_gpus = get_num_gpus();
|
||||
if num_gpus <= 1 {
|
||||
return None;
|
||||
}
|
||||
let thread_idx = rayon::current_thread_index().unwrap_or(0);
|
||||
Some(
|
||||
(thread_idx % num_gpus as usize)
|
||||
.try_into()
|
||||
.expect("GPU index fits in u32"),
|
||||
)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Type conversion helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Convert an Fq element (arkworks) to a zk-cuda-backend Fp (normal form limbs).
|
||||
#[inline]
|
||||
fn fq_to_cuda_fp(fq: &Fq) -> zk_cuda_backend::Fp {
|
||||
zk_cuda_backend::Fp::new(fq.into_bigint().0)
|
||||
}
|
||||
|
||||
/// Convert a zk-cuda-backend Fp (normal form limbs) back to an arkworks Fq.
|
||||
#[inline]
|
||||
fn fq_from_cuda_fp(fp: &zk_cuda_backend::Fp) -> Fq {
|
||||
Fq::from_bigint(BigInt::new(fp.limb)).expect("invalid Fq element from CUDA Fp limbs")
|
||||
}
|
||||
|
||||
/// Convert an Fq2 element (arkworks) to a zk-cuda-backend Fp2 (normal form).
|
||||
#[inline]
|
||||
fn fq2_to_cuda_fp2(fq2: &Fq2) -> zk_cuda_backend::bindings::Fp2 {
|
||||
zk_cuda_backend::bindings::Fp2 {
|
||||
c0: fq_to_cuda_fp(&fq2.c0),
|
||||
c1: fq_to_cuda_fp(&fq2.c1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a zk-cuda-backend Fp2 back to an arkworks Fq2.
|
||||
#[inline]
|
||||
fn fq2_from_cuda_fp2(fp2: &zk_cuda_backend::bindings::Fp2) -> Fq2 {
|
||||
Fq2::new(fq_from_cuda_fp(&fp2.c0), fq_from_cuda_fp(&fp2.c1))
|
||||
}
|
||||
|
||||
/// Convert a tfhe-zk-pok G1Affine to a zk-cuda-backend G1Affine (normal form).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the input point is non-identity but has no coordinates (malformed arkworks point).
|
||||
pub fn g1_affine_to_zk_cuda(affine: &G1Affine) -> ZkG1Affine {
|
||||
use ark_ec::AffineRepr;
|
||||
|
||||
if affine.inner.is_zero() {
|
||||
return ZkG1Affine::infinity();
|
||||
}
|
||||
let xy = affine
|
||||
.inner
|
||||
.xy()
|
||||
.expect("non-identity point must have coordinates");
|
||||
let x = fq_to_cuda_fp(&xy.0);
|
||||
let y = fq_to_cuda_fp(&xy.1);
|
||||
ZkG1Affine::new(x, y, affine.inner.infinity)
|
||||
}
|
||||
|
||||
/// Convert a zk-cuda-backend G1Affine back to a tfhe-zk-pok G1Affine.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the Fp limbs from the zk-cuda-backend point do not represent a valid `Fq` element
|
||||
/// (i.e., the value is not in the base field).
|
||||
pub fn g1_affine_from_zk_cuda(affine: &ZkG1Affine) -> G1Affine {
|
||||
if affine.is_infinity() {
|
||||
return G1::ZERO.normalize();
|
||||
}
|
||||
|
||||
let x = fq_from_cuda_fp(&affine.x());
|
||||
let y = fq_from_cuda_fp(&affine.y());
|
||||
|
||||
use crate::curve_446::g1::G1Projective;
|
||||
let one = MontFp!("1");
|
||||
let proj = G1Projective::new_unchecked(x, y, one);
|
||||
let inner = <G1Projective as CurveGroup>::into_affine(proj);
|
||||
G1Affine { inner }
|
||||
}
|
||||
|
||||
/// Convert a tfhe-zk-pok G2Affine to a zk-cuda-backend G2Affine (normal form).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the input point is non-identity but has no coordinates (malformed arkworks point).
|
||||
pub fn g2_affine_to_zk_cuda(affine: &G2Affine) -> ZkG2Affine {
|
||||
use ark_ec::AffineRepr;
|
||||
|
||||
if affine.inner.is_zero() {
|
||||
return ZkG2Affine::infinity();
|
||||
}
|
||||
let xy = affine
|
||||
.inner
|
||||
.xy()
|
||||
.expect("non-identity point must have coordinates");
|
||||
let x = fq2_to_cuda_fp2(&xy.0);
|
||||
let y = fq2_to_cuda_fp2(&xy.1);
|
||||
ZkG2Affine::new(x, y, affine.inner.infinity)
|
||||
}
|
||||
|
||||
/// Convert a zk-cuda-backend G2Affine back to a tfhe-zk-pok G2Affine.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the Fp limbs from the zk-cuda-backend point do not represent valid `Fq` elements
|
||||
/// (i.e., the values are not in the base field).
|
||||
pub fn g2_affine_from_zk_cuda(affine: &ZkG2Affine) -> G2Affine {
|
||||
if affine.is_infinity() {
|
||||
return G2::ZERO.normalize();
|
||||
}
|
||||
|
||||
let x = fq2_from_cuda_fp2(&affine.x());
|
||||
let y = fq2_from_cuda_fp2(&affine.y());
|
||||
|
||||
use crate::curve_446::g2::G2Projective;
|
||||
let one = MontFp!("1");
|
||||
let zero = MontFp!("0");
|
||||
let one_fq2 = Fq2::new(one, zero);
|
||||
let proj = G2Projective::new_unchecked(x, y, one_fq2);
|
||||
let inner = <G2Projective as CurveGroup>::into_affine(proj);
|
||||
G2Affine { inner }
|
||||
}
|
||||
|
||||
/// Convert a Zp scalar to a zk-cuda-backend Scalar.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function does not panic. The `into_bigint` conversion on arkworks `Fp` types is
|
||||
/// infallible, and `ZkScalar::from` accepts any 5-limb array.
|
||||
pub fn zp_to_zk_scalar(zp: &Zp) -> ZkScalar {
|
||||
let limbs = zp.inner.into_bigint().0;
|
||||
ZkScalar::from(limbs)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GPU MSM functions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// GPU-accelerated multi-scalar multiplication for G1. The `gpu_index` parameter
|
||||
/// selects which GPU device to use; `None` defaults to GPU 0.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// - If `gpu_index` is `Some(i)` where `i >= number of available GPUs`.
|
||||
/// - If `bases` and `scalars` have different lengths (checked inside the backend).
|
||||
/// - If the GPU MSM call fails.
|
||||
pub fn g1_msm_gpu(bases: &[G1Affine], scalars: &[Zp], gpu_index: Option<u32>) -> G1 {
|
||||
use crate::curve_446::g1::G1Projective;
|
||||
|
||||
// Convert points to zk-cuda-backend format (normal form)
|
||||
let gpu_bases: Vec<_> = bases
|
||||
.iter()
|
||||
.map(|b| zk_cuda_backend::g1_affine_from_arkworks(&b.inner))
|
||||
.collect();
|
||||
|
||||
let gpu_scalars: Vec<_> = scalars
|
||||
.iter()
|
||||
.map(|s| zk_cuda_backend::Scalar::from(s.inner.into_bigint().0))
|
||||
.collect();
|
||||
|
||||
let gpu_index = gpu_index.unwrap_or(0);
|
||||
let num_gpus = get_num_gpus();
|
||||
assert!(
|
||||
gpu_index < num_gpus,
|
||||
"gpu_index {gpu_index} exceeds available GPUs ({num_gpus})",
|
||||
);
|
||||
// SAFETY: gpu_index was validated by the assert above
|
||||
let stream = unsafe { cuda_create_stream(gpu_index) };
|
||||
|
||||
let result =
|
||||
zk_cuda_backend::G1Projective::msm(&gpu_bases, &gpu_scalars, stream, gpu_index, false);
|
||||
|
||||
// SAFETY: stream was created by cuda_create_stream above with the same gpu_index and is not
|
||||
// used after this point
|
||||
unsafe { cuda_destroy_stream(stream, gpu_index) };
|
||||
|
||||
let (gpu_result, _size_tracker) = result.unwrap_or_else(|e| panic!("G1 GPU MSM failed: {e}"));
|
||||
|
||||
// Convert result from Montgomery form back to arkworks types
|
||||
let normalized = gpu_result.from_montgomery_normalized();
|
||||
|
||||
let z_fp = normalized.Z();
|
||||
if z_fp.limb.iter().all(|&limb| limb == 0) {
|
||||
return G1::ZERO;
|
||||
}
|
||||
|
||||
let x = fq_from_cuda_fp(&normalized.X());
|
||||
let y = fq_from_cuda_fp(&normalized.Y());
|
||||
let z = fq_from_cuda_fp(&z_fp);
|
||||
G1 {
|
||||
inner: G1Projective::new_unchecked(x, y, z),
|
||||
}
|
||||
}
|
||||
|
||||
/// GPU-accelerated multi-scalar multiplication for G2. The `gpu_index` parameter
|
||||
/// selects which GPU device to use; `None` defaults to GPU 0.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// - If `gpu_index` is `Some(i)` where `i >= number of available GPUs`.
|
||||
/// - If `bases` and `scalars` have different lengths (checked inside the backend).
|
||||
/// - If the GPU MSM call fails.
|
||||
pub fn g2_msm_gpu(bases: &[G2Affine], scalars: &[Zp], gpu_index: Option<u32>) -> G2 {
|
||||
use crate::curve_446::g2::G2Projective;
|
||||
use ark_ec::AffineRepr;
|
||||
|
||||
// Convert points to zk-cuda-backend format (normal form)
|
||||
let gpu_bases: Vec<_> = bases
|
||||
.iter()
|
||||
.map(|b| {
|
||||
if b.inner.is_zero() {
|
||||
return zk_cuda_backend::G2Affine::infinity();
|
||||
}
|
||||
let x = fq2_to_cuda_fp2(&b.inner.x);
|
||||
let y = fq2_to_cuda_fp2(&b.inner.y);
|
||||
zk_cuda_backend::G2Affine::new(x, y, false)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let gpu_scalars: Vec<_> = scalars
|
||||
.iter()
|
||||
.map(|s| zk_cuda_backend::Scalar::from(s.inner.into_bigint().0))
|
||||
.collect();
|
||||
|
||||
let gpu_index = gpu_index.unwrap_or(0);
|
||||
let num_gpus = get_num_gpus();
|
||||
assert!(
|
||||
gpu_index < num_gpus,
|
||||
"gpu_index {gpu_index} exceeds available GPUs ({num_gpus})",
|
||||
);
|
||||
// SAFETY: gpu_index was validated by the assert above
|
||||
let stream = unsafe { cuda_create_stream(gpu_index) };
|
||||
|
||||
let result =
|
||||
zk_cuda_backend::G2Projective::msm(&gpu_bases, &gpu_scalars, stream, gpu_index, false);
|
||||
|
||||
// SAFETY: stream was created by cuda_create_stream above with the same gpu_index and is not
|
||||
// used after this point
|
||||
unsafe { cuda_destroy_stream(stream, gpu_index) };
|
||||
|
||||
let (gpu_result, _size_tracker) = result.unwrap_or_else(|e| panic!("G2 GPU MSM failed: {e}"));
|
||||
|
||||
// Convert result from Montgomery form back to arkworks types
|
||||
let normalized = gpu_result.from_montgomery_normalized();
|
||||
let x_fp2 = normalized.X();
|
||||
let y_fp2 = normalized.Y();
|
||||
let z_fp2 = normalized.Z();
|
||||
|
||||
let z_is_zero =
|
||||
z_fp2.c0.limb.iter().all(|&limb| limb == 0) && z_fp2.c1.limb.iter().all(|&limb| limb == 0);
|
||||
if z_is_zero {
|
||||
return G2::ZERO;
|
||||
}
|
||||
|
||||
let x = fq2_from_cuda_fp2(&x_fp2);
|
||||
let y = fq2_from_cuda_fp2(&y_fp2);
|
||||
let z = fq2_from_cuda_fp2(&z_fp2);
|
||||
G2 {
|
||||
inner: G2Projective::new_unchecked(x, y, z),
|
||||
}
|
||||
}
|
||||
7
tfhe-zk-pok/src/gpu/tests/mod.rs
Normal file
7
tfhe-zk-pok/src/gpu/tests/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
//! Integration tests for zk-cuda-backend
|
||||
//!
|
||||
//! These tests verify that zk-cuda-backend produces correct results
|
||||
//! by comparing against tfhe-zk-pok's CPU implementation.
|
||||
|
||||
mod prove_verify_stress;
|
||||
mod zk_cuda_backend;
|
||||
417
tfhe-zk-pok/src/gpu/tests/prove_verify_stress.rs
Normal file
417
tfhe-zk-pok/src/gpu/tests/prove_verify_stress.rs
Normal file
@@ -0,0 +1,417 @@
|
||||
//! Stress tests comparing GPU and CPU prove/verify for PKE v1 and v2.
|
||||
//!
|
||||
//! These tests mirror the `test_pke()` functions in `proofs/pke.rs` (v1) and
|
||||
//! `proofs/pke_v2/mod.rs` (v2), running GPU prove/verify alongside CPU to
|
||||
//! verify that:
|
||||
//! - Serialized proofs are byte-identical between CPU and GPU.
|
||||
//! - CPU and GPU verifiers agree on accept/reject for every input.
|
||||
//! - Both directions are covered: GPU-prove→GPU-verify and CPU-prove→GPU-verify.
|
||||
//!
|
||||
//! Each test iterates over 3 CRS variants (original, compressed, not-compressed)
|
||||
//! × 32 combinations of invalid inputs (e1, e2, r, m, metadata) × 2 compute
|
||||
//! loads (Proof, Verify), yielding 192 iterations per test. The v2 test doubles
|
||||
//! that by also sweeping both pairing modes (TwoSteps, Batched).
|
||||
|
||||
use crate::curve_api::Bls12_446;
|
||||
use crate::proofs::pke_v2::VerificationPairingMode;
|
||||
use crate::proofs::test::*;
|
||||
use crate::proofs::{pke, pke_v2, ComputeLoad};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{thread_rng, Rng, SeedableRng};
|
||||
|
||||
type Curve = Bls12_446;
|
||||
|
||||
/// Exhaustive GPU-vs-CPU equivalence test for PKE v1.
///
/// For every combination of (CRS variant, invalid-input flags, compute load) we:
/// 1. Prove on CPU and GPU with identical inputs.
/// 2. Assert the serialized proofs are byte-identical.
/// 3. Verify the GPU proof on both CPU and GPU, checking that both agree on accept/reject and
///    that the outcome matches the expected validity.
/// 4. Verify the CPU proof on GPU and check that it agrees with the CPU verifier.
///
/// The RNG seed is printed at the start of the test so that a failing run can be reproduced.
#[test]
fn test_pke_v1_gpu_cpu_equivalence() {
    /// Clones `invalid` when `use_invalid` is set and `valid` otherwise.
    fn pick<T: Clone>(use_invalid: bool, invalid: &T, valid: &T) -> T {
        if use_invalid {
            invalid.clone()
        } else {
            valid.clone()
        }
    }

    let params = crate::proofs::pke::tests::PKEV1_TEST_PARAMS;
    let PkeTestParameters {
        d,
        k,
        B,
        q,
        t,
        msbs_zero_padding_bit_count,
    } = params;

    // The seed drives every random choice below; report it so failures are reproducible.
    let seed: u64 = thread_rng().gen();
    println!("pke v1 gpu/cpu equivalence seed: {seed:x}");
    let rng = &mut StdRng::seed_from_u64(seed);

    // Generate a valid test case: keys, randomness, noise, and plaintext
    let valid_testcase = PkeTestcase::gen(rng, params);
    let ct = valid_testcase.encrypt(params);

    // Independent witnesses for rejection testing. The values are
    // in-range (same distribution as `valid_testcase`), but they are the *wrong*
    // witnesses for `ct` — proving with them produces a proof that does not
    // satisfy the ciphertext-witness relation, so verification must reject.
    let invalid_testcase = PkeTestcase::gen(rng, params);

    // CRS k > message count k exercises the path where the CRS is larger
    // than strictly needed, as happens in production
    let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));

    // Same three CRS variants as in `test_pke()`: original, round-tripped
    // through compressed serde, and round-tripped through uncompressed serde.
    let original_public_param =
        pke::crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
    let public_param_that_was_compressed =
        serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap();
    let public_param_that_was_not_compressed =
        serialize_then_deserialize(&original_public_param, Compress::No).unwrap();

    // Runs for all combinations: 3 CRS variants × 32 invalid-input flags
    let cases = itertools::iproduct!(
        [
            original_public_param,
            public_param_that_was_compressed,
            public_param_that_was_not_compressed
        ],
        [false, true], // e1
        [false, true], // e2
        [false, true], // m
        [false, true], // r
        [false, true]  // metadata
    );

    // v1 convention: seed bytes are big-endian (to_be_bytes).
    let seed_bytes = seed.to_be_bytes();

    for (
        public_param,
        use_invalid_e1,
        use_invalid_e2,
        use_invalid_m,
        use_invalid_r,
        use_invalid_metadata,
    ) in cases
    {
        // Build the commit, substituting invalid witnesses where flagged.
        let (public_commit, private_commit) = pke::commit(
            valid_testcase.a.clone(),
            valid_testcase.b.clone(),
            ct.c1.clone(),
            ct.c2.clone(),
            pick(use_invalid_r, &invalid_testcase.r, &valid_testcase.r),
            pick(use_invalid_e1, &invalid_testcase.e1, &valid_testcase.e1),
            pick(use_invalid_m, &invalid_testcase.m, &valid_testcase.m),
            pick(use_invalid_e2, &invalid_testcase.e2, &valid_testcase.e2),
            &public_param,
        );

        // When invalid metadata is used at verification time (but not at
        // proving time), the proof is valid but Fiat-Shamir binding fails
        let verify_metadata = if use_invalid_metadata {
            &invalid_testcase.metadata
        } else {
            &valid_testcase.metadata
        };

        // Any single invalid input should cause verification to reject
        let should_fail = use_invalid_e1
            || use_invalid_e2
            || use_invalid_r
            || use_invalid_m
            || use_invalid_metadata;

        // ComputeLoad::Proof shifts work to the prover; ComputeLoad::Verify
        // shifts it to the verifier. Both must yield identical results.
        for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
            // Produce proofs from identical inputs on CPU and GPU
            let cpu_proof = pke::prove(
                (&public_param, &public_commit),
                &private_commit,
                &valid_testcase.metadata,
                load,
                &seed_bytes,
            );

            let gpu_proof = pke::gpu::prove(
                (&public_param, &public_commit),
                &private_commit,
                &valid_testcase.metadata,
                load,
                &seed_bytes,
            );

            // GPU MSM is exact integer arithmetic, so the serialized proofs
            // must be byte-for-byte identical
            assert_eq!(
                bincode::serialize(&cpu_proof).unwrap(),
                bincode::serialize(&gpu_proof).unwrap(),
                "v1 proof mismatch: load={load}, invalid_e1={use_invalid_e1}, \
                 invalid_e2={use_invalid_e2}, invalid_m={use_invalid_m}, \
                 invalid_r={use_invalid_r}",
            );

            // --- Verify GPU proof on CPU ---
            let cpu_verify_result =
                pke::verify(&gpu_proof, (&public_param, &public_commit), verify_metadata);
            assert_eq!(
                cpu_verify_result.is_err(),
                should_fail,
                "v1 CPU verify mismatch: load={load}, should_fail={should_fail}",
            );

            // --- Verify GPU proof on GPU ---
            let gpu_verify_result =
                pke::gpu::verify(&gpu_proof, (&public_param, &public_commit), verify_metadata);
            assert_eq!(
                gpu_verify_result.is_err(),
                should_fail,
                "v1 GPU verify mismatch: load={load}, should_fail={should_fail}",
            );

            // CPU and GPU verifiers must produce the same Result value
            assert_eq!(
                cpu_verify_result, gpu_verify_result,
                "v1 CPU/GPU verify disagree: load={load}",
            );

            // --- Cross-direction: GPU verify of CPU-produced proof ---
            // Although proofs are byte-identical, this exercises pke::gpu::verify
            // with a proof object constructed entirely on the CPU side, ensuring
            // the GPU verifier does not depend on any GPU-side proof metadata.
            let gpu_verify_cpu_proof =
                pke::gpu::verify(&cpu_proof, (&public_param, &public_commit), verify_metadata);
            assert_eq!(
                gpu_verify_cpu_proof, cpu_verify_result,
                "v1 GPU-verify-of-CPU-proof disagrees with CPU-verify: load={load}",
            );
        }
    }
}
|
||||
|
||||
/// Exhaustive GPU-vs-CPU equivalence test for PKE v2.
///
/// Same structure as the v1 test but exercises `pke_v2` functions. Key
/// differences from v1:
/// - Uses `PKEV2_TEST_PARAMS` and `pke_v2::*` functions.
/// - Seed bytes are little-endian (`to_le_bytes`) per v2 convention.
/// - Verify takes a `VerificationPairingMode`; both `TwoSteps` and `Batched` are run and each is
///   checked against the expected outcome and against the CPU verifier.
///
/// The RNG seed is printed at the start of the test so that a failing run can be reproduced.
#[test]
fn test_pke_v2_gpu_cpu_equivalence() {
    /// Clones `invalid` when `use_invalid` is set and `valid` otherwise.
    fn pick<T: Clone>(use_invalid: bool, invalid: &T, valid: &T) -> T {
        if use_invalid {
            invalid.clone()
        } else {
            valid.clone()
        }
    }

    let params = crate::proofs::pke_v2::tests::PKEV2_TEST_PARAMS;
    let PkeTestParameters {
        d,
        k,
        B,
        q,
        t,
        msbs_zero_padding_bit_count,
    } = params;

    // The seed drives every random choice below; report it so failures are reproducible.
    let seed: u64 = thread_rng().gen();
    println!("pke v2 gpu/cpu equivalence seed: {seed:x}");
    let rng = &mut StdRng::seed_from_u64(seed);

    // Generate a valid test case: keys, randomness, noise, and plaintext
    let testcase = PkeTestcase::gen(rng, params);
    let ct = testcase.encrypt(params);

    // Independent witnesses for rejection testing. The values are
    // in-range (same distribution as `testcase`), but they are the *wrong*
    // witnesses for `ct` — proving with them produces a proof that does not
    // satisfy the ciphertext-witness relation, so verification must reject.
    let invalid_testcase = PkeTestcase::gen(rng, params);

    // CRS k > message count k exercises the path where the CRS is larger
    // than strictly needed, as happens in production
    let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));

    // Same three CRS variants as in `test_pke()`: original, round-tripped
    // through compressed serde, and round-tripped through uncompressed serde.
    let original_public_param =
        pke_v2::crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
    let public_param_that_was_compressed =
        serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap();
    let public_param_that_was_not_compressed =
        serialize_then_deserialize(&original_public_param, Compress::No).unwrap();

    // Sweep all combinations: 3 CRS variants × 32 invalid-input flags
    let cases = itertools::iproduct!(
        [
            original_public_param,
            public_param_that_was_compressed,
            public_param_that_was_not_compressed
        ],
        [false, true], // r
        [false, true], // e1
        [false, true], // e2
        [false, true], // m
        [false, true]  // metadata
    );

    // v2 convention: seed bytes are little-endian (to_le_bytes).
    let seed_bytes = seed.to_le_bytes();

    for (
        public_param,
        use_invalid_r,
        use_invalid_e1,
        use_invalid_e2,
        use_invalid_m,
        use_invalid_metadata,
    ) in cases
    {
        // Build the commit, substituting invalid witnesses where flagged.
        let (public_commit, private_commit) = pke_v2::commit(
            testcase.a.clone(),
            testcase.b.clone(),
            ct.c1.clone(),
            ct.c2.clone(),
            pick(use_invalid_r, &invalid_testcase.r, &testcase.r),
            pick(use_invalid_e1, &invalid_testcase.e1, &testcase.e1),
            pick(use_invalid_m, &invalid_testcase.m, &testcase.m),
            pick(use_invalid_e2, &invalid_testcase.e2, &testcase.e2),
            &public_param,
        );

        // When invalid metadata is used at verification time (but not at
        // proving time), the proof is valid but Fiat-Shamir binding fails
        let verify_metadata = if use_invalid_metadata {
            &invalid_testcase.metadata
        } else {
            &testcase.metadata
        };

        // Any single invalid input should cause verification to reject
        let should_fail = use_invalid_e1
            || use_invalid_e2
            || use_invalid_r
            || use_invalid_m
            || use_invalid_metadata;

        // ComputeLoad::Proof shifts work to the prover; ComputeLoad::Verify
        // shifts it to the verifier. Both must yield identical results.
        for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
            // Produce proofs from identical inputs on CPU and GPU.
            let cpu_proof = pke_v2::prove(
                (&public_param, &public_commit),
                &private_commit,
                &testcase.metadata,
                load,
                &seed_bytes,
            );

            let gpu_proof = pke_v2::gpu::prove(
                (&public_param, &public_commit),
                &private_commit,
                &testcase.metadata,
                load,
                &seed_bytes,
            );

            // GPU MSM is exact integer arithmetic, so the serialized proofs
            // must be byte-for-byte identical
            assert_eq!(
                bincode::serialize(&cpu_proof).unwrap(),
                bincode::serialize(&gpu_proof).unwrap(),
                "v2 proof mismatch: load={load}, invalid_r={use_invalid_r}, \
                 invalid_e1={use_invalid_e1}, invalid_e2={use_invalid_e2}, \
                 invalid_m={use_invalid_m}",
            );

            // v2 supports two pairing strategies for verification:
            //   - TwoSteps: two independent pairing checks
            //   - Batched: single batched pairing check (faster but same result)
            // Both must match the expected outcome and the CPU verifier.
            for pairing_mode in [
                VerificationPairingMode::TwoSteps,
                VerificationPairingMode::Batched,
            ] {
                // --- Verify GPU proof on CPU ---
                let cpu_verify_result = pke_v2::verify(
                    &gpu_proof,
                    (&public_param, &public_commit),
                    verify_metadata,
                    pairing_mode,
                );
                assert_eq!(
                    cpu_verify_result.is_err(),
                    should_fail,
                    "v2 CPU verify mismatch: load={load}, mode={pairing_mode:?}, \
                     should_fail={should_fail}",
                );

                // --- Verify GPU proof on GPU ---
                let gpu_verify_result = pke_v2::gpu::verify(
                    &gpu_proof,
                    (&public_param, &public_commit),
                    verify_metadata,
                    pairing_mode,
                );
                assert_eq!(
                    gpu_verify_result.is_err(),
                    should_fail,
                    "v2 GPU verify mismatch: load={load}, mode={pairing_mode:?}, \
                     should_fail={should_fail}",
                );

                // CPU and GPU verifiers must produce the same Result value
                assert_eq!(
                    cpu_verify_result, gpu_verify_result,
                    "v2 CPU/GPU verify disagree: load={load}, mode={pairing_mode:?}",
                );

                // --- Cross-direction: GPU verify of CPU-produced proof ---
                // Although proofs are byte-identical, this exercises
                // pke_v2::gpu::verify with a proof object constructed entirely
                // on the CPU side, ensuring the GPU verifier does not depend
                // on any GPU-side proof metadata.
                let gpu_verify_cpu_proof = pke_v2::gpu::verify(
                    &cpu_proof,
                    (&public_param, &public_commit),
                    verify_metadata,
                    pairing_mode,
                );
                assert_eq!(
                    gpu_verify_cpu_proof, cpu_verify_result,
                    "v2 GPU-verify-of-CPU-proof disagrees with CPU-verify: load={load}, \
                     mode={pairing_mode:?}",
                );
            }
        }
    }
}
|
||||
392
tfhe-zk-pok/src/gpu/tests/zk_cuda_backend.rs
Normal file
392
tfhe-zk-pok/src/gpu/tests/zk_cuda_backend.rs
Normal file
@@ -0,0 +1,392 @@
|
||||
//! Tests comparing zk-cuda-backend MSM results against tfhe-zk-pok CPU implementation
|
||||
|
||||
use crate::curve_api::bls12_446::{Zp, G1, G2};
|
||||
use crate::curve_api::CurveGroupOps;
|
||||
use crate::gpu::{
|
||||
g1_affine_from_zk_cuda, g1_affine_to_zk_cuda, g2_affine_from_zk_cuda, g2_affine_to_zk_cuda,
|
||||
};
|
||||
use tfhe_cuda_backend::cuda_bind::{cuda_create_stream, cuda_destroy_stream};
|
||||
use zk_cuda_backend::conversions::{g1_affine_from_montgomery, g2_affine_from_montgomery};
|
||||
use zk_cuda_backend::{
|
||||
G1Affine as ZkG1Affine, G1Projective as ZkG1Projective, G2Affine as ZkG2Affine,
|
||||
G2Projective as ZkG2Projective, Scalar as ZkScalar,
|
||||
};
|
||||
|
||||
/// Returns the `n`-th triangular number, `n * (n + 1) / 2` (the sum `1 + 2 + ... + n`).
///
/// Exactly one of `n` and `n + 1` is even; that factor is halved *before* multiplying, so the
/// intermediate product equals the final result instead of twice its value. This avoids a
/// spurious `u64` overflow for inputs whose triangular number still fits in 64 bits.
fn triangular_number(n: u64) -> u64 {
    if n % 2 == 0 {
        (n / 2) * (n + 1)
    } else {
        n * ((n + 1) / 2)
    }
}
|
||||
|
||||
/// `r - 1`, where `r` is the BLS12-446 scalar field modulus, as five little-endian 64-bit limbs
/// (least significant limb first).
///
/// The canceling-scalar tests rely on it: `1*G + (r-1)*G = r*G = O`.
const R_MINUS_1: [u64; 5] = [
    0x0428_0014_0004_0000,
    0x7bb9_b0e8_d8ca_3461,
    0xd04c_98cc_c4c0_50bc,
    0x7995_b349_9583_0fa4,
    0x0000_0511_b705_39f2,
];
|
||||
|
||||
/// Check that a G1 projective point has Z == 0 (point at infinity).
|
||||
/// Converts from Montgomery form first so limbs are directly comparable.
|
||||
fn g1_proj_z_is_zero(p: &ZkG1Projective) -> bool {
|
||||
let z = p.from_montgomery_normalized().Z();
|
||||
z.limb.iter().all(|&limb| limb == 0)
|
||||
}
|
||||
|
||||
/// Check that a G2 projective point has Z == 0 (point at infinity).
|
||||
/// G2 lives over Fp2, so both c0 and c1 components must be zero.
|
||||
fn g2_proj_z_is_zero(p: &ZkG2Projective) -> bool {
|
||||
let z = p.from_montgomery_normalized().Z();
|
||||
z.c0.limb.iter().all(|&limb| limb == 0) && z.c1.limb.iter().all(|&limb| limb == 0)
|
||||
}
|
||||
|
||||
// =============================================================================
// Macro: generates G1 and G2 MSM test variants from a single template.
//
// Parameters:
//   $Group          - tfhe-zk-pok curve group (G1 or G2)
//   $ZkAffine       - zk-cuda-backend affine type (ZkG1Affine or ZkG2Affine)
//   $ZkProjective   - zk-cuda-backend projective type
//   $to_zk          - conversion: tfhe-zk-pok affine -> zk-cuda affine
//   $from_zk        - conversion: zk-cuda affine -> tfhe-zk-pok affine
//   $from_mont      - Montgomery-to-normal conversion for zk-cuda affine
//   $proj_z_is_zero - fn to check projective Z == 0 (group-specific)
//   $label          - group name used in assertion messages ("G1" or "G2")
//   $test_*         - test function names (avoids `paste` dependency)
//
// Every generated test runs on GPU 0. Tests that cannot run the MSM at all
// (the backend returns an error on the first call) print a message and return
// early instead of failing.
// =============================================================================
macro_rules! msm_tests {
    (
        group: $Group:ty,
        zk_affine: $ZkAffine:ty,
        zk_proj: $ZkProjective:ty,
        to_zk: $to_zk:expr,
        from_zk: $from_zk:expr,
        from_mont: $from_mont:expr,
        proj_z_is_zero: $proj_z_is_zero:expr,
        label: $label:expr,
        test_large_n: $test_large_n:ident,
        test_zero_scalars: $test_zero_scalars:ident,
        test_canceling: $test_canceling:ident,
        test_infinity_input: $test_infinity_input:ident
    ) => {
        /// Sweeps N = 1..=MAX_N with points `[G; N]` and scalars `[1..=N]`, and compares the
        /// GPU result with the CPU computation of `G * triangular(N)`.
        #[test]
        fn $test_large_n() {
            const MAX_N: u64 = 100;

            let generator = <$Group>::GENERATOR.normalize();
            let gen_zk = $to_zk(&generator);

            // Probe CUDA availability with a trivial MSM before the sweep
            {
                let probe_points: Vec<$ZkAffine> = vec![gen_zk];
                let probe_scalars: Vec<ZkScalar> = vec![ZkScalar::from_u64(1)];
                // SAFETY: gpu_index 0 is valid (checked by test setup)
                let probe_stream = unsafe { cuda_create_stream(0) };
                let probe_failed =
                    <$ZkProjective>::msm(&probe_points, &probe_scalars, probe_stream, 0, false)
                        .is_err();
                // SAFETY: stream was created above and is not used after this point
                unsafe { cuda_destroy_stream(probe_stream, 0) };
                if probe_failed {
                    eprintln!("CUDA not available - Skipping test");
                    return;
                }
            }

            // Sweep N from 1..=MAX_N: points = [G; N], scalars = [1..=N].
            // Expected result = G * triangular(N).
            for n in 1..=MAX_N {
                let points: Vec<$ZkAffine> = (0..n).map(|_| gen_zk).collect();
                let scalars: Vec<ZkScalar> = (1..=n).map(ZkScalar::from_u64).collect();

                // SAFETY: gpu_index 0 is valid (checked by test setup)
                let stream = unsafe { cuda_create_stream(0) };
                let msm_result = <$ZkProjective>::msm(&points, &scalars, stream, 0, false);
                // Release the stream before inspecting the result so that a failing MSM
                // does not leave it behind when the test panics.
                // SAFETY: stream was created above and is not used after this point
                unsafe { cuda_destroy_stream(stream, 0) };
                let (gpu_result_proj, _size_tracker) = msm_result
                    .unwrap_or_else(|e| panic!("CUDA MSM failed at N={}: {}", n, e));

                let gpu_result = $from_mont(&gpu_result_proj.to_affine());

                let expected_scalar = Zp::from_u64(triangular_number(n));
                let cpu_result = <$Group>::GENERATOR.mul_scalar(expected_scalar).normalize();

                let gpu_tfhe = $from_zk(&gpu_result);

                assert_eq!(
                    $to_zk(&gpu_tfhe).is_infinity(),
                    $to_zk(&cpu_result).is_infinity(),
                    "{} MSM large_n: N={} infinity mismatch",
                    $label,
                    n
                );
                // Coordinates are only meaningful for finite points.
                if !gpu_result.is_infinity() {
                    assert_eq!(
                        $to_zk(&gpu_tfhe).x(),
                        $to_zk(&cpu_result).x(),
                        "{} MSM large_n: N={} x mismatch",
                        $label,
                        n
                    );
                    assert_eq!(
                        $to_zk(&gpu_tfhe).y(),
                        $to_zk(&cpu_result).y(),
                        "{} MSM large_n: N={} y mismatch",
                        $label,
                        n
                    );
                }
            }
        }

        /// An MSM whose scalars are all zero must return the point at infinity.
        #[test]
        fn $test_zero_scalars() {
            let generator = <$Group>::GENERATOR.normalize();
            let gen_zk = $to_zk(&generator);

            // All-zero scalars: 0*G + 0*G + ... = O
            let points: Vec<$ZkAffine> = vec![gen_zk; 5];
            let scalars: Vec<ZkScalar> = vec![ZkScalar::from_u64(0); 5];

            let gpu_index = 0;
            // SAFETY: gpu_index 0 is valid (checked by test setup)
            let stream = unsafe { cuda_create_stream(gpu_index) };
            let msm_result = <$ZkProjective>::msm(&points, &scalars, stream, gpu_index, false);
            // SAFETY: stream was created above and is not used after this point
            unsafe { cuda_destroy_stream(stream, gpu_index) };
            let result_proj = match msm_result {
                Ok((result, _size_tracker)) => result,
                Err(e) => {
                    eprintln!("CUDA MSM failed: {} - Skipping test", e);
                    return;
                }
            };

            let is_infinity = ($proj_z_is_zero)(&result_proj);
            assert!(
                is_infinity,
                "{} MSM with all-zero scalars should return infinity",
                $label
            );
        }

        /// Scalars that sum to the group order must cancel to the point at infinity.
        #[test]
        fn $test_canceling() {
            let generator = <$Group>::GENERATOR.normalize();
            let gen_zk = $to_zk(&generator);

            // 1*G + (r-1)*G = r*G = O
            let points: Vec<$ZkAffine> = vec![gen_zk, gen_zk];
            let scalars: Vec<ZkScalar> = vec![ZkScalar::from_u64(1), ZkScalar::from(R_MINUS_1)];

            let gpu_index = 0;
            // SAFETY: gpu_index 0 is valid (checked by test setup)
            let stream = unsafe { cuda_create_stream(gpu_index) };
            let msm_result = <$ZkProjective>::msm(&points, &scalars, stream, gpu_index, false);
            // SAFETY: stream was created above and is not used after this point
            unsafe { cuda_destroy_stream(stream, gpu_index) };
            let result_proj = match msm_result {
                Ok((result, _size_tracker)) => result,
                Err(e) => {
                    eprintln!("CUDA MSM failed: {} - Skipping test", e);
                    return;
                }
            };

            let is_infinity = ($proj_z_is_zero)(&result_proj);
            assert!(
                is_infinity,
                "{} MSM with canceling scalars (1*G + (r-1)*G) should return infinity",
                $label
            );
        }

        /// Infinity points among the inputs must not contribute to the result.
        #[test]
        fn $test_infinity_input() {
            let generator = <$Group>::GENERATOR.normalize();
            let gen_zk = $to_zk(&generator);
            let inf = <$ZkAffine>::infinity();

            // 5*O + 3*G + 7*O = 3*G (infinity inputs contribute nothing)
            let points: Vec<$ZkAffine> = vec![inf, gen_zk, inf];
            let scalars: Vec<ZkScalar> = vec![
                ZkScalar::from_u64(5),
                ZkScalar::from_u64(3),
                ZkScalar::from_u64(7),
            ];

            let gpu_index = 0;
            // SAFETY: gpu_index 0 is valid (checked by test setup)
            let stream = unsafe { cuda_create_stream(gpu_index) };
            let msm_result = <$ZkProjective>::msm(&points, &scalars, stream, gpu_index, false);
            // SAFETY: stream was created above and is not used after this point
            unsafe { cuda_destroy_stream(stream, gpu_index) };
            let result_proj = match msm_result {
                Ok((result, _size_tracker)) => result,
                Err(e) => {
                    eprintln!("CUDA MSM failed: {} - Skipping test", e);
                    return;
                }
            };

            let expected = <$Group>::GENERATOR.mul_scalar(Zp::from_u64(3)).normalize();
            let expected_zk = $to_zk(&expected);

            let result = $from_mont(&result_proj.to_affine());

            assert_eq!(
                result.x(),
                expected_zk.x(),
                "{} MSM with infinity points: x mismatch",
                $label
            );
            assert_eq!(
                result.y(),
                expected_zk.y(),
                "{} MSM with infinity points: y mismatch",
                $label
            );
        }
    };
}
|
||||
|
||||
// Generate G1 MSM tests
|
||||
msm_tests! {
|
||||
group: G1,
|
||||
zk_affine: ZkG1Affine,
|
||||
zk_proj: ZkG1Projective,
|
||||
to_zk: g1_affine_to_zk_cuda,
|
||||
from_zk: g1_affine_from_zk_cuda,
|
||||
from_mont: g1_affine_from_montgomery,
|
||||
proj_z_is_zero: g1_proj_z_is_zero,
|
||||
label: "G1",
|
||||
test_large_n: test_g1_msm_large_n,
|
||||
test_zero_scalars: test_g1_msm_zero_scalars_returns_infinity,
|
||||
test_canceling: test_g1_msm_canceling_scalars_returns_infinity,
|
||||
test_infinity_input: test_g1_msm_infinity_point_input
|
||||
}
|
||||
|
||||
// Generate G2 MSM tests
|
||||
msm_tests! {
|
||||
group: G2,
|
||||
zk_affine: ZkG2Affine,
|
||||
zk_proj: ZkG2Projective,
|
||||
to_zk: g2_affine_to_zk_cuda,
|
||||
from_zk: g2_affine_from_zk_cuda,
|
||||
from_mont: g2_affine_from_montgomery,
|
||||
proj_z_is_zero: g2_proj_z_is_zero,
|
||||
label: "G2",
|
||||
test_large_n: test_g2_msm_large_n,
|
||||
test_zero_scalars: test_g2_msm_zero_scalars_returns_infinity,
|
||||
test_canceling: test_g2_msm_canceling_scalars_returns_infinity,
|
||||
test_infinity_input: test_g2_msm_infinity_point_input
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Non-macro tests: these test unique behavior not shared between G1/G2
|
||||
// =============================================================================
|
||||
|
||||
/// Converting the G1 generator to the zk-cuda-backend representation, back to tfhe-zk-pok, and
/// to zk-cuda-backend again must yield the same affine coordinates.
#[test]
fn test_g1_conversion_roundtrip() {
    let generator = G1::GENERATOR.normalize();

    let first_pass = g1_affine_to_zk_cuda(&generator);
    let back_in_tfhe = g1_affine_from_zk_cuda(&first_pass);
    let second_pass = g1_affine_to_zk_cuda(&back_in_tfhe);

    assert_eq!(first_pass.x(), second_pass.x());
    assert_eq!(first_pass.y(), second_pass.y());
}
|
||||
|
||||
/// Converting the G2 generator to the zk-cuda-backend representation, back to tfhe-zk-pok, and
/// to zk-cuda-backend again must yield the same affine coordinates.
#[test]
fn test_g2_conversion_roundtrip() {
    let generator = G2::GENERATOR.normalize();

    let first_pass = g2_affine_to_zk_cuda(&generator);
    let back_in_tfhe = g2_affine_from_zk_cuda(&first_pass);
    let second_pass = g2_affine_to_zk_cuda(&back_in_tfhe);

    assert_eq!(first_pass.x(), second_pass.x());
    assert_eq!(first_pass.y(), second_pass.y());
}
|
||||
|
||||
/// A single-point G1 MSM with the scalar 2^64 must match the CPU scalar multiplication.
///
/// 2^64 does not fit in one 64-bit limb, so this exercises multi-limb scalar handling. If the
/// backend reports an error the test prints a message and returns without failing.
#[test]
fn test_g1_msm_multi_limb_scalar() {
    // Little-endian limbs of 2^64: only the second limb is set.
    const TWO_POW_64_LIMBS: [u64; 5] = [0u64, 1u64, 0u64, 0u64, 0u64];

    let generator = G1::GENERATOR.normalize();
    let generator_zk = g1_affine_to_zk_cuda(&generator);

    let points: Vec<ZkG1Affine> = vec![generator_zk];
    let scalars: Vec<ZkScalar> = vec![ZkScalar::new(TWO_POW_64_LIMBS)];

    let gpu_index = 0;
    // SAFETY: gpu_index 0 is valid (checked by test setup)
    let stream = unsafe { cuda_create_stream(gpu_index) };
    let msm_result = ZkG1Projective::msm(&points, &scalars, stream, gpu_index, false);
    // SAFETY: stream was created above and is not used after this point
    unsafe { cuda_destroy_stream(stream, gpu_index) };

    let gpu_projective = match msm_result {
        Ok((result, _size_tracker)) => result,
        Err(e) => {
            eprintln!("CUDA MSM failed: {} - Skipping test", e);
            return;
        }
    };

    // GPU side: leave Montgomery form, then go through the tfhe-zk-pok affine type.
    let gpu_affine = g1_affine_from_montgomery(&gpu_projective.to_affine());
    let gpu_in_tfhe = g1_affine_from_zk_cuda(&gpu_affine);

    // CPU side: the same scalar multiplication with the reference implementation.
    let cpu_scalar = Zp::from_bigint(TWO_POW_64_LIMBS);
    let cpu_result = G1::GENERATOR.mul_scalar(cpu_scalar).normalize();

    let gpu_coords = g1_affine_to_zk_cuda(&gpu_in_tfhe);
    let cpu_coords = g1_affine_to_zk_cuda(&cpu_result);

    assert_eq!(
        gpu_coords.x(),
        cpu_coords.x(),
        "G1 MSM 2^64 scalar x mismatch"
    );
    assert_eq!(
        gpu_coords.y(),
        cpu_coords.y(),
        "G1 MSM 2^64 scalar y mismatch"
    );
}
|
||||
|
||||
/// Checks `ZkScalar::is_valid` around the modulus boundary and `reduce_once` on the modulus.
///
/// A small value and `r - 1` must be reported valid, `r` itself must be reported invalid, and
/// reducing `r` once must give a valid scalar again.
#[test]
fn test_scalar_validation() {
    // Scalar equal to the modulus r, little-endian limbs (R_MINUS_1 with the lowest limb + 1).
    let modulus_limbs: [u64; 5] = [
        0x0428_0014_0004_0001,
        0x7bb9_b0e8_d8ca_3461,
        0xd04c_98cc_c4c0_50bc,
        0x7995_b349_9583_0fa4,
        0x0000_0511_b705_39f2,
    ];

    // Valid scalar (less than modulus)
    let small = ZkScalar::from_u64(12345);
    assert!(small.is_valid(), "Small scalar should be valid");

    // Largest valid scalar: modulus minus 1 (r - 1)
    let largest_valid = ZkScalar::from(R_MINUS_1);
    assert!(largest_valid.is_valid(), "r-1 should be valid");

    // Scalar equal to modulus (invalid)
    let modulus = ZkScalar::from(modulus_limbs);
    assert!(!modulus.is_valid(), "r should be invalid");

    // A single reduction brings r back into the valid range
    let after_reduction = modulus.reduce_once();
    assert!(
        after_reduction.is_valid(),
        "Reduced scalar should be valid (equals zero)"
    );
}
|
||||
@@ -1,5 +1,7 @@
|
||||
pub mod curve_446;
|
||||
pub mod curve_api;
|
||||
#[cfg(feature = "gpu-experimental")]
|
||||
pub mod gpu;
|
||||
pub mod proofs;
|
||||
pub mod serialization;
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ use tfhe_versionable::Versionize;
|
||||
|
||||
#[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize, Versionize)]
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct OneBased<T: ?Sized>(T);
|
||||
pub(crate) struct OneBased<T: ?Sized>(pub(crate) T);
|
||||
|
||||
/// The proving scheme is available in 2 versions, one that puts more load on the prover and one
|
||||
/// that puts more load on the verifier
|
||||
@@ -153,7 +153,7 @@ pub(crate) enum ProofSanityCheckMode {
|
||||
/// Check the preconditions of the pke proof before computing it. Panic if one of the conditions
|
||||
/// does not hold.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn assert_pke_proof_preconditions(
|
||||
pub(crate) fn assert_pke_proof_preconditions(
|
||||
a: &[i64],
|
||||
b: &[i64],
|
||||
c1: &[i64],
|
||||
@@ -179,7 +179,7 @@ fn assert_pke_proof_preconditions(
|
||||
|
||||
/// q (modulus) is encoded on 64b, with 0 meaning 2^64. This converts the encoded q to its effective
|
||||
/// value for modular operations.
|
||||
fn decode_q(q: u64) -> u128 {
|
||||
pub(crate) fn decode_q(q: u64) -> u128 {
|
||||
if q == 0 {
|
||||
1u128 << 64
|
||||
} else {
|
||||
@@ -193,7 +193,7 @@ fn decode_q(q: u64) -> u128 {
|
||||
/// implies
|
||||
/// phi(r1) = (rot(a) * phi(bar(r)) + phi(e1) - phi(c1)) / q
|
||||
/// (phi is the function that maps a polynomial to its coeffs vector)
|
||||
fn compute_r1(
|
||||
pub(crate) fn compute_r1(
|
||||
e1: &[i64],
|
||||
c1: &[i64],
|
||||
a: &[i64],
|
||||
@@ -233,7 +233,7 @@ fn compute_r1(
|
||||
/// r2_i = (phi_[d - i](b).T * phi(bar(r)) + delta * m_i + e2_i - c2_i) / q
|
||||
/// (phi is the function that maps a polynomial to its coeffs vector)
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn compute_r2(
|
||||
pub(crate) fn compute_r2(
|
||||
e2: &[i64],
|
||||
c2: &[i64],
|
||||
m: &[i64],
|
||||
@@ -314,7 +314,7 @@ impl Sid {
|
||||
Self(Some(rng.gen()))
|
||||
}
|
||||
|
||||
fn to_le_bytes(self) -> SidBytes {
|
||||
pub(crate) fn to_le_bytes(self) -> SidBytes {
|
||||
self.0
|
||||
.map(|val| SidBytes(Some(val.to_le_bytes())))
|
||||
.unwrap_or_default()
|
||||
@@ -322,10 +322,10 @@ impl Sid {
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct SidBytes(Option<[u8; 16]>);
|
||||
pub(crate) struct SidBytes(Option<[u8; 16]>);
|
||||
|
||||
impl SidBytes {
|
||||
fn as_slice(&self) -> &[u8] {
|
||||
pub(crate) fn as_slice(&self) -> &[u8] {
|
||||
self.0.as_ref().map(|val| val.as_slice()).unwrap_or(&[])
|
||||
}
|
||||
}
|
||||
@@ -367,7 +367,7 @@ fn get_or_init_pools() -> &'static Vec<VerificationPool> {
|
||||
///
|
||||
/// When multiple calls of this function are made in parallel, each of them is executed in a
|
||||
/// dedicated pool, if there is enough free cores on the CPU.
|
||||
fn run_in_pool<OP, R>(f: OP) -> R
|
||||
pub(crate) fn run_in_pool<OP, R>(f: OP) -> R
|
||||
where
|
||||
OP: FnOnce() -> R + Send,
|
||||
R: Send,
|
||||
@@ -422,7 +422,7 @@ pub mod pke;
|
||||
pub mod pke_v2;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
pub(crate) mod test {
|
||||
#![allow(non_snake_case)]
|
||||
use std::fmt::Display;
|
||||
use std::num::Wrapping;
|
||||
@@ -439,9 +439,9 @@ mod test {
|
||||
use crate::proofs::decode_q;
|
||||
|
||||
// One of our usecases uses 320 bits of additional metadata
|
||||
pub(super) const METADATA_LEN: usize = (320 / u8::BITS) as usize;
|
||||
pub(crate) const METADATA_LEN: usize = (320 / u8::BITS) as usize;
|
||||
|
||||
pub(super) enum Compress {
|
||||
pub(crate) enum Compress {
|
||||
Yes,
|
||||
No,
|
||||
}
|
||||
@@ -453,7 +453,7 @@ mod test {
|
||||
Nominal,
|
||||
}
|
||||
|
||||
pub(super) fn serialize_then_deserialize<
|
||||
pub(crate) fn serialize_then_deserialize<
|
||||
Params: Compressible + Serialize + for<'de> Deserialize<'de>,
|
||||
>(
|
||||
public_params: &Params,
|
||||
@@ -503,36 +503,36 @@ mod test {
|
||||
|
||||
/// Parameters needed for a PKE zk proof test
|
||||
#[derive(Copy, Clone)]
|
||||
pub(super) struct PkeTestParameters {
|
||||
pub(super) d: usize,
|
||||
pub(super) k: usize,
|
||||
pub(super) B: u64,
|
||||
pub(super) q: u64,
|
||||
pub(super) t: u64,
|
||||
pub(super) msbs_zero_padding_bit_count: u64,
|
||||
pub(crate) struct PkeTestParameters {
|
||||
pub(crate) d: usize,
|
||||
pub(crate) k: usize,
|
||||
pub(crate) B: u64,
|
||||
pub(crate) q: u64,
|
||||
pub(crate) t: u64,
|
||||
pub(crate) msbs_zero_padding_bit_count: u64,
|
||||
}
|
||||
|
||||
/// An encrypted PKE ciphertext
|
||||
pub struct PkeTestCiphertext {
|
||||
pub(super) c1: Vec<i64>,
|
||||
pub(super) c2: Vec<i64>,
|
||||
pub(crate) struct PkeTestCiphertext {
|
||||
pub(crate) c1: Vec<i64>,
|
||||
pub(crate) c2: Vec<i64>,
|
||||
}
|
||||
|
||||
/// A randomly generated testcase of pke encryption
|
||||
#[derive(Clone)]
|
||||
pub(super) struct PkeTestcase {
|
||||
pub(super) a: Vec<i64>,
|
||||
pub(super) e1: Vec<i64>,
|
||||
pub(super) e2: Vec<i64>,
|
||||
pub(super) r: Vec<i64>,
|
||||
pub(super) m: Vec<i64>,
|
||||
pub(super) b: Vec<i64>,
|
||||
pub(super) metadata: [u8; METADATA_LEN],
|
||||
pub(super) s: Vec<i64>,
|
||||
pub(crate) struct PkeTestcase {
|
||||
pub(crate) a: Vec<i64>,
|
||||
pub(crate) e1: Vec<i64>,
|
||||
pub(crate) e2: Vec<i64>,
|
||||
pub(crate) r: Vec<i64>,
|
||||
pub(crate) m: Vec<i64>,
|
||||
pub(crate) b: Vec<i64>,
|
||||
pub(crate) metadata: [u8; METADATA_LEN],
|
||||
pub(crate) s: Vec<i64>,
|
||||
}
|
||||
|
||||
impl PkeTestcase {
|
||||
pub(super) fn gen(rng: &mut StdRng, params: PkeTestParameters) -> Self {
|
||||
pub(crate) fn gen(rng: &mut StdRng, params: PkeTestParameters) -> Self {
|
||||
let PkeTestParameters {
|
||||
d,
|
||||
k,
|
||||
@@ -587,7 +587,7 @@ mod test {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn sk_encrypt_zero(
|
||||
pub(crate) fn sk_encrypt_zero(
|
||||
&self,
|
||||
params: PkeTestParameters,
|
||||
rng: &mut StdRng,
|
||||
@@ -618,7 +618,7 @@ mod test {
|
||||
}
|
||||
|
||||
/// Decrypt a ciphertext list
|
||||
pub(super) fn decrypt(
|
||||
pub(crate) fn decrypt(
|
||||
&self,
|
||||
ct: &PkeTestCiphertext,
|
||||
params: PkeTestParameters,
|
||||
@@ -659,7 +659,7 @@ mod test {
|
||||
}
|
||||
|
||||
/// Encrypt using compact pke, the encryption is validated by doing a decryption
|
||||
pub(super) fn encrypt(&self, params: PkeTestParameters) -> PkeTestCiphertext {
|
||||
pub(crate) fn encrypt(&self, params: PkeTestParameters) -> PkeTestCiphertext {
|
||||
let ct = self.encrypt_unchecked(params);
|
||||
|
||||
// Check decryption
|
||||
@@ -671,7 +671,7 @@ mod test {
|
||||
}
|
||||
|
||||
/// Encrypt using compact pke, without checking that the decryption is correct
|
||||
pub(super) fn encrypt_unchecked(&self, params: PkeTestParameters) -> PkeTestCiphertext {
|
||||
pub(crate) fn encrypt_unchecked(&self, params: PkeTestParameters) -> PkeTestCiphertext {
|
||||
let PkeTestParameters {
|
||||
d,
|
||||
k,
|
||||
|
||||
438
tfhe-zk-pok/src/proofs/pke/gpu.rs
Normal file
438
tfhe-zk-pok/src/proofs/pke/gpu.rs
Normal file
@@ -0,0 +1,438 @@
|
||||
//! GPU-accelerated prove/verify for PKE v1.
|
||||
//!
|
||||
//! `prove` duplicates the logic of [`super::prove_impl`] but replaces every
|
||||
//! `multi_mul_scalar` call with the GPU-accelerated [`crate::gpu::g1_msm_gpu`]
|
||||
//! / [`crate::gpu::g2_msm_gpu`]. `verify` simply delegates to the CPU
|
||||
//! verifier since v1 verification contains no MSM calls.
|
||||
|
||||
use crate::curve_api::bls12_446::{Zp, G1, G2};
|
||||
use crate::curve_api::{Bls12_446, CurveGroupOps, FieldOps};
|
||||
use crate::gpu::select_gpu_for_msm;
|
||||
use crate::proofs::{
|
||||
assert_pke_proof_preconditions, compute_r1, compute_r2, decode_q, ComputeLoad, OneBased,
|
||||
ProofSanityCheckMode,
|
||||
};
|
||||
|
||||
use super::{
|
||||
bit_iter, compute_a_theta, ComputeLoadProofFields, PrivateCommit, Proof, PublicCommit,
|
||||
PublicParams,
|
||||
};
|
||||
|
||||
/// GPU-accelerated proof generation for PKE v1.
|
||||
///
|
||||
/// Identical to [`super::prove`] but dispatches MSM to the GPU via
|
||||
/// [`crate::gpu::g1_msm_gpu`] / [`crate::gpu::g2_msm_gpu`].
|
||||
pub fn prove(
|
||||
public: (&PublicParams<Bls12_446>, &PublicCommit<Bls12_446>),
|
||||
private_commit: &PrivateCommit<Bls12_446>,
|
||||
metadata: &[u8],
|
||||
load: ComputeLoad,
|
||||
seed: &[u8],
|
||||
) -> Proof<Bls12_446> {
|
||||
prove_impl(
|
||||
public,
|
||||
private_commit,
|
||||
metadata,
|
||||
load,
|
||||
seed,
|
||||
ProofSanityCheckMode::Panic,
|
||||
)
|
||||
}
|
||||
|
||||
/// GPU-accelerated verification for PKE v1.
|
||||
///
|
||||
/// PKE v1 verification has no MSM calls, so this delegates directly to the CPU
|
||||
/// verifier.
|
||||
#[allow(clippy::result_unit_err)]
|
||||
pub fn verify(
|
||||
proof: &Proof<Bls12_446>,
|
||||
public: (&PublicParams<Bls12_446>, &PublicCommit<Bls12_446>),
|
||||
metadata: &[u8],
|
||||
) -> Result<(), ()> {
|
||||
super::verify(proof, public, metadata)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// prove_impl – GPU variant of super::prove_impl
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn prove_impl(
|
||||
public: (&PublicParams<Bls12_446>, &PublicCommit<Bls12_446>),
|
||||
private_commit: &PrivateCommit<Bls12_446>,
|
||||
metadata: &[u8],
|
||||
load: ComputeLoad,
|
||||
seed: &[u8],
|
||||
sanity_check_mode: ProofSanityCheckMode,
|
||||
) -> Proof<Bls12_446> {
|
||||
let &PublicParams {
|
||||
ref g_lists,
|
||||
big_d: big_d_max,
|
||||
n,
|
||||
d,
|
||||
b,
|
||||
b_r,
|
||||
q,
|
||||
t,
|
||||
msbs_zero_padding_bit_count,
|
||||
k: k_max,
|
||||
sid,
|
||||
domain_separators: ref ds,
|
||||
} = public.0;
|
||||
let g_list = &g_lists.g_list;
|
||||
let g_hat_list = &g_lists.g_hat_list;
|
||||
|
||||
let b_i = b;
|
||||
|
||||
let PublicCommit { a, b, c1, c2, .. } = public.1;
|
||||
let PrivateCommit { r, e1, m, e2, .. } = private_commit;
|
||||
|
||||
let k = c2.len();
|
||||
|
||||
let effective_t_for_decomposition = t >> msbs_zero_padding_bit_count;
|
||||
|
||||
let decoded_q = decode_q(q);
|
||||
|
||||
let big_d = d
|
||||
+ k * effective_t_for_decomposition.ilog2() as usize
|
||||
+ (d + k) * (2 + b_i.ilog2() as usize + b_r.ilog2() as usize);
|
||||
|
||||
if sanity_check_mode == ProofSanityCheckMode::Panic {
|
||||
assert_pke_proof_preconditions(a, b, c1, e1, c2, e2, d, k_max, big_d, big_d_max);
|
||||
}
|
||||
|
||||
// FIXME: div_round
|
||||
let delta = {
|
||||
// delta takes the encoding with the padding bit
|
||||
// decoded_q <= 2^64 and t >= 1, so the quotient always fits in u64
|
||||
(decoded_q / t as u128) as u64
|
||||
};
|
||||
|
||||
let g = G1::GENERATOR;
|
||||
let g_hat = G2::GENERATOR;
|
||||
let mut gamma_list = [Zp::ZERO; 2];
|
||||
Zp::hash(&mut gamma_list, &[ds.hash_gamma(), seed]);
|
||||
let [gamma, gamma_y] = gamma_list;
|
||||
|
||||
let r1 = compute_r1(e1, c1, a, r, d, decoded_q);
|
||||
let r2 = compute_r2(e2, c2, m, b, r, d, delta, decoded_q);
|
||||
|
||||
let mut w = vec![false; n];
|
||||
|
||||
// Reinterpret as unsigned for bit decomposition; bit pattern is preserved,
|
||||
// which is correct for torus arithmetic
|
||||
let u64 = |x: i64| x as u64;
|
||||
|
||||
w[..big_d]
|
||||
.iter_mut()
|
||||
.zip(
|
||||
r.iter()
|
||||
.rev()
|
||||
.flat_map(|&r| bit_iter(u64(r), 1))
|
||||
.chain(
|
||||
m.iter()
|
||||
.flat_map(|&m| bit_iter(u64(m), effective_t_for_decomposition.ilog2())),
|
||||
)
|
||||
.chain(e1.iter().flat_map(|&e1| bit_iter(u64(e1), 1 + b_i.ilog2())))
|
||||
.chain(e2.iter().flat_map(|&e2| bit_iter(u64(e2), 1 + b_i.ilog2())))
|
||||
.chain(r1.iter().flat_map(|&r1| bit_iter(u64(r1), 1 + b_r.ilog2())))
|
||||
.chain(r2.iter().flat_map(|&r2| bit_iter(u64(r2), 1 + b_r.ilog2()))),
|
||||
)
|
||||
.for_each(|(dst, src)| *dst = src);
|
||||
|
||||
let w = OneBased(w);
|
||||
|
||||
let mut c_hat = g_hat.mul_scalar(gamma);
|
||||
for j in 1..big_d + 1 {
|
||||
if w[j] {
|
||||
c_hat += G2::projective(g_hat_list[j]);
|
||||
}
|
||||
}
|
||||
|
||||
let x_bytes = &*[
|
||||
q.to_le_bytes().as_slice(),
|
||||
(d as u64).to_le_bytes().as_slice(),
|
||||
b_i.to_le_bytes().as_slice(),
|
||||
t.to_le_bytes().as_slice(),
|
||||
msbs_zero_padding_bit_count.to_le_bytes().as_slice(),
|
||||
&*a.iter().flat_map(|&x| x.to_le_bytes()).collect::<Box<_>>(),
|
||||
&*b.iter().flat_map(|&x| x.to_le_bytes()).collect::<Box<_>>(),
|
||||
&*c1.iter().flat_map(|&x| x.to_le_bytes()).collect::<Box<_>>(),
|
||||
&*c2.iter().flat_map(|&x| x.to_le_bytes()).collect::<Box<_>>(),
|
||||
]
|
||||
.iter()
|
||||
.copied()
|
||||
.flatten()
|
||||
.copied()
|
||||
.collect::<Box<_>>();
|
||||
|
||||
let mut y = vec![Zp::ZERO; n];
|
||||
Zp::hash(
|
||||
&mut y,
|
||||
&[
|
||||
ds.hash(),
|
||||
sid.to_le_bytes().as_slice(),
|
||||
metadata,
|
||||
x_bytes,
|
||||
c_hat.to_le_bytes().as_ref(),
|
||||
],
|
||||
);
|
||||
let y = OneBased(y);
|
||||
|
||||
// GPU MSM: c_y commitment
|
||||
let scalars = (n + 1 - big_d..n + 1)
|
||||
.map(|j| y[n + 1 - j] * Zp::from_u64(w[n + 1 - j] as u64))
|
||||
.collect::<Vec<_>>();
|
||||
let c_y = g.mul_scalar(gamma_y)
|
||||
+ crate::gpu::g1_msm_gpu(&g_list.0[n - big_d..n], &scalars, select_gpu_for_msm());
|
||||
|
||||
let mut theta = vec![Zp::ZERO; d + k + 1];
|
||||
Zp::hash(
|
||||
&mut theta,
|
||||
&[
|
||||
ds.hash_lmap(),
|
||||
sid.to_le_bytes().as_slice(),
|
||||
metadata,
|
||||
x_bytes,
|
||||
c_hat.to_le_bytes().as_ref(),
|
||||
c_y.to_le_bytes().as_ref(),
|
||||
],
|
||||
);
|
||||
|
||||
let theta0 = &theta[..d + k];
|
||||
|
||||
let delta_theta = theta[d + k];
|
||||
|
||||
let mut a_theta = vec![Zp::ZERO; big_d];
|
||||
|
||||
compute_a_theta::<Bls12_446>(
|
||||
theta0,
|
||||
d,
|
||||
a,
|
||||
k,
|
||||
b,
|
||||
&mut a_theta,
|
||||
effective_t_for_decomposition,
|
||||
delta,
|
||||
b_i,
|
||||
b_r,
|
||||
decoded_q,
|
||||
);
|
||||
|
||||
let mut t = vec![Zp::ZERO; n];
|
||||
Zp::hash_128bit(
|
||||
&mut t,
|
||||
&[
|
||||
ds.hash_t(),
|
||||
sid.to_le_bytes().as_slice(),
|
||||
metadata,
|
||||
&(1..n + 1)
|
||||
.flat_map(|i| y[i].to_le_bytes().as_ref().to_vec())
|
||||
.collect::<Box<_>>(),
|
||||
x_bytes,
|
||||
c_hat.to_le_bytes().as_ref(),
|
||||
c_y.to_le_bytes().as_ref(),
|
||||
],
|
||||
);
|
||||
let t = OneBased(t);
|
||||
|
||||
let mut delta = [Zp::ZERO; 2];
|
||||
Zp::hash(
|
||||
&mut delta,
|
||||
&[
|
||||
ds.hash_agg(),
|
||||
sid.to_le_bytes().as_slice(),
|
||||
metadata,
|
||||
x_bytes,
|
||||
c_hat.to_le_bytes().as_ref(),
|
||||
c_y.to_le_bytes().as_ref(),
|
||||
],
|
||||
);
|
||||
let [delta_eq, delta_y] = delta;
|
||||
let delta = [delta_eq, delta_y, delta_theta];
|
||||
|
||||
let mut poly_0 = vec![Zp::ZERO; n + 1];
|
||||
let mut poly_1 = vec![Zp::ZERO; big_d + 1];
|
||||
let mut poly_2 = vec![Zp::ZERO; n + 1];
|
||||
let mut poly_3 = vec![Zp::ZERO; n + 1];
|
||||
|
||||
poly_0[0] = delta_y * gamma_y;
|
||||
for i in 1..n + 1 {
|
||||
poly_0[n + 1 - i] =
|
||||
delta_y * (y[i] * Zp::from_u64(w[i] as u64)) + (delta_eq * t[i] - delta_y) * y[i];
|
||||
|
||||
if i < big_d + 1 {
|
||||
poly_0[n + 1 - i] += delta_theta * a_theta[i - 1];
|
||||
}
|
||||
}
|
||||
|
||||
poly_1[0] = gamma;
|
||||
for i in 1..big_d + 1 {
|
||||
poly_1[i] = Zp::from_u64(w[i] as u64);
|
||||
}
|
||||
|
||||
poly_2[0] = gamma_y;
|
||||
for i in 1..big_d + 1 {
|
||||
poly_2[n + 1 - i] = y[i] * Zp::from_u64(w[i] as u64);
|
||||
}
|
||||
|
||||
for i in 1..n + 1 {
|
||||
poly_3[i] = delta_eq * t[i];
|
||||
}
|
||||
|
||||
let mut t_theta = Zp::ZERO;
|
||||
for i in 0..d {
|
||||
t_theta += theta0[i] * Zp::from_i64(c1[i]);
|
||||
}
|
||||
for i in 0..k {
|
||||
t_theta += theta0[d + i] * Zp::from_i64(c2[i]);
|
||||
}
|
||||
|
||||
let mul = rayon::join(
|
||||
|| Zp::poly_mul(&poly_0, &poly_1),
|
||||
|| Zp::poly_mul(&poly_2, &poly_3),
|
||||
);
|
||||
let mut poly = Zp::poly_sub(&mul.0, &mul.1);
|
||||
if poly.len() > n + 1 {
|
||||
poly[n + 1] -= t_theta * delta_theta;
|
||||
}
|
||||
|
||||
// GPU MSM: pi commitment
|
||||
let pi = g.mul_scalar(poly[0])
|
||||
+ crate::gpu::g1_msm_gpu(
|
||||
&g_list.0[..poly.len() - 1],
|
||||
&poly[1..],
|
||||
select_gpu_for_msm(),
|
||||
);
|
||||
|
||||
if load == ComputeLoad::Proof {
|
||||
// GPU MSM: c_hat_t commitment
|
||||
let c_hat_t = crate::gpu::g2_msm_gpu(&g_hat_list.0, &t.0, select_gpu_for_msm());
|
||||
let scalars = (1..n + 1)
|
||||
.map(|i| {
|
||||
let i = n + 1 - i;
|
||||
(delta_eq * t[i] - delta_y) * y[i]
|
||||
+ if i < big_d + 1 {
|
||||
delta_theta * a_theta[i - 1]
|
||||
} else {
|
||||
Zp::ZERO
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// GPU MSM: c_h commitment
|
||||
let c_h = crate::gpu::g1_msm_gpu(&g_list.0[..n], &scalars, select_gpu_for_msm());
|
||||
|
||||
let mut z = Zp::ZERO;
|
||||
Zp::hash(
|
||||
core::array::from_mut(&mut z),
|
||||
&[
|
||||
ds.hash_z(),
|
||||
sid.to_le_bytes().as_slice(),
|
||||
metadata,
|
||||
x_bytes,
|
||||
c_hat.to_le_bytes().as_ref(),
|
||||
c_y.to_le_bytes().as_ref(),
|
||||
pi.to_le_bytes().as_ref(),
|
||||
c_h.to_le_bytes().as_ref(),
|
||||
c_hat_t.to_le_bytes().as_ref(),
|
||||
&y.0.iter()
|
||||
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
|
||||
.collect::<Box<[_]>>(),
|
||||
&t.0.iter()
|
||||
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
|
||||
.collect::<Box<[_]>>(),
|
||||
&delta
|
||||
.iter()
|
||||
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
|
||||
.collect::<Box<[_]>>(),
|
||||
],
|
||||
);
|
||||
|
||||
let mut pow = z;
|
||||
let mut p_t = Zp::ZERO;
|
||||
let mut p_h = Zp::ZERO;
|
||||
|
||||
for i in 1..n + 1 {
|
||||
p_t += t[i] * pow;
|
||||
if n - i < big_d {
|
||||
p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i]
|
||||
+ delta_theta * a_theta[n - i])
|
||||
* pow;
|
||||
} else {
|
||||
p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i]) * pow;
|
||||
}
|
||||
pow = pow * z;
|
||||
}
|
||||
|
||||
let mut w = Zp::ZERO;
|
||||
Zp::hash(
|
||||
core::array::from_mut(&mut w),
|
||||
&[
|
||||
ds.hash_w(),
|
||||
sid.to_le_bytes().as_slice(),
|
||||
metadata,
|
||||
x_bytes,
|
||||
c_hat.to_le_bytes().as_ref(),
|
||||
c_y.to_le_bytes().as_ref(),
|
||||
pi.to_le_bytes().as_ref(),
|
||||
c_h.to_le_bytes().as_ref(),
|
||||
c_hat_t.to_le_bytes().as_ref(),
|
||||
&y.0.iter()
|
||||
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
|
||||
.collect::<Box<[_]>>(),
|
||||
&t.0.iter()
|
||||
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
|
||||
.collect::<Box<[_]>>(),
|
||||
&delta
|
||||
.iter()
|
||||
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
|
||||
.collect::<Box<[_]>>(),
|
||||
z.to_le_bytes().as_ref(),
|
||||
p_h.to_le_bytes().as_ref(),
|
||||
p_t.to_le_bytes().as_ref(),
|
||||
],
|
||||
);
|
||||
|
||||
let mut poly = vec![Zp::ZERO; n + 1];
|
||||
for i in 1..n + 1 {
|
||||
poly[i] += w * t[i];
|
||||
if i < big_d + 1 {
|
||||
poly[n + 1 - i] +=
|
||||
(delta_eq * t[i] - delta_y) * y[i] + delta_theta * a_theta[i - 1];
|
||||
} else {
|
||||
poly[n + 1 - i] += (delta_eq * t[i] - delta_y) * y[i];
|
||||
}
|
||||
}
|
||||
|
||||
let mut q = vec![Zp::ZERO; n];
|
||||
// Polynomial long division by (X - z)
|
||||
for i in (0..n).rev() {
|
||||
poly[i] = poly[i] + z * poly[i + 1];
|
||||
q[i] = poly[i + 1];
|
||||
poly[i + 1] = Zp::ZERO;
|
||||
}
|
||||
|
||||
// GPU MSM: pi_kzg commitment
|
||||
let pi_kzg = g.mul_scalar(q[0])
|
||||
+ crate::gpu::g1_msm_gpu(&g_list.0[..n - 1], &q[1..n], select_gpu_for_msm());
|
||||
|
||||
Proof {
|
||||
c_hat,
|
||||
c_y,
|
||||
pi,
|
||||
compute_load_proof_fields: Some(ComputeLoadProofFields {
|
||||
c_hat_t,
|
||||
c_h,
|
||||
pi_kzg,
|
||||
}),
|
||||
}
|
||||
} else {
|
||||
Proof {
|
||||
c_hat,
|
||||
c_y,
|
||||
pi,
|
||||
compute_load_proof_fields: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,6 @@
|
||||
#[cfg(feature = "gpu-experimental")]
|
||||
pub mod gpu;
|
||||
|
||||
// TODO: refactor copy-pasted code in proof/verify
|
||||
|
||||
use crate::backward_compatibility::pke::{
|
||||
@@ -15,7 +18,7 @@ use core::marker::PhantomData;
|
||||
use rayon::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
|
||||
pub(crate) fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
|
||||
(0..nbits).map(move |idx| ((x >> idx) & 1) != 0)
|
||||
}
|
||||
|
||||
@@ -441,10 +444,10 @@ where
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PublicCommit<G: Curve> {
|
||||
a: Vec<i64>,
|
||||
b: Vec<i64>,
|
||||
c1: Vec<i64>,
|
||||
c2: Vec<i64>,
|
||||
pub(crate) a: Vec<i64>,
|
||||
pub(crate) b: Vec<i64>,
|
||||
pub(crate) c1: Vec<i64>,
|
||||
pub(crate) c2: Vec<i64>,
|
||||
__marker: PhantomData<G>,
|
||||
}
|
||||
|
||||
@@ -462,10 +465,10 @@ impl<G: Curve> PublicCommit<G> {
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PrivateCommit<G: Curve> {
|
||||
r: Vec<i64>,
|
||||
e1: Vec<i64>,
|
||||
m: Vec<i64>,
|
||||
e2: Vec<i64>,
|
||||
pub(crate) r: Vec<i64>,
|
||||
pub(crate) e1: Vec<i64>,
|
||||
pub(crate) m: Vec<i64>,
|
||||
pub(crate) e2: Vec<i64>,
|
||||
__marker: PhantomData<G>,
|
||||
}
|
||||
|
||||
@@ -932,7 +935,7 @@ fn prove_impl<G: Curve>(
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn compute_a_theta<G: Curve>(
|
||||
pub(crate) fn compute_a_theta<G: Curve>(
|
||||
theta0: &[G::Zp],
|
||||
d: usize,
|
||||
a: &[i64],
|
||||
@@ -1353,7 +1356,7 @@ pub fn verify<G: Curve>(
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
pub(crate) mod tests {
|
||||
use crate::curve_api::{self, bls12_446};
|
||||
|
||||
use super::super::test::*;
|
||||
@@ -1364,7 +1367,7 @@ mod tests {
|
||||
type Curve = curve_api::Bls12_446;
|
||||
|
||||
/// Compact key params used with pkev1
|
||||
pub(super) const PKEV1_TEST_PARAMS: PkeTestParameters = PkeTestParameters {
|
||||
pub(crate) const PKEV1_TEST_PARAMS: PkeTestParameters = PkeTestParameters {
|
||||
d: 1024,
|
||||
k: 320,
|
||||
B: 4398046511104, // 2**42
|
||||
1587
tfhe-zk-pok/src/proofs/pke_v2/gpu.rs
Normal file
1587
tfhe-zk-pok/src/proofs/pke_v2/gpu.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -337,13 +337,13 @@ struct RInputs<'a> {
|
||||
mode: PkeV2HashMode,
|
||||
}
|
||||
|
||||
pub(super) struct RHash<'a> {
|
||||
pub(crate) struct RHash<'a> {
|
||||
R_inputs: RInputs<'a>,
|
||||
R_bytes: Box<[u8]>,
|
||||
}
|
||||
|
||||
impl<'a> RHash<'a> {
|
||||
pub(super) fn new<G: Curve>(
|
||||
pub(crate) fn new<G: Curve>(
|
||||
public: (&'a PublicParams<G>, &PublicCommit<G>),
|
||||
metadata: &'a [u8],
|
||||
C_hat_e_bytes: &'a [u8],
|
||||
@@ -502,7 +502,7 @@ impl<'a> RHash<'a> {
|
||||
]
|
||||
}
|
||||
|
||||
pub(super) fn gen_phi<Zp: FieldOps>(self, C_R_bytes: &'a [u8]) -> ([Zp; 128], PhiHash<'a>) {
|
||||
pub(crate) fn gen_phi<Zp: FieldOps>(self, C_R_bytes: &'a [u8]) -> ([Zp; 128], PhiHash<'a>) {
|
||||
let mode = self.R_inputs.mode;
|
||||
let phi_inputs = PhiInputs { C_R_bytes };
|
||||
|
||||
@@ -526,7 +526,7 @@ struct PhiInputs<'a> {
|
||||
C_R_bytes: &'a [u8],
|
||||
}
|
||||
|
||||
pub(super) struct PhiHash<'a> {
|
||||
pub(crate) struct PhiHash<'a> {
|
||||
R_inputs: RInputs<'a>,
|
||||
R_bytes: Box<[u8]>,
|
||||
phi_inputs: PhiInputs<'a>,
|
||||
@@ -572,7 +572,7 @@ impl<'a> PhiHash<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn gen_xi<Zp: FieldOps>(self, C_hat_bin_bytes: &'a [u8]) -> ([Zp; 128], XiHash<'a>) {
|
||||
pub(crate) fn gen_xi<Zp: FieldOps>(self, C_hat_bin_bytes: &'a [u8]) -> ([Zp; 128], XiHash<'a>) {
|
||||
let mode = self.R_inputs.mode;
|
||||
let xi_inputs = XiInputs { C_hat_bin_bytes };
|
||||
|
||||
@@ -598,7 +598,7 @@ struct XiInputs<'a> {
|
||||
C_hat_bin_bytes: &'a [u8],
|
||||
}
|
||||
|
||||
pub(super) struct XiHash<'a> {
|
||||
pub(crate) struct XiHash<'a> {
|
||||
R_inputs: RInputs<'a>,
|
||||
R_bytes: Box<[u8]>,
|
||||
phi_inputs: PhiInputs<'a>,
|
||||
@@ -650,7 +650,7 @@ impl<'a> XiHash<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn gen_y<Zp: FieldOps>(self) -> (Vec<Zp>, YHash<'a>) {
|
||||
pub(crate) fn gen_y<Zp: FieldOps>(self) -> (Vec<Zp>, YHash<'a>) {
|
||||
let mode = self.R_inputs.mode;
|
||||
|
||||
let mut y = vec![Zp::ZERO; self.R_inputs.D + 128 * self.R_inputs.m];
|
||||
@@ -671,7 +671,7 @@ impl<'a> XiHash<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct YHash<'a> {
|
||||
pub(crate) struct YHash<'a> {
|
||||
R_inputs: RInputs<'a>,
|
||||
R_bytes: Box<[u8]>,
|
||||
phi_inputs: PhiInputs<'a>,
|
||||
@@ -729,7 +729,7 @@ impl<'a> YHash<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn gen_t<Zp: FieldOps>(self, C_y_bytes: &'a [u8]) -> (Vec<Zp>, THash<'a>) {
|
||||
pub(crate) fn gen_t<Zp: FieldOps>(self, C_y_bytes: &'a [u8]) -> (Vec<Zp>, THash<'a>) {
|
||||
let mode = self.R_inputs.mode;
|
||||
let t_inputs = TInputs { C_y_bytes };
|
||||
|
||||
@@ -761,7 +761,7 @@ struct TInputs<'a> {
|
||||
C_y_bytes: &'a [u8],
|
||||
}
|
||||
|
||||
pub(super) struct THash<'a> {
|
||||
pub(crate) struct THash<'a> {
|
||||
R_inputs: RInputs<'a>,
|
||||
R_bytes: Box<[u8]>,
|
||||
phi_inputs: PhiInputs<'a>,
|
||||
@@ -825,7 +825,7 @@ impl<'a> THash<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn gen_theta<Zp: FieldOps>(self) -> (Vec<Zp>, ThetaHash<'a>) {
|
||||
pub(crate) fn gen_theta<Zp: FieldOps>(self) -> (Vec<Zp>, ThetaHash<'a>) {
|
||||
let mode = self.R_inputs.mode;
|
||||
|
||||
let mut theta = vec![Zp::ZERO; self.R_inputs.d + self.R_inputs.k];
|
||||
@@ -849,7 +849,7 @@ impl<'a> THash<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct ThetaHash<'a> {
|
||||
pub(crate) struct ThetaHash<'a> {
|
||||
R_inputs: RInputs<'a>,
|
||||
R_bytes: Box<[u8]>,
|
||||
phi_inputs: PhiInputs<'a>,
|
||||
@@ -917,7 +917,7 @@ impl<'a> ThetaHash<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn gen_omega<Zp: FieldOps>(self) -> (Vec<Zp>, OmegaHash<'a>) {
|
||||
pub(crate) fn gen_omega<Zp: FieldOps>(self) -> (Vec<Zp>, OmegaHash<'a>) {
|
||||
let mode = self.R_inputs.mode;
|
||||
|
||||
let mut omega = vec![Zp::ZERO; self.R_inputs.n];
|
||||
@@ -946,7 +946,7 @@ impl<'a> ThetaHash<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct OmegaHash<'a> {
|
||||
pub(crate) struct OmegaHash<'a> {
|
||||
R_inputs: RInputs<'a>,
|
||||
R_bytes: Box<[u8]>,
|
||||
phi_inputs: PhiInputs<'a>,
|
||||
@@ -1018,7 +1018,7 @@ impl<'a> OmegaHash<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn gen_delta<Zp: FieldOps>(self) -> ([Zp; 7], DeltaHash<'a>) {
|
||||
pub(crate) fn gen_delta<Zp: FieldOps>(self) -> ([Zp; 7], DeltaHash<'a>) {
|
||||
let mut delta = [Zp::ZERO; 7];
|
||||
|
||||
// Delta does not use the compact hash optimization
|
||||
@@ -1048,7 +1048,7 @@ impl<'a> OmegaHash<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct DeltaHash<'a> {
|
||||
pub(crate) struct DeltaHash<'a> {
|
||||
R_inputs: RInputs<'a>,
|
||||
R_bytes: Box<[u8]>,
|
||||
phi_inputs: PhiInputs<'a>,
|
||||
@@ -1161,7 +1161,7 @@ impl<'a> DeltaHash<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn gen_z<Zp: FieldOps>(
|
||||
pub(crate) fn gen_z<Zp: FieldOps>(
|
||||
self,
|
||||
C_h1_bytes: &'a [u8],
|
||||
C_h2_bytes: &'a [u8],
|
||||
@@ -1210,7 +1210,7 @@ struct ZInputs<'a> {
|
||||
C_hat_omega_bytes: &'a [u8],
|
||||
}
|
||||
|
||||
pub(super) struct ZHash<'a> {
|
||||
pub(crate) struct ZHash<'a> {
|
||||
R_inputs: RInputs<'a>,
|
||||
R_bytes: Box<[u8]>,
|
||||
phi_inputs: PhiInputs<'a>,
|
||||
@@ -1351,7 +1351,7 @@ impl<'a> ZHash<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn gen_chi<Zp: FieldOps>(
|
||||
pub(crate) fn gen_chi<Zp: FieldOps>(
|
||||
self,
|
||||
p_h1: Zp,
|
||||
p_h2: Zp,
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
// to follow the notation of the paper
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
#[cfg(feature = "gpu-experimental")]
|
||||
pub mod gpu;
|
||||
|
||||
use super::*;
|
||||
use crate::backward_compatibility::pke_v2::*;
|
||||
use crate::backward_compatibility::BoundVersions;
|
||||
@@ -16,13 +19,13 @@ use core::marker::PhantomData;
|
||||
use rayon::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
mod hashes;
|
||||
pub(crate) mod hashes;
|
||||
|
||||
use hashes::RHash;
|
||||
|
||||
pub use hashes::*;
|
||||
|
||||
fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
|
||||
pub(crate) fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
|
||||
(0..nbits).map(move |idx| ((x >> idx) & 1) != 0)
|
||||
}
|
||||
|
||||
@@ -436,7 +439,7 @@ pub(crate) struct ComputeLoadProofFields<G: Curve> {
|
||||
}
|
||||
|
||||
impl<G: Curve> ComputeLoadProofFields<G> {
|
||||
fn to_le_bytes(fields: &Option<Self>) -> (Box<[u8]>, Box<[u8]>) {
|
||||
pub(crate) fn to_le_bytes(fields: &Option<Self>) -> (Box<[u8]>, Box<[u8]>) {
|
||||
if let Some(ComputeLoadProofFields { C_hat_h3, C_hat_w }) = fields.as_ref() {
|
||||
(
|
||||
Box::from(G::G2::to_le_bytes(*C_hat_h3).as_ref()),
|
||||
@@ -589,13 +592,13 @@ where
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PublicCommit<G: Curve> {
|
||||
/// Mask of the public key
|
||||
a: Vec<i64>,
|
||||
pub(crate) a: Vec<i64>,
|
||||
/// Body of the public key
|
||||
b: Vec<i64>,
|
||||
pub(crate) b: Vec<i64>,
|
||||
/// Mask of the ciphertexts
|
||||
c1: Vec<i64>,
|
||||
pub(crate) c1: Vec<i64>,
|
||||
/// Bodies of the ciphertexts
|
||||
c2: Vec<i64>,
|
||||
pub(crate) c2: Vec<i64>,
|
||||
__marker: PhantomData<G>,
|
||||
}
|
||||
|
||||
@@ -614,13 +617,13 @@ impl<G: Curve> PublicCommit<G> {
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PrivateCommit<G: Curve> {
|
||||
/// Public key sampling vector
|
||||
r: Vec<i64>,
|
||||
pub(crate) r: Vec<i64>,
|
||||
/// Error vector associated with the masks
|
||||
e1: Vec<i64>,
|
||||
pub(crate) e1: Vec<i64>,
|
||||
/// Input messages
|
||||
m: Vec<i64>,
|
||||
pub(crate) m: Vec<i64>,
|
||||
/// Error vector associated with the bodies
|
||||
e2: Vec<i64>,
|
||||
pub(crate) e2: Vec<i64>,
|
||||
__marker: PhantomData<G>,
|
||||
}
|
||||
|
||||
@@ -697,7 +700,7 @@ The computed m parameter is {m_bound} > 64. Please select a smaller B, d and/or
|
||||
///
|
||||
/// Use the relationship: `||x||_2 <= sqrt(dim)*||x||_inf`. Since we are only interested in the
|
||||
/// squared bound, we avoid the sqrt by returning dim*(||x||_inf)^2.
|
||||
fn inf_norm_bound_to_euclidean_squared(B_inf: u64, dim: usize) -> u128 {
|
||||
pub(crate) fn inf_norm_bound_to_euclidean_squared(B_inf: u64, dim: usize) -> u128 {
|
||||
let norm_squared = sqr(B_inf);
|
||||
norm_squared
|
||||
.checked_mul(dim as u128)
|
||||
@@ -1812,7 +1815,7 @@ fn precompute_xi_powers<Zp: FieldOps>(xi: &[Zp; 128], m: usize) -> Box<[Zp]> {
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn compute_a_theta<G: Curve>(
|
||||
pub(crate) fn compute_a_theta<G: Curve>(
|
||||
a_theta: &mut [G::Zp],
|
||||
theta: &[G::Zp],
|
||||
a: &[i64],
|
||||
@@ -1931,23 +1934,23 @@ pub enum VerificationPairingMode {
|
||||
Batched,
|
||||
}
|
||||
|
||||
struct GeneratedScalars<G: Curve> {
|
||||
phi: [G::Zp; 128],
|
||||
xi: [G::Zp; 128],
|
||||
theta: Vec<G::Zp>,
|
||||
omega: Vec<G::Zp>,
|
||||
delta: [G::Zp; 7],
|
||||
chi_powers: [G::Zp; 4],
|
||||
z: G::Zp,
|
||||
t_theta: G::Zp,
|
||||
pub(crate) struct GeneratedScalars<G: Curve> {
|
||||
pub(crate) phi: [G::Zp; 128],
|
||||
pub(crate) xi: [G::Zp; 128],
|
||||
pub(crate) theta: Vec<G::Zp>,
|
||||
pub(crate) omega: Vec<G::Zp>,
|
||||
pub(crate) delta: [G::Zp; 7],
|
||||
pub(crate) chi_powers: [G::Zp; 4],
|
||||
pub(crate) z: G::Zp,
|
||||
pub(crate) t_theta: G::Zp,
|
||||
}
|
||||
|
||||
struct EvaluationPoints<G: Curve> {
|
||||
p_h1: G::Zp,
|
||||
p_h2: G::Zp,
|
||||
p_h3: G::Zp,
|
||||
p_t: G::Zp,
|
||||
p_omega: G::Zp,
|
||||
pub(crate) struct EvaluationPoints<G: Curve> {
|
||||
pub(crate) p_h1: G::Zp,
|
||||
pub(crate) p_h2: G::Zp,
|
||||
pub(crate) p_h3: G::Zp,
|
||||
pub(crate) p_t: G::Zp,
|
||||
pub(crate) p_omega: G::Zp,
|
||||
}
|
||||
|
||||
#[allow(clippy::result_unit_err)]
|
||||
@@ -2655,7 +2658,7 @@ fn pairing_check_batched<G: Curve>(
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
pub(crate) mod tests {
|
||||
use crate::curve_api::{self, bls12_446};
|
||||
|
||||
use super::super::test::*;
|
||||
@@ -2667,7 +2670,7 @@ mod tests {
|
||||
type Curve = curve_api::Bls12_446;
|
||||
|
||||
/// Compact key params used with pkev2
|
||||
pub(super) const PKEV2_TEST_PARAMS: PkeTestParameters = PkeTestParameters {
|
||||
pub(crate) const PKEV2_TEST_PARAMS: PkeTestParameters = PkeTestParameters {
|
||||
d: 2048,
|
||||
k: 320,
|
||||
B: 131072, // 2**17
|
||||
|
||||
@@ -102,6 +102,7 @@ integer = ["shortint", "dep:strum"]
|
||||
strings = ["integer"]
|
||||
internal-keycache = ["dep:fs2"]
|
||||
gpu = ["dep:tfhe-cuda-backend", "shortint"]
|
||||
gpu-experimental-zk = ["gpu", "zk-pok", "tfhe-zk-pok/gpu-experimental"]
|
||||
gpu-experimental-multi-arch = [
|
||||
"gpu",
|
||||
"tfhe-cuda-backend/experimental-multi-arch",
|
||||
|
||||
@@ -56,7 +56,6 @@ impl CudaProvenCompactCiphertextList {
|
||||
},
|
||||
|| self.expand_without_verification(key, streams),
|
||||
);
|
||||
|
||||
if all_valid {
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -17,15 +17,21 @@ use std::fmt::Debug;
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
use tfhe_zk_pok::proofs::pke::{
|
||||
commit as commit_v1, crs_gen as crs_gen_v1, prove as prove_v1, verify as verify_v1,
|
||||
Proof as ProofV1, PublicCommit as PublicCommitV1,
|
||||
commit as commit_v1, crs_gen as crs_gen_v1, verify as verify_v1, Proof as ProofV1,
|
||||
PublicCommit as PublicCommitV1,
|
||||
};
|
||||
|
||||
use tfhe_zk_pok::proofs::pke_v2::{
|
||||
commit as commit_v2, crs_gen as crs_gen_v2, prove as prove_v2, verify as verify_v2,
|
||||
PkeV2SupportedHashConfig, Proof as ProofV2, PublicCommit as PublicCommitV2,
|
||||
VerificationPairingMode,
|
||||
commit as commit_v2, crs_gen as crs_gen_v2, PkeV2SupportedHashConfig, Proof as ProofV2,
|
||||
PublicCommit as PublicCommitV2, VerificationPairingMode,
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "gpu-experimental-zk"))]
|
||||
use tfhe_zk_pok::proofs::pke::prove as prove_v1;
|
||||
|
||||
#[cfg(not(feature = "gpu-experimental-zk"))]
|
||||
use tfhe_zk_pok::proofs::pke_v2::{prove as prove_v2, verify as verify_v2};
|
||||
|
||||
pub use tfhe_zk_pok::curve_api::Compressible;
|
||||
pub use tfhe_zk_pok::proofs::pke_v2::PkeV2SupportedHashConfig as ZkPkeV2SupportedHashConfig;
|
||||
pub use tfhe_zk_pok::proofs::ComputeLoad as ZkComputeLoad;
|
||||
@@ -725,6 +731,15 @@ impl CompactPkeCrs {
|
||||
public_params,
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu-experimental-zk")]
|
||||
let proof = tfhe_zk_pok::proofs::pke::gpu::prove(
|
||||
(public_params, &public_commit),
|
||||
&private_commit,
|
||||
metadata,
|
||||
load,
|
||||
&seed,
|
||||
);
|
||||
#[cfg(not(feature = "gpu-experimental-zk"))]
|
||||
let proof = prove_v1(
|
||||
(public_params, &public_commit),
|
||||
&private_commit,
|
||||
@@ -748,6 +763,15 @@ impl CompactPkeCrs {
|
||||
public_params,
|
||||
);
|
||||
|
||||
#[cfg(feature = "gpu-experimental-zk")]
|
||||
let proof = tfhe_zk_pok::proofs::pke_v2::gpu::prove(
|
||||
(public_params, &public_commit),
|
||||
&private_commit,
|
||||
metadata,
|
||||
load,
|
||||
&seed,
|
||||
);
|
||||
#[cfg(not(feature = "gpu-experimental-zk"))]
|
||||
let proof = prove_v2(
|
||||
(public_params, &public_commit),
|
||||
&private_commit,
|
||||
@@ -816,12 +840,21 @@ impl CompactPkeCrs {
|
||||
}
|
||||
(Self::PkeV2(public_params), CompactPkeProof::PkeV2(proof)) => {
|
||||
let public_commit = PublicCommitV2::new(key_mask, key_body, ct_mask, ct_body);
|
||||
verify_v2(
|
||||
#[cfg(feature = "gpu-experimental-zk")]
|
||||
let res = tfhe_zk_pok::proofs::pke_v2::gpu::verify(
|
||||
proof,
|
||||
(public_params, &public_commit),
|
||||
metadata,
|
||||
VerificationPairingMode::default(),
|
||||
)
|
||||
);
|
||||
#[cfg(not(feature = "gpu-experimental-zk"))]
|
||||
let res = verify_v2(
|
||||
proof,
|
||||
(public_params, &public_commit),
|
||||
metadata,
|
||||
VerificationPairingMode::default(),
|
||||
);
|
||||
res
|
||||
}
|
||||
|
||||
(Self::PkeV1(_), CompactPkeProof::PkeV2(_))
|
||||
|
||||
Reference in New Issue
Block a user