Compare commits

...

4 Commits

Author SHA1 Message Date
Pedro Alves
359ebe02b8 chore(gpu): 32-bit zk
ZK_CUDA_LIMB_BITS=32 make test_integer_zk_experimental_gpu
2026-03-03 14:52:35 -03:00
Pedro Alves
b89d4419d7 chore(gpu): add new benchmark target for zk-cuda-backend accelerated functions 2026-03-02 16:03:50 -03:00
Pedro Alves
538692cc48 fix(gpu): propagate gpu-experimental-zk feature to tfhe crate in benchmarks
The benchmark's gpu-experimental-zk feature enabled tfhe-zk-pok/gpu-experimental
but never activated tfhe/gpu-experimental-zk. The #[cfg(feature = "gpu-experimental-zk")]
gate in tfhe/src/zk/mod.rs was always false, so the GPU verify path was dead code.
2026-03-02 16:03:49 -03:00
Pedro Alves
4505c5209a feat(gpu): integrate zk-cuda-backend with tfhe-zk-pok 2026-03-02 16:03:48 -03:00
30 changed files with 4298 additions and 150 deletions

View File

@@ -14,6 +14,7 @@ on:
- signed_integer
- integer_compression
- integer_zk
- msm_zk
- shortint
- shortint_oprf
- hlapi_unsigned

View File

@@ -31,6 +31,8 @@ on:
- pbs128
- ks
- ks_pbs
- tfhe_zk_pok
- msm_zk
- integer_zk
- integer_aes
- integer_aes256

View File

@@ -51,7 +51,12 @@ jobs:
with:
files_yaml: |
gpu:
- tfhe/Cargo.toml
- tfhe/build.rs
- backends/zk-cuda-backend/**
- tfhe/src/integer/gpu/zk/**
- tfhe-zk-pok/**
- 'tfhe/docs/**/**.md'
- '.github/workflows/gpu_zk_tests.yml'
- ci/slab.toml
@@ -126,6 +131,9 @@ jobs:
- name: Run zk-cuda-backend integration tests
run: |
make test_zk_cuda_backend
make test_zk_pok_gpu
make test_integer_zk_gpu
make test_integer_zk_experimental_gpu
slack-notify:
name: gpu_zk_tests/slack-notify

1
.gitignore vendored
View File

@@ -25,6 +25,7 @@ dieharder_run.log
# Cuda local build
backends/tfhe-cuda-backend/cuda/cmake-build-debug/
backends/tfhe-cuda-backend/cuda/build/
# WASM tests
tfhe/web_wasm_parallel_tests/server.PID

View File

@@ -353,14 +353,14 @@ check_typos: install_typos_checker
.PHONY: clippy_gpu # Run clippy lints on tfhe with "gpu" enabled
clippy_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types,zk-pok \
--features=boolean,shortint,integer,internal-keycache,gpu,gpu-experimental-zk,pbs-stats,extended-types,zk-pok \
--all-targets \
-p tfhe -- --no-deps -D warnings
.PHONY: check_gpu # Run check on tfhe with "gpu" enabled
check_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" check \
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats \
--features=boolean,shortint,integer,internal-keycache,gpu,gpu-experimental-zk,pbs-stats \
--all-targets \
-p tfhe
@@ -374,7 +374,7 @@ clippy_hpu: install_rs_check_toolchain
.PHONY: clippy_gpu_hpu # Run clippy lints on tfhe with "gpu" and "hpu" enabled
clippy_gpu_hpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=boolean,shortint,integer,internal-keycache,gpu,hpu,pbs-stats,extended-types,zk-pok \
--features=boolean,shortint,integer,internal-keycache,gpu,gpu-experimental-zk,hpu,pbs-stats,extended-types,zk-pok \
--all-targets \
-p tfhe -- --no-deps -D warnings
@@ -467,7 +467,7 @@ clippy_rustdoc_gpu: install_rs_check_toolchain
fi && \
CARGO_TERM_QUIET=true CLIPPYFLAGS="-D warnings" RUSTDOCFLAGS="--no-run --test-builder ./scripts/clippy_driver.sh -Z unstable-options" \
cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" test --doc \
--features=boolean,shortint,integer,zk-pok,pbs-stats,strings,experimental,gpu \
--features=boolean,shortint,integer,zk-pok,pbs-stats,strings,experimental,gpu,gpu-experimental-zk \
-p tfhe -- --nocapture
.PHONY: clippy_c_api # Run clippy lints enabling the boolean, shortint and the C API
@@ -670,7 +670,7 @@ build_c_api: install_rs_check_toolchain
.PHONY: build_c_api_gpu # Build the C API for boolean, shortint and integer
build_c_api_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) build --profile $(CARGO_PROFILE) \
--features=boolean-c-api,shortint-c-api,high-level-c-api,zk-pok,extended-types,gpu \
--features=boolean-c-api,shortint-c-api,high-level-c-api,zk-pok,extended-types,gpu,gpu-experimental-zk \
-p tfhe
.PHONY: build_c_api_experimental_deterministic_fft # Build the C API for boolean, shortint and integer with experimental deterministic FFT
@@ -769,7 +769,7 @@ test_zk_cuda_backend:
.PHONY: test_gpu # Run the tests of the core_crypto module including experimental on the gpu backend
test_gpu: test_core_crypto_gpu test_integer_gpu test_cuda_backend
test_gpu: test_core_crypto_gpu test_integer_gpu test_cuda_backend test_zk_cuda_backend
.PHONY: test_core_crypto_gpu # Run the tests of the core_crypto module including experimental on the gpu backend
test_core_crypto_gpu:
@@ -1205,12 +1205,31 @@ test_tfhe_csprng_big_endian: install_cargo_cross
RUSTFLAGS="" cross test --profile $(CARGO_PROFILE) \
-p tfhe-csprng --target=powerpc64-unknown-linux-gnu
.PHONY: test_zk_pok # Run tfhe-zk-pok tests
test_zk_pok:
RUSTFLAGS="$(RUSTFLAGS)" cargo test --profile $(CARGO_PROFILE) \
-p tfhe-zk-pok --features experimental
.PHONY: test_zk_pok_gpu # Run tfhe-zk-pok GPU-accelerated tests
test_zk_pok_gpu:
RUSTFLAGS="$(RUSTFLAGS)" cargo test --profile $(CARGO_PROFILE) \
-p tfhe-zk-pok --features experimental,gpu-experimental -- gpu
.PHONY: test_integer_zk_gpu # Run the tfhe integer ZK tests (integer::gpu::zk) on the gpu backend
# The recipe drives cargo through $(CARGO_RS_BUILD_TOOLCHAIN), so the build
# toolchain (not the check toolchain) is the one that has to be installed first.
test_integer_zk_gpu: install_rs_build_toolchain
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile release \
		--features=integer,zk-pok,gpu -p tfhe -- \
		integer::gpu::zk::
.PHONY: test_integer_zk_experimental_gpu # Run the tfhe integer ZK tests (integer::gpu::zk) on the gpu backend with gpu-experimental-zk enabled
# The recipe drives cargo through $(CARGO_RS_BUILD_TOOLCHAIN), so the build
# toolchain (not the check toolchain) is the one that has to be installed first.
test_integer_zk_experimental_gpu: install_rs_build_toolchain
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile release \
		--features=integer,zk-pok,gpu,gpu-experimental-zk -p tfhe -- \
		integer::gpu::zk::
.PHONY: test_zk_cuda # Run all GPU MSM integration tests (CPU vs GPU comparison + integration test)
test_zk_cuda: install_rs_check_toolchain test_zk_cuda_backend test_zk_pok_gpu test_integer_zk_gpu test_integer_zk_experimental_gpu
.PHONY: test_zk_wasm_x86_compat_ci
test_zk_wasm_x86_compat_ci: check_nvm_installed
source ~/.nvm/nvm.sh && \
@@ -1503,27 +1522,47 @@ bench_integer_compression_128b_gpu: install_rs_check_toolchain
--bench glwe_packing_compression_128b-integer-bench \
--features=integer,internal-keycache,gpu,pbs-stats -p tfhe-benchmark --
.PHONY: bench_msm_zk # Run the zk-msm (multi-scalar multiplication) benchmark of tfhe-benchmark without any GPU feature
bench_msm_zk: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench zk-msm \
--features=zk-pok -p tfhe-benchmark --profile release --
.PHONY: bench_msm_zk_gpu # Run the zk-msm (multi-scalar multiplication) benchmark of tfhe-benchmark with gpu and gpu-experimental-zk enabled
bench_msm_zk_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench zk-msm \
--features=gpu,gpu-experimental-zk,zk-pok -p tfhe-benchmark --profile release --
.PHONY: bench_integer_zk_gpu
bench_integer_zk_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_BENCH_BIT_SIZES_SET=$(BIT_SIZES_SET) __TFHE_RS_BENCH_OP_FLAVOR=$(BENCH_OP_FLAVOR) \
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_BENCH_OP_FLAVOR=$(BENCH_OP_FLAVOR) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench integer-zk-pke \
--features=integer,internal-keycache,gpu,pbs-stats,zk-pok -p tfhe-benchmark --profile release_lto_off --
--features=integer,internal-keycache,gpu,pbs-stats,zk-pok -p tfhe-benchmark --profile release --
.PHONY: bench_integer_zk_experimental_gpu # Run the integer-zk-pke benchmark on the gpu backend with gpu-experimental-zk enabled
bench_integer_zk_experimental_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_BENCH_BIT_SIZES_SET=$(BIT_SIZES_SET) __TFHE_RS_BENCH_OP_FLAVOR=$(BENCH_OP_FLAVOR) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench integer-zk-pke \
--features=integer,internal-keycache,gpu,gpu-experimental-zk,pbs-stats,zk-pok -p tfhe-benchmark --profile release --
.PHONY: bench_integer_aes_gpu # Run benchmarks for AES on GPU backend
bench_integer_aes_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench integer-aes \
--features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off --
--features=integer,internal-keycache,gpu -p tfhe-benchmark --profile release_lto_off --
.PHONY: bench_integer_aes256_gpu # Run benchmarks for AES256 on GPU backend
bench_integer_aes256_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench integer-aes256 \
--features=integer,internal-keycache,gpu, -p tfhe-benchmark --profile release_lto_off --
--features=integer,internal-keycache,gpu -p tfhe-benchmark --profile release_lto_off --
.PHONY: bench_integer_trivium_gpu # Run benchmarks for trivium on GPU backend
bench_integer_trivium_gpu: install_rs_check_toolchain
@@ -1806,6 +1845,13 @@ bench_tfhe_zk_pok: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench -p tfhe-zk-pok --
.PHONY: bench_tfhe_zk_pok_gpu # Run benchmarks for the tfhe_zk_pok crate using GPU acceleration
bench_tfhe_zk_pok_gpu: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--package tfhe-zk-pok \
--features=gpu-experimental --profile release
.PHONY: bench_hlapi_noise_squash # Run benchmarks for noise squash operation
bench_hlapi_noise_squash: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=$(BENCH_TYPE) __TFHE_RS_BENCH_BIT_SIZES_SET=$(BIT_SIZES_SET) \

View File

@@ -46,7 +46,10 @@ fn main() {
}
// Build CUDA library using cmake crate
let limb_bits = std::env::var("ZK_CUDA_LIMB_BITS").unwrap_or_else(|_| "64".to_string());
println!("cargo::rerun-if-env-changed=ZK_CUDA_LIMB_BITS");
let mut cmake_config = cmake::Config::new("cuda");
cmake_config.define("ZK_CUDA_LIMB_BITS", &limb_bits);
let dest = cmake_config.build();
// cmake crate installs to dest/lib subdirectory

View File

@@ -51,6 +51,16 @@ else()
set(CMAKE_CUDA_ARCHITECTURES 70)
endif()
# Limb size configuration: 32 or 64 (default: 64)
# 32-bit limbs enable PTX carry-chain optimizations on GPU
set(ZK_CUDA_LIMB_BITS "64" CACHE STRING "Limb size in bits for Fp arithmetic (32 or 64)")
set_property(CACHE ZK_CUDA_LIMB_BITS PROPERTY STRINGS "32" "64")
if(NOT ZK_CUDA_LIMB_BITS STREQUAL "32" AND NOT ZK_CUDA_LIMB_BITS STREQUAL "64")
message(FATAL_ERROR "ZK_CUDA_LIMB_BITS must be 32 or 64, got: ${ZK_CUDA_LIMB_BITS}")
endif()
add_compile_definitions(LIMB_BITS_CONFIG=${ZK_CUDA_LIMB_BITS})
message(STATUS "Limb size: ${ZK_CUDA_LIMB_BITS}-bit")
# Enable CUDA separable compilation for better optimization
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)

View File

@@ -17,7 +17,15 @@ __host__ __device__ void fp2_zero(Fp2 &a);
// G1 point: (x, y) coordinates in Fp
// Curve equation: y^2 = x^3 + b (short Weierstrass form with a = 0)
struct G1Affine {
// alignas(8) ensures identical struct layout (size 120) in both 32-bit and
// 64-bit limb modes, matching the Rust FFI bindings generated from 64-bit.
// Without this, 32-bit mode produces 116-byte structs (4-byte alignment from
// uint32_t limbs) vs 120 bytes in Rust FFI, causing array stride mismatches
// that corrupt point data for n>1.
// The 4-byte padding overhead is negligible: MSM is compute-bound (Montgomery
// multiplications dominate), and point access patterns in Pippenger-style MSM
// are non-coalescing regardless of struct size.
struct alignas(8) G1Affine {
Fp x;
Fp y;
bool infinity; // true if point at infinity (identity element)
@@ -36,7 +44,9 @@ struct G1Affine {
// G2 point: (x, y) coordinates in Fp2
// Curve equation: y^2 = x^3 + b' (twisted curve over Fp2)
struct G2Affine {
// alignas(8): same rationale as G1Affine above — ensures FFI layout
// compatibility (size 232) between 32-bit and 64-bit limb modes.
struct alignas(8) G2Affine {
Fp2 x;
Fp2 y;
bool infinity; // true if point at infinity (identity element)

View File

@@ -282,6 +282,8 @@ __host__ __device__ void fp_cmov(Fp &dst, const Fp &src, uint64_t condition);
// Helper functions to access constants
// Get modulus reference (device: from constant memory, host: static copy)
__host__ __device__ const Fp &fp_modulus();
// Get Montgomery reduction constant p' = -p^(-1) mod 2^LIMB_BITS
__host__ __device__ UNSIGNED_LIMB fp_p_prime();
// ============================================================================
// Async/Sync API for device memory operations

View File

@@ -1,6 +1,7 @@
#include "bls12_446_params.h"
#include "device.h"
#include "fp.h"
#include "fp_ptx32.cuh"
#include <cstdio>
#include <cstdlib>
#include <cstring>
@@ -187,6 +188,9 @@ __host__ __device__ void fp_copy(Fp &dst, const Fp &src) {
// "Raw" means without modular reduction - performs a + b and returns carry.
// This is an internal helper used by fp_add() which handles reduction.
__host__ __device__ UNSIGNED_LIMB fp_add_raw(Fp &c, const Fp &a, const Fp &b) {
#if LIMB_BITS_CONFIG == 32 && defined(__CUDA_ARCH__)
return fp_add_raw_ptx32(c, a, b);
#else
UNSIGNED_LIMB carry = 0;
for (int i = 0; i < FP_LIMBS; i++) {
@@ -199,12 +203,16 @@ __host__ __device__ UNSIGNED_LIMB fp_add_raw(Fp &c, const Fp &a, const Fp &b) {
}
return carry;
#endif
}
// Subtraction with borrow propagation
// "Raw" means without modular reduction - performs a - b and returns borrow.
// This is an internal helper used by fp_sub() which handles reduction.
__host__ __device__ UNSIGNED_LIMB fp_sub_raw(Fp &c, const Fp &a, const Fp &b) {
#if LIMB_BITS_CONFIG == 32 && defined(__CUDA_ARCH__)
return fp_sub_raw_ptx32(c, a, b);
#else
UNSIGNED_LIMB borrow = 0;
for (int i = 0; i < FP_LIMBS; i++) {
@@ -218,11 +226,15 @@ __host__ __device__ UNSIGNED_LIMB fp_sub_raw(Fp &c, const Fp &a, const Fp &b) {
}
return borrow;
#endif
}
// Addition with modular reduction: c = (a + b) mod p
// MONTGOMERY: Both inputs and output must be in Montgomery form
__host__ __device__ void fp_add(Fp &c, const Fp &a, const Fp &b) {
#if LIMB_BITS_CONFIG == 32 && defined(__CUDA_ARCH__)
fp_add_ptx32(c, a, b);
#else
Fp sum;
UNSIGNED_LIMB carry = fp_add_raw(sum, a, b);
@@ -235,11 +247,15 @@ __host__ __device__ void fp_add(Fp &c, const Fp &a, const Fp &b) {
} else {
fp_copy(c, sum);
}
#endif
}
// Subtraction with modular reduction: c = (a - b) mod p
// MONTGOMERY: Both inputs and output must be in Montgomery form
__host__ __device__ void fp_sub(Fp &c, const Fp &a, const Fp &b) {
#if LIMB_BITS_CONFIG == 32 && defined(__CUDA_ARCH__)
fp_sub_ptx32(c, a, b);
#else
Fp diff;
UNSIGNED_LIMB borrow = fp_sub_raw(diff, a, b);
@@ -250,6 +266,7 @@ __host__ __device__ void fp_sub(Fp &c, const Fp &a, const Fp &b) {
} else {
fp_copy(c, diff);
}
#endif
}
// Small-constant multiplication via addition chains.
@@ -458,6 +475,9 @@ __host__ __device__ void fp_mont_reduce(Fp &c, const UNSIGNED_LIMB *a) {
// Uses only FP_LIMBS+1 limbs of working space instead of 2*FP_LIMBS.
// Both a and b are in Montgomery form, result is in Montgomery form.
__host__ __device__ void fp_mont_mul_cios(Fp &c, const Fp &a, const Fp &b) {
#if LIMB_BITS_CONFIG == 32 && defined(__CUDA_ARCH__)
fp_mont_mul_cios_ptx32(c, a, b);
#else
const Fp &p = fp_modulus();
UNSIGNED_LIMB p_prime = fp_p_prime();
@@ -545,6 +565,7 @@ __host__ __device__ void fp_mont_mul_cios(Fp &c, const Fp &a, const Fp &b) {
fp_copy(c, reduced);
}
// Result is in Montgomery form
#endif
}
// Montgomery multiplication: c = (a * b * R_INV) mod p

View File

@@ -29,6 +29,7 @@ rand = { workspace = true }
rayon = { workspace = true }
tfhe = { path = "../tfhe", default-features = false }
tfhe-csprng = { path = "../tfhe-csprng" }
tfhe-zk-pok = { path = "../tfhe-zk-pok", optional = true }
cpu-time = "1.0"
num_cpus = "1.17"
gag = "1.0.0"
@@ -39,12 +40,14 @@ boolean = ["tfhe/boolean"]
shortint = ["tfhe/shortint"]
integer = ["shortint", "tfhe/integer"]
gpu = ["tfhe/gpu"]
# gpu enables tfhe-cuda-backend which provides CUDA stream management used by tfhe-zk-pok
gpu-experimental-zk = ["gpu", "zk-pok", "tfhe/gpu-experimental-zk", "tfhe-zk-pok/gpu-experimental"]
hpu = ["tfhe/hpu"]
hpu-v80 = ["tfhe/hpu-v80"]
internal-keycache = ["tfhe/internal-keycache"]
avx512 = ["tfhe/avx512"]
pbs-stats = ["tfhe/pbs-stats"]
zk-pok = ["tfhe/zk-pok"]
zk-pok = ["tfhe/zk-pok", "dep:tfhe-zk-pok"]
[[bench]]
name = "boolean"
@@ -196,6 +199,12 @@ path = "benches/core_crypto/pbs128_bench.rs"
harness = false
required-features = ["shortint", "internal-keycache"]
[[bench]]
name = "zk-msm"
path = "benches/zk/msm.rs"
harness = false
required-features = ["zk-pok"]
[[bin]]
name = "boolean_key_sizes"
path = "src/bin/boolean_key_sizes.rs"

View File

@@ -485,6 +485,24 @@ mod cuda {
use tfhe::integer::gpu::zk::CudaProvenCompactCiphertextList;
use tfhe::integer::gpu::CudaServerKey;
use tfhe::integer::CompressedServerKey;
use tfhe::GpuIndex;
/// Per-GPU element count used by the GPU ZK throughput benchmarks.
///
/// The counts are hand-tuned so that a run does not go out of memory on H100
/// GPUs while still keeping the GPU saturated; memory usage grows with both
/// the CRS size and the number of bits being proven.
///
/// * `crs_size` - CRS size in bits (64, 2048 and 4096 have dedicated values).
/// * `bits` - number of bits being proven; only matters for a 2048-bit CRS.
///
/// Any other `crs_size` falls back to 10.
fn gpu_zk_throughput_elements(crs_size: usize, bits: usize) -> u64 {
    // 64-bit CRS: smaller proofs, can handle more elements
    if crs_size == 64 {
        return 30;
    }

    // 2048-bit CRS: moderate memory usage, reduced further for wide inputs
    if crs_size == 2048 {
        return if bits <= 256 { 15 } else { 10 };
    }

    // 4096-bit CRS: largest proofs, most memory intensive
    if crs_size == 4096 {
        return 6;
    }

    // Unknown configuration: conservative default
    10
}
fn gpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) {
let bench_name = "integer::cuda::zk::pke_zk_verify";
@@ -686,12 +704,8 @@ mod cuda {
});
}
BenchmarkType::Throughput => {
let mut elements_per_gpu = 100;
if *bits == 4096 {
elements_per_gpu /= 5;
}
// This value, found empirically, ensure saturation of 8XH100 SXM5
let elements = elements_per_gpu * get_number_of_gpus() as u64;
let elements = gpu_zk_throughput_elements(crs_size, *bits)
* get_number_of_gpus() as u64;
bench_group.throughput(Throughput::Elements(elements));
bench_id_verify = format!(
@@ -716,15 +730,38 @@ mod cuda {
.collect::<Vec<_>>();
let local_streams = cuda_local_streams(num_block, elements as usize);
let d_ksk_material_vec = local_streams
.par_iter()
.map(|local_stream| {
CudaKeySwitchingKeyMaterial::from_key_switching_key(
&ksk,
local_stream,
let gpu_count = get_number_of_gpus() as usize;
let gpu_sks_vec: Vec<CudaServerKey> = (0..gpu_count)
.map(|gpu_idx| {
let stream =
CudaStreams::new_single_gpu(GpuIndex::new(gpu_idx as u32));
CudaServerKey::decompress_from_cpu(
&compressed_server_key,
&stream,
)
})
.collect::<Vec<_>>();
.collect();
let d_ksk_material_vec: Vec<CudaKeySwitchingKeyMaterial> = (0
..gpu_count)
.map(|gpu_idx| {
let stream =
CudaStreams::new_single_gpu(GpuIndex::new(gpu_idx as u32));
CudaKeySwitchingKeyMaterial::from_key_switching_key(
&ksk, &stream,
)
})
.collect();
let d_ksks: Vec<CudaKeySwitchingKey> = (0..gpu_count)
.map(|gpu_idx| {
CudaKeySwitchingKey::from_cuda_key_switching_key_material(
&d_ksk_material_vec[gpu_idx],
&gpu_sks_vec[gpu_idx],
)
})
.collect();
bench_group.bench_function(&bench_id_verify, |b| {
b.iter(|| {
@@ -750,17 +787,16 @@ mod cuda {
|gpu_cts| {
gpu_cts.par_iter().enumerate().for_each
(|(i, gpu_ct)| {
let local_stream = &local_streams[i % local_streams.len()];
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, local_stream);
let d_ksk =
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % local_streams.len()], &gpu_sk);
let stream_idx = i % local_streams.len();
let local_stream = &local_streams[stream_idx];
let gpu_idx = i % gpu_count;
let d_ksk = &d_ksks[gpu_idx];
gpu_ct
.expand_without_verification(&d_ksk, local_stream)
.expand_without_verification(d_ksk, local_stream)
.unwrap();
});
}, BatchSize::SmallInput);
}, BatchSize::PerIteration);
});
bench_group.bench_function(&bench_id_verify_and_expand, |b| {
@@ -778,18 +814,18 @@ mod cuda {
|gpu_cts| {
gpu_cts.par_iter().enumerate().for_each
(|(i, gpu_ct)| {
let local_stream = &local_streams[i % local_streams.len()];
let gpu_sk = CudaServerKey::decompress_from_cpu(&compressed_server_key, local_stream);
let d_ksk =
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % local_streams.len()], &gpu_sk);
let stream_idx = i % local_streams.len();
let local_stream = &local_streams[stream_idx];
let gpu_idx = i % gpu_count;
let d_ksk = &d_ksks[gpu_idx];
gpu_ct
.verify_and_expand(
&crs, &pk, &metadata, &d_ksk, local_stream,
&crs, &pk, &metadata, d_ksk, local_stream,
)
.unwrap();
});
}, BatchSize::SmallInput);
}, BatchSize::PerIteration);
});
}
}
@@ -816,11 +852,154 @@ mod cuda {
bench_group.finish()
}
/// Criterion benchmark of packed ZK proof generation, reported under the
/// `zk::cuda::pke_zk_proof` group (15 samples, 60 s measurement time).
///
/// For each of the two parameter triples below, every entry of
/// `default_proof_config()`, every `bits_to_prove` value and both
/// `ZkComputeLoad` variants, it measures
/// `ProvenCompactCiphertextList::builder(..).build_with_proof_packed(..)`:
/// - latency mode: one proof per iteration;
/// - throughput mode: `elements` proofs built in parallel with rayon.
///
/// One record per benchmark id is emitted through `write_to_json`.
///
/// NOTE(review): nothing in this function selects a GPU explicitly; presumably
/// the GPU path is taken inside the proof code when the relevant cargo feature
/// is enabled -- confirm.
fn gpu_pke_zk_proof(c: &mut Criterion) {
let bench_name = "zk::cuda::pke_zk_proof";
let mut bench_group = c.benchmark_group(bench_name);
bench_group
.sample_size(15)
.measurement_time(std::time::Duration::from_secs(60));
// (compact public key encryption params, key switching params, PBS params)
let params: [(
CompactPublicKeyEncryptionParameters,
ShortintKeySwitchingParameters,
PBSParameters,
); 2] = [
(
PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
PARAM_GPU_MULTI_BIT_GROUP_4_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into(),
),
(
BENCH_PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
BENCH_PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into(),
),
];
// NOTE(review): `_param_ksk` carries an underscore prefix but is read below
// when building `_casting_key`.
for (param_pke, _param_ksk, param_fhe) in params.iter() {
let param_name = param_fhe.name();
let param_name = param_name.as_str();
let cks = ClientKey::new(*param_fhe);
let sks = ServerKey::new_radix_server_key(&cks);
let compact_private_key = CompactPrivateKey::new(*param_pke);
let pk = CompactPublicKey::new(&compact_private_key);
// Kept for consistency
let _casting_key =
KeySwitchingKey::new((&compact_private_key, None), (&cks, &sks), *_param_ksk);
// We have a use case with 320 bits of metadata
let mut metadata = [0u8; (320 / u8::BITS) as usize];
let mut rng = rand::thread_rng();
metadata.fill_with(|| rng.gen());
let zk_vers = param_pke.zk_scheme;
for proof_config in default_proof_config().iter() {
// Number of message bits carried by one block (message and carry space)
let msg_bits =
(param_pke.message_modulus.0 * param_pke.carry_modulus.0).ilog2() as usize;
println!("Generating CRS... ");
let crs_size = proof_config.crs_size;
// The CRS is built for `crs_size / msg_bits` LWE ciphertexts
let crs = CompactPkeCrs::from_shortint_params(
*param_pke,
LweCiphertextCount(crs_size / msg_bits),
)
.unwrap();
for bits in proof_config.bits_to_prove.iter() {
// Inputs are proven as whole u64 values
assert_eq!(bits % 64, 0);
// Packing, so we take the message and carry modulus to compute our block count
let num_block = 64usize.div_ceil(msg_bits);
let fhe_uint_count = bits / 64;
for compute_load in [ZkComputeLoad::Proof, ZkComputeLoad::Verify] {
let zk_load = match compute_load {
ZkComputeLoad::Proof => "compute_load_proof",
ZkComputeLoad::Verify => "compute_load_verify",
};
// Assigned in exactly one of the two match arms below
let bench_id;
match get_bench_type() {
BenchmarkType::Latency => {
bench_id = format!(
"{bench_name}::{param_name}_{bits}_bits_packed_{crs_size}_bits_crs_{zk_load}_ZK{zk_vers:?}"
);
bench_group.bench_function(&bench_id, |b| {
// A single random u64, repeated `fhe_uint_count` times
let input_msg = rng.gen::<u64>();
let messages = vec![input_msg; fhe_uint_count];
b.iter(|| {
let _ct1 =
tfhe::integer::ProvenCompactCiphertextList::builder(
&pk,
)
.extend(messages.iter().copied())
.build_with_proof_packed(&crs, &metadata, compute_load)
.unwrap();
})
});
}
BenchmarkType::Throughput => {
// The zk proof is currently not pooled, so we simply use the number
// of threads as heuristic for the batch size
let elements =
(rayon::current_num_threads() / num_block).max(1) + 1;
bench_group.throughput(Throughput::Elements(elements as u64));
bench_id = format!(
"{bench_name}::throughput::{param_name}_{bits}_bits_packed_{crs_size}_bits_crs_{zk_load}_ZK{zk_vers:?}"
);
bench_group.bench_function(&bench_id, |b| {
// One message vector per parallel proof, generated outside the timed loop
let messages = (0..elements)
.map(|_| {
let input_msg = rng.gen::<u64>();
vec![input_msg; fhe_uint_count]
})
.collect::<Vec<_>>();
b.iter(|| {
messages.par_iter().for_each(|msg| {
tfhe::integer::ProvenCompactCiphertextList::builder(
&pk,
)
.extend(msg.iter().copied())
.build_with_proof_packed(&crs, &metadata, compute_load)
.unwrap();
})
})
});
}
}
// Record the benchmark id together with the parameters it ran with
let shortint_params: PBSParameters = *param_fhe;
write_to_json::<u64, _>(
&bench_id,
shortint_params,
param_name,
"pke_zk_proof",
&OperatorType::Atomic,
shortint_params.message_modulus().0 as u32,
vec![shortint_params.message_modulus().0.ilog2(); num_block],
);
}
}
}
}
}
pub fn gpu_zk_verify() {
let results_file = Path::new("gpu_pke_zk_crs_sizes.csv");
let mut criterion: Criterion<_> = (Criterion::default()).configure_from_args();
gpu_pke_zk_verify(&mut criterion, results_file);
}
pub fn gpu_zk_proof() {
let mut criterion: Criterion<_> = (Criterion::default()).configure_from_args();
gpu_pke_zk_proof(&mut criterion);
}
}
pub fn zk_verify_and_proof() {
@@ -831,11 +1010,14 @@ pub fn zk_verify_and_proof() {
}
#[cfg(all(feature = "gpu", feature = "zk-pok"))]
use crate::cuda::gpu_zk_verify;
use crate::cuda::{gpu_zk_proof, gpu_zk_verify};
fn main() {
#[cfg(all(feature = "gpu", feature = "zk-pok"))]
gpu_zk_verify();
{
gpu_zk_proof();
gpu_zk_verify();
}
#[cfg(not(feature = "gpu"))]
zk_verify_and_proof();

View File

@@ -0,0 +1,406 @@
//! Benchmark comparing CPU MSM vs GPU MSM for BLS12-446
//!
//! This benchmark measures the performance of multi-scalar multiplication (MSM)
//! for both G1 and G2 points on the BLS12-446 curve.
//!
//! CPU benchmarks use the arkworks-based `G1Affine::multi_mul_scalar` /
//! `G2Affine::multi_mul_scalar`. GPU benchmarks (gated behind the
//! `gpu-experimental-zk` feature) call `tfhe_zk_pok::gpu::g1_msm_gpu` /
//! `tfhe_zk_pok::gpu::g2_msm_gpu` directly, which dispatch to the
//! zk-cuda-backend.
//!
//! ## Running the benchmarks
//!
//! ```bash
//! # CPU only
//! cargo bench --package tfhe-benchmark --bench zk-msm
//!
//! # CPU and GPU
//! cargo bench --package tfhe-benchmark --bench zk-msm --features gpu-experimental-zk
//! ```
use benchmark::utilities::{
get_bench_type, write_to_json, BenchmarkType, CryptoParametersRecord, OperatorType,
};
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion, Throughput};
use rand::rngs::StdRng;
use rand::SeedableRng;
use rayon::prelude::*;
use std::time::Duration;
use tfhe_zk_pok::curve_api::bls12_446::{G1Affine, G2Affine, Zp, G1, G2};
use tfhe_zk_pok::curve_api::CurveGroupOps;
/// Number of MSM instances run in parallel by the throughput benchmarks.
///
/// Bigger inputs get fewer parallel instances so that setup time stays
/// reasonable: 64 for up to 1000 points, 32 for up to 4096 points, and 16 for
/// anything larger.
fn msm_throughput_elements(input_size: usize) -> u64 {
    if input_size <= 1000 {
        64
    } else if input_size <= 4096 {
        32
    } else {
        16
    }
}
/// Produce `n` random G1 affine points.
///
/// Each point is the G1 generator multiplied by a fresh random `Zp` scalar
/// drawn from `rng`, then normalized to affine form.
fn generate_g1_affine_points(rng: &mut StdRng, n: usize) -> Vec<G1Affine> {
    let mut points = Vec::with_capacity(n);
    for _ in 0..n {
        let scalar = Zp::rand(&mut *rng);
        points.push(G1::GENERATOR.mul_scalar(scalar).normalize());
    }
    points
}
/// Produce `n` random G2 affine points.
///
/// Each point is the G2 generator multiplied by a fresh random `Zp` scalar
/// drawn from `rng`, then normalized to affine form.
fn generate_g2_affine_points(rng: &mut StdRng, n: usize) -> Vec<G2Affine> {
    let mut points = Vec::with_capacity(n);
    for _ in 0..n {
        let scalar = Zp::rand(&mut *rng);
        points.push(G2::GENERATOR.mul_scalar(scalar).normalize());
    }
    points
}
/// Draw `n` random `Zp` scalars from `rng`, in order.
fn generate_scalars(rng: &mut StdRng, n: usize) -> Vec<Zp> {
    let mut scalars = Vec::with_capacity(n);
    for _ in 0..n {
        scalars.push(Zp::rand(&mut *rng));
    }
    scalars
}
/// Benchmark CPU MSM for G1 points using tfhe-zk-pok entry points
fn bench_cpu_g1_msm(c: &mut Criterion) {
let curve_name = "bls12_446";
let subgroup_name = "G1";
let bench_name = format!("zk::msm::{curve_name}::{subgroup_name}");
let mut group = c.benchmark_group(&bench_name);
group.sample_size(10);
group.measurement_time(Duration::from_secs(30));
for size in [100, 1000, 2048, 4096, 10000].iter() {
let n = *size;
let bench_id;
let bench_shortname = "zk::msm::bls12_446::g1";
match get_bench_type() {
BenchmarkType::Latency => {
let mut rng = StdRng::seed_from_u64(42);
let bases = generate_g1_affine_points(&mut rng, n);
let scalars = generate_scalars(&mut rng, n);
bench_id = format!("{bench_name}::{n}");
group.bench_with_input(&bench_id, &n, |b, _| {
b.iter(|| {
let result =
G1Affine::multi_mul_scalar(black_box(&bases), black_box(&scalars));
black_box(result)
});
});
}
BenchmarkType::Throughput => {
let elements = msm_throughput_elements(n);
group.throughput(Throughput::Elements(elements));
bench_id = format!("{bench_name}::throughput::{n}");
group.bench_with_input(&bench_id, &n, |b, _| {
// Setup generates test data in parallel, excluded from measurement
let setup = || {
(0..elements)
.into_par_iter()
.map(|i| {
let mut rng = StdRng::seed_from_u64(42 + i);
let bases = generate_g1_affine_points(&mut rng, n);
let scalars = generate_scalars(&mut rng, n);
(bases, scalars)
})
.collect::<Vec<_>>()
};
b.iter_batched(
setup,
|test_data| {
test_data.par_iter().for_each(|(bases, scalars)| {
let result = G1Affine::multi_mul_scalar(
black_box(bases),
black_box(scalars),
);
black_box(result);
});
},
BatchSize::LargeInput,
);
});
}
}
// MSM benchmarks are curve operations, use minimal parameters
let params: CryptoParametersRecord<u64> = CryptoParametersRecord::default();
write_to_json(
&bench_id,
params,
"MSM_BLS12_446_G1",
bench_shortname,
&OperatorType::Atomic,
64, // bit_size for curve scalar operations
vec![], // decomposition_basis not applicable for MSM
);
}
group.finish();
}
/// Benchmark CPU MSM for G2 points using tfhe-zk-pok entry points
fn bench_cpu_g2_msm(c: &mut Criterion) {
let curve_name = "bls12_446";
let subgroup_name = "G2";
let bench_name = format!("zk::msm::{curve_name}::{subgroup_name}");
let mut group = c.benchmark_group(&bench_name);
group.sample_size(10);
group.measurement_time(Duration::from_secs(30));
for size in [100, 1000, 2048, 4096, 10000].iter() {
let n = *size;
let bench_id;
let bench_shortname = "zk::msm::bls12_446::g2";
match get_bench_type() {
BenchmarkType::Latency => {
let mut rng = StdRng::seed_from_u64(42);
let bases = generate_g2_affine_points(&mut rng, n);
let scalars = generate_scalars(&mut rng, n);
bench_id = format!("{bench_name}::{n}");
group.bench_with_input(&bench_id, &n, |b, _| {
b.iter(|| {
let result =
G2Affine::multi_mul_scalar(black_box(&bases), black_box(&scalars));
black_box(result)
});
});
}
BenchmarkType::Throughput => {
let elements = msm_throughput_elements(n);
group.throughput(Throughput::Elements(elements));
bench_id = format!("{bench_name}::throughput::{n}");
group.bench_with_input(&bench_id, &n, |b, _| {
// Setup generates test data in parallel, excluded from measurement
let setup = || {
(0..elements)
.into_par_iter()
.map(|i| {
let mut rng = StdRng::seed_from_u64(42 + i);
let bases = generate_g2_affine_points(&mut rng, n);
let scalars = generate_scalars(&mut rng, n);
(bases, scalars)
})
.collect::<Vec<_>>()
};
b.iter_batched(
setup,
|test_data| {
test_data.par_iter().for_each(|(bases, scalars)| {
let result = G2Affine::multi_mul_scalar(
black_box(bases),
black_box(scalars),
);
black_box(result);
});
},
BatchSize::LargeInput,
);
});
}
}
// MSM benchmarks are curve operations, use minimal parameters
let params: CryptoParametersRecord<u64> = CryptoParametersRecord::default();
write_to_json(
&bench_id,
params,
"MSM_BLS12_446_G2",
bench_shortname,
&OperatorType::Atomic,
64, // bit_size for curve scalar operations
vec![], // decomposition_basis not applicable for MSM
);
}
group.finish();
}
/// Criterion benchmark of `tfhe_zk_pok::gpu::g1_msm_gpu` (GPU MSM on G1).
///
/// Only compiled with the `gpu-experimental-zk` feature. Runs once per input
/// size in `[100, 1000, 2048, 4096, 10000]`, in either latency or throughput
/// mode depending on `get_bench_type()`, always on the GPU returned by
/// `select_gpu_for_msm()`, and writes one JSON record per benchmark id.
#[cfg(feature = "gpu-experimental-zk")]
fn bench_gpu_g1_msm(c: &mut Criterion) {
    use tfhe_zk_pok::gpu::{g1_msm_gpu, select_gpu_for_msm};

    let curve_name = "bls12_446";
    let subgroup_name = "G1";
    let bench_name = format!("zk::cuda::msm::{curve_name}::{subgroup_name}");
    let bench_shortname = "zk::cuda::msm::bls12_446::g1";

    let mut group = c.benchmark_group(&bench_name);
    group.sample_size(10);
    group.measurement_time(Duration::from_secs(30));

    // Resolve GPU index once — stream creation/destruction is handled inside g1_msm_gpu
    let gpu_index = select_gpu_for_msm();

    for &n in [100, 1000, 2048, 4096, 10000].iter() {
        let bench_id = match get_bench_type() {
            BenchmarkType::Latency => {
                // One fixed-seed input set, built once and reused by every iteration
                let mut rng = StdRng::seed_from_u64(42);
                let bases = generate_g1_affine_points(&mut rng, n);
                let scalars = generate_scalars(&mut rng, n);

                let id = format!("{bench_name}::{n}");
                group.bench_with_input(&id, &n, |b, _| {
                    b.iter(|| {
                        black_box(g1_msm_gpu(
                            black_box(&bases),
                            black_box(&scalars),
                            gpu_index,
                        ))
                    });
                });
                id
            }
            BenchmarkType::Throughput => {
                let elements = msm_throughput_elements(n);
                group.throughput(Throughput::Elements(elements));

                let id = format!("{bench_name}::throughput::{n}");
                group.bench_with_input(&id, &n, |b, _| {
                    // Inputs are generated in parallel by the setup closure
                    let make_inputs = || {
                        (0..elements)
                            .into_par_iter()
                            .map(|i| {
                                let mut rng = StdRng::seed_from_u64(42 + i);
                                let bases = generate_g1_affine_points(&mut rng, n);
                                let scalars = generate_scalars(&mut rng, n);
                                (bases, scalars)
                            })
                            .collect::<Vec<_>>()
                    };
                    // Measured routine: every MSM of the batch, submitted in
                    // parallel to the same GPU index
                    let run_batch = |test_data: Vec<(Vec<G1Affine>, Vec<Zp>)>| {
                        test_data.par_iter().for_each(|(points, coeffs)| {
                            black_box(g1_msm_gpu(
                                black_box(points),
                                black_box(coeffs),
                                gpu_index,
                            ));
                        });
                    };
                    b.iter_batched(make_inputs, run_batch, BatchSize::LargeInput);
                });
                id
            }
        };

        let params: CryptoParametersRecord<u64> = CryptoParametersRecord::default();
        write_to_json(
            &bench_id,
            params,
            "MSM_BLS12_446_G1_CUDA",
            bench_shortname,
            &OperatorType::Atomic,
            64,     // bit_size for curve scalar operations
            vec![], // decomposition_basis not applicable for MSM
        );
    }

    group.finish();
}
/// Benchmark GPU MSM for G2 points via `tfhe_zk_pok::gpu::g2_msm_gpu`.
///
/// One benchmark is registered per input size, in either latency or throughput
/// mode depending on `get_bench_type()`, and every benchmark id is recorded
/// through `write_to_json`.
///
/// In throughput mode the target GPU is selected *inside* each rayon task:
/// `select_gpu_for_msm` derives its choice from the rayon thread index, so a
/// value resolved once on the benchmark driver thread (presumably not a rayon
/// worker) would route every MSM of the batch to the same device.
#[cfg(feature = "gpu-experimental-zk")]
fn bench_gpu_g2_msm(c: &mut Criterion) {
    use tfhe_zk_pok::gpu::{g2_msm_gpu, select_gpu_for_msm};

    let curve_name = "bls12_446";
    let subgroup_name = "G2";
    let bench_name = format!("zk::cuda::msm::{curve_name}::{subgroup_name}");
    let mut group = c.benchmark_group(&bench_name);
    group.sample_size(10);
    group.measurement_time(Duration::from_secs(30));

    // Latency mode issues one MSM at a time, so the GPU index is resolved once.
    let gpu_index = select_gpu_for_msm();

    for size in [100, 1000, 2048, 4096, 10000].iter() {
        let n = *size;
        let bench_id;
        let bench_shortname = "zk::cuda::msm::bls12_446::g2";
        match get_bench_type() {
            BenchmarkType::Latency => {
                let mut rng = StdRng::seed_from_u64(42);
                let bases = generate_g2_affine_points(&mut rng, n);
                let scalars = generate_scalars(&mut rng, n);
                bench_id = format!("{bench_name}::{n}");
                group.bench_with_input(&bench_id, &n, |b, _| {
                    b.iter(|| {
                        let result = g2_msm_gpu(black_box(&bases), black_box(&scalars), gpu_index);
                        black_box(result)
                    });
                });
            }
            BenchmarkType::Throughput => {
                let elements = msm_throughput_elements(n);
                group.throughput(Throughput::Elements(elements));
                bench_id = format!("{bench_name}::throughput::{n}");
                group.bench_with_input(&bench_id, &n, |b, _| {
                    // Setup generates test data in parallel, excluded from measurement
                    let setup = || {
                        (0..elements)
                            .into_par_iter()
                            .map(|i| {
                                let mut rng = StdRng::seed_from_u64(42 + i);
                                let bases = generate_g2_affine_points(&mut rng, n);
                                let scalars = generate_scalars(&mut rng, n);
                                (bases, scalars)
                            })
                            .collect::<Vec<_>>()
                    };
                    b.iter_batched(
                        setup,
                        |test_data| {
                            test_data.par_iter().for_each(|(bases, scalars)| {
                                // Select the device from within the rayon worker so
                                // concurrent MSMs are spread over all available GPUs.
                                let gpu_index = select_gpu_for_msm();
                                let result =
                                    g2_msm_gpu(black_box(bases), black_box(scalars), gpu_index);
                                black_box(result);
                            });
                        },
                        BatchSize::LargeInput,
                    );
                });
            }
        }
        // MSM benchmarks are curve operations, use minimal parameters
        let params: CryptoParametersRecord<u64> = CryptoParametersRecord::default();
        write_to_json(
            &bench_id,
            params,
            "MSM_BLS12_446_G2_CUDA",
            bench_shortname,
            &OperatorType::Atomic,
            64,     // bit_size for curve scalar operations
            vec![], // decomposition_basis not applicable for MSM
        );
    }
    group.finish();
}
// CPU benchmarks (always available): groups the G1 and G2 CPU MSM benchmarks
// under the name `benches_cpu`.
criterion_group!(benches_cpu, bench_cpu_g1_msm, bench_cpu_g2_msm,);
// GPU benchmarks (only when GPU feature is enabled): groups the G1 and G2 GPU
// MSM benchmarks under the name `benches_gpu`.
#[cfg(feature = "gpu-experimental-zk")]
criterion_group!(benches_gpu, bench_gpu_g1_msm, bench_gpu_g2_msm,);
// Conditionally include GPU benchmarks in main. Exactly one of the two
// `criterion_main!` invocations below is compiled, depending on whether the
// `gpu-experimental-zk` feature is enabled.
#[cfg(feature = "gpu-experimental-zk")]
criterion_main!(benches_cpu, benches_gpu);
#[cfg(not(feature = "gpu-experimental-zk"))]
criterion_main!(benches_cpu);

View File

@@ -14,8 +14,8 @@ rust-version.workspace = true
[dependencies]
ark-bls12-381 = "0.5.0"
ark-ec = { version = "0.5.0", features = ["parallel"] }
ark-ff = { version = "0.5.0", features = ["parallel"] }
ark-ec = { workspace = true, features = ["parallel"] }
ark-ff = { workspace = true, features = ["parallel"] }
ark-poly = { version = "0.5.0", features = ["parallel"] }
rand = { workspace = true }
rayon = { workspace = true }
@@ -24,9 +24,13 @@ serde = { workspace = true, features = ["default", "derive"] }
zeroize = "1.7.0"
num-bigint = "0.4.5"
tfhe-versionable = { version = "0.7.0", path = "../utils/tfhe-versionable" }
zk-cuda-backend = { version = "0.1.0", path = "../backends/zk-cuda-backend", optional = true }
tfhe-cuda-backend = { version = "0.13.0", path = "../backends/tfhe-cuda-backend", optional = true }
itertools.workspace = true
[features]
experimental = []
gpu-experimental = ["dep:zk-cuda-backend", "dep:tfhe-cuda-backend"]
[dev-dependencies]
serde_json = "~1.0"

View File

@@ -91,5 +91,110 @@ fn bench_pke_v1_verify(c: &mut Criterion) {
}
}
#[cfg(feature = "gpu-experimental")]
mod gpu {
    //! GPU variants of the PKE v1 prove/verify benchmarks.
    use super::*;
    use tfhe_zk_pok::proofs::pke;

    /// Benchmarks `pke::gpu::prove` for every (parameter set, compute load) pair
    /// and records each benchmark id through `write_to_json`.
    pub fn bench_pke_v1_prove_gpu(c: &mut Criterion) {
        let bench_shortname = "pke_zk_proof_v1";
        let bench_name = format!("tfhe_zk_pok::cuda::{bench_shortname}");

        let mut bench_group = c.benchmark_group(&bench_name);
        bench_group.sample_size(15);
        bench_group.measurement_time(std::time::Duration::from_secs(60));

        let seed_rng = &mut rand::thread_rng();
        let param_sets = [
            (PKEV1_TEST_PARAMS, "PKEV1_TEST_PARAMS"),
            (PKEV2_TEST_PARAMS, "PKEV2_TEST_PARAMS"),
        ];

        for (params, param_name) in param_sets {
            let (public_param, public_commit, private_commit, metadata) = init_params_v1(params);
            let effective_t = params.t >> 1;
            let packed_bits = (params.k as u32) * effective_t.ilog2();

            for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
                let bench_id =
                    format!("{bench_name}::{param_name}_{packed_bits}_bits_packed_{load}");
                // One fresh seed per benchmark id, reused for every iteration.
                let seed_bytes = seed_rng.gen::<u128>().to_le_bytes();

                bench_group.bench_function(&bench_id, |bencher| {
                    bencher.iter(|| {
                        pke::gpu::prove(
                            (&public_param, &public_commit),
                            &private_commit,
                            &metadata,
                            load,
                            &seed_bytes,
                        )
                    })
                });

                write_to_json(&bench_id, params, param_name, bench_shortname);
            }
        }
    }

    /// Benchmarks `pke::gpu::verify` for every (parameter set, compute load)
    /// pair. The proof being verified is produced once, outside the measured
    /// closure, with `pke::gpu::prove`.
    pub fn bench_pke_v1_verify_gpu(c: &mut Criterion) {
        let bench_shortname = "pke_zk_verify_v1";
        let bench_name = format!("tfhe_zk_pok::cuda::{bench_shortname}");

        let mut bench_group = c.benchmark_group(&bench_name);
        bench_group.sample_size(15);
        bench_group.measurement_time(std::time::Duration::from_secs(60));

        let seed_rng = &mut rand::thread_rng();
        let param_sets = [
            (PKEV1_TEST_PARAMS, "PKEV1_TEST_PARAMS"),
            (PKEV2_TEST_PARAMS, "PKEV2_TEST_PARAMS"),
        ];

        for (params, param_name) in param_sets {
            let (public_param, public_commit, private_commit, metadata) = init_params_v1(params);
            let effective_t = params.t >> 1;
            let packed_bits = (params.k as u32) * effective_t.ilog2();

            for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
                let bench_id =
                    format!("{bench_name}::{param_name}_{packed_bits}_bits_packed_{load}");
                let seed_bytes = seed_rng.gen::<u128>().to_le_bytes();

                // Use GPU prove to generate the proof
                let proof = pke::gpu::prove(
                    (&public_param, &public_commit),
                    &private_commit,
                    &metadata,
                    load,
                    &seed_bytes,
                );

                bench_group.bench_function(&bench_id, |bencher| {
                    bencher.iter(|| {
                        pke::gpu::verify(&proof, (&public_param, &public_commit), &metadata)
                            .unwrap();
                    })
                });

                write_to_json(&bench_id, params, param_name, bench_shortname);
            }
        }
    }
}
// CPU benchmark group for PKE v1 (always compiled).
criterion_group!(benches_pke_v1, bench_pke_v1_verify, bench_pke_v1_prove);
// The GPU benchmark functions live in the `gpu` module; bring them into scope
// so `criterion_group!` can refer to them by bare name.
#[cfg(feature = "gpu-experimental")]
use gpu::{bench_pke_v1_prove_gpu, bench_pke_v1_verify_gpu};
// GPU benchmark group for PKE v1 (only with the `gpu-experimental` feature).
#[cfg(feature = "gpu-experimental")]
criterion_group!(
benches_pke_v1_gpu,
bench_pke_v1_verify_gpu,
bench_pke_v1_prove_gpu
);
// Exactly one `criterion_main!` is compiled, depending on the
// `gpu-experimental` feature.
#[cfg(feature = "gpu-experimental")]
criterion_main!(benches_pke_v1, benches_pke_v1_gpu);
#[cfg(not(feature = "gpu-experimental"))]
criterion_main!(benches_pke_v1);

View File

@@ -107,5 +107,130 @@ fn bench_pke_v2_verify(c: &mut Criterion) {
}
}
#[cfg(feature = "gpu-experimental")]
mod gpu {
    //! GPU variants of the PKE v2 prove/verify benchmarks.
    use super::*;
    use tfhe_zk_pok::proofs::pke_v2;

    /// Benchmarks `pke_v2::gpu::prove` over the cartesian product of parameter
    /// sets, compute loads and bounds, recording each benchmark id through
    /// `write_to_json`.
    pub fn bench_pke_v2_prove_gpu(c: &mut Criterion) {
        let bench_shortname = "pke_zk_proof_v2";
        let bench_name = format!("tfhe_zk_pok::cuda::{bench_shortname}");

        let mut bench_group = c.benchmark_group(&bench_name);
        bench_group.sample_size(15);
        bench_group.measurement_time(std::time::Duration::from_secs(60));

        let seed_rng = &mut rand::thread_rng();
        let configurations = itertools::iproduct!(
            [
                (PKEV1_TEST_PARAMS, "PKEV1_TEST_PARAMS"),
                (PKEV2_TEST_PARAMS, "PKEV2_TEST_PARAMS"),
            ],
            [ComputeLoad::Proof, ComputeLoad::Verify],
            [Bound::CS, Bound::GHL]
        );

        for ((params, param_name), load, bound) in configurations {
            let (public_param, public_commit, private_commit, metadata) =
                init_params_v2(params, bound);
            let effective_t = params.t >> 1;
            let packed_bits = (params.k as u32) * effective_t.ilog2();
            let bench_id = format!(
                "{bench_name}::{param_name}_{packed_bits}_bits_packed_{load}_{bound:?}"
            );
            println!("{bench_id}");

            // One fresh seed per benchmark id, reused for every iteration.
            let seed_bytes = seed_rng.gen::<u128>().to_le_bytes();

            bench_group.bench_function(&bench_id, |bencher| {
                bencher.iter(|| {
                    pke_v2::gpu::prove(
                        (&public_param, &public_commit),
                        &private_commit,
                        &metadata,
                        load,
                        &seed_bytes,
                    )
                })
            });

            write_to_json(&bench_id, params, param_name, bench_shortname);
        }
    }

    /// Benchmarks `pke_v2::gpu::verify` over the cartesian product of parameter
    /// sets, compute loads, bounds and pairing modes. The proof being verified
    /// is produced once, outside the measured closure, with `pke_v2::gpu::prove`.
    pub fn bench_pke_v2_verify_gpu(c: &mut Criterion) {
        let bench_shortname = "pke_zk_verify_v2";
        let bench_name = format!("tfhe_zk_pok::cuda::{bench_shortname}");

        let mut bench_group = c.benchmark_group(&bench_name);
        bench_group.sample_size(15);
        bench_group.measurement_time(std::time::Duration::from_secs(60));

        let seed_rng = &mut rand::thread_rng();
        let configurations = itertools::iproduct!(
            [
                (PKEV1_TEST_PARAMS, "PKEV1_TEST_PARAMS"),
                (PKEV2_TEST_PARAMS, "PKEV2_TEST_PARAMS"),
            ],
            [ComputeLoad::Proof, ComputeLoad::Verify],
            [Bound::CS, Bound::GHL],
            [
                VerificationPairingMode::TwoSteps,
                VerificationPairingMode::Batched
            ]
        );

        for ((params, param_name), load, bound, pairing_mode) in configurations {
            let (public_param, public_commit, private_commit, metadata) =
                init_params_v2(params, bound);
            let effective_t = params.t >> 1;
            let packed_bits = (params.k as u32) * effective_t.ilog2();
            let bench_id = format!(
                "{bench_name}::{param_name}_{packed_bits}_bits_packed_{load}_{bound:?}_{pairing_mode:?}"
            );
            println!("{bench_id}");

            let seed_bytes = seed_rng.gen::<u128>().to_le_bytes();

            // Use GPU prove to generate the proof
            let proof = pke_v2::gpu::prove(
                (&public_param, &public_commit),
                &private_commit,
                &metadata,
                load,
                &seed_bytes,
            );

            bench_group.bench_function(&bench_id, |bencher| {
                bencher.iter(|| {
                    pke_v2::gpu::verify(
                        &proof,
                        (&public_param, &public_commit),
                        &metadata,
                        pairing_mode,
                    )
                    .unwrap();
                })
            });

            write_to_json(&bench_id, params, param_name, bench_shortname);
        }
    }
}
// CPU benchmark group for PKE v2 (always compiled).
criterion_group!(benches_pke_v2, bench_pke_v2_verify, bench_pke_v2_prove);
// The GPU benchmark functions live in the `gpu` module; bring them into scope
// so `criterion_group!` can refer to them by bare name.
#[cfg(feature = "gpu-experimental")]
use gpu::{bench_pke_v2_prove_gpu, bench_pke_v2_verify_gpu};
// GPU benchmark group for PKE v2 (only with the `gpu-experimental` feature).
#[cfg(feature = "gpu-experimental")]
criterion_group!(
benches_pke_v2_gpu,
bench_pke_v2_verify_gpu,
bench_pke_v2_prove_gpu
);
// Exactly one `criterion_main!` is compiled, depending on the
// `gpu-experimental` feature.
#[cfg(feature = "gpu-experimental")]
criterion_main!(benches_pke_v2, benches_pke_v2_gpu);
#[cfg(not(feature = "gpu-experimental"))]
criterion_main!(benches_pke_v2);

331
tfhe-zk-pok/src/gpu/mod.rs Normal file
View File

@@ -0,0 +1,331 @@
//! GPU acceleration module for tfhe-zk-pok
//!
//! This module provides GPU-accelerated operations using zk-cuda-backend,
//! type conversions between tfhe-zk-pok and zk-cuda-backend types,
//! and GPU MSM helper functions used by `proofs::pke::gpu` and
//! `proofs::pke_v2::gpu`.
#[cfg(test)]
mod tests;
use crate::curve_446::{Fq, Fq2};
use crate::curve_api::bls12_446::{G1Affine, G2Affine, Zp, G1, G2};
use crate::curve_api::CurveGroupOps;
use ark_ec::CurveGroup;
use ark_ff::{BigInt, MontFp, PrimeField};
use tfhe_cuda_backend::cuda_bind::{
cuda_create_stream, cuda_destroy_stream, cuda_get_number_of_gpus,
};
use zk_cuda_backend::{G1Affine as ZkG1Affine, G2Affine as ZkG2Affine, Scalar as ZkScalar};
// ---------------------------------------------------------------------------
// Compile-time assertions guarding transmutes between the wrapper types and the inner arkworks
// types. They check that the G1Affine/G2Affine wrappers have the same size and alignment as their
// inner types, which is a necessary (but not sufficient) condition for the layouts to match.
// ---------------------------------------------------------------------------
const _: () = {
    use std::mem::{align_of, size_of};

    // Inner arkworks types wrapped by the curve_api G1Affine / G2Affine.
    type InnerG1Affine = crate::curve_446::g1::G1Affine;
    type InnerG2Affine = crate::curve_446::g2::G2Affine;

    // Each wrapper must have exactly the size and alignment of its inner type;
    // compilation fails otherwise.
    assert!(size_of::<G1Affine>() == size_of::<InnerG1Affine>());
    assert!(align_of::<G1Affine>() == align_of::<InnerG1Affine>());
    assert!(size_of::<G2Affine>() == size_of::<InnerG2Affine>());
    assert!(align_of::<G2Affine>() == align_of::<InnerG2Affine>());
};
// ---------------------------------------------------------------------------
// GPU helpers
// ---------------------------------------------------------------------------
/// Number of GPUs reported by `cuda_get_number_of_gpus`.
///
/// The backend is queried on the first call only; the value is stored in a
/// process-wide `OnceLock` and returned from there afterwards.
///
/// # Panics
///
/// Panics if the backend reports no GPU.
pub(crate) fn get_num_gpus() -> u32 {
    use std::sync::OnceLock;

    static GPU_COUNT: OnceLock<u32> = OnceLock::new();

    let cached = GPU_COUNT.get_or_init(|| {
        // SAFETY: cuda_get_number_of_gpus is a pure query with no preconditions
        let reported = unsafe { cuda_get_number_of_gpus() };
        assert!(reported > 0, "No GPU available");
        u32::try_from(reported).expect("cuda_get_number_of_gpus returned negative value")
    });
    *cached
}
/// Picks the GPU an MSM should run on, keyed on the current rayon thread index
/// so that concurrent workers are spread over all available GPUs.
///
/// Returns `None` (meaning GPU 0) when a single GPU is present. Otherwise
/// returns `Some(thread_index % num_gpus)`; when
/// `rayon::current_thread_index()` yields `None` the thread index is taken to
/// be 0, so the result is `Some(0)`.
#[inline]
pub fn select_gpu_for_msm() -> Option<u32> {
    let gpu_count = get_num_gpus();
    if gpu_count > 1 {
        let worker_idx = rayon::current_thread_index().unwrap_or(0);
        let gpu = u32::try_from(worker_idx % gpu_count as usize).expect("GPU index fits in u32");
        Some(gpu)
    } else {
        None
    }
}
// ---------------------------------------------------------------------------
// Type conversion helpers
// ---------------------------------------------------------------------------
/// Build a zk-cuda-backend `Fp` from the big-integer limbs of an arkworks `Fq`
/// (`into_bigint`, i.e. normal form rather than Montgomery form).
#[inline]
fn fq_to_cuda_fp(fq: &Fq) -> zk_cuda_backend::Fp {
    let limbs = fq.into_bigint().0;
    zk_cuda_backend::Fp::new(limbs)
}
/// Rebuild an arkworks `Fq` from the normal-form limbs of a zk-cuda-backend `Fp`.
///
/// # Panics
///
/// Panics if `Fq::from_bigint` rejects the limbs.
#[inline]
fn fq_from_cuda_fp(fp: &zk_cuda_backend::Fp) -> Fq {
    let repr = BigInt::new(fp.limb);
    Fq::from_bigint(repr).expect("invalid Fq element from CUDA Fp limbs")
}
/// Convert an arkworks `Fq2` to a zk-cuda-backend `Fp2` by converting its two
/// components `c0` and `c1` independently (normal form).
#[inline]
fn fq2_to_cuda_fp2(fq2: &Fq2) -> zk_cuda_backend::bindings::Fp2 {
    let c0 = fq_to_cuda_fp(&fq2.c0);
    let c1 = fq_to_cuda_fp(&fq2.c1);
    zk_cuda_backend::bindings::Fp2 { c0, c1 }
}
/// Convert a zk-cuda-backend `Fp2` back to an arkworks `Fq2`, component by
/// component.
///
/// # Panics
///
/// Panics if either component is rejected by `fq_from_cuda_fp`.
#[inline]
fn fq2_from_cuda_fp2(fp2: &zk_cuda_backend::bindings::Fp2) -> Fq2 {
    let c0 = fq_from_cuda_fp(&fp2.c0);
    let c1 = fq_from_cuda_fp(&fp2.c1);
    Fq2::new(c0, c1)
}
/// Convert a tfhe-zk-pok `G1Affine` into the zk-cuda-backend `G1Affine`
/// representation (coordinates in normal form).
///
/// The identity maps to `ZkG1Affine::infinity()`.
///
/// # Panics
///
/// Panics if the input point is non-identity but has no coordinates (malformed arkworks point).
pub fn g1_affine_to_zk_cuda(affine: &G1Affine) -> ZkG1Affine {
    use ark_ec::AffineRepr;

    let point = &affine.inner;
    if point.is_zero() {
        return ZkG1Affine::infinity();
    }

    let coords = point
        .xy()
        .expect("non-identity point must have coordinates");
    ZkG1Affine::new(
        fq_to_cuda_fp(&coords.0),
        fq_to_cuda_fp(&coords.1),
        point.infinity,
    )
}
/// Convert a zk-cuda-backend G1Affine back to a tfhe-zk-pok G1Affine.
///
/// # Panics
///
/// Panics if the Fp limbs from the zk-cuda-backend point do not represent a valid `Fq` element
/// (i.e., the value is not in the base field).
pub fn g1_affine_from_zk_cuda(affine: &ZkG1Affine) -> G1Affine {
if affine.is_infinity() {
return G1::ZERO.normalize();
}
let x = fq_from_cuda_fp(&affine.x());
let y = fq_from_cuda_fp(&affine.y());
use crate::curve_446::g1::G1Projective;
let one = MontFp!("1");
let proj = G1Projective::new_unchecked(x, y, one);
let inner = <G1Projective as CurveGroup>::into_affine(proj);
G1Affine { inner }
}
/// Convert a tfhe-zk-pok `G2Affine` into the zk-cuda-backend `G2Affine`
/// representation (coordinates in normal form).
///
/// The identity maps to `ZkG2Affine::infinity()`.
///
/// # Panics
///
/// Panics if the input point is non-identity but has no coordinates (malformed arkworks point).
pub fn g2_affine_to_zk_cuda(affine: &G2Affine) -> ZkG2Affine {
    use ark_ec::AffineRepr;

    let point = &affine.inner;
    if point.is_zero() {
        return ZkG2Affine::infinity();
    }

    let coords = point
        .xy()
        .expect("non-identity point must have coordinates");
    ZkG2Affine::new(
        fq2_to_cuda_fp2(&coords.0),
        fq2_to_cuda_fp2(&coords.1),
        point.infinity,
    )
}
/// Convert a zk-cuda-backend G2Affine back to a tfhe-zk-pok G2Affine.
///
/// # Panics
///
/// Panics if the Fp limbs from the zk-cuda-backend point do not represent valid `Fq` elements
/// (i.e., the values are not in the base field).
pub fn g2_affine_from_zk_cuda(affine: &ZkG2Affine) -> G2Affine {
if affine.is_infinity() {
return G2::ZERO.normalize();
}
let x = fq2_from_cuda_fp2(&affine.x());
let y = fq2_from_cuda_fp2(&affine.y());
use crate::curve_446::g2::G2Projective;
let one = MontFp!("1");
let zero = MontFp!("0");
let one_fq2 = Fq2::new(one, zero);
let proj = G2Projective::new_unchecked(x, y, one_fq2);
let inner = <G2Projective as CurveGroup>::into_affine(proj);
G2Affine { inner }
}
/// Convert a `Zp` scalar into a zk-cuda-backend `Scalar` by passing along the
/// limbs of its big-integer (`into_bigint`) representation.
///
/// # Panics
///
/// This function does not panic. The `into_bigint` conversion on arkworks `Fp` types is
/// infallible, and the limbs are handed to `ZkScalar::from` as-is.
pub fn zp_to_zk_scalar(zp: &Zp) -> ZkScalar {
    ZkScalar::from(zp.inner.into_bigint().0)
}
// ---------------------------------------------------------------------------
// GPU MSM functions
// ---------------------------------------------------------------------------
/// GPU-accelerated multi-scalar multiplication for G1. The `gpu_index` parameter
/// selects which GPU device to use; `None` defaults to GPU 0.
///
/// # Panics
///
/// - If `gpu_index` is `Some(i)` where `i >= number of available GPUs`.
/// - If `bases` and `scalars` have different lengths (checked inside the backend).
/// - If the GPU MSM call fails.
pub fn g1_msm_gpu(bases: &[G1Affine], scalars: &[Zp], gpu_index: Option<u32>) -> G1 {
use crate::curve_446::g1::G1Projective;
// Convert points to zk-cuda-backend format (normal form)
let gpu_bases: Vec<_> = bases
.iter()
.map(|b| zk_cuda_backend::g1_affine_from_arkworks(&b.inner))
.collect();
let gpu_scalars: Vec<_> = scalars
.iter()
.map(|s| zk_cuda_backend::Scalar::from(s.inner.into_bigint().0))
.collect();
let gpu_index = gpu_index.unwrap_or(0);
let num_gpus = get_num_gpus();
assert!(
gpu_index < num_gpus,
"gpu_index {gpu_index} exceeds available GPUs ({num_gpus})",
);
// SAFETY: gpu_index was validated by the assert above
let stream = unsafe { cuda_create_stream(gpu_index) };
let result =
zk_cuda_backend::G1Projective::msm(&gpu_bases, &gpu_scalars, stream, gpu_index, false);
// SAFETY: stream was created by cuda_create_stream above with the same gpu_index and is not
// used after this point
unsafe { cuda_destroy_stream(stream, gpu_index) };
let (gpu_result, _size_tracker) = result.unwrap_or_else(|e| panic!("G1 GPU MSM failed: {e}"));
// Convert result from Montgomery form back to arkworks types
let normalized = gpu_result.from_montgomery_normalized();
let z_fp = normalized.Z();
if z_fp.limb.iter().all(|&limb| limb == 0) {
return G1::ZERO;
}
let x = fq_from_cuda_fp(&normalized.X());
let y = fq_from_cuda_fp(&normalized.Y());
let z = fq_from_cuda_fp(&z_fp);
G1 {
inner: G1Projective::new_unchecked(x, y, z),
}
}
/// GPU-accelerated multi-scalar multiplication for G2. The `gpu_index` parameter
/// selects which GPU device to use; `None` defaults to GPU 0.
///
/// # Panics
///
/// - If `gpu_index` is `Some(i)` where `i >= number of available GPUs`.
/// - If `bases` and `scalars` have different lengths (checked inside the backend).
/// - If the GPU MSM call fails.
pub fn g2_msm_gpu(bases: &[G2Affine], scalars: &[Zp], gpu_index: Option<u32>) -> G2 {
use crate::curve_446::g2::G2Projective;
use ark_ec::AffineRepr;
// Convert points to zk-cuda-backend format (normal form)
let gpu_bases: Vec<_> = bases
.iter()
.map(|b| {
if b.inner.is_zero() {
return zk_cuda_backend::G2Affine::infinity();
}
let x = fq2_to_cuda_fp2(&b.inner.x);
let y = fq2_to_cuda_fp2(&b.inner.y);
zk_cuda_backend::G2Affine::new(x, y, false)
})
.collect();
let gpu_scalars: Vec<_> = scalars
.iter()
.map(|s| zk_cuda_backend::Scalar::from(s.inner.into_bigint().0))
.collect();
let gpu_index = gpu_index.unwrap_or(0);
let num_gpus = get_num_gpus();
assert!(
gpu_index < num_gpus,
"gpu_index {gpu_index} exceeds available GPUs ({num_gpus})",
);
// SAFETY: gpu_index was validated by the assert above
let stream = unsafe { cuda_create_stream(gpu_index) };
let result =
zk_cuda_backend::G2Projective::msm(&gpu_bases, &gpu_scalars, stream, gpu_index, false);
// SAFETY: stream was created by cuda_create_stream above with the same gpu_index and is not
// used after this point
unsafe { cuda_destroy_stream(stream, gpu_index) };
let (gpu_result, _size_tracker) = result.unwrap_or_else(|e| panic!("G2 GPU MSM failed: {e}"));
// Convert result from Montgomery form back to arkworks types
let normalized = gpu_result.from_montgomery_normalized();
let x_fp2 = normalized.X();
let y_fp2 = normalized.Y();
let z_fp2 = normalized.Z();
let z_is_zero =
z_fp2.c0.limb.iter().all(|&limb| limb == 0) && z_fp2.c1.limb.iter().all(|&limb| limb == 0);
if z_is_zero {
return G2::ZERO;
}
let x = fq2_from_cuda_fp2(&x_fp2);
let y = fq2_from_cuda_fp2(&y_fp2);
let z = fq2_from_cuda_fp2(&z_fp2);
G2 {
inner: G2Projective::new_unchecked(x, y, z),
}
}

View File

@@ -0,0 +1,7 @@
//! Integration tests for zk-cuda-backend
//!
//! These tests verify that zk-cuda-backend produces correct results
//! by comparing against tfhe-zk-pok's CPU implementation.
mod prove_verify_stress;
mod zk_cuda_backend;

View File

@@ -0,0 +1,417 @@
//! Stress tests comparing GPU and CPU prove/verify for PKE v1 and v2.
//!
//! These tests mirror the `test_pke()` functions in `proofs/pke.rs` (v1) and
//! `proofs/pke_v2/mod.rs` (v2), running GPU prove/verify alongside CPU to
//! verify that:
//! - Serialized proofs are byte-identical between CPU and GPU.
//! - CPU and GPU verifiers agree on accept/reject for every input.
//! - Both directions are covered: GPU-prove→GPU-verify and CPU-prove→GPU-verify.
//!
//! Each test iterates over 3 CRS variants (original, compressed, not-compressed)
//! × 32 combinations of invalid inputs (e1, e2, r, m, metadata) × 2 compute
//! loads (Proof, Verify), yielding 192 iterations per test. The v2 test doubles
//! that by also sweeping both pairing modes (TwoSteps, Batched).
use crate::curve_api::Bls12_446;
use crate::proofs::pke_v2::VerificationPairingMode;
use crate::proofs::test::*;
use crate::proofs::{pke, pke_v2, ComputeLoad};
use rand::rngs::StdRng;
use rand::{thread_rng, Rng, SeedableRng};
type Curve = Bls12_446;
/// Exhaustive GPU-vs-CPU equivalence test for PKE v1.
///
/// For every combination of (CRS variant, invalid flags, compute load) we:
/// 1. Prove on CPU and GPU with identical inputs.
/// 2. Assert the serialized proofs are byte-identical.
/// 3. Verify the GPU proof on both CPU and GPU, checking that both agree on accept/reject and
///    that the outcome matches the expected validity.
/// 4. Verify the CPU proof on the GPU and check it agrees with the CPU verifier.
///
/// The RNG seed is drawn at random and printed, so that a failing run can be
/// reproduced from the captured test output.
#[test]
fn test_pke_v1_gpu_cpu_equivalence() {
    let params = crate::proofs::pke::tests::PKEV1_TEST_PARAMS;
    let PkeTestParameters {
        d,
        k,
        B,
        q,
        t,
        msbs_zero_padding_bit_count,
    } = params;

    let seed: u64 = thread_rng().gen();
    // Report the seed: everything below is derived from it, so it is the only
    // information needed to replay a failure.
    println!("pke v1 gpu/cpu equivalence seed: {seed:x}");
    let rng = &mut StdRng::seed_from_u64(seed);

    // Generate a valid test case: keys, randomness, noise, and plaintext
    let valid_testcase = PkeTestcase::gen(rng, params);
    let ct = valid_testcase.encrypt(params);

    // Independent witnesses for rejection testing. The values are
    // in-range (same distribution as `testcase`), but they are the *wrong*
    // witnesses for `ct` — proving with them produces a proof that does not
    // satisfy the ciphertext-witness relation, so verification must reject.
    let invalid_testcase = PkeTestcase::gen(rng, params);

    // CRS k > message count k exercises the path where the CRS is larger
    // than strictly needed, as happens in production
    let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));

    // Same three CRS variants as in `test_pke()`: original, round-tripped
    // through compressed serde, and round-tripped through uncompressed serde.
    let original_public_param =
        pke::crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
    let public_param_that_was_compressed =
        serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap();
    let public_param_that_was_not_compressed =
        serialize_then_deserialize(&original_public_param, Compress::No).unwrap();

    // Runs for all combinations: 3 CRS variants × 32 invalid-witness flags
    let cases = itertools::iproduct!(
        [
            original_public_param,
            public_param_that_was_compressed,
            public_param_that_was_not_compressed
        ],
        [false, true], // e1
        [false, true], // e2
        [false, true], // m
        [false, true], // r
        [false, true]  // metadata
    );

    for (
        public_param,
        use_invalid_e1,
        use_invalid_e2,
        use_invalid_m,
        use_invalid_r,
        use_invalid_metadata,
    ) in cases
    {
        // Build the commit, substituting invalid witnesses where flagged.
        let (public_commit, private_commit) = pke::commit(
            valid_testcase.a.clone(),
            valid_testcase.b.clone(),
            ct.c1.clone(),
            ct.c2.clone(),
            (if use_invalid_r {
                &invalid_testcase.r
            } else {
                &valid_testcase.r
            })
            .clone(),
            (if use_invalid_e1 {
                &invalid_testcase.e1
            } else {
                &valid_testcase.e1
            })
            .clone(),
            (if use_invalid_m {
                &invalid_testcase.m
            } else {
                &valid_testcase.m
            })
            .clone(),
            (if use_invalid_e2 {
                &invalid_testcase.e2
            } else {
                &valid_testcase.e2
            })
            .clone(),
            &public_param,
        );

        // ComputeLoad::Proof shifts work to the prover; ComputeLoad::Verify
        // shifts it to the verifier. Both must yield identical results.
        for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
            // Produce proofs from identical inputs on CPU and GPU
            let cpu_proof = pke::prove(
                (&public_param, &public_commit),
                &private_commit,
                &valid_testcase.metadata,
                load,
                &seed.to_be_bytes(),
            );
            let gpu_proof = pke::gpu::prove(
                (&public_param, &public_commit),
                &private_commit,
                &valid_testcase.metadata,
                load,
                &seed.to_be_bytes(),
            );

            // GPU MSM is exact integer arithmetic, so the serialized proofs
            // must be byte-for-byte identical
            assert_eq!(
                bincode::serialize(&cpu_proof).unwrap(),
                bincode::serialize(&gpu_proof).unwrap(),
                "v1 proof mismatch: load={load}, invalid_e1={use_invalid_e1}, \
                 invalid_e2={use_invalid_e2}, invalid_m={use_invalid_m}, \
                 invalid_r={use_invalid_r}",
            );

            // When invalid metadata is used at verification time (but not at
            // proving time), the proof is valid but Fiat-Shamir binding fails
            let verify_metadata = if use_invalid_metadata {
                &invalid_testcase.metadata
            } else {
                &valid_testcase.metadata
            };

            // Any single invalid input should cause verification to reject
            let should_fail = use_invalid_e1
                || use_invalid_e2
                || use_invalid_r
                || use_invalid_m
                || use_invalid_metadata;

            // --- Verify GPU proof on CPU ---
            let cpu_verify_result =
                pke::verify(&gpu_proof, (&public_param, &public_commit), verify_metadata);
            assert_eq!(
                cpu_verify_result.is_err(),
                should_fail,
                "v1 CPU verify mismatch: load={load}, should_fail={should_fail}",
            );

            // --- Verify GPU proof on GPU ---
            let gpu_verify_result =
                pke::gpu::verify(&gpu_proof, (&public_param, &public_commit), verify_metadata);
            assert_eq!(
                gpu_verify_result.is_err(),
                should_fail,
                "v1 GPU verify mismatch: load={load}, should_fail={should_fail}",
            );

            // CPU and GPU verifiers must produce the same Result value
            assert_eq!(
                cpu_verify_result, gpu_verify_result,
                "v1 CPU/GPU verify disagree: load={load}",
            );

            // --- Cross-direction: GPU verify of CPU-produced proof ---
            // Although proofs are byte-identical, this exercises pke::gpu::verify
            // with a proof object constructed entirely on the CPU side, ensuring
            // the GPU verifier does not depend on any GPU-side proof metadata.
            let gpu_verify_cpu_proof =
                pke::gpu::verify(&cpu_proof, (&public_param, &public_commit), verify_metadata);
            assert_eq!(
                gpu_verify_cpu_proof, cpu_verify_result,
                "v1 GPU-verify-of-CPU-proof disagrees with CPU-verify: load={load}",
            );
        }
    }
}
/// Exhaustive GPU-vs-CPU equivalence test for PKE v2.
///
/// Same structure as the v1 test but exercises `pke_v2` functions. Key
/// differences from v1:
/// - Uses `PKEV2_TEST_PARAMS` and `pke_v2::*` functions.
/// - Seed bytes are little-endian (`to_le_bytes`) per v2 convention.
/// - Verify takes a `VerificationPairingMode`; we test both `TwoSteps` and `Batched` and assert
///   each agrees with the expected outcome and with the CPU verifier.
///
/// The RNG seed is drawn at random and printed, so that a failing run can be
/// reproduced from the captured test output.
#[test]
fn test_pke_v2_gpu_cpu_equivalence() {
    let params = crate::proofs::pke_v2::tests::PKEV2_TEST_PARAMS;
    let PkeTestParameters {
        d,
        k,
        B,
        q,
        t,
        msbs_zero_padding_bit_count,
    } = params;

    let seed: u64 = thread_rng().gen();
    // Report the seed: everything below is derived from it, so it is the only
    // information needed to replay a failure.
    println!("pke v2 gpu/cpu equivalence seed: {seed:x}");
    let rng = &mut StdRng::seed_from_u64(seed);

    // Generate a valid test case: keys, randomness, noise, and plaintext
    let testcase = PkeTestcase::gen(rng, params);
    let ct = testcase.encrypt(params);

    // Independent witnesses for rejection testing. The values are
    // in-range (same distribution as `testcase`), but they are the *wrong*
    // witnesses for `ct` — proving with them produces a proof that does not
    // satisfy the ciphertext-witness relation, so verification must reject.
    let invalid_testcase = PkeTestcase::gen(rng, params);

    // CRS k > message count k exercises the path where the CRS is larger
    // than strictly needed, as happens in production
    let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));

    // Same three CRS variants as in `test_pke()`: original, round-tripped
    // through compressed serde, and round-tripped through uncompressed serde.
    let original_public_param =
        pke_v2::crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
    let public_param_that_was_compressed =
        serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap();
    let public_param_that_was_not_compressed =
        serialize_then_deserialize(&original_public_param, Compress::No).unwrap();

    // Sweep all combinations: 3 CRS variants × 32 invalid-witness flags
    let cases = itertools::iproduct!(
        [
            original_public_param,
            public_param_that_was_compressed,
            public_param_that_was_not_compressed
        ],
        [false, true], // r
        [false, true], // e1
        [false, true], // e2
        [false, true], // m
        [false, true]  // metadata
    );

    for (
        public_param,
        use_invalid_r,
        use_invalid_e1,
        use_invalid_e2,
        use_invalid_m,
        use_invalid_metadata,
    ) in cases
    {
        // Build the commit, substituting invalid witnesses where flagged.
        let (public_commit, private_commit) = pke_v2::commit(
            testcase.a.clone(),
            testcase.b.clone(),
            ct.c1.clone(),
            ct.c2.clone(),
            (if use_invalid_r {
                &invalid_testcase.r
            } else {
                &testcase.r
            })
            .clone(),
            (if use_invalid_e1 {
                &invalid_testcase.e1
            } else {
                &testcase.e1
            })
            .clone(),
            (if use_invalid_m {
                &invalid_testcase.m
            } else {
                &testcase.m
            })
            .clone(),
            (if use_invalid_e2 {
                &invalid_testcase.e2
            } else {
                &testcase.e2
            })
            .clone(),
            &public_param,
        );

        // ComputeLoad::Proof shifts work to the prover; ComputeLoad::Verify
        // shifts it to the verifier. Both must yield identical results.
        for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
            // Produce proofs from identical inputs on CPU and GPU.
            // v2 convention: seed bytes are little-endian (to_le_bytes).
            let cpu_proof = pke_v2::prove(
                (&public_param, &public_commit),
                &private_commit,
                &testcase.metadata,
                load,
                &seed.to_le_bytes(),
            );
            let gpu_proof = pke_v2::gpu::prove(
                (&public_param, &public_commit),
                &private_commit,
                &testcase.metadata,
                load,
                &seed.to_le_bytes(),
            );

            // GPU MSM is exact integer arithmetic, so the serialized proofs
            // must be byte-for-byte identical
            assert_eq!(
                bincode::serialize(&cpu_proof).unwrap(),
                bincode::serialize(&gpu_proof).unwrap(),
                "v2 proof mismatch: load={load}, invalid_r={use_invalid_r}, \
                 invalid_e1={use_invalid_e1}, invalid_e2={use_invalid_e2}, \
                 invalid_m={use_invalid_m}",
            );

            // When invalid metadata is used at verification time (but not at
            // proving time), the proof is valid but Fiat-Shamir binding fails
            let verify_metadata = if use_invalid_metadata {
                &invalid_testcase.metadata
            } else {
                &testcase.metadata
            };

            // Any single invalid input should cause verification to reject
            let should_fail = use_invalid_e1
                || use_invalid_e2
                || use_invalid_r
                || use_invalid_m
                || use_invalid_metadata;

            // v2 supports two pairing strategies for verification:
            // - TwoSteps: two independent pairing checks
            // - Batched: single batched pairing check (faster but same result)
            // Both must agree with each other and with the CPU verifier.
            for pairing_mode in [
                VerificationPairingMode::TwoSteps,
                VerificationPairingMode::Batched,
            ] {
                // --- Verify GPU proof on CPU ---
                let cpu_verify_result = pke_v2::verify(
                    &gpu_proof,
                    (&public_param, &public_commit),
                    verify_metadata,
                    pairing_mode,
                );
                assert_eq!(
                    cpu_verify_result.is_err(),
                    should_fail,
                    "v2 CPU verify mismatch: load={load}, mode={pairing_mode:?}, \
                     should_fail={should_fail}",
                );

                // --- Verify GPU proof on GPU ---
                let gpu_verify_result = pke_v2::gpu::verify(
                    &gpu_proof,
                    (&public_param, &public_commit),
                    verify_metadata,
                    pairing_mode,
                );
                assert_eq!(
                    gpu_verify_result.is_err(),
                    should_fail,
                    "v2 GPU verify mismatch: load={load}, mode={pairing_mode:?}, \
                     should_fail={should_fail}",
                );

                // CPU and GPU verifiers must produce the same Result value
                assert_eq!(
                    cpu_verify_result, gpu_verify_result,
                    "v2 CPU/GPU verify disagree: load={load}, mode={pairing_mode:?}",
                );

                // --- Cross-direction: GPU verify of CPU-produced proof ---
                // Although proofs are byte-identical, this exercises
                // pke_v2::gpu::verify with a proof object constructed entirely
                // on the CPU side, ensuring the GPU verifier does not depend
                // on any GPU-side proof metadata.
                let gpu_verify_cpu_proof = pke_v2::gpu::verify(
                    &cpu_proof,
                    (&public_param, &public_commit),
                    verify_metadata,
                    pairing_mode,
                );
                assert_eq!(
                    gpu_verify_cpu_proof, cpu_verify_result,
                    "v2 GPU-verify-of-CPU-proof disagrees with CPU-verify: load={load}, \
                     mode={pairing_mode:?}",
                );
            }
        }
    }
}

View File

@@ -0,0 +1,392 @@
//! Tests comparing zk-cuda-backend MSM results against tfhe-zk-pok CPU implementation
use crate::curve_api::bls12_446::{Zp, G1, G2};
use crate::curve_api::CurveGroupOps;
use crate::gpu::{
g1_affine_from_zk_cuda, g1_affine_to_zk_cuda, g2_affine_from_zk_cuda, g2_affine_to_zk_cuda,
};
use tfhe_cuda_backend::cuda_bind::{cuda_create_stream, cuda_destroy_stream};
use zk_cuda_backend::conversions::{g1_affine_from_montgomery, g2_affine_from_montgomery};
use zk_cuda_backend::{
G1Affine as ZkG1Affine, G1Projective as ZkG1Projective, G2Affine as ZkG2Affine,
G2Projective as ZkG2Projective, Scalar as ZkScalar,
};
/// Helper function to compute triangular number: N * (N+1) / 2
///
/// Exactly one of `n` and `n + 1` is even, so the even factor is halved
/// *before* multiplying. This keeps the intermediate product no larger than the
/// final result, whereas computing `n * (n + 1)` first overflows `u64` for
/// `n` around 2^32 even though the triangular number itself still fits.
///
/// Still overflows when the triangular number itself does not fit in a `u64`.
fn triangular_number(n: u64) -> u64 {
    if n % 2 == 0 {
        (n / 2) * (n + 1)
    } else {
        n * ((n + 1) / 2)
    }
}
/// BLS12-446 scalar field modulus minus one (r - 1), little-endian limbs
/// (index 0 is the least-significant limb).
/// Used by canceling-scalar tests: 1*G + (r-1)*G = r*G = O.
/// Also used by `test_scalar_validation` as a scalar that must still be valid.
const R_MINUS_1: [u64; 5] = [
0x0428001400040000,
0x7bb9b0e8d8ca3461,
0xd04c98ccc4c050bc,
0x7995b34995830fa4,
0x00000511b70539f2,
];
/// Returns `true` when the Z coordinate of a G1 projective point is zero,
/// i.e. the point is the point at infinity.
/// The point is taken out of Montgomery form first so limbs are directly
/// comparable with zero.
fn g1_proj_z_is_zero(p: &ZkG1Projective) -> bool {
    let z_coord = p.from_montgomery_normalized().Z();
    // Zero means that no limb carries a set bit.
    !z_coord.limb.iter().any(|&word| word != 0)
}
/// Returns `true` when the Z coordinate of a G2 projective point is zero,
/// i.e. the point is the point at infinity.
/// G2 lives over Fp2, so the limbs of both the c0 and c1 components are
/// inspected (c0 first, then c1).
fn g2_proj_z_is_zero(p: &ZkG2Projective) -> bool {
    let z_coord = p.from_montgomery_normalized().Z();
    z_coord
        .c0
        .limb
        .iter()
        .chain(z_coord.c1.limb.iter())
        .all(|&word| word == 0)
}
// =============================================================================
// Macro: generates G1 and G2 MSM test variants from a single template.
//
// Parameters:
//   $Group          - tfhe-zk-pok curve group (G1 or G2)
//   $ZkAffine       - zk-cuda-backend affine type (ZkG1Affine or ZkG2Affine)
//   $ZkProjective   - zk-cuda-backend projective type
//   $to_zk          - conversion: tfhe-zk-pok affine -> zk-cuda affine
//   $from_zk        - conversion: zk-cuda affine -> tfhe-zk-pok affine
//   $from_mont      - Montgomery-to-normal conversion for zk-cuda affine
//   $proj_z_is_zero - fn to check projective Z == 0 (group-specific)
//   $label          - group name interpolated into assertion messages
//   $test_*         - test function names (avoids `paste` dependency)
//
// Every generated test creates its own CUDA stream on GPU 0 and destroys it as
// soon as the MSM call has returned, before the result is inspected. When the
// first MSM of a test returns an error the test prints a message and returns
// early (it is treated as "CUDA not available"), so it is reported as passed.
// =============================================================================
macro_rules! msm_tests {
    (
        group: $Group:ty,
        zk_affine: $ZkAffine:ty,
        zk_proj: $ZkProjective:ty,
        to_zk: $to_zk:expr,
        from_zk: $from_zk:expr,
        from_mont: $from_mont:expr,
        proj_z_is_zero: $proj_z_is_zero:expr,
        label: $label:expr,
        test_large_n: $test_large_n:ident,
        test_zero_scalars: $test_zero_scalars:ident,
        test_canceling: $test_canceling:ident,
        test_infinity_input: $test_infinity_input:ident
    ) => {
        #[test]
        fn $test_large_n() {
            const MAX_N: u64 = 100;
            let gen = <$Group>::GENERATOR.normalize();
            let gen_zk = $to_zk(&gen);
            // Probe CUDA availability with a trivial MSM before the sweep
            {
                let probe_points: Vec<$ZkAffine> = vec![gen_zk];
                let probe_scalars: Vec<ZkScalar> = vec![ZkScalar::from_u64(1)];
                // SAFETY: gpu_index 0 is valid (checked by test setup)
                let probe_stream = unsafe { cuda_create_stream(0) };
                if <$ZkProjective>::msm(&probe_points, &probe_scalars, probe_stream, 0, false)
                    .is_err()
                {
                    // SAFETY: stream was created above and is not used after this point
                    unsafe { cuda_destroy_stream(probe_stream, 0) };
                    eprintln!("CUDA not available - Skipping test");
                    return;
                }
                // SAFETY: stream was created above and is not used after this point
                unsafe { cuda_destroy_stream(probe_stream, 0) };
            }
            // Sweep N from 1..=MAX_N: points = [G; N], scalars = [1..=N].
            // Expected result = G * triangular(N).
            for n in 1..=MAX_N {
                let points: Vec<$ZkAffine> = (0..n).map(|_| gen_zk).collect();
                let scalars: Vec<ZkScalar> = (1..=n).map(ZkScalar::from_u64).collect();
                // SAFETY: gpu_index 0 is valid (checked by test setup)
                let stream = unsafe { cuda_create_stream(0) };
                let msm_result = <$ZkProjective>::msm(&points, &scalars, stream, 0, false);
                // Release the stream before unwrapping: the probe above succeeded, so
                // a failure here is a real error and panics, and the panic must not
                // leak the stream.
                // SAFETY: stream was created above and is not used after this point
                unsafe { cuda_destroy_stream(stream, 0) };
                let (gpu_result_proj, _size_tracker) =
                    msm_result.unwrap_or_else(|_| panic!("CUDA MSM failed at N={}", n));
                let gpu_result = $from_mont(&gpu_result_proj.to_affine());
                let expected_scalar = Zp::from_u64(triangular_number(n));
                let cpu_result = <$Group>::GENERATOR.mul_scalar(expected_scalar).normalize();
                // Round-trip the GPU result through the tfhe-zk-pok representation so
                // both sides are compared in the same (zk-cuda affine) form.
                let gpu_tfhe = $from_zk(&gpu_result);
                assert_eq!(
                    $to_zk(&gpu_tfhe).is_infinity(),
                    $to_zk(&cpu_result).is_infinity(),
                    "{} MSM large_n: N={} infinity mismatch",
                    $label,
                    n
                );
                // Coordinates are only compared for finite points.
                if !gpu_result.is_infinity() {
                    assert_eq!(
                        $to_zk(&gpu_tfhe).x(),
                        $to_zk(&cpu_result).x(),
                        "{} MSM large_n: N={} x mismatch",
                        $label,
                        n
                    );
                    assert_eq!(
                        $to_zk(&gpu_tfhe).y(),
                        $to_zk(&cpu_result).y(),
                        "{} MSM large_n: N={} y mismatch",
                        $label,
                        n
                    );
                }
            }
        }

        #[test]
        fn $test_zero_scalars() {
            let gen = <$Group>::GENERATOR.normalize();
            let gen_zk = $to_zk(&gen);
            // All-zero scalars: 0*G + 0*G + ... = O
            let points: Vec<$ZkAffine> = vec![gen_zk; 5];
            let scalars: Vec<ZkScalar> = vec![ZkScalar::from_u64(0); 5];
            let gpu_index = 0;
            // SAFETY: gpu_index 0 is valid (checked by test setup)
            let stream = unsafe { cuda_create_stream(gpu_index) };
            let result_proj =
                match <$ZkProjective>::msm(&points, &scalars, stream, gpu_index, false) {
                    Ok((result, _size_tracker)) => result,
                    Err(e) => {
                        eprintln!("CUDA MSM failed: {} - Skipping test", e);
                        // SAFETY: stream was created above and is not used after this point
                        unsafe { cuda_destroy_stream(stream, gpu_index) };
                        return;
                    }
                };
            // SAFETY: stream was created above and is not used after this point
            unsafe { cuda_destroy_stream(stream, gpu_index) };
            let is_infinity = ($proj_z_is_zero)(&result_proj);
            assert!(
                is_infinity,
                "{} MSM with all-zero scalars should return infinity",
                $label
            );
        }

        #[test]
        fn $test_canceling() {
            let gen = <$Group>::GENERATOR.normalize();
            let gen_zk = $to_zk(&gen);
            // 1*G + (r-1)*G = r*G = O
            let points: Vec<$ZkAffine> = vec![gen_zk, gen_zk];
            let scalars: Vec<ZkScalar> = vec![ZkScalar::from_u64(1), ZkScalar::from(R_MINUS_1)];
            let gpu_index = 0;
            // SAFETY: gpu_index 0 is valid (checked by test setup)
            let stream = unsafe { cuda_create_stream(gpu_index) };
            let result_proj =
                match <$ZkProjective>::msm(&points, &scalars, stream, gpu_index, false) {
                    Ok((result, _size_tracker)) => result,
                    Err(e) => {
                        eprintln!("CUDA MSM failed: {} - Skipping test", e);
                        // SAFETY: stream was created above and is not used after this point
                        unsafe { cuda_destroy_stream(stream, gpu_index) };
                        return;
                    }
                };
            // SAFETY: stream was created above and is not used after this point
            unsafe { cuda_destroy_stream(stream, gpu_index) };
            let is_infinity = ($proj_z_is_zero)(&result_proj);
            assert!(
                is_infinity,
                "{} MSM with canceling scalars (1*G + (r-1)*G) should return infinity",
                $label
            );
        }

        #[test]
        fn $test_infinity_input() {
            let gen = <$Group>::GENERATOR.normalize();
            let gen_zk = $to_zk(&gen);
            let inf = <$ZkAffine>::infinity();
            // 5*O + 3*G + 7*O = 3*G (infinity inputs contribute nothing)
            let points: Vec<$ZkAffine> = vec![inf, gen_zk, inf];
            let scalars: Vec<ZkScalar> = vec![
                ZkScalar::from_u64(5),
                ZkScalar::from_u64(3),
                ZkScalar::from_u64(7),
            ];
            let gpu_index = 0;
            // SAFETY: gpu_index 0 is valid (checked by test setup)
            let stream = unsafe { cuda_create_stream(gpu_index) };
            let result_proj =
                match <$ZkProjective>::msm(&points, &scalars, stream, gpu_index, false) {
                    Ok((result, _size_tracker)) => result,
                    Err(e) => {
                        eprintln!("CUDA MSM failed: {} - Skipping test", e);
                        // SAFETY: stream was created above and is not used after this point
                        unsafe { cuda_destroy_stream(stream, gpu_index) };
                        return;
                    }
                };
            // SAFETY: stream was created above and is not used after this point
            unsafe { cuda_destroy_stream(stream, gpu_index) };
            let expected = <$Group>::GENERATOR.mul_scalar(Zp::from_u64(3)).normalize();
            let expected_zk = $to_zk(&expected);
            let result = $from_mont(&result_proj.to_affine());
            // 3*G is a finite point; rule out an infinity result explicitly before
            // comparing coordinates.
            assert!(
                !result.is_infinity(),
                "{} MSM with infinity points: result should not be infinity",
                $label
            );
            assert_eq!(
                result.x(),
                expected_zk.x(),
                "{} MSM with infinity points: x mismatch",
                $label
            );
            assert_eq!(
                result.y(),
                expected_zk.y(),
                "{} MSM with infinity points: y mismatch",
                $label
            );
        }
    };
}
// Generate G1 MSM tests
// Expands `msm_tests!` into the four `test_g1_msm_*` functions named below,
// using the G1 conversion helpers and the G1 infinity check.
msm_tests! {
group: G1,
zk_affine: ZkG1Affine,
zk_proj: ZkG1Projective,
to_zk: g1_affine_to_zk_cuda,
from_zk: g1_affine_from_zk_cuda,
from_mont: g1_affine_from_montgomery,
proj_z_is_zero: g1_proj_z_is_zero,
label: "G1",
test_large_n: test_g1_msm_large_n,
test_zero_scalars: test_g1_msm_zero_scalars_returns_infinity,
test_canceling: test_g1_msm_canceling_scalars_returns_infinity,
test_infinity_input: test_g1_msm_infinity_point_input
}
// Generate G2 MSM tests
// Same template as above, instantiated with the G2 types, the G2 conversion
// helpers and the Fp2-aware infinity check.
msm_tests! {
group: G2,
zk_affine: ZkG2Affine,
zk_proj: ZkG2Projective,
to_zk: g2_affine_to_zk_cuda,
from_zk: g2_affine_from_zk_cuda,
from_mont: g2_affine_from_montgomery,
proj_z_is_zero: g2_proj_z_is_zero,
label: "G2",
test_large_n: test_g2_msm_large_n,
test_zero_scalars: test_g2_msm_zero_scalars_returns_infinity,
test_canceling: test_g2_msm_canceling_scalars_returns_infinity,
test_infinity_input: test_g2_msm_infinity_point_input
}
// =============================================================================
// Non-macro tests: these test unique behavior not shared between G1/G2
// =============================================================================
/// Converting the G1 generator to the zk-cuda representation, back to the
/// tfhe-zk-pok representation, and to zk-cuda again must yield the same
/// coordinates as the first conversion.
#[test]
fn test_g1_conversion_roundtrip() {
    let first_pass = g1_affine_to_zk_cuda(&G1::GENERATOR.normalize());
    let back_in_tfhe = g1_affine_from_zk_cuda(&first_pass);
    let second_pass = g1_affine_to_zk_cuda(&back_in_tfhe);

    assert_eq!(first_pass.x(), second_pass.x());
    assert_eq!(first_pass.y(), second_pass.y());
}
/// Converting the G2 generator to the zk-cuda representation, back to the
/// tfhe-zk-pok representation, and to zk-cuda again must yield the same
/// coordinates as the first conversion.
#[test]
fn test_g2_conversion_roundtrip() {
    let first_pass = g2_affine_to_zk_cuda(&G2::GENERATOR.normalize());
    let back_in_tfhe = g2_affine_from_zk_cuda(&first_pass);
    let second_pass = g2_affine_to_zk_cuda(&back_in_tfhe);

    assert_eq!(first_pass.x(), second_pass.x());
    assert_eq!(first_pass.y(), second_pass.y());
}
/// Single-point G1 MSM with a scalar that does not fit in one 64-bit limb,
/// checked against the CPU scalar multiplication.
#[test]
fn test_g1_msm_multi_limb_scalar() {
    // 2^64 as little-endian 64-bit limbs: only the second limb is set, so the
    // scalar exercises multi-limb handling.
    const TWO_POW_64: [u64; 5] = [0u64, 1u64, 0u64, 0u64, 0u64];

    let generator_zk = g1_affine_to_zk_cuda(&G1::GENERATOR.normalize());
    let points: Vec<ZkG1Affine> = vec![generator_zk];
    let scalars: Vec<ZkScalar> = vec![ZkScalar::new(TWO_POW_64)];

    let device = 0;
    // SAFETY: gpu_index 0 is valid (checked by test setup)
    let stream = unsafe { cuda_create_stream(device) };
    let msm_output = match ZkG1Projective::msm(&points, &scalars, stream, device, false) {
        Err(e) => {
            // An MSM error is treated as "no usable GPU": report and bail out.
            eprintln!("CUDA MSM failed: {} - Skipping test", e);
            // SAFETY: stream was created above and is not used after this point
            unsafe { cuda_destroy_stream(stream, device) };
            return;
        }
        Ok((projective, _)) => projective,
    };
    // SAFETY: stream was created above and is not used after this point
    unsafe { cuda_destroy_stream(stream, device) };

    // GPU side: projective -> affine -> out of Montgomery form.
    let gpu_affine = g1_affine_from_montgomery(&msm_output.to_affine());

    // CPU reference: G * 2^64.
    let reference = G1::GENERATOR
        .mul_scalar(Zp::from_bigint(TWO_POW_64))
        .normalize();

    // Round-trip the GPU point through the tfhe-zk-pok type so both sides are
    // compared in the same zk-cuda affine form.
    let gpu_side = g1_affine_to_zk_cuda(&g1_affine_from_zk_cuda(&gpu_affine));
    let cpu_side = g1_affine_to_zk_cuda(&reference);

    assert_eq!(gpu_side.x(), cpu_side.x(), "G1 MSM 2^64 scalar x mismatch");
    assert_eq!(gpu_side.y(), cpu_side.y(), "G1 MSM 2^64 scalar y mismatch");
}
/// Checks `is_valid` on both sides of the scalar modulus and that
/// `reduce_once` brings the modulus itself back into the valid range.
#[test]
fn test_scalar_validation() {
    // A small value is below the modulus.
    let small = ZkScalar::from_u64(12345);
    assert!(small.is_valid(), "Small scalar should be valid");

    // r - 1 must still be accepted.
    let below_modulus = ZkScalar::from(R_MINUS_1);
    assert!(below_modulus.is_valid(), "r-1 should be valid");

    // r itself: R_MINUS_1 with its least-significant limb incremented. That
    // limb is 0x0428001400040000, so the increment cannot carry.
    let mut modulus_limbs: [u64; 5] = R_MINUS_1;
    modulus_limbs[0] += 1;
    let modulus = ZkScalar::from(modulus_limbs);
    assert!(!modulus.is_valid(), "r should be invalid");

    // A single reduction of r must land back in the valid range.
    let after_reduction = modulus.reduce_once();
    assert!(
        after_reduction.is_valid(),
        "Reduced scalar should be valid (equals zero)"
    );
}

View File

@@ -1,5 +1,7 @@
pub mod curve_446;
pub mod curve_api;
#[cfg(feature = "gpu-experimental")]
pub mod gpu;
pub mod proofs;
pub mod serialization;

View File

@@ -15,7 +15,7 @@ use tfhe_versionable::Versionize;
#[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize, Versionize)]
#[repr(transparent)]
pub(crate) struct OneBased<T: ?Sized>(T);
pub(crate) struct OneBased<T: ?Sized>(pub(crate) T);
/// The proving scheme is available in 2 versions, one that puts more load on the prover and one
/// that puts more load on the verifier
@@ -153,7 +153,7 @@ pub(crate) enum ProofSanityCheckMode {
/// Check the preconditions of the pke proof before computing it. Panic if one of the conditions
/// does not hold.
#[allow(clippy::too_many_arguments)]
fn assert_pke_proof_preconditions(
pub(crate) fn assert_pke_proof_preconditions(
a: &[i64],
b: &[i64],
c1: &[i64],
@@ -179,7 +179,7 @@ fn assert_pke_proof_preconditions(
/// q (modulus) is encoded on 64b, with 0 meaning 2^64. This converts the encoded q to its effective
/// value for modular operations.
fn decode_q(q: u64) -> u128 {
pub(crate) fn decode_q(q: u64) -> u128 {
if q == 0 {
1u128 << 64
} else {
@@ -193,7 +193,7 @@ fn decode_q(q: u64) -> u128 {
/// implies
/// phi(r1) = (rot(a) * phi(bar(r)) + phi(e1) - phi(c1)) / q
/// (phi is the function that maps a polynomial to its coeffs vector)
fn compute_r1(
pub(crate) fn compute_r1(
e1: &[i64],
c1: &[i64],
a: &[i64],
@@ -233,7 +233,7 @@ fn compute_r1(
/// r2_i = (phi_[d - i](b).T * phi(bar(r)) + delta * m_i + e2_i - c2_i) / q
/// (phi is the function that maps a polynomial to its coeffs vector)
#[allow(clippy::too_many_arguments)]
fn compute_r2(
pub(crate) fn compute_r2(
e2: &[i64],
c2: &[i64],
m: &[i64],
@@ -314,7 +314,7 @@ impl Sid {
Self(Some(rng.gen()))
}
fn to_le_bytes(self) -> SidBytes {
pub(crate) fn to_le_bytes(self) -> SidBytes {
self.0
.map(|val| SidBytes(Some(val.to_le_bytes())))
.unwrap_or_default()
@@ -322,10 +322,10 @@ impl Sid {
}
#[derive(Default)]
struct SidBytes(Option<[u8; 16]>);
pub(crate) struct SidBytes(Option<[u8; 16]>);
impl SidBytes {
fn as_slice(&self) -> &[u8] {
pub(crate) fn as_slice(&self) -> &[u8] {
self.0.as_ref().map(|val| val.as_slice()).unwrap_or(&[])
}
}
@@ -367,7 +367,7 @@ fn get_or_init_pools() -> &'static Vec<VerificationPool> {
///
/// When multiple calls of this function are made in parallel, each of them is executed in a
/// dedicated pool, if there is enough free cores on the CPU.
fn run_in_pool<OP, R>(f: OP) -> R
pub(crate) fn run_in_pool<OP, R>(f: OP) -> R
where
OP: FnOnce() -> R + Send,
R: Send,
@@ -422,7 +422,7 @@ pub mod pke;
pub mod pke_v2;
#[cfg(test)]
mod test {
pub(crate) mod test {
#![allow(non_snake_case)]
use std::fmt::Display;
use std::num::Wrapping;
@@ -439,9 +439,9 @@ mod test {
use crate::proofs::decode_q;
// One of our usecases uses 320 bits of additional metadata
pub(super) const METADATA_LEN: usize = (320 / u8::BITS) as usize;
pub(crate) const METADATA_LEN: usize = (320 / u8::BITS) as usize;
pub(super) enum Compress {
pub(crate) enum Compress {
Yes,
No,
}
@@ -453,7 +453,7 @@ mod test {
Nominal,
}
pub(super) fn serialize_then_deserialize<
pub(crate) fn serialize_then_deserialize<
Params: Compressible + Serialize + for<'de> Deserialize<'de>,
>(
public_params: &Params,
@@ -503,36 +503,36 @@ mod test {
/// Parameters needed for a PKE zk proof test
#[derive(Copy, Clone)]
pub(super) struct PkeTestParameters {
pub(super) d: usize,
pub(super) k: usize,
pub(super) B: u64,
pub(super) q: u64,
pub(super) t: u64,
pub(super) msbs_zero_padding_bit_count: u64,
pub(crate) struct PkeTestParameters {
pub(crate) d: usize,
pub(crate) k: usize,
pub(crate) B: u64,
pub(crate) q: u64,
pub(crate) t: u64,
pub(crate) msbs_zero_padding_bit_count: u64,
}
/// An encrypted PKE ciphertext
pub struct PkeTestCiphertext {
pub(super) c1: Vec<i64>,
pub(super) c2: Vec<i64>,
pub(crate) struct PkeTestCiphertext {
pub(crate) c1: Vec<i64>,
pub(crate) c2: Vec<i64>,
}
/// A randomly generated testcase of pke encryption
#[derive(Clone)]
pub(super) struct PkeTestcase {
pub(super) a: Vec<i64>,
pub(super) e1: Vec<i64>,
pub(super) e2: Vec<i64>,
pub(super) r: Vec<i64>,
pub(super) m: Vec<i64>,
pub(super) b: Vec<i64>,
pub(super) metadata: [u8; METADATA_LEN],
pub(super) s: Vec<i64>,
pub(crate) struct PkeTestcase {
pub(crate) a: Vec<i64>,
pub(crate) e1: Vec<i64>,
pub(crate) e2: Vec<i64>,
pub(crate) r: Vec<i64>,
pub(crate) m: Vec<i64>,
pub(crate) b: Vec<i64>,
pub(crate) metadata: [u8; METADATA_LEN],
pub(crate) s: Vec<i64>,
}
impl PkeTestcase {
pub(super) fn gen(rng: &mut StdRng, params: PkeTestParameters) -> Self {
pub(crate) fn gen(rng: &mut StdRng, params: PkeTestParameters) -> Self {
let PkeTestParameters {
d,
k,
@@ -587,7 +587,7 @@ mod test {
}
}
pub(super) fn sk_encrypt_zero(
pub(crate) fn sk_encrypt_zero(
&self,
params: PkeTestParameters,
rng: &mut StdRng,
@@ -618,7 +618,7 @@ mod test {
}
/// Decrypt a ciphertext list
pub(super) fn decrypt(
pub(crate) fn decrypt(
&self,
ct: &PkeTestCiphertext,
params: PkeTestParameters,
@@ -659,7 +659,7 @@ mod test {
}
/// Encrypt using compact pke, the encryption is validated by doing a decryption
pub(super) fn encrypt(&self, params: PkeTestParameters) -> PkeTestCiphertext {
pub(crate) fn encrypt(&self, params: PkeTestParameters) -> PkeTestCiphertext {
let ct = self.encrypt_unchecked(params);
// Check decryption
@@ -671,7 +671,7 @@ mod test {
}
/// Encrypt using compact pke, without checking that the decryption is correct
pub(super) fn encrypt_unchecked(&self, params: PkeTestParameters) -> PkeTestCiphertext {
pub(crate) fn encrypt_unchecked(&self, params: PkeTestParameters) -> PkeTestCiphertext {
let PkeTestParameters {
d,
k,

View File

@@ -0,0 +1,438 @@
//! GPU-accelerated prove/verify for PKE v1.
//!
//! `prove` duplicates the logic of [`super::prove_impl`] but replaces every
//! `multi_mul_scalar` call with the GPU-accelerated [`crate::gpu::g1_msm_gpu`]
//! / [`crate::gpu::g2_msm_gpu`]. `verify` simply delegates to the CPU
//! verifier since v1 verification contains no MSM calls.
use crate::curve_api::bls12_446::{Zp, G1, G2};
use crate::curve_api::{Bls12_446, CurveGroupOps, FieldOps};
use crate::gpu::select_gpu_for_msm;
use crate::proofs::{
assert_pke_proof_preconditions, compute_r1, compute_r2, decode_q, ComputeLoad, OneBased,
ProofSanityCheckMode,
};
use super::{
bit_iter, compute_a_theta, ComputeLoadProofFields, PrivateCommit, Proof, PublicCommit,
PublicParams,
};
/// GPU-accelerated proof generation for PKE v1.
///
/// Identical to [`super::prove`] but dispatches MSM to the GPU via
/// [`crate::gpu::g1_msm_gpu`] / [`crate::gpu::g2_msm_gpu`].
///
/// All arguments are forwarded unchanged to this module's `prove_impl` with
/// [`ProofSanityCheckMode::Panic`], so the PKE proof preconditions are asserted
/// before the proof is computed.
///
/// With [`ComputeLoad::Proof`] the returned proof carries the additional
/// `compute_load_proof_fields`; for any other `load` that field is `None`.
///
/// # Panics
///
/// Panics if the proof preconditions checked by
/// `assert_pke_proof_preconditions` do not hold.
pub fn prove(
public: (&PublicParams<Bls12_446>, &PublicCommit<Bls12_446>),
private_commit: &PrivateCommit<Bls12_446>,
metadata: &[u8],
load: ComputeLoad,
seed: &[u8],
) -> Proof<Bls12_446> {
prove_impl(
public,
private_commit,
metadata,
load,
seed,
ProofSanityCheckMode::Panic,
)
}
/// GPU-accelerated verification for PKE v1.
///
/// PKE v1 verification has no MSM calls, so this delegates directly to the CPU
/// verifier.
///
/// This function exists so that the GPU module offers a `verify` entry point
/// next to `prove`; its result is exactly the value returned by
/// [`super::verify`] for the same arguments.
#[allow(clippy::result_unit_err)]
pub fn verify(
proof: &Proof<Bls12_446>,
public: (&PublicParams<Bls12_446>, &PublicCommit<Bls12_446>),
metadata: &[u8],
) -> Result<(), ()> {
super::verify(proof, public, metadata)
}
// ---------------------------------------------------------------------------
// prove_impl: GPU variant of super::prove_impl
// ---------------------------------------------------------------------------
/// Builds a PKE v1 proof, running the multi-scalar multiplications through
/// [`crate::gpu::g1_msm_gpu`] / [`crate::gpu::g2_msm_gpu`] on the device
/// returned by `select_gpu_for_msm()` (queried again for every MSM).
///
/// * `public` - public parameters and the public part of the commitment.
/// * `private_commit` - the witness (`r`, `e1`, `m`, `e2`).
/// * `metadata` - extra bytes mixed into the hashes below.
/// * `load` - with [`ComputeLoad::Proof`] the extra commitments `c_hat_t`,
///   `c_h` and `pi_kzg` are computed and returned in
///   `compute_load_proof_fields`; otherwise that field is `None`.
/// * `seed` - bytes hashed to obtain `gamma` and `gamma_y`.
/// * `sanity_check_mode` - with [`ProofSanityCheckMode::Panic`] the proof
///   preconditions are asserted first (and a violation panics).
///
/// NOTE: several names are shadowed as the function progresses (`t`, `delta`,
/// `w`, `q`, `poly`); each shadowing point is marked below.
#[allow(clippy::too_many_arguments)]
fn prove_impl(
public: (&PublicParams<Bls12_446>, &PublicCommit<Bls12_446>),
private_commit: &PrivateCommit<Bls12_446>,
metadata: &[u8],
load: ComputeLoad,
seed: &[u8],
sanity_check_mode: ProofSanityCheckMode,
) -> Proof<Bls12_446> {
let &PublicParams {
ref g_lists,
big_d: big_d_max,
n,
d,
b,
b_r,
q,
t,
msbs_zero_padding_bit_count,
k: k_max,
sid,
domain_separators: ref ds,
} = public.0;
let g_list = &g_lists.g_list;
let g_hat_list = &g_lists.g_hat_list;
// Keep the scalar bound `b` under another name: `b` is rebound to the
// public commit's vector on the next line.
let b_i = b;
let PublicCommit { a, b, c1, c2, .. } = public.1;
let PrivateCommit { r, e1, m, e2, .. } = private_commit;
// `k` comes from the commit being proven; `k_max` (from the public
// parameters) is only used by the precondition check.
let k = c2.len();
let effective_t_for_decomposition = t >> msbs_zero_padding_bit_count;
let decoded_q = decode_q(q);
// Number of witness bits written into `w[..big_d]` below; `big_d_max` is the
// value stored in the public parameters.
let big_d = d
+ k * effective_t_for_decomposition.ilog2() as usize
+ (d + k) * (2 + b_i.ilog2() as usize + b_r.ilog2() as usize);
if sanity_check_mode == ProofSanityCheckMode::Panic {
assert_pke_proof_preconditions(a, b, c1, e1, c2, e2, d, k_max, big_d, big_d_max);
}
// FIXME: div_round
let delta = {
// delta takes the encoding with the padding bit
// decoded_q <= 2^64 and t >= 1, so the quotient always fits in u64
(decoded_q / t as u128) as u64
};
let g = G1::GENERATOR;
let g_hat = G2::GENERATOR;
// gamma and gamma_y are derived from the caller-provided seed; they scale
// the generators in `c_hat` and `c_y` respectively.
let mut gamma_list = [Zp::ZERO; 2];
Zp::hash(&mut gamma_list, &[ds.hash_gamma(), seed]);
let [gamma, gamma_y] = gamma_list;
let r1 = compute_r1(e1, c1, a, r, d, decoded_q);
let r2 = compute_r2(e2, c2, m, b, r, d, delta, decoded_q);
// Witness bit vector of length n; only the first big_d entries are set, in
// this order: r (reversed), m, e1, e2, r1, r2.
let mut w = vec![false; n];
// Reinterpret as unsigned for bit decomposition; bit pattern is preserved,
// which is correct for torus arithmetic
let u64 = |x: i64| x as u64;
w[..big_d]
.iter_mut()
.zip(
r.iter()
.rev()
.flat_map(|&r| bit_iter(u64(r), 1))
.chain(
m.iter()
.flat_map(|&m| bit_iter(u64(m), effective_t_for_decomposition.ilog2())),
)
.chain(e1.iter().flat_map(|&e1| bit_iter(u64(e1), 1 + b_i.ilog2())))
.chain(e2.iter().flat_map(|&e2| bit_iter(u64(e2), 1 + b_i.ilog2())))
.chain(r1.iter().flat_map(|&r1| bit_iter(u64(r1), 1 + b_r.ilog2())))
.chain(r2.iter().flat_map(|&r2| bit_iter(u64(r2), 1 + b_r.ilog2()))),
)
.for_each(|(dst, src)| *dst = src);
let w = OneBased(w);
// c_hat = gamma * g_hat + sum of g_hat_list[j] over the set witness bits.
// This one is accumulated on the CPU, not through a GPU MSM.
let mut c_hat = g_hat.mul_scalar(gamma);
for j in 1..big_d + 1 {
if w[j] {
c_hat += G2::projective(g_hat_list[j]);
}
}
// Little-endian byte encoding of the public inputs (q, d, b_i, t, padding
// bit count, a, b, c1, c2); passed to every hash below.
let x_bytes = &*[
q.to_le_bytes().as_slice(),
(d as u64).to_le_bytes().as_slice(),
b_i.to_le_bytes().as_slice(),
t.to_le_bytes().as_slice(),
msbs_zero_padding_bit_count.to_le_bytes().as_slice(),
&*a.iter().flat_map(|&x| x.to_le_bytes()).collect::<Box<_>>(),
&*b.iter().flat_map(|&x| x.to_le_bytes()).collect::<Box<_>>(),
&*c1.iter().flat_map(|&x| x.to_le_bytes()).collect::<Box<_>>(),
&*c2.iter().flat_map(|&x| x.to_le_bytes()).collect::<Box<_>>(),
]
.iter()
.copied()
.flatten()
.copied()
.collect::<Box<_>>();
let mut y = vec![Zp::ZERO; n];
Zp::hash(
&mut y,
&[
ds.hash(),
sid.to_le_bytes().as_slice(),
metadata,
x_bytes,
c_hat.to_le_bytes().as_ref(),
],
);
let y = OneBased(y);
// GPU MSM: c_y commitment
let scalars = (n + 1 - big_d..n + 1)
.map(|j| y[n + 1 - j] * Zp::from_u64(w[n + 1 - j] as u64))
.collect::<Vec<_>>();
let c_y = g.mul_scalar(gamma_y)
+ crate::gpu::g1_msm_gpu(&g_list.0[n - big_d..n], &scalars, select_gpu_for_msm());
// theta has d + k + 1 entries: the first d + k form theta0, the last one is
// delta_theta.
let mut theta = vec![Zp::ZERO; d + k + 1];
Zp::hash(
&mut theta,
&[
ds.hash_lmap(),
sid.to_le_bytes().as_slice(),
metadata,
x_bytes,
c_hat.to_le_bytes().as_ref(),
c_y.to_le_bytes().as_ref(),
],
);
let theta0 = &theta[..d + k];
let delta_theta = theta[d + k];
let mut a_theta = vec![Zp::ZERO; big_d];
compute_a_theta::<Bls12_446>(
theta0,
d,
a,
k,
b,
&mut a_theta,
effective_t_for_decomposition,
delta,
b_i,
b_r,
decoded_q,
);
// Shadows the scalar `t` taken from the public parameters: from here on `t`
// is the hash-derived vector of n field elements.
let mut t = vec![Zp::ZERO; n];
Zp::hash_128bit(
&mut t,
&[
ds.hash_t(),
sid.to_le_bytes().as_slice(),
metadata,
&(1..n + 1)
.flat_map(|i| y[i].to_le_bytes().as_ref().to_vec())
.collect::<Box<_>>(),
x_bytes,
c_hat.to_le_bytes().as_ref(),
c_y.to_le_bytes().as_ref(),
],
);
let t = OneBased(t);
// Shadows the u64 `delta` computed above (already consumed by compute_r2
// and compute_a_theta): from here on `delta` holds hash-derived field
// elements.
let mut delta = [Zp::ZERO; 2];
Zp::hash(
&mut delta,
&[
ds.hash_agg(),
sid.to_le_bytes().as_slice(),
metadata,
x_bytes,
c_hat.to_le_bytes().as_ref(),
c_y.to_le_bytes().as_ref(),
],
);
let [delta_eq, delta_y] = delta;
let delta = [delta_eq, delta_y, delta_theta];
// Coefficient vectors of the four polynomials combined below as
// poly_0 * poly_1 - poly_2 * poly_3.
let mut poly_0 = vec![Zp::ZERO; n + 1];
let mut poly_1 = vec![Zp::ZERO; big_d + 1];
let mut poly_2 = vec![Zp::ZERO; n + 1];
let mut poly_3 = vec![Zp::ZERO; n + 1];
poly_0[0] = delta_y * gamma_y;
for i in 1..n + 1 {
poly_0[n + 1 - i] =
delta_y * (y[i] * Zp::from_u64(w[i] as u64)) + (delta_eq * t[i] - delta_y) * y[i];
if i < big_d + 1 {
poly_0[n + 1 - i] += delta_theta * a_theta[i - 1];
}
}
poly_1[0] = gamma;
for i in 1..big_d + 1 {
poly_1[i] = Zp::from_u64(w[i] as u64);
}
poly_2[0] = gamma_y;
for i in 1..big_d + 1 {
poly_2[n + 1 - i] = y[i] * Zp::from_u64(w[i] as u64);
}
for i in 1..n + 1 {
poly_3[i] = delta_eq * t[i];
}
// t_theta = <theta0, (c1 || c2)>
let mut t_theta = Zp::ZERO;
for i in 0..d {
t_theta += theta0[i] * Zp::from_i64(c1[i]);
}
for i in 0..k {
t_theta += theta0[d + i] * Zp::from_i64(c2[i]);
}
// The two polynomial products are independent, so they run in parallel.
let mul = rayon::join(
|| Zp::poly_mul(&poly_0, &poly_1),
|| Zp::poly_mul(&poly_2, &poly_3),
);
let mut poly = Zp::poly_sub(&mul.0, &mul.1);
if poly.len() > n + 1 {
poly[n + 1] -= t_theta * delta_theta;
}
// GPU MSM: pi commitment
let pi = g.mul_scalar(poly[0])
+ crate::gpu::g1_msm_gpu(
&g_list.0[..poly.len() - 1],
&poly[1..],
select_gpu_for_msm(),
);
// Only ComputeLoad::Proof computes the extra commitments below; every other
// load returns the three commitments computed so far.
if load == ComputeLoad::Proof {
// GPU MSM: c_hat_t commitment
let c_hat_t = crate::gpu::g2_msm_gpu(&g_hat_list.0, &t.0, select_gpu_for_msm());
let scalars = (1..n + 1)
.map(|i| {
let i = n + 1 - i;
(delta_eq * t[i] - delta_y) * y[i]
+ if i < big_d + 1 {
delta_theta * a_theta[i - 1]
} else {
Zp::ZERO
}
})
.collect::<Vec<_>>();
// GPU MSM: c_h commitment
let c_h = crate::gpu::g1_msm_gpu(&g_list.0[..n], &scalars, select_gpu_for_msm());
// Evaluation point z, hashed from everything committed so far.
let mut z = Zp::ZERO;
Zp::hash(
core::array::from_mut(&mut z),
&[
ds.hash_z(),
sid.to_le_bytes().as_slice(),
metadata,
x_bytes,
c_hat.to_le_bytes().as_ref(),
c_y.to_le_bytes().as_ref(),
pi.to_le_bytes().as_ref(),
c_h.to_le_bytes().as_ref(),
c_hat_t.to_le_bytes().as_ref(),
&y.0.iter()
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
.collect::<Box<[_]>>(),
&t.0.iter()
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
.collect::<Box<[_]>>(),
&delta
.iter()
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
.collect::<Box<[_]>>(),
],
);
// p_t and p_h accumulate sums weighted by successive powers of z
// (`pow` starts at z and is multiplied by z every iteration).
let mut pow = z;
let mut p_t = Zp::ZERO;
let mut p_h = Zp::ZERO;
for i in 1..n + 1 {
p_t += t[i] * pow;
if n - i < big_d {
p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i]
+ delta_theta * a_theta[n - i])
* pow;
} else {
p_h += ((delta_eq * t[n + 1 - i] - delta_y) * y[n + 1 - i]) * pow;
}
pow = pow * z;
}
// Shadows the witness bit vector `w`: from here on `w` is a single
// hash-derived field element.
let mut w = Zp::ZERO;
Zp::hash(
core::array::from_mut(&mut w),
&[
ds.hash_w(),
sid.to_le_bytes().as_slice(),
metadata,
x_bytes,
c_hat.to_le_bytes().as_ref(),
c_y.to_le_bytes().as_ref(),
pi.to_le_bytes().as_ref(),
c_h.to_le_bytes().as_ref(),
c_hat_t.to_le_bytes().as_ref(),
&y.0.iter()
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
.collect::<Box<[_]>>(),
&t.0.iter()
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
.collect::<Box<[_]>>(),
&delta
.iter()
.flat_map(|x| x.to_le_bytes().as_ref().to_vec())
.collect::<Box<[_]>>(),
z.to_le_bytes().as_ref(),
p_h.to_le_bytes().as_ref(),
p_t.to_le_bytes().as_ref(),
],
);
// Shadows the outer `poly` (already committed to as `pi`).
let mut poly = vec![Zp::ZERO; n + 1];
for i in 1..n + 1 {
poly[i] += w * t[i];
if i < big_d + 1 {
poly[n + 1 - i] +=
(delta_eq * t[i] - delta_y) * y[i] + delta_theta * a_theta[i - 1];
} else {
poly[n + 1 - i] += (delta_eq * t[i] - delta_y) * y[i];
}
}
// Shadows the modulus `q` from the public parameters: here `q` is the
// quotient of the division below.
let mut q = vec![Zp::ZERO; n];
// Polynomial long division by (X - z)
for i in (0..n).rev() {
poly[i] = poly[i] + z * poly[i + 1];
q[i] = poly[i + 1];
poly[i + 1] = Zp::ZERO;
}
// GPU MSM: pi_kzg commitment
let pi_kzg = g.mul_scalar(q[0])
+ crate::gpu::g1_msm_gpu(&g_list.0[..n - 1], &q[1..n], select_gpu_for_msm());
Proof {
c_hat,
c_y,
pi,
compute_load_proof_fields: Some(ComputeLoadProofFields {
c_hat_t,
c_h,
pi_kzg,
}),
}
} else {
Proof {
c_hat,
c_y,
pi,
compute_load_proof_fields: None,
}
}
}

View File

@@ -1,3 +1,6 @@
#[cfg(feature = "gpu-experimental")]
pub mod gpu;
// TODO: refactor copy-pasted code in proof/verify
use crate::backward_compatibility::pke::{
@@ -15,7 +18,7 @@ use core::marker::PhantomData;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
pub(crate) fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
(0..nbits).map(move |idx| ((x >> idx) & 1) != 0)
}
@@ -441,10 +444,10 @@ where
#[derive(Clone, Debug)]
pub struct PublicCommit<G: Curve> {
a: Vec<i64>,
b: Vec<i64>,
c1: Vec<i64>,
c2: Vec<i64>,
pub(crate) a: Vec<i64>,
pub(crate) b: Vec<i64>,
pub(crate) c1: Vec<i64>,
pub(crate) c2: Vec<i64>,
__marker: PhantomData<G>,
}
@@ -462,10 +465,10 @@ impl<G: Curve> PublicCommit<G> {
#[derive(Clone, Debug)]
pub struct PrivateCommit<G: Curve> {
r: Vec<i64>,
e1: Vec<i64>,
m: Vec<i64>,
e2: Vec<i64>,
pub(crate) r: Vec<i64>,
pub(crate) e1: Vec<i64>,
pub(crate) m: Vec<i64>,
pub(crate) e2: Vec<i64>,
__marker: PhantomData<G>,
}
@@ -932,7 +935,7 @@ fn prove_impl<G: Curve>(
}
#[allow(clippy::too_many_arguments)]
fn compute_a_theta<G: Curve>(
pub(crate) fn compute_a_theta<G: Curve>(
theta0: &[G::Zp],
d: usize,
a: &[i64],
@@ -1353,7 +1356,7 @@ pub fn verify<G: Curve>(
}
#[cfg(test)]
mod tests {
pub(crate) mod tests {
use crate::curve_api::{self, bls12_446};
use super::super::test::*;
@@ -1364,7 +1367,7 @@ mod tests {
type Curve = curve_api::Bls12_446;
/// Compact key params used with pkev1
pub(super) const PKEV1_TEST_PARAMS: PkeTestParameters = PkeTestParameters {
pub(crate) const PKEV1_TEST_PARAMS: PkeTestParameters = PkeTestParameters {
d: 1024,
k: 320,
B: 4398046511104, // 2**42

File diff suppressed because it is too large Load Diff

View File

@@ -337,13 +337,13 @@ struct RInputs<'a> {
mode: PkeV2HashMode,
}
pub(super) struct RHash<'a> {
pub(crate) struct RHash<'a> {
R_inputs: RInputs<'a>,
R_bytes: Box<[u8]>,
}
impl<'a> RHash<'a> {
pub(super) fn new<G: Curve>(
pub(crate) fn new<G: Curve>(
public: (&'a PublicParams<G>, &PublicCommit<G>),
metadata: &'a [u8],
C_hat_e_bytes: &'a [u8],
@@ -502,7 +502,7 @@ impl<'a> RHash<'a> {
]
}
pub(super) fn gen_phi<Zp: FieldOps>(self, C_R_bytes: &'a [u8]) -> ([Zp; 128], PhiHash<'a>) {
pub(crate) fn gen_phi<Zp: FieldOps>(self, C_R_bytes: &'a [u8]) -> ([Zp; 128], PhiHash<'a>) {
let mode = self.R_inputs.mode;
let phi_inputs = PhiInputs { C_R_bytes };
@@ -526,7 +526,7 @@ struct PhiInputs<'a> {
C_R_bytes: &'a [u8],
}
pub(super) struct PhiHash<'a> {
pub(crate) struct PhiHash<'a> {
R_inputs: RInputs<'a>,
R_bytes: Box<[u8]>,
phi_inputs: PhiInputs<'a>,
@@ -572,7 +572,7 @@ impl<'a> PhiHash<'a> {
}
}
pub(super) fn gen_xi<Zp: FieldOps>(self, C_hat_bin_bytes: &'a [u8]) -> ([Zp; 128], XiHash<'a>) {
pub(crate) fn gen_xi<Zp: FieldOps>(self, C_hat_bin_bytes: &'a [u8]) -> ([Zp; 128], XiHash<'a>) {
let mode = self.R_inputs.mode;
let xi_inputs = XiInputs { C_hat_bin_bytes };
@@ -598,7 +598,7 @@ struct XiInputs<'a> {
C_hat_bin_bytes: &'a [u8],
}
pub(super) struct XiHash<'a> {
pub(crate) struct XiHash<'a> {
R_inputs: RInputs<'a>,
R_bytes: Box<[u8]>,
phi_inputs: PhiInputs<'a>,
@@ -650,7 +650,7 @@ impl<'a> XiHash<'a> {
}
}
pub(super) fn gen_y<Zp: FieldOps>(self) -> (Vec<Zp>, YHash<'a>) {
pub(crate) fn gen_y<Zp: FieldOps>(self) -> (Vec<Zp>, YHash<'a>) {
let mode = self.R_inputs.mode;
let mut y = vec![Zp::ZERO; self.R_inputs.D + 128 * self.R_inputs.m];
@@ -671,7 +671,7 @@ impl<'a> XiHash<'a> {
}
}
pub(super) struct YHash<'a> {
pub(crate) struct YHash<'a> {
R_inputs: RInputs<'a>,
R_bytes: Box<[u8]>,
phi_inputs: PhiInputs<'a>,
@@ -729,7 +729,7 @@ impl<'a> YHash<'a> {
}
}
pub(super) fn gen_t<Zp: FieldOps>(self, C_y_bytes: &'a [u8]) -> (Vec<Zp>, THash<'a>) {
pub(crate) fn gen_t<Zp: FieldOps>(self, C_y_bytes: &'a [u8]) -> (Vec<Zp>, THash<'a>) {
let mode = self.R_inputs.mode;
let t_inputs = TInputs { C_y_bytes };
@@ -761,7 +761,7 @@ struct TInputs<'a> {
C_y_bytes: &'a [u8],
}
pub(super) struct THash<'a> {
pub(crate) struct THash<'a> {
R_inputs: RInputs<'a>,
R_bytes: Box<[u8]>,
phi_inputs: PhiInputs<'a>,
@@ -825,7 +825,7 @@ impl<'a> THash<'a> {
}
}
pub(super) fn gen_theta<Zp: FieldOps>(self) -> (Vec<Zp>, ThetaHash<'a>) {
pub(crate) fn gen_theta<Zp: FieldOps>(self) -> (Vec<Zp>, ThetaHash<'a>) {
let mode = self.R_inputs.mode;
let mut theta = vec![Zp::ZERO; self.R_inputs.d + self.R_inputs.k];
@@ -849,7 +849,7 @@ impl<'a> THash<'a> {
}
}
pub(super) struct ThetaHash<'a> {
pub(crate) struct ThetaHash<'a> {
R_inputs: RInputs<'a>,
R_bytes: Box<[u8]>,
phi_inputs: PhiInputs<'a>,
@@ -917,7 +917,7 @@ impl<'a> ThetaHash<'a> {
}
}
pub(super) fn gen_omega<Zp: FieldOps>(self) -> (Vec<Zp>, OmegaHash<'a>) {
pub(crate) fn gen_omega<Zp: FieldOps>(self) -> (Vec<Zp>, OmegaHash<'a>) {
let mode = self.R_inputs.mode;
let mut omega = vec![Zp::ZERO; self.R_inputs.n];
@@ -946,7 +946,7 @@ impl<'a> ThetaHash<'a> {
}
}
pub(super) struct OmegaHash<'a> {
pub(crate) struct OmegaHash<'a> {
R_inputs: RInputs<'a>,
R_bytes: Box<[u8]>,
phi_inputs: PhiInputs<'a>,
@@ -1018,7 +1018,7 @@ impl<'a> OmegaHash<'a> {
}
}
pub(super) fn gen_delta<Zp: FieldOps>(self) -> ([Zp; 7], DeltaHash<'a>) {
pub(crate) fn gen_delta<Zp: FieldOps>(self) -> ([Zp; 7], DeltaHash<'a>) {
let mut delta = [Zp::ZERO; 7];
// Delta does not use the compact hash optimization
@@ -1048,7 +1048,7 @@ impl<'a> OmegaHash<'a> {
}
}
pub(super) struct DeltaHash<'a> {
pub(crate) struct DeltaHash<'a> {
R_inputs: RInputs<'a>,
R_bytes: Box<[u8]>,
phi_inputs: PhiInputs<'a>,
@@ -1161,7 +1161,7 @@ impl<'a> DeltaHash<'a> {
}
}
pub(super) fn gen_z<Zp: FieldOps>(
pub(crate) fn gen_z<Zp: FieldOps>(
self,
C_h1_bytes: &'a [u8],
C_h2_bytes: &'a [u8],
@@ -1210,7 +1210,7 @@ struct ZInputs<'a> {
C_hat_omega_bytes: &'a [u8],
}
pub(super) struct ZHash<'a> {
pub(crate) struct ZHash<'a> {
R_inputs: RInputs<'a>,
R_bytes: Box<[u8]>,
phi_inputs: PhiInputs<'a>,
@@ -1351,7 +1351,7 @@ impl<'a> ZHash<'a> {
}
}
pub(super) fn gen_chi<Zp: FieldOps>(
pub(crate) fn gen_chi<Zp: FieldOps>(
self,
p_h1: Zp,
p_h2: Zp,

View File

@@ -1,6 +1,9 @@
// to follow the notation of the paper
#![allow(non_snake_case)]
#[cfg(feature = "gpu-experimental")]
pub mod gpu;
use super::*;
use crate::backward_compatibility::pke_v2::*;
use crate::backward_compatibility::BoundVersions;
@@ -16,13 +19,13 @@ use core::marker::PhantomData;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
mod hashes;
pub(crate) mod hashes;
use hashes::RHash;
pub use hashes::*;
fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
pub(crate) fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
(0..nbits).map(move |idx| ((x >> idx) & 1) != 0)
}
@@ -436,7 +439,7 @@ pub(crate) struct ComputeLoadProofFields<G: Curve> {
}
impl<G: Curve> ComputeLoadProofFields<G> {
fn to_le_bytes(fields: &Option<Self>) -> (Box<[u8]>, Box<[u8]>) {
pub(crate) fn to_le_bytes(fields: &Option<Self>) -> (Box<[u8]>, Box<[u8]>) {
if let Some(ComputeLoadProofFields { C_hat_h3, C_hat_w }) = fields.as_ref() {
(
Box::from(G::G2::to_le_bytes(*C_hat_h3).as_ref()),
@@ -589,13 +592,13 @@ where
#[derive(Clone, Debug)]
pub struct PublicCommit<G: Curve> {
/// Mask of the public key
a: Vec<i64>,
pub(crate) a: Vec<i64>,
/// Body of the public key
b: Vec<i64>,
pub(crate) b: Vec<i64>,
/// Mask of the ciphertexts
c1: Vec<i64>,
pub(crate) c1: Vec<i64>,
/// Bodies of the ciphertexts
c2: Vec<i64>,
pub(crate) c2: Vec<i64>,
__marker: PhantomData<G>,
}
@@ -614,13 +617,13 @@ impl<G: Curve> PublicCommit<G> {
#[derive(Clone, Debug)]
pub struct PrivateCommit<G: Curve> {
/// Public key sampling vector
r: Vec<i64>,
pub(crate) r: Vec<i64>,
/// Error vector associated with the masks
e1: Vec<i64>,
pub(crate) e1: Vec<i64>,
/// Input messages
m: Vec<i64>,
pub(crate) m: Vec<i64>,
/// Error vector associated with the bodies
e2: Vec<i64>,
pub(crate) e2: Vec<i64>,
__marker: PhantomData<G>,
}
@@ -697,7 +700,7 @@ The computed m parameter is {m_bound} > 64. Please select a smaller B, d and/or
///
/// Use the relationship: `||x||_2 <= sqrt(dim)*||x||_inf`. Since we are only interested in the
/// squared bound, we avoid the sqrt by returning dim*(||x||_inf)^2.
fn inf_norm_bound_to_euclidean_squared(B_inf: u64, dim: usize) -> u128 {
pub(crate) fn inf_norm_bound_to_euclidean_squared(B_inf: u64, dim: usize) -> u128 {
let norm_squared = sqr(B_inf);
norm_squared
.checked_mul(dim as u128)
@@ -1812,7 +1815,7 @@ fn precompute_xi_powers<Zp: FieldOps>(xi: &[Zp; 128], m: usize) -> Box<[Zp]> {
}
#[allow(clippy::too_many_arguments)]
fn compute_a_theta<G: Curve>(
pub(crate) fn compute_a_theta<G: Curve>(
a_theta: &mut [G::Zp],
theta: &[G::Zp],
a: &[i64],
@@ -1931,23 +1934,23 @@ pub enum VerificationPairingMode {
Batched,
}
struct GeneratedScalars<G: Curve> {
phi: [G::Zp; 128],
xi: [G::Zp; 128],
theta: Vec<G::Zp>,
omega: Vec<G::Zp>,
delta: [G::Zp; 7],
chi_powers: [G::Zp; 4],
z: G::Zp,
t_theta: G::Zp,
pub(crate) struct GeneratedScalars<G: Curve> {
pub(crate) phi: [G::Zp; 128],
pub(crate) xi: [G::Zp; 128],
pub(crate) theta: Vec<G::Zp>,
pub(crate) omega: Vec<G::Zp>,
pub(crate) delta: [G::Zp; 7],
pub(crate) chi_powers: [G::Zp; 4],
pub(crate) z: G::Zp,
pub(crate) t_theta: G::Zp,
}
struct EvaluationPoints<G: Curve> {
p_h1: G::Zp,
p_h2: G::Zp,
p_h3: G::Zp,
p_t: G::Zp,
p_omega: G::Zp,
pub(crate) struct EvaluationPoints<G: Curve> {
pub(crate) p_h1: G::Zp,
pub(crate) p_h2: G::Zp,
pub(crate) p_h3: G::Zp,
pub(crate) p_t: G::Zp,
pub(crate) p_omega: G::Zp,
}
#[allow(clippy::result_unit_err)]
@@ -2655,7 +2658,7 @@ fn pairing_check_batched<G: Curve>(
}
#[cfg(test)]
mod tests {
pub(crate) mod tests {
use crate::curve_api::{self, bls12_446};
use super::super::test::*;
@@ -2667,7 +2670,7 @@ mod tests {
type Curve = curve_api::Bls12_446;
/// Compact key params used with pkev2
pub(super) const PKEV2_TEST_PARAMS: PkeTestParameters = PkeTestParameters {
pub(crate) const PKEV2_TEST_PARAMS: PkeTestParameters = PkeTestParameters {
d: 2048,
k: 320,
B: 131072, // 2**17

View File

@@ -102,6 +102,7 @@ integer = ["shortint", "dep:strum"]
strings = ["integer"]
internal-keycache = ["dep:fs2"]
gpu = ["dep:tfhe-cuda-backend", "shortint"]
gpu-experimental-zk = ["gpu", "zk-pok", "tfhe-zk-pok/gpu-experimental"]
gpu-experimental-multi-arch = [
"gpu",
"tfhe-cuda-backend/experimental-multi-arch",

View File

@@ -56,7 +56,6 @@ impl CudaProvenCompactCiphertextList {
},
|| self.expand_without_verification(key, streams),
);
if all_valid {
return r;
}

View File

@@ -17,15 +17,21 @@ use std::fmt::Debug;
use tfhe_versionable::Versionize;
use tfhe_zk_pok::proofs::pke::{
commit as commit_v1, crs_gen as crs_gen_v1, prove as prove_v1, verify as verify_v1,
Proof as ProofV1, PublicCommit as PublicCommitV1,
commit as commit_v1, crs_gen as crs_gen_v1, verify as verify_v1, Proof as ProofV1,
PublicCommit as PublicCommitV1,
};
use tfhe_zk_pok::proofs::pke_v2::{
commit as commit_v2, crs_gen as crs_gen_v2, prove as prove_v2, verify as verify_v2,
PkeV2SupportedHashConfig, Proof as ProofV2, PublicCommit as PublicCommitV2,
VerificationPairingMode,
commit as commit_v2, crs_gen as crs_gen_v2, PkeV2SupportedHashConfig, Proof as ProofV2,
PublicCommit as PublicCommitV2, VerificationPairingMode,
};
#[cfg(not(feature = "gpu-experimental-zk"))]
use tfhe_zk_pok::proofs::pke::prove as prove_v1;
#[cfg(not(feature = "gpu-experimental-zk"))]
use tfhe_zk_pok::proofs::pke_v2::{prove as prove_v2, verify as verify_v2};
pub use tfhe_zk_pok::curve_api::Compressible;
pub use tfhe_zk_pok::proofs::pke_v2::PkeV2SupportedHashConfig as ZkPkeV2SupportedHashConfig;
pub use tfhe_zk_pok::proofs::ComputeLoad as ZkComputeLoad;
@@ -725,6 +731,15 @@ impl CompactPkeCrs {
public_params,
);
#[cfg(feature = "gpu-experimental-zk")]
let proof = tfhe_zk_pok::proofs::pke::gpu::prove(
(public_params, &public_commit),
&private_commit,
metadata,
load,
&seed,
);
#[cfg(not(feature = "gpu-experimental-zk"))]
let proof = prove_v1(
(public_params, &public_commit),
&private_commit,
@@ -748,6 +763,15 @@ impl CompactPkeCrs {
public_params,
);
#[cfg(feature = "gpu-experimental-zk")]
let proof = tfhe_zk_pok::proofs::pke_v2::gpu::prove(
(public_params, &public_commit),
&private_commit,
metadata,
load,
&seed,
);
#[cfg(not(feature = "gpu-experimental-zk"))]
let proof = prove_v2(
(public_params, &public_commit),
&private_commit,
@@ -816,12 +840,21 @@ impl CompactPkeCrs {
}
(Self::PkeV2(public_params), CompactPkeProof::PkeV2(proof)) => {
let public_commit = PublicCommitV2::new(key_mask, key_body, ct_mask, ct_body);
verify_v2(
#[cfg(feature = "gpu-experimental-zk")]
let res = tfhe_zk_pok::proofs::pke_v2::gpu::verify(
proof,
(public_params, &public_commit),
metadata,
VerificationPairingMode::default(),
)
);
#[cfg(not(feature = "gpu-experimental-zk"))]
let res = verify_v2(
proof,
(public_params, &public_commit),
metadata,
VerificationPairingMode::default(),
);
res
}
(Self::PkeV1(_), CompactPkeProof::PkeV2(_))