Compare commits

...

1 Commits

Author SHA1 Message Date
Agnes Leroy
699fe25950 chore(gpu): improve scalar div performance 2024-10-02 11:24:40 +02:00

View File

@@ -631,68 +631,77 @@ impl CudaServerKey {
// multiplier is less than the max possible value of Scalar
// Issue q = SRA(MULSH(m, n), shpost) - XSIGN(n);
let (mut tmp, xsign) = rayon::join(
move || {
// MULSH(m, n)
let mut tmp = self.signed_scalar_mul_high_async(
numerator,
chosen_multiplier.multiplier,
streams,
);
// SRA(MULSH(m, n), shpost)
self.unchecked_scalar_right_shift_assign_async(
&mut tmp,
chosen_multiplier.shift_post,
streams,
);
tmp
},
|| {
// XSIGN is: if x < 0 { -1 } else { 0 }
// It is equivalent to SRA(x, N - 1)
self.unchecked_scalar_right_shift_async(numerator, numerator_bits - 1, streams)
},
let streams_1 = match streams.len() {
1 => &CudaStreams::new_single_gpu(streams.gpu_indexes[0]),
_ => &CudaStreams::new_multi_gpu(),
};
let streams_2 = match streams.len() {
1 => &CudaStreams::new_single_gpu(streams.gpu_indexes[0]),
_ => &CudaStreams::new_multi_gpu(),
};
streams.synchronize();
// MULSH(m, n)
let mut tmp = self.signed_scalar_mul_high_async(
numerator,
chosen_multiplier.multiplier,
streams_1,
);
// SRA(MULSH(m, n), shpost)
self.unchecked_scalar_right_shift_assign_async(
&mut tmp,
chosen_multiplier.shift_post,
streams_1,
);
// XSIGN is: if x < 0 { -1 } else { 0 }
// It is equivalent to SRA(x, N - 1)
let xsign =
self.unchecked_scalar_right_shift_async(numerator, numerator_bits - 1, streams_2);
streams_1.synchronize();
streams_2.synchronize();
self.sub_assign_async(&mut tmp, &xsign, streams);
quotient = tmp;
} else {
// Issue q = SRA(n + MULSH(m - 2^N, n), shpost) - XSIGN(n);
// Note from the paper: m - 2^N is negative
let (mut tmp, xsign) = rayon::join(
move || {
// The subtraction may overflow.
// We then cast the result to a signed type.
// Overall, this will work fine due to two's complement representation
let cst = chosen_multiplier.multiplier
- (<Scalar::Unsigned as Reciprocable>::DoublePrecision::ONE
<< numerator_bits);
let cst = Scalar::DoublePrecision::cast_from(cst);
let streams_1 = match streams.len() {
1 => &CudaStreams::new_single_gpu(streams.gpu_indexes[0]),
_ => &CudaStreams::new_multi_gpu(),
};
let streams_2 = match streams.len() {
1 => &CudaStreams::new_single_gpu(streams.gpu_indexes[0]),
_ => &CudaStreams::new_multi_gpu(),
};
streams.synchronize();
// The subtraction may overflow.
// We then cast the result to a signed type.
// Overall, this will work fine due to two's complement representation
let cst = chosen_multiplier.multiplier
- (<Scalar::Unsigned as Reciprocable>::DoublePrecision::ONE << numerator_bits);
let cst = Scalar::DoublePrecision::cast_from(cst);
// MULSH(m - 2^N, n)
let mut tmp = self.signed_scalar_mul_high_async(numerator, cst, streams);
// MULSH(m - 2^N, n)
let mut tmp = self.signed_scalar_mul_high_async(numerator, cst, streams_1);
// n + MULSH(m - 2^N, n)
self.add_assign_async(&mut tmp, numerator, streams);
// n + MULSH(m - 2^N, n)
self.add_assign_async(&mut tmp, numerator, streams_1);
// SRA(n + MULSH(m - 2^N, n), shpost)
tmp = self.unchecked_scalar_right_shift_async(
&tmp,
chosen_multiplier.shift_post,
streams,
);
tmp
},
|| {
// XSIGN is: if x < 0 { -1 } else { 0 }
// It is equivalent to SRA(x, N - 1)
self.unchecked_scalar_right_shift_async(numerator, numerator_bits - 1, streams)
},
// SRA(n + MULSH(m - 2^N, n), shpost)
tmp = self.unchecked_scalar_right_shift_async(
&tmp,
chosen_multiplier.shift_post,
streams_1,
);
// XSIGN is: if x < 0 { -1 } else { 0 }
// It is equivalent to SRA(x, N - 1)
let xsign =
self.unchecked_scalar_right_shift_async(numerator, numerator_bits - 1, streams_2);
streams_1.synchronize();
streams_2.synchronize();
self.sub_assign_async(&mut tmp, &xsign, streams);
quotient = tmp;
}