feat(gpu): add PTX carry chains for fp_add/sub and branchless reduction

- replace software carry detection in fp_add_raw/fp_sub_raw with inline
PTX add.cc.u64/addc.cc.u64 and sub.cc.u64/subc.cc.u64 chains\
- now we always compute both reduced and unreduced result and select via bitmask
This commit is contained in:
Pedro Alves
2026-02-24 23:50:36 +00:00
parent af7e0f6fc8
commit ec64200123

View File

@@ -187,37 +187,82 @@ __host__ __device__ void fp_copy(Fp &dst, const Fp &src) {
// "Raw" means without modular reduction - performs a + b and returns carry.
// This is an internal helper used by fp_add() which handles reduction.
__host__ __device__ UNSIGNED_LIMB fp_add_raw(Fp &c, const Fp &a, const Fp &b) {
#if defined(__CUDA_ARCH__) && LIMB_BITS_CONFIG == 64
// PTX carry-chain: add.cc sets the hardware carry flag, addc.cc propagates
// it. This replaces 2 software carry-detect comparisons per limb (~14 extra
// instructions across 7 limbs) with zero-cost hardware flag propagation.
uint64_t carry_out;
asm("add.cc.u64 %0, %8, %15;\n\t" // c[0] = a[0] + b[0], set CF
"addc.cc.u64 %1, %9, %16;\n\t" // c[1] = a[1] + b[1] + CF
"addc.cc.u64 %2, %10, %17;\n\t" // c[2] = a[2] + b[2] + CF
"addc.cc.u64 %3, %11, %18;\n\t" // c[3] = a[3] + b[3] + CF
"addc.cc.u64 %4, %12, %19;\n\t" // c[4] = a[4] + b[4] + CF
"addc.cc.u64 %5, %13, %20;\n\t" // c[5] = a[5] + b[5] + CF
"addc.cc.u64 %6, %14, %21;\n\t" // c[6] = a[6] + b[6] + CF
"addc.u64 %7, 0, 0;\n\t" // carry_out = 0 + 0 + CF
: "=l"(c.limb[0]), "=l"(c.limb[1]), "=l"(c.limb[2]), "=l"(c.limb[3]),
"=l"(c.limb[4]), "=l"(c.limb[5]), "=l"(c.limb[6]), "=l"(carry_out)
: "l"(a.limb[0]), "l"(a.limb[1]), "l"(a.limb[2]), "l"(a.limb[3]),
"l"(a.limb[4]), "l"(a.limb[5]), "l"(a.limb[6]), "l"(b.limb[0]),
"l"(b.limb[1]), "l"(b.limb[2]), "l"(b.limb[3]), "l"(b.limb[4]),
"l"(b.limb[5]), "l"(b.limb[6]));
return carry_out;
#else
// Host path: portable software carry detection
UNSIGNED_LIMB carry = 0;
for (int i = 0; i < FP_LIMBS; i++) {
// Add with carry: c = a + b + carry
UNSIGNED_LIMB sum = a.limb[i] + carry;
carry = (sum < a.limb[i]) ? 1 : 0; // Check for overflow
carry = (sum < a.limb[i]) ? 1 : 0;
sum += b.limb[i];
carry += (sum < b.limb[i]) ? 1 : 0; // Check for overflow
carry += (sum < b.limb[i]) ? 1 : 0;
c.limb[i] = sum;
}
return carry;
#endif
}
// Subtraction with borrow propagation
// "Raw" means without modular reduction - performs a - b and returns borrow.
// This is an internal helper used by fp_sub() which handles reduction.
__host__ __device__ UNSIGNED_LIMB fp_sub_raw(Fp &c, const Fp &a, const Fp &b) {
#if defined(__CUDA_ARCH__) && LIMB_BITS_CONFIG == 64
// PTX borrow-chain: sub.cc sets the hardware borrow flag, subc.cc propagates
// it. Same benefit as fp_add_raw -- eliminates 2 comparisons per limb.
uint64_t borrow_out;
asm("sub.cc.u64 %0, %8, %15;\n\t" // c[0] = a[0] - b[0], set CF
"subc.cc.u64 %1, %9, %16;\n\t" // c[1] = a[1] - b[1] - CF
"subc.cc.u64 %2, %10, %17;\n\t" // c[2] = a[2] - b[2] - CF
"subc.cc.u64 %3, %11, %18;\n\t" // c[3] = a[3] - b[3] - CF
"subc.cc.u64 %4, %12, %19;\n\t" // c[4] = a[4] - b[4] - CF
"subc.cc.u64 %5, %13, %20;\n\t" // c[5] = a[5] - b[5] - CF
"subc.cc.u64 %6, %14, %21;\n\t" // c[6] = a[6] - b[6] - CF
"subc.u64 %7, 0, 0;\n\t" // borrow_out = 0 - 0 - CF
: "=l"(c.limb[0]), "=l"(c.limb[1]), "=l"(c.limb[2]), "=l"(c.limb[3]),
"=l"(c.limb[4]), "=l"(c.limb[5]), "=l"(c.limb[6]), "=l"(borrow_out)
: "l"(a.limb[0]), "l"(a.limb[1]), "l"(a.limb[2]), "l"(a.limb[3]),
"l"(a.limb[4]), "l"(a.limb[5]), "l"(a.limb[6]), "l"(b.limb[0]),
"l"(b.limb[1]), "l"(b.limb[2]), "l"(b.limb[3]), "l"(b.limb[4]),
"l"(b.limb[5]), "l"(b.limb[6]));
// subc.u64 with 0-0-CF produces 0 if no borrow, or 0xFFFFFFFFFFFFFFFF if
// borrow. Normalize to 0/1 for callers that check (borrow != 0) or add it.
return borrow_out & 1;
#else
// Host path: portable software borrow detection
UNSIGNED_LIMB borrow = 0;
for (int i = 0; i < FP_LIMBS; i++) {
// Subtract with borrow: c = a - b - borrow
UNSIGNED_LIMB diff = a.limb[i] - borrow;
borrow = (diff > a.limb[i]) ? 1 : 0; // Check for underflow
borrow = (diff > a.limb[i]) ? 1 : 0;
UNSIGNED_LIMB old_diff = diff;
diff -= b.limb[i];
borrow += (diff > old_diff) ? 1 : 0; // Check for underflow
borrow += (diff > old_diff) ? 1 : 0;
c.limb[i] = diff;
}
return borrow;
#endif
}
// Addition with modular reduction: c = (a + b) mod p
@@ -226,7 +271,27 @@ __host__ __device__ void fp_add(Fp &c, const Fp &a, const Fp &b) {
Fp sum;
UNSIGNED_LIMB carry = fp_add_raw(sum, a, b);
// If there's a carry or sum >= MODULUS, we need to reduce
#if defined(__CUDA_ARCH__) && LIMB_BITS_CONFIG == 64
// Branchless reduction: always compute sum - p, then select based on
// whether reduction was needed. This avoids divergent branches that stall
// warps when some threads need reduction and others don't.
//
// Decision logic:
// carry=1 -> sum overflowed 448 bits, definitely >= p -> use reduced
// carry=0, borrow=0 -> sum >= p in 448 bits -> use reduced
// carry=0, borrow=1 -> sum < p -> use original sum
// So: use_original = (!carry) & borrow
Fp reduced;
UNSIGNED_LIMB borrow = fp_sub_raw(reduced, sum, fp_modulus());
UNSIGNED_LIMB use_original = ((carry ^ 1) & borrow);
UNSIGNED_LIMB mask =
-use_original; // all-ones if keep sum, all-zeros if keep reduced
for (int i = 0; i < FP_LIMBS; i++) {
c.limb[i] = (sum.limb[i] & mask) | (reduced.limb[i] & ~mask);
}
#else
// Host path: branching is fine on CPU (branch predictor handles it well)
const Fp &p = fp_modulus();
if (carry || fp_cmp(sum, p) != ComparisonType::Less) {
Fp reduced;
@@ -235,6 +300,7 @@ __host__ __device__ void fp_add(Fp &c, const Fp &a, const Fp &b) {
} else {
fp_copy(c, sum);
}
#endif
}
// Subtraction with modular reduction: c = (a - b) mod p
@@ -243,13 +309,28 @@ __host__ __device__ void fp_sub(Fp &c, const Fp &a, const Fp &b) {
Fp diff;
UNSIGNED_LIMB borrow = fp_sub_raw(diff, a, b);
// If there was a borrow, we need to add MODULUS
#if defined(__CUDA_ARCH__) && LIMB_BITS_CONFIG == 64
// Branchless correction: always compute diff + p, select based on borrow.
// Same rationale as fp_add -- avoids warp divergence.
// borrow=1 -> a < b, need to add p -> use corrected
// borrow=0 -> a >= b, result is valid -> use diff
Fp corrected;
fp_add_raw(corrected, diff, fp_modulus());
UNSIGNED_LIMB mask =
-borrow; // all-ones if borrow (use corrected), all-zeros if not
for (int i = 0; i < FP_LIMBS; i++) {
c.limb[i] = (corrected.limb[i] & mask) | (diff.limb[i] & ~mask);
}
#else
// Host path: branching is fine on CPU
const Fp &p = fp_modulus();
if (borrow) {
fp_add_raw(c, diff, p);
} else {
fp_copy(c, diff);
}
#endif
}
// Small-constant multiplication via addition chains.
@@ -461,9 +542,8 @@ __host__ __device__ void fp_mont_reduce(Fp &c, const UNSIGNED_LIMB *a) {
// (carry, t[j]) = t[j] + a[j] * b_i + carry
// which is a 64x64->128 multiply plus a three-operand addition with carry.
//
// The C++ path uses software carry detection: carry = (sum < old) ? 1 : 0,
// which compiles to SETP + SELP (2-3 extra instructions per carry). The PTX
// path below uses hardware carry flags via the .cc suffix:
// The C++ path uses software carry detection: carry = (sum < old) ? 1 : 0.
// The PTX path below uses hardware carry flags via the .cc suffix:
// - mul.lo.u64 / mul.hi.u64 : 64x64->128 wide multiply
// - add.cc.u64 / addc.u64 : addition chain with hardware carry flag
//
@@ -493,17 +573,18 @@ __host__ __device__ void fp_mont_reduce(Fp &c, const UNSIGNED_LIMB *a) {
// addc.u64 _hi, _hi, 0 -- _hi += CF
// add.cc.u64 t_j, t_j, carry -- t_j += carry_in, set CF
// addc.u64 carry, _hi, 0 -- carry_out = _hi + CF
#define LIMB_MACC(t_j, carry, a_j, b_i) \
#define LIMB_MACC(t_j, carry, a_j, b_i) \
asm volatile("{\n\t" \
".reg .u64 _lo, _hi;\n\t" \
"mul.lo.u64 _lo, %2, %3;\n\t" \
"mul.hi.u64 _hi, %2, %3;\n\t" \
"add.cc.u64 %0, %0, _lo;\n\t" \
"addc.u64 _hi, _hi, 0;\n\t" \
"add.cc.u64 %0, %0, %1;\n\t" \
"addc.u64 %1, _hi, 0;\n\t" \
"mul.lo.u64 _lo, %2, %3;\n\t" \
"mul.hi.u64 _hi, %2, %3;\n\t" \
"add.cc.u64 %0, %0, _lo;\n\t" \
"addc.u64 _hi, _hi, 0;\n\t" \
"add.cc.u64 %0, %0, %1;\n\t" \
"addc.u64 %1, _hi, 0;\n\t" \
"}\n\t" \
: "+&l"(t_j), "+&l"(carry) : "l"(a_j), "l"(b_i))
: "+&l"(t_j), "+&l"(carry) \
: "l"(a_j), "l"(b_i))
// Single CIOS iteration: multiply-accumulate, reduce, and shift.
//
@@ -515,13 +596,9 @@ __host__ __device__ void fp_mont_reduce(Fp &c, const UNSIGNED_LIMB *a) {
//
// The macro lets the compiler allocate registers across all 7 unrolled
// iterations, avoiding spills to local memory.
#define CIOS_ITERATION_PTX( \
t0, t1, t2, t3, t4, t5, t6, t7, \
a0, a1, a2, a3, a4, a5, a6, \
b_i, \
p0, p1, p2, p3, p4, p5, p6, \
p_prime, \
r0, r1, r2, r3, r4, r5, r6, r7) \
#define CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, \
a5, a6, b_i, p0, p1, p2, p3, p4, p5, p6, p_prime, \
r0, r1, r2, r3, r4, r5, r6, r7) \
do { \
uint64_t _carry = 0; \
/* Step 1: t += a * b_i */ \
@@ -534,7 +611,7 @@ __host__ __device__ void fp_mont_reduce(Fp &c, const UNSIGNED_LIMB *a) {
LIMB_MACC(t6, _carry, a6, b_i); \
/* Accumulate final carry into overflow limb t7 */ \
uint64_t _overflow; \
asm("add.cc.u64 %0, %0, %2;\n\t" \
asm("add.cc.u64 %0, %0, %2;\n\t" \
"addc.u64 %1, 0, 0;\n\t" \
: "+l"(t7), "=l"(_overflow) \
: "l"(_carry)); \
@@ -551,19 +628,26 @@ __host__ __device__ void fp_mont_reduce(Fp &c, const UNSIGNED_LIMB *a) {
LIMB_MACC(t4, _carry, _m, p4); \
LIMB_MACC(t5, _carry, _m, p5); \
LIMB_MACC(t6, _carry, _m, p6); \
/* Finalize overflow: t7 = t7 + _carry + _overflow */ \
/* Plain adds (no carry chain) -- the CIOS invariant guarantees this */ \
/* sum fits in 64 bits so intermediate overflow does not matter. */ \
/* Finalize overflow: t7 = t7 + _carry + _overflow */ \
/* Plain adds (no carry chain) -- the CIOS invariant guarantees this */ \
/* sum fits in 64 bits so intermediate overflow does not matter. */ \
t7 += _carry; \
t7 += _overflow; \
t7 += _overflow; \
\
/* Step 4: Shift right by one limb via register renaming */ \
/* t0 is now zero (by construction of m), discard it */ \
r0 = t1; r1 = t2; r2 = t3; r3 = t4; \
r4 = t5; r5 = t6; r6 = t7; r7 = 0; \
r0 = t1; \
r1 = t2; \
r2 = t3; \
r3 = t4; \
r4 = t5; \
r5 = t6; \
r6 = t7; \
r7 = 0; \
} while (0)
__device__ __noinline__ void fp_mont_mul_cios_ptx(Fp &c, const Fp &a, const Fp &b) {
__device__ __noinline__ void fp_mont_mul_cios_ptx(Fp &c, const Fp &a,
const Fp &b) {
const uint64_t p0 = DEVICE_MODULUS.limb[0];
const uint64_t p1 = DEVICE_MODULUS.limb[1];
const uint64_t p2 = DEVICE_MODULUS.limb[2];
@@ -585,61 +669,53 @@ __device__ __noinline__ void fp_mont_mul_cios_ptx(Fp &c, const Fp &a, const Fp &
// Each iteration processes one limb of b, accumulates a*b[i], reduces,
// and shifts. The output registers become the input for the next iteration.
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7,
a0, a1, a2, a3, a4, a5, a6, b.limb[0],
p0, p1, p2, p3, p4, p5, p6, pp,
t0, t1, t2, t3, t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6,
b.limb[0], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3,
t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7,
a0, a1, a2, a3, a4, a5, a6, b.limb[1],
p0, p1, p2, p3, p4, p5, p6, pp,
t0, t1, t2, t3, t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6,
b.limb[1], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3,
t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7,
a0, a1, a2, a3, a4, a5, a6, b.limb[2],
p0, p1, p2, p3, p4, p5, p6, pp,
t0, t1, t2, t3, t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6,
b.limb[2], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3,
t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7,
a0, a1, a2, a3, a4, a5, a6, b.limb[3],
p0, p1, p2, p3, p4, p5, p6, pp,
t0, t1, t2, t3, t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6,
b.limb[3], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3,
t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7,
a0, a1, a2, a3, a4, a5, a6, b.limb[4],
p0, p1, p2, p3, p4, p5, p6, pp,
t0, t1, t2, t3, t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6,
b.limb[4], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3,
t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7,
a0, a1, a2, a3, a4, a5, a6, b.limb[5],
p0, p1, p2, p3, p4, p5, p6, pp,
t0, t1, t2, t3, t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6,
b.limb[5], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3,
t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7,
a0, a1, a2, a3, a4, a5, a6, b.limb[6],
p0, p1, p2, p3, p4, p5, p6, pp,
t0, t1, t2, t3, t4, t5, t6, t7);
CIOS_ITERATION_PTX(t0, t1, t2, t3, t4, t5, t6, t7, a0, a1, a2, a3, a4, a5, a6,
b.limb[6], p0, p1, p2, p3, p4, p5, p6, pp, t0, t1, t2, t3,
t4, t5, t6, t7);
// Final reduction: if t[0..7] >= p (extended to 8 limbs), subtract p.
// Compute (t[0..6] - p[0..6]) with borrow, then subtract borrow from t7.
// If t7 after subtraction is non-negative, the reduced result is valid;
// otherwise the original t[0..6] is already in [0, p).
uint64_t r0, r1, r2, r3, r4, r5, r6, mask;
asm("sub.cc.u64 %0, %8, %15;\n\t" // r0 = t0 - p0
"subc.cc.u64 %1, %9, %16;\n\t" // r1 = t1 - p1 - borrow
"subc.cc.u64 %2, %10, %17;\n\t" // r2 = t2 - p2 - borrow
"subc.cc.u64 %3, %11, %18;\n\t" // r3 = t3 - p3 - borrow
"subc.cc.u64 %4, %12, %19;\n\t" // r4 = t4 - p4 - borrow
"subc.cc.u64 %5, %13, %20;\n\t" // r5 = t5 - p5 - borrow
"subc.cc.u64 %6, %14, %21;\n\t" // r6 = t6 - p6 - borrow
"subc.u64 %7, %22, 0;\n\t" // mask_src = t7 - 0 - borrow
"shr.s64 %7, %7, 63;\n\t" // mask = sign-extend: -1 if negative, 0 if >= 0
: "=l"(r0), "=l"(r1), "=l"(r2), "=l"(r3),
"=l"(r4), "=l"(r5), "=l"(r6), "=l"(mask)
: "l"(t0), "l"(t1), "l"(t2), "l"(t3),
"l"(t4), "l"(t5), "l"(t6),
"l"(p0), "l"(p1), "l"(p2), "l"(p3),
"l"(p4), "l"(p5), "l"(p6), "l"(t7));
asm("sub.cc.u64 %0, %8, %15;\n\t" // r0 = t0 - p0
"subc.cc.u64 %1, %9, %16;\n\t" // r1 = t1 - p1 - borrow
"subc.cc.u64 %2, %10, %17;\n\t" // r2 = t2 - p2 - borrow
"subc.cc.u64 %3, %11, %18;\n\t" // r3 = t3 - p3 - borrow
"subc.cc.u64 %4, %12, %19;\n\t" // r4 = t4 - p4 - borrow
"subc.cc.u64 %5, %13, %20;\n\t" // r5 = t5 - p5 - borrow
"subc.cc.u64 %6, %14, %21;\n\t" // r6 = t6 - p6 - borrow
"subc.u64 %7, %22, 0;\n\t" // mask_src = t7 - 0 - borrow
"shr.s64 %7, %7, 63;\n\t" // mask = sign-extend: -1 if negative, 0 if
// >= 0
: "=l"(r0), "=l"(r1), "=l"(r2), "=l"(r3), "=l"(r4), "=l"(r5), "=l"(r6),
"=l"(mask)
: "l"(t0), "l"(t1), "l"(t2), "l"(t3), "l"(t4), "l"(t5), "l"(t6), "l"(p0),
"l"(p1), "l"(p2), "l"(p3), "l"(p4), "l"(p5), "l"(p6), "l"(t7));
// Branchless selection:
// mask = 0 -> t >= p (use reduced r[0..6])