diff --git a/backends/zk-cuda-backend/cuda/src/primitives/fp.cu b/backends/zk-cuda-backend/cuda/src/primitives/fp.cu index 38323c054..93c444398 100644 --- a/backends/zk-cuda-backend/cuda/src/primitives/fp.cu +++ b/backends/zk-cuda-backend/cuda/src/primitives/fp.cu @@ -187,37 +187,82 @@ __host__ __device__ void fp_copy(Fp &dst, const Fp &src) { // "Raw" means without modular reduction - performs a + b and returns carry. // This is an internal helper used by fp_add() which handles reduction. __host__ __device__ UNSIGNED_LIMB fp_add_raw(Fp &c, const Fp &a, const Fp &b) { +#if defined(__CUDA_ARCH__) && LIMB_BITS_CONFIG == 64 + // PTX carry-chain: add.cc sets the hardware carry flag, addc.cc propagates + // it. This replaces 2 software carry-detect comparisons per limb (~14 extra + // instructions across 7 limbs) with zero-cost hardware flag propagation. + uint64_t carry_out; + asm("add.cc.u64 %0, %8, %15;\n\t" // c[0] = a[0] + b[0], set CF + "addc.cc.u64 %1, %9, %16;\n\t" // c[1] = a[1] + b[1] + CF + "addc.cc.u64 %2, %10, %17;\n\t" // c[2] = a[2] + b[2] + CF + "addc.cc.u64 %3, %11, %18;\n\t" // c[3] = a[3] + b[3] + CF + "addc.cc.u64 %4, %12, %19;\n\t" // c[4] = a[4] + b[4] + CF + "addc.cc.u64 %5, %13, %20;\n\t" // c[5] = a[5] + b[5] + CF + "addc.cc.u64 %6, %14, %21;\n\t" // c[6] = a[6] + b[6] + CF + "addc.u64 %7, 0, 0;\n\t" // carry_out = 0 + 0 + CF + : "=l"(c.limb[0]), "=l"(c.limb[1]), "=l"(c.limb[2]), "=l"(c.limb[3]), + "=l"(c.limb[4]), "=l"(c.limb[5]), "=l"(c.limb[6]), "=l"(carry_out) + : "l"(a.limb[0]), "l"(a.limb[1]), "l"(a.limb[2]), "l"(a.limb[3]), + "l"(a.limb[4]), "l"(a.limb[5]), "l"(a.limb[6]), "l"(b.limb[0]), + "l"(b.limb[1]), "l"(b.limb[2]), "l"(b.limb[3]), "l"(b.limb[4]), + "l"(b.limb[5]), "l"(b.limb[6])); + return carry_out; +#else + // Host path: portable software carry detection UNSIGNED_LIMB carry = 0; for (int i = 0; i < FP_LIMBS; i++) { - // Add with carry: c = a + b + carry UNSIGNED_LIMB sum = a.limb[i] + carry; - carry = (sum < a.limb[i]) ? 1 : 0; // Check for overflow + carry = (sum < a.limb[i]) ? 1 : 0; sum += b.limb[i]; - carry += (sum < b.limb[i]) ? 1 : 0; // Check for overflow + carry += (sum < b.limb[i]) ? 1 : 0; c.limb[i] = sum; } return carry; +#endif } // Subtraction with borrow propagation // "Raw" means without modular reduction - performs a - b and returns borrow. // This is an internal helper used by fp_sub() which handles reduction. __host__ __device__ UNSIGNED_LIMB fp_sub_raw(Fp &c, const Fp &a, const Fp &b) { +#if defined(__CUDA_ARCH__) && LIMB_BITS_CONFIG == 64 + // PTX borrow-chain: sub.cc sets the hardware borrow flag, subc.cc propagates + // it. Same benefit as fp_add_raw -- eliminates 2 comparisons per limb. + uint64_t borrow_out; + asm("sub.cc.u64 %0, %8, %15;\n\t" // c[0] = a[0] - b[0], set CF + "subc.cc.u64 %1, %9, %16;\n\t" // c[1] = a[1] - b[1] - CF + "subc.cc.u64 %2, %10, %17;\n\t" // c[2] = a[2] - b[2] - CF + "subc.cc.u64 %3, %11, %18;\n\t" // c[3] = a[3] - b[3] - CF + "subc.cc.u64 %4, %12, %19;\n\t" // c[4] = a[4] - b[4] - CF + "subc.cc.u64 %5, %13, %20;\n\t" // c[5] = a[5] - b[5] - CF + "subc.cc.u64 %6, %14, %21;\n\t" // c[6] = a[6] - b[6] - CF + "subc.u64 %7, 0, 0;\n\t" // borrow_out = 0 - 0 - CF + : "=l"(c.limb[0]), "=l"(c.limb[1]), "=l"(c.limb[2]), "=l"(c.limb[3]), + "=l"(c.limb[4]), "=l"(c.limb[5]), "=l"(c.limb[6]), "=l"(borrow_out) + : "l"(a.limb[0]), "l"(a.limb[1]), "l"(a.limb[2]), "l"(a.limb[3]), + "l"(a.limb[4]), "l"(a.limb[5]), "l"(a.limb[6]), "l"(b.limb[0]), + "l"(b.limb[1]), "l"(b.limb[2]), "l"(b.limb[3]), "l"(b.limb[4]), + "l"(b.limb[5]), "l"(b.limb[6])); + // subc.u64 with 0-0-CF produces 0 if no borrow, or 0xFFFFFFFFFFFFFFFF if + // borrow. Normalize to 0/1 for callers that check (borrow != 0) or add it. + return borrow_out & 1; +#else + // Host path: portable software borrow detection UNSIGNED_LIMB borrow = 0; for (int i = 0; i < FP_LIMBS; i++) { - // Subtract with borrow: c = a - b - borrow UNSIGNED_LIMB diff = a.limb[i] - borrow; - borrow = (diff > a.limb[i]) ? 1 : 0; // Check for underflow + borrow = (diff > a.limb[i]) ? 1 : 0; UNSIGNED_LIMB old_diff = diff; diff -= b.limb[i]; - borrow += (diff > old_diff) ? 1 : 0; // Check for underflow + borrow += (diff > old_diff) ? 1 : 0; c.limb[i] = diff; } return borrow; +#endif } // Addition with modular reduction: c = (a + b) mod p @@ -226,7 +271,27 @@ __host__ __device__ void fp_add(Fp &c, const Fp &a, const Fp &b) { Fp sum; UNSIGNED_LIMB carry = fp_add_raw(sum, a, b); - // If there's a carry or sum >= MODULUS, we need to reduce +#if defined(__CUDA_ARCH__) && LIMB_BITS_CONFIG == 64 + // Branchless reduction: always compute sum - p, then select based on + // whether reduction was needed. This avoids divergent branches that stall + // warps when some threads need reduction and others don't. + // + // Decision logic: + // carry=1 -> sum overflowed 448 bits, definitely >= p -> use reduced + // carry=0, borrow=0 -> sum >= p in 448 bits -> use reduced + // carry=0, borrow=1 -> sum < p -> use original sum + // So: use_original = (!carry) & borrow + Fp reduced; + UNSIGNED_LIMB borrow = fp_sub_raw(reduced, sum, fp_modulus()); + UNSIGNED_LIMB use_original = ((carry ^ 1) & borrow); + UNSIGNED_LIMB mask = + -use_original; // all-ones if keep sum, all-zeros if keep reduced + + for (int i = 0; i < FP_LIMBS; i++) { + c.limb[i] = (sum.limb[i] & mask) | (reduced.limb[i] & ~mask); + } +#else + // Host path: branching is fine on CPU (branch predictor handles it well) const Fp &p = fp_modulus(); if (carry || fp_cmp(sum, p) != ComparisonType::Less) { Fp reduced; @@ -235,6 +300,7 @@ __host__ __device__ void fp_add(Fp &c, const Fp &a, const Fp &b) { } else { fp_copy(c, sum); } +#endif } // Subtraction with modular reduction: c = (a - b) mod p @@ -243,13 +309,28 @@ __host__ __device__ void fp_sub(Fp &c, const Fp &a, const Fp &b) { Fp diff; UNSIGNED_LIMB borrow = fp_sub_raw(diff, a, b); - // If there was a borrow, we need to add MODULUS +#if defined(__CUDA_ARCH__) && LIMB_BITS_CONFIG == 64 + // Branchless correction: always compute diff + p, select based on borrow. + // Same rationale as fp_add -- avoids warp divergence. + // borrow=1 -> a < b, need to add p -> use corrected + // borrow=0 -> a >= b, result is valid -> use diff + Fp corrected; + fp_add_raw(corrected, diff, fp_modulus()); + UNSIGNED_LIMB mask = + -borrow; // all-ones if borrow (use corrected), all-zeros if not + + for (int i = 0; i < FP_LIMBS; i++) { + c.limb[i] = (corrected.limb[i] & mask) | (diff.limb[i] & ~mask); + } +#else + // Host path: branching is fine on CPU const Fp &p = fp_modulus(); if (borrow) { fp_add_raw(c, diff, p); } else { fp_copy(c, diff); } +#endif } // Small-constant multiplication via addition chains. @@ -461,9 +542,8 @@ __host__ __device__ void fp_mont_reduce(Fp &c, const UNSIGNED_LIMB *a) { // (carry, t[j]) = t[j] + a[j] * b_i + carry // which is a 64x64->128 multiply plus a three-operand addition with carry. // -// The C++ path uses software carry detection: carry = (sum < old) ? 1 : 0, -// which compiles to SETP + SELP (2-3 extra instructions per carry). The PTX -// path below uses hardware carry flags via the .cc suffix: +// The C++ path uses software carry detection: carry = (sum < old) ? 1 : 0. +// The PTX path below uses hardware carry flags via the .cc suffix: // - mul.lo.u64 / mul.hi.u64 : 64x64->128 wide multiply // - add.cc.u64 / addc.u64 : addition chain with hardware carry flag // @@ -493,17 +573,18 @@ __host__ __device__ void fp_mont_reduce(Fp &c, const UNSIGNED_LIMB *a) { // addc.u64 _hi, _hi, 0 -- _hi += CF // add.cc.u64 t_j, t_j, carry -- t_j += carry_in, set CF // addc.u64 carry, _hi, 0 -- carry_out = _hi + CF -#define LIMB_MACC(t_j, carry, a_j, b_i) \ +#define LIMB_MACC(t_j, carry, a_j, b_i) \ asm volatile("{\n\t" \ ".reg .u64 _lo, _hi;\n\t" \ - "mul.lo.u64 _lo, %2, %3;\n\t" \ - "mul.hi.u64 _hi, %2, %3;\n\t" \ - "add.cc.u64 %0, %0, _lo;\n\t" \ - "addc.u64 _hi, _hi, 0;\n\t" \ - "add.cc.u64 %0, %0, %1;\n\t" \ - "addc.u64 %1, _hi, 0;\n\t" \ + "mul.lo.u64 _lo, %2, %3;\n\t" \ + "mul.hi.u64 _hi, %2, %3;\n\t" \ + "add.cc.u64 %0, %0, _lo;\n\t" \ + "addc.u64 _hi, _hi, 0;\n\t" \ + "add.cc.u64 %0, %0, %1;\n\t" \ + "addc.u64 %1, _hi, 0;\n\t" \ "}\n\t" \ - : "+&l"(t_j), "+&l"(carry) : "l"(a_j), "l"(b_i)) + : "+&l"(t_j), "+&l"(carry) \ + : "l"(a_j), "l"(b_i)) // Single CIOS iteration: multiply-accumulate, reduce, and shift. // @@ -515,13 +596,9 @@ __host__ __device__ void fp_mont_reduce(Fp &c, const UNSIGNED_LIMB *a) { // // The macro lets the compiler allocate registers across all 7 unrolled // iterations, avoiding spills to local memory. -#define CIOS_ITERATION_PTX( \ - t0, t1, t2, t3, t4, t5, t6, t7, \ - a0, a1, a2, a3, a4, a5, a6, \ - b_i, \ - p0, p1, p2, p3, p4, p5, p6, \ - p_prime, \ - r0, r1, r2, r3, r4, r5, r6, r7) \ +#define CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, \ + a5, a6, b_i, p0, p1, p2, p3, p4, p5, p6, p_prime, \ + r0, r1, r2, r3, r4, r5, r6, r7) \ do { \ uint64_t _carry = 0; \ /* Step 1: t += a * b_i */ \ @@ -534,7 +611,7 @@ __host__ __device__ void fp_mont_reduce(Fp &c, const UNSIGNED_LIMB *a) { LIMB_MACC(t6, _carry, a6, b_i); \ /* Accumulate final carry into overflow limb t7 */ \ uint64_t _overflow; \ - asm("add.cc.u64 %0, %0, %2;\n\t" \ + asm("add.cc.u64 %0, %0, %2;\n\t" \ "addc.u64 %1, 0, 0;\n\t" \ : "+l"(t7), "=l"(_overflow) \ : "l"(_carry)); \ @@ -551,19 +628,26 @@ __host__ __device__ void fp_mont_reduce(Fp &c, const UNSIGNED_LIMB *a) { LIMB_MACC(t4, _carry, _m, p4); \ LIMB_MACC(t5, _carry, _m, p5); \ LIMB_MACC(t6, _carry, _m, p6); \ - /* Finalize overflow: t7 = t7 + _carry + _overflow */ \ - /* Plain adds (no carry chain) -- the CIOS invariant guarantees this */ \ - /* sum fits in 64 bits so intermediate overflow does not matter. */ \ + /* Finalize overflow: t7 = t7 + _carry + _overflow */ \ + /* Plain adds (no carry chain) -- the CIOS invariant guarantees this */ \ + /* sum fits in 64 bits so intermediate overflow does not matter. */ \ t7 += _carry; \ - t7 += _overflow; \ + t7 += _overflow; \ \ /* Step 4: Shift right by one limb via register renaming */ \ /* t0 is now zero (by construction of m), discard it */ \ - r0 = t1; r1 = t2; r2 = t3; r3 = t4; \ - r4 = t5; r5 = t6; r6 = t7; r7 = 0; \ + r0 = t1; \ + r1 = t2; \ + r2 = t3; \ + r3 = t4; \ + r4 = t5; \ + r5 = t6; \ + r6 = t7; \ + r7 = 0; \ } while (0) -__device__ __noinline__ void fp_mont_mul_cios_ptx(Fp &c, const Fp &a, const Fp &b) { +__device__ __noinline__ void fp_mont_mul_cios_ptx(Fp &c, const Fp &a, + const Fp &b) { const uint64_t p0 = DEVICE_MODULUS.limb[0]; const uint64_t p1 = DEVICE_MODULUS.limb[1]; const uint64_t p2 = DEVICE_MODULUS.limb[2]; @@ -585,61 +669,53 @@ __device__ __noinline__ void fp_mont_mul_cios_ptx(Fp &c, const Fp &a, const Fp & // Each iteration processes one limb of b, accumulates a*b[i], reduces, // and shifts. The output registers become the input for the next iteration. - CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, - a0, a1, a2, a3, a4, a5, a6, b.limb[0], - p0, p1, p2, p3, p4, p5, p6, pp, - t0, t1, t2, t3, t4, t5, t6, t7); + CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6, + b.limb[0], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3, + t4, t5, t6, t7); - CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, - a0, a1, a2, a3, a4, a5, a6, b.limb[1], - p0, p1, p2, p3, p4, p5, p6, pp, - t0, t1, t2, t3, t4, t5, t6, t7); + CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6, + b.limb[1], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3, + t4, t5, t6, t7); - CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, - a0, a1, a2, a3, a4, a5, a6, b.limb[2], - p0, p1, p2, p3, p4, p5, p6, pp, - t0, t1, t2, t3, t4, t5, t6, t7); + CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6, + b.limb[2], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3, + t4, t5, t6, t7); - CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, - a0, a1, a2, a3, a4, a5, a6, b.limb[3], - p0, p1, p2, p3, p4, p5, p6, pp, - t0, t1, t2, t3, t4, t5, t6, t7); + CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6, + b.limb[3], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3, + t4, t5, t6, t7); - CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, - a0, a1, a2, a3, a4, a5, a6, b.limb[4], - p0, p1, p2, p3, p4, p5, p6, pp, - t0, t1, t2, t3, t4, t5, t6, t7); + CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6, + b.limb[4], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3, + t4, t5, t6, t7); - CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, - a0, a1, a2, a3, a4, a5, a6, b.limb[5], - p0, p1, p2, p3, p4, p5, p6, pp, - t0, t1, t2, t3, t4, t5, t6, t7); + CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6, + b.limb[5], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3, + t4, t5, t6, t7); - CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, - a0, a1, a2, a3, a4, a5, a6, b.limb[6], - p0, p1, p2, p3, p4, p5, p6, pp, - t0, t1, t2, t3, t4, t5, t6, t7); + CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6, + b.limb[6], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3, + t4, t5, t6, t7); // Final reduction: if t[0..7] >= p (extended to 8 limbs), subtract p. // Compute (t[0..6] - p[0..6]) with borrow, then subtract borrow from t7. // If t7 after subtraction is non-negative, the reduced result is valid; // otherwise the original t[0..6] is already in [0, p). uint64_t r0, r1, r2, r3, r4, r5, r6, mask; - asm("sub.cc.u64 %0, %8, %15;\n\t" // r0 = t0 - p0 - "subc.cc.u64 %1, %9, %16;\n\t" // r1 = t1 - p1 - borrow - "subc.cc.u64 %2, %10, %17;\n\t" // r2 = t2 - p2 - borrow - "subc.cc.u64 %3, %11, %18;\n\t" // r3 = t3 - p3 - borrow - "subc.cc.u64 %4, %12, %19;\n\t" // r4 = t4 - p4 - borrow - "subc.cc.u64 %5, %13, %20;\n\t" // r5 = t5 - p5 - borrow - "subc.cc.u64 %6, %14, %21;\n\t" // r6 = t6 - p6 - borrow - "subc.u64 %7, %22, 0;\n\t" // mask_src = t7 - 0 - borrow - "shr.s64 %7, %7, 63;\n\t" // mask = sign-extend: -1 if negative, 0 if >= 0 - : "=l"(r0), "=l"(r1), "=l"(r2), "=l"(r3), - "=l"(r4), "=l"(r5), "=l"(r6), "=l"(mask) - : "l"(t0), "l"(t1), "l"(t2), "l"(t3), - "l"(t4), "l"(t5), "l"(t6), - "l"(p0), "l"(p1), "l"(p2), "l"(p3), - "l"(p4), "l"(p5), "l"(p6), "l"(t7)); + asm("sub.cc.u64 %0, %8, %15;\n\t" // r0 = t0 - p0 + "subc.cc.u64 %1, %9, %16;\n\t" // r1 = t1 - p1 - borrow + "subc.cc.u64 %2, %10, %17;\n\t" // r2 = t2 - p2 - borrow + "subc.cc.u64 %3, %11, %18;\n\t" // r3 = t3 - p3 - borrow + "subc.cc.u64 %4, %12, %19;\n\t" // r4 = t4 - p4 - borrow + "subc.cc.u64 %5, %13, %20;\n\t" // r5 = t5 - p5 - borrow + "subc.cc.u64 %6, %14, %21;\n\t" // r6 = t6 - p6 - borrow + "subc.u64 %7, %22, 0;\n\t" // mask_src = t7 - 0 - borrow + "shr.s64 %7, %7, 63;\n\t" // mask = sign-extend: -1 if negative, 0 if + // >= 0 + : "=l"(r0), "=l"(r1), "=l"(r2), "=l"(r3), "=l"(r4), "=l"(r5), "=l"(r6), + "=l"(mask) + : "l"(t0), "l"(t1), "l"(t2), "l"(t3), "l"(t4), "l"(t5), "l"(t6), "l"(p0), + "l"(p1), "l"(p2), "l"(p3), "l"(p4), "l"(p5), "l"(p6), "l"(t7)); // Branchless selection: // mask = 0 -> t >= p (use reduced r[0..6])