Compare commits

..

1 Commits

Author SHA1 Message Date
Pedro Alves
739f99788a chore(zk): improve parallelism in prove_impl 2026-01-13 11:30:41 -03:00

View File

@@ -4,7 +4,7 @@
use super::*;
use crate::backward_compatibility::pke_v2::*;
use crate::backward_compatibility::BoundVersions;
use crate::curve_api::{CompressedG1, CompressedG2};
use crate::curve_api::{CompressedG1, CompressedG2, FieldOps};
use crate::four_squares::*;
use crate::serialization::{
InvalidSerializedAffineError, InvalidSerializedPublicParamsError, SerializableGroupElements,
@@ -956,28 +956,29 @@ fn prove_impl<G: Curve>(
.map(G::Zp::from_i64)
.collect::<Box<[_]>>();
let mut scalars = e1_zp
let scalars_e = e1_zp
.iter()
.copied()
.chain(e2_zp.iter().copied())
.chain(v_zp)
.collect::<Box<[_]>>();
let C_hat_e =
g_hat.mul_scalar(gamma_hat_e) + G::G2::multi_mul_scalar(&g_hat_list[..d + k + 4], &scalars);
let scalars_e_rev: Box<[_]> = scalars_e.iter().copied().rev().collect();
let scalars_r: Box<[_]> = r1_zp.iter().chain(r2_zp.iter()).copied().collect();
let (C_e, C_r_tilde) = rayon::join(
let ((C_hat_e, C_e), C_r_tilde) = rayon::join(
|| {
scalars.reverse();
g.mul_scalar(gamma_e) + G::G1::multi_mul_scalar(&g_list[n - (d + k + 4)..n], &scalars)
},
|| {
let scalars = r1_zp
.iter()
.chain(r2_zp.iter())
.copied()
.collect::<Box<[_]>>();
g.mul_scalar(gamma_r) + G::G1::multi_mul_scalar(&g_list[..d + k], &scalars)
rayon::join(
|| {
g_hat.mul_scalar(gamma_hat_e)
+ G::G2::multi_mul_scalar(&g_hat_list[..d + k + 4], &scalars_e)
},
|| {
g.mul_scalar(gamma_e)
+ G::G1::multi_mul_scalar(&g_list[n - (d + k + 4)..n], &scalars_e_rev)
},
)
},
|| g.mul_scalar(gamma_r) + G::G1::multi_mul_scalar(&g_list[..d + k], &scalars_r),
);
let C_hat_e_bytes = C_hat_e.to_le_bytes();
@@ -1102,178 +1103,230 @@ fn prove_impl<G: Curve>(
let (delta, delta_hash) = omega_hash.gen_delta::<G::Zp>();
let [delta_r, delta_dec, delta_eq, delta_y, delta_theta, delta_e, delta_l] = delta;
let mut poly_0_lhs = vec![G::Zp::ZERO; 1 + n];
let mut poly_0_rhs = vec![G::Zp::ZERO; 1 + D + 128 * m];
let mut poly_1_lhs = vec![G::Zp::ZERO; 1 + n];
let mut poly_1_rhs = vec![G::Zp::ZERO; 1 + d + k + 4];
let mut poly_2_lhs = vec![G::Zp::ZERO; 1 + d + k];
let mut poly_2_rhs = vec![G::Zp::ZERO; 1 + n];
let mut poly_3_lhs = vec![G::Zp::ZERO; 1 + 128];
let mut poly_3_rhs = vec![G::Zp::ZERO; 1 + n];
let mut poly_4_lhs = vec![G::Zp::ZERO; 1 + n];
let mut poly_4_rhs = vec![G::Zp::ZERO; 1 + d + k + 4];
let mut poly_5_lhs = vec![G::Zp::ZERO; 1 + n];
let mut poly_5_rhs = vec![G::Zp::ZERO; 1 + n];
let mut xi_scaled = xi;
poly_0_lhs[0] = delta_y * gamma_y;
for j in 0..D + 128 * m {
let p = &mut poly_0_lhs[n - j];
if !w_bin[j] {
*p -= delta_y * y[j];
}
if j < D {
*p += delta_theta * a_theta[j];
}
*p += delta_eq * t[j] * y[j];
if j >= D {
let j = j - D;
let xi = &mut xi_scaled[j / m];
let H_xi = *xi;
*xi = *xi + *xi;
let r = delta_dec * H_xi;
if j % m < m - 1 {
*p += r;
} else {
*p -= r;
}
}
}
poly_0_rhs[0] = gamma_bin;
for j in 0..D + 128 * m {
let p = &mut poly_0_rhs[j + 1];
if w_bin[j] {
*p = G::Zp::ONE;
}
}
poly_1_lhs[0] = delta_l * gamma_e;
for j in 0..d {
let p = &mut poly_1_lhs[n - j];
*p = delta_l * e1_zp[j];
}
for j in 0..k {
let p = &mut poly_1_lhs[n - (d + j)];
*p = delta_l * e2_zp[j];
}
for j in 0..4 {
let p = &mut poly_1_lhs[n - (d + k + j)];
*p = delta_l * v_zp[j];
}
for j in 0..n {
let p = &mut poly_1_lhs[n - j];
let mut acc = delta_e * omega[j];
if j < d + k {
acc += delta_theta * theta[j];
}
if j < d + k + 4 {
let mut acc2 = G::Zp::ZERO;
for (i, &phi) in phi.iter().enumerate() {
match R(i, j) {
0 => {}
1 => acc2 += phi,
-1 => acc2 -= phi,
_ => unreachable!(),
}
}
acc += delta_r * acc2;
}
*p += acc;
}
poly_1_rhs[0] = gamma_hat_e;
for j in 0..d {
let p = &mut poly_1_rhs[1 + j];
*p = e1_zp[j];
}
for j in 0..k {
let p = &mut poly_1_rhs[1 + (d + j)];
*p = e2_zp[j];
}
for j in 0..4 {
let p = &mut poly_1_rhs[1 + (d + k + j)];
*p = v_zp[j];
}
poly_2_lhs[0] = gamma_r;
for j in 0..d {
let p = &mut poly_2_lhs[1 + j];
*p = r1_zp[j];
}
for j in 0..k {
let p = &mut poly_2_lhs[1 + (d + j)];
*p = r2_zp[j];
}
// Precompute xi powers to enable parallel polynomial construction
let xi_powers = precompute_xi_powers(&xi, m);
let delta_theta_q = delta_theta * G::Zp::from_u128(decoded_q);
for j in 0..d + k {
let p = &mut poly_2_rhs[n - j];
let mut acc = G::Zp::ZERO;
for (i, &phi) in phi.iter().enumerate() {
match R(i, d + k + 4 + j) {
0 => {}
1 => acc += phi,
-1 => acc -= phi,
_ => unreachable!(),
}
}
*p = delta_r * acc - delta_theta_q * theta[j];
}
// Build all polynomial pairs in parallel
let (
((poly_0_lhs, poly_0_rhs), (poly_1_lhs, poly_1_rhs)),
(
(poly_2_lhs, poly_2_rhs),
((poly_3_lhs, poly_3_rhs), ((poly_4_lhs, poly_4_rhs), (poly_5_lhs, poly_5_rhs))),
),
) = rayon::join(
|| {
rayon::join(
// poly_0
|| {
let mut poly_0_lhs = vec![G::Zp::ZERO; 1 + n];
let mut poly_0_rhs = vec![G::Zp::ZERO; 1 + D + 128 * m];
poly_3_lhs[0] = gamma_R;
for j in 0..128 {
let p = &mut poly_3_lhs[1 + j];
*p = G::Zp::from_i64(w_R[j]);
}
poly_0_lhs[0] = delta_y * gamma_y;
for j in 0..D + 128 * m {
let p = &mut poly_0_lhs[n - j];
for j in 0..128 {
let p = &mut poly_3_rhs[n - j];
*p = delta_r * phi[j] + delta_dec * xi[j];
}
if !w_bin[j] {
*p -= delta_y * y[j];
}
poly_4_lhs[0] = delta_e * gamma_e;
for j in 0..d {
let p = &mut poly_4_lhs[n - j];
*p = delta_e * e1_zp[j];
}
for j in 0..k {
let p = &mut poly_4_lhs[n - (d + j)];
*p = delta_e * e2_zp[j];
}
for j in 0..4 {
let p = &mut poly_4_lhs[n - (d + k + j)];
*p = delta_e * v_zp[j];
}
if j < D {
*p += delta_theta * a_theta[j];
}
*p += delta_eq * t[j] * y[j];
for j in 0..d + k + 4 {
let p = &mut poly_4_rhs[1 + j];
*p = omega[j];
}
if j >= D {
let j_inner = j - D;
let r = delta_dec * xi_powers[j_inner];
poly_5_lhs[0] = delta_eq * gamma_y;
for j in 0..D + 128 * m {
let p = &mut poly_5_lhs[n - j];
if j_inner % m < m - 1 {
*p += r;
} else {
*p -= r;
}
}
}
if w_bin[j] {
*p = delta_eq * y[j];
}
}
poly_0_rhs[0] = gamma_bin;
for j in 0..D + 128 * m {
let p = &mut poly_0_rhs[j + 1];
for j in 0..n {
let p = &mut poly_5_rhs[1 + j];
*p = t[j];
}
if w_bin[j] {
*p = G::Zp::ONE;
}
}
(poly_0_lhs, poly_0_rhs)
},
// poly_1
|| {
let mut poly_1_lhs = vec![G::Zp::ZERO; 1 + n];
let mut poly_1_rhs = vec![G::Zp::ZERO; 1 + d + k + 4];
poly_1_lhs[0] = delta_l * gamma_e;
for j in 0..d {
let p = &mut poly_1_lhs[n - j];
*p = delta_l * e1_zp[j];
}
for j in 0..k {
let p = &mut poly_1_lhs[n - (d + j)];
*p = delta_l * e2_zp[j];
}
for j in 0..4 {
let p = &mut poly_1_lhs[n - (d + k + j)];
*p = delta_l * v_zp[j];
}
for j in 0..n {
let p = &mut poly_1_lhs[n - j];
let mut acc = delta_e * omega[j];
if j < d + k {
acc += delta_theta * theta[j];
}
if j < d + k + 4 {
let mut acc2 = G::Zp::ZERO;
for (i, &phi) in phi.iter().enumerate() {
match R(i, j) {
0 => {}
1 => acc2 += phi,
-1 => acc2 -= phi,
_ => unreachable!(),
}
}
acc += delta_r * acc2;
}
*p += acc;
}
poly_1_rhs[0] = gamma_hat_e;
for j in 0..d {
let p = &mut poly_1_rhs[1 + j];
*p = e1_zp[j];
}
for j in 0..k {
let p = &mut poly_1_rhs[1 + (d + j)];
*p = e2_zp[j];
}
for j in 0..4 {
let p = &mut poly_1_rhs[1 + (d + k + j)];
*p = v_zp[j];
}
(poly_1_lhs, poly_1_rhs)
},
)
},
|| {
rayon::join(
// poly_2
|| {
let mut poly_2_lhs = vec![G::Zp::ZERO; 1 + d + k];
let mut poly_2_rhs = vec![G::Zp::ZERO; 1 + n];
poly_2_lhs[0] = gamma_r;
for j in 0..d {
let p = &mut poly_2_lhs[1 + j];
*p = r1_zp[j];
}
for j in 0..k {
let p = &mut poly_2_lhs[1 + (d + j)];
*p = r2_zp[j];
}
for j in 0..d + k {
let p = &mut poly_2_rhs[n - j];
let mut acc = G::Zp::ZERO;
for (i, &phi) in phi.iter().enumerate() {
match R(i, d + k + 4 + j) {
0 => {}
1 => acc += phi,
-1 => acc -= phi,
_ => unreachable!(),
}
}
*p = delta_r * acc - delta_theta_q * theta[j];
}
(poly_2_lhs, poly_2_rhs)
},
|| {
rayon::join(
// poly_3
|| {
let mut poly_3_lhs = vec![G::Zp::ZERO; 1 + 128];
let mut poly_3_rhs = vec![G::Zp::ZERO; 1 + n];
poly_3_lhs[0] = gamma_R;
for j in 0..128 {
let p = &mut poly_3_lhs[1 + j];
*p = G::Zp::from_i64(w_R[j]);
}
for j in 0..128 {
let p = &mut poly_3_rhs[n - j];
*p = delta_r * phi[j] + delta_dec * xi_powers[j * m];
}
(poly_3_lhs, poly_3_rhs)
},
|| {
rayon::join(
// poly_4
|| {
let mut poly_4_lhs = vec![G::Zp::ZERO; 1 + n];
let mut poly_4_rhs = vec![G::Zp::ZERO; 1 + d + k + 4];
poly_4_lhs[0] = delta_e * gamma_e;
for j in 0..d {
let p = &mut poly_4_lhs[n - j];
*p = delta_e * e1_zp[j];
}
for j in 0..k {
let p = &mut poly_4_lhs[n - (d + j)];
*p = delta_e * e2_zp[j];
}
for j in 0..4 {
let p = &mut poly_4_lhs[n - (d + k + j)];
*p = delta_e * v_zp[j];
}
for j in 0..d + k + 4 {
let p = &mut poly_4_rhs[1 + j];
*p = omega[j];
}
(poly_4_lhs, poly_4_rhs)
},
// poly_5
|| {
let mut poly_5_lhs = vec![G::Zp::ZERO; 1 + n];
let mut poly_5_rhs = vec![G::Zp::ZERO; 1 + n];
poly_5_lhs[0] = delta_eq * gamma_y;
for j in 0..D + 128 * m {
let p = &mut poly_5_lhs[n - j];
if w_bin[j] {
*p = delta_eq * y[j];
}
}
for j in 0..n {
let p = &mut poly_5_rhs[1 + j];
*p = t[j];
}
(poly_5_lhs, poly_5_rhs)
},
)
},
)
},
)
},
);
let poly = [
(&poly_0_lhs, &poly_0_rhs),
@@ -1346,101 +1399,125 @@ fn prove_impl<G: Curve>(
P_pi[n + 1] -= delta_theta * t_theta + delta_l * G::Zp::from_u128(B_squared);
}
let pi = if P_pi.is_empty() {
G::G1::ZERO
} else {
g.mul_scalar(P_pi[0]) + G::G1::multi_mul_scalar(&g_list[..P_pi.len() - 1], &P_pi[1..])
};
let mut xi_scaled = xi;
let mut scalars = (0..D + 128 * m)
.map(|j| {
let mut acc = G::Zp::ZERO;
if j < D {
acc += delta_theta * a_theta[j];
}
acc -= delta_y * y[j];
acc += delta_eq * t[j] * y[j];
if j >= D {
let j = j - D;
let xi = &mut xi_scaled[j / m];
let H_xi = *xi;
*xi = *xi + *xi;
let r = delta_dec * H_xi;
if j % m < m - 1 {
acc += r;
} else {
acc -= r;
}
}
acc
})
.collect::<Box<[_]>>();
scalars.reverse();
let C_h1 = G::G1::multi_mul_scalar(&g_list[n - (D + 128 * m)..n], &scalars);
let mut scalars = (0..n)
.map(|j| {
let mut acc = G::Zp::ZERO;
if j < d + k {
acc += delta_theta * theta[j];
}
acc += delta_e * omega[j];
if j < d + k + 4 {
let mut acc2 = G::Zp::ZERO;
for (i, &phi) in phi.iter().enumerate() {
match R(i, j) {
0 => {}
1 => acc2 += phi,
-1 => acc2 -= phi,
_ => unreachable!(),
}
}
acc += delta_r * acc2;
}
acc
})
.collect::<Box<[_]>>();
scalars.reverse();
let C_h2 = G::G1::multi_mul_scalar(&g_list[..n], &scalars);
let compute_load_proof_fields = match load {
ComputeLoad::Proof => {
let (C_hat_h3, C_hat_w) = rayon::join(
// Parallelize pi, C_h1, C_h2, compute_load_proof_fields, and C_hat_t computations
let ((pi, C_h1), (C_h2, (compute_load_proof_fields, C_hat_t))) = rayon::join(
|| {
rayon::join(
// pi computation
|| {
G::G2::multi_mul_scalar(
&g_hat_list[n - (d + k)..n],
&(0..d + k)
.rev()
.map(|j| {
let mut acc = G::Zp::ZERO;
if P_pi.is_empty() {
G::G1::ZERO
} else {
g.mul_scalar(P_pi[0])
+ G::G1::multi_mul_scalar(&g_list[..P_pi.len() - 1], &P_pi[1..])
}
},
// C_h1 computation
|| {
let scalars_h1: Box<[_]> = (0..D + 128 * m)
.rev()
.map(|j| {
let mut acc = G::Zp::ZERO;
if j < D {
acc += delta_theta * a_theta[j];
}
acc -= delta_y * y[j];
acc += delta_eq * t[j] * y[j];
if j >= D {
let j_inner = j - D;
let r = delta_dec * xi_powers[j_inner];
if j_inner % m < m - 1 {
acc += r;
} else {
acc -= r;
}
}
acc
})
.collect();
G::G1::multi_mul_scalar(&g_list[n - (D + 128 * m)..n], &scalars_h1)
},
)
},
|| {
rayon::join(
// C_h2 computation
|| {
let scalars_h2: Box<[_]> = (0..n)
.rev()
.map(|j| {
let mut acc = G::Zp::ZERO;
if j < d + k {
acc += delta_theta * theta[j];
}
acc += delta_e * omega[j];
if j < d + k + 4 {
let mut acc2 = G::Zp::ZERO;
for (i, &phi) in phi.iter().enumerate() {
match R(i, d + k + 4 + j) {
match R(i, j) {
0 => {}
1 => acc += phi,
-1 => acc -= phi,
1 => acc2 += phi,
-1 => acc2 -= phi,
_ => unreachable!(),
}
}
delta_r * acc - delta_theta_q * theta[j]
})
.collect::<Box<[_]>>(),
acc += delta_r * acc2;
}
acc
})
.collect();
G::G1::multi_mul_scalar(&g_list[..n], &scalars_h2)
},
|| {
rayon::join(
// compute_load_proof_fields computation
|| match load {
ComputeLoad::Proof => {
let (C_hat_h3, C_hat_w) = rayon::join(
|| {
G::G2::multi_mul_scalar(
&g_hat_list[n - (d + k)..n],
&(0..d + k)
.rev()
.map(|j| {
let mut acc = G::Zp::ZERO;
for (i, &phi) in phi.iter().enumerate() {
match R(i, d + k + 4 + j) {
0 => {}
1 => acc += phi,
-1 => acc -= phi,
_ => unreachable!(),
}
}
delta_r * acc - delta_theta_q * theta[j]
})
.collect::<Box<[_]>>(),
)
},
|| {
G::G2::multi_mul_scalar(
&g_hat_list[..d + k + 4],
&omega[..d + k + 4],
)
},
);
Some(ComputeLoadProofFields { C_hat_h3, C_hat_w })
}
ComputeLoad::Verify => None,
},
// C_hat_t computation
|| G::G2::multi_mul_scalar(g_hat_list, &t),
)
},
|| G::G2::multi_mul_scalar(&g_hat_list[..d + k + 4], &omega[..d + k + 4]),
);
Some(ComputeLoadProofFields { C_hat_h3, C_hat_w })
}
ComputeLoad::Verify => None,
};
let C_hat_t = G::G2::multi_mul_scalar(g_hat_list, &t);
)
},
);
let (C_hat_h3_bytes, C_hat_w_bytes) =
ComputeLoadProofFields::to_le_bytes(&compute_load_proof_fields);
@@ -1457,110 +1534,175 @@ fn prove_impl<G: Curve>(
&C_hat_w_bytes,
);
let mut P_h1 = vec![G::Zp::ZERO; 1 + n];
let mut P_h2 = vec![G::Zp::ZERO; 1 + n];
let mut P_t = vec![G::Zp::ZERO; 1 + n];
let mut P_h3 = match load {
ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + n],
ComputeLoad::Verify => vec![],
};
let mut P_omega = match load {
ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + d + k + 4],
ComputeLoad::Verify => vec![],
// Build P_h1, P_h2, P_t, P_h3, P_omega in parallel
let ((P_h1, P_h2), (P_t, (P_h3, P_omega))) = rayon::join(
|| {
rayon::join(
// P_h1
|| {
let mut P_h1 = vec![G::Zp::ZERO; 1 + n];
for j in 0..D + 128 * m {
let p = &mut P_h1[n - j];
if j < D {
*p += delta_theta * a_theta[j];
}
*p -= delta_y * y[j];
*p += delta_eq * t[j] * y[j];
if j >= D {
let j_inner = j - D;
let r = delta_dec * xi_powers[j_inner];
if j_inner % m < m - 1 {
*p += r;
} else {
*p -= r;
}
}
}
P_h1
},
// P_h2
|| {
let mut P_h2 = vec![G::Zp::ZERO; 1 + n];
for j in 0..n {
let p = &mut P_h2[n - j];
if j < d + k {
*p += delta_theta * theta[j];
}
*p += delta_e * omega[j];
if j < d + k + 4 {
let mut acc = G::Zp::ZERO;
for (i, &phi) in phi.iter().enumerate() {
match R(i, j) {
0 => {}
1 => acc += phi,
-1 => acc -= phi,
_ => unreachable!(),
}
}
*p += delta_r * acc;
}
}
P_h2
},
)
},
|| {
rayon::join(
// P_t
|| {
let mut P_t = vec![G::Zp::ZERO; 1 + n];
P_t[1..].copy_from_slice(&t);
P_t
},
|| {
rayon::join(
// P_h3
|| match load {
ComputeLoad::Proof => {
let mut P_h3 = vec![G::Zp::ZERO; 1 + n];
for j in 0..d + k {
let p = &mut P_h3[n - j];
let mut acc = G::Zp::ZERO;
for (i, &phi) in phi.iter().enumerate() {
match R(i, d + k + 4 + j) {
0 => {}
1 => acc += phi,
-1 => acc -= phi,
_ => unreachable!(),
}
}
*p = delta_r * acc - delta_theta_q * theta[j];
}
P_h3
}
ComputeLoad::Verify => vec![],
},
// P_omega
|| match load {
ComputeLoad::Proof => {
let mut P_omega = vec![G::Zp::ZERO; 1 + d + k + 4];
P_omega[1..].copy_from_slice(&omega[..d + k + 4]);
P_omega
}
ComputeLoad::Verify => vec![],
},
)
},
)
},
);
// Precompute powers of z for parallel polynomial evaluation
let z_powers: Box<[_]> = {
let mut powers = Vec::with_capacity(n + 1);
let mut pow = G::Zp::ONE;
for _ in 0..n + 1 {
powers.push(pow);
pow = pow * z;
}
powers.into_boxed_slice()
};
let mut xi_scaled = xi;
for j in 0..D + 128 * m {
let p = &mut P_h1[n - j];
if j < D {
*p += delta_theta * a_theta[j];
}
*p -= delta_y * y[j];
*p += delta_eq * t[j] * y[j];
if j >= D {
let j = j - D;
let xi = &mut xi_scaled[j / m];
let H_xi = *xi;
*xi = *xi + *xi;
let r = delta_dec * H_xi;
if j % m < m - 1 {
*p += r;
} else {
*p -= r;
}
}
}
for j in 0..n {
let p = &mut P_h2[n - j];
if j < d + k {
*p += delta_theta * theta[j];
}
*p += delta_e * omega[j];
if j < d + k + 4 {
let mut acc = G::Zp::ZERO;
for (i, &phi) in phi.iter().enumerate() {
match R(i, j) {
0 => {}
1 => acc += phi,
-1 => acc -= phi,
_ => unreachable!(),
}
}
*p += delta_r * acc;
}
}
P_t[1..].copy_from_slice(&t);
if !P_h3.is_empty() {
for j in 0..d + k {
let p = &mut P_h3[n - j];
let mut acc = G::Zp::ZERO;
for (i, &phi) in phi.iter().enumerate() {
match R(i, d + k + 4 + j) {
0 => {}
1 => acc += phi,
-1 => acc -= phi,
_ => unreachable!(),
}
}
*p = delta_r * acc - delta_theta_q * theta[j];
}
}
if !P_omega.is_empty() {
P_omega[1..].copy_from_slice(&omega[..d + k + 4]);
}
let mut p_h1 = G::Zp::ZERO;
let mut p_h2 = G::Zp::ZERO;
let mut p_t = G::Zp::ZERO;
let mut p_h3 = G::Zp::ZERO;
let mut p_omega = G::Zp::ZERO;
let mut pow = G::Zp::ONE;
for j in 0..n + 1 {
p_h1 += P_h1[j] * pow;
p_h2 += P_h2[j] * pow;
p_t += P_t[j] * pow;
if j < P_h3.len() {
p_h3 += P_h3[j] * pow;
}
if j < P_omega.len() {
p_omega += P_omega[j] * pow;
}
pow = pow * z;
}
// Evaluate polynomials at z in parallel
let ((p_h1, p_h2), (p_t, (p_h3, p_omega))) = rayon::join(
|| {
rayon::join(
|| {
P_h1.iter()
.zip(z_powers.iter())
.map(|(&p, &pow)| p * pow)
.sum::<G::Zp>()
},
|| {
P_h2.iter()
.zip(z_powers.iter())
.map(|(&p, &pow)| p * pow)
.sum::<G::Zp>()
},
)
},
|| {
rayon::join(
|| {
P_t.iter()
.zip(z_powers.iter())
.map(|(&p, &pow)| p * pow)
.sum::<G::Zp>()
},
|| {
rayon::join(
|| {
if P_h3.is_empty() {
G::Zp::ZERO
} else {
P_h3.iter()
.zip(z_powers.iter())
.map(|(&p, &pow)| p * pow)
.sum::<G::Zp>()
}
},
|| {
if P_omega.is_empty() {
G::Zp::ZERO
} else {
P_omega
.iter()
.zip(z_powers.iter())
.map(|(&p, &pow)| p * pow)
.sum::<G::Zp>()
}
},
)
},
)
},
);
let p_h3_opt = if P_h3.is_empty() { None } else { Some(p_h3) };
let p_omega_opt = if P_omega.is_empty() {
@@ -1613,6 +1755,23 @@ fn prove_impl<G: Curve>(
}
}
/// Precomputes the doubled multiples of every `xi` entry as one flat table.
///
/// The output is laid out group by group: entry `j` (for `j` in `0..128 * m`) holds
/// `2^(j % m) * xi[j / m]`, where the power of two is obtained by repeated doubling
/// (`x + x`). Having the whole table available up front replaces the sequential
/// accumulator pattern that mutates `xi_scaled`, so callers can index it in any order.
///
/// * `xi` - the 128 base values, one per group.
/// * `m` - number of consecutive doublings stored per group; `m == 0` yields an empty table.
///
/// Returns a boxed slice of length `128 * m`.
fn precompute_xi_powers<Zp: FieldOps>(xi: &[Zp; 128], m: usize) -> Box<[Zp]> {
    // The final length is known exactly, so allocate once.
    let mut powers = Vec::with_capacity(128 * m);
    for &base in xi {
        // Carry the running value across the group instead of re-doubling from
        // `base` for every entry; this keeps the work linear in the table size.
        let mut power = base;
        for _ in 0..m {
            powers.push(power);
            power = power + power;
        }
    }
    powers.into_boxed_slice()
}
#[allow(clippy::too_many_arguments)]
fn compute_a_theta<G: Curve>(
a_theta: &mut [G::Zp],