mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
Compare commits
1 Commits
hw-team/pg
...
al/improve
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
699fe25950 |
@@ -631,68 +631,77 @@ impl CudaServerKey {
|
||||
// multiplier is less than the max possible value of Scalar
|
||||
// Issue q = SRA(MULSH(m, n), shpost) − XSIGN(n);
|
||||
|
||||
let (mut tmp, xsign) = rayon::join(
|
||||
move || {
|
||||
// MULSH(m, n)
|
||||
let mut tmp = self.signed_scalar_mul_high_async(
|
||||
numerator,
|
||||
chosen_multiplier.multiplier,
|
||||
streams,
|
||||
);
|
||||
|
||||
// SRA(MULSH(m, n), shpost)
|
||||
self.unchecked_scalar_right_shift_assign_async(
|
||||
&mut tmp,
|
||||
chosen_multiplier.shift_post,
|
||||
streams,
|
||||
);
|
||||
tmp
|
||||
},
|
||||
|| {
|
||||
// XSIGN is: -1 if x < 0 { -1 } else { 0 }
|
||||
// It is equivalent to SRA(x, N − 1)
|
||||
self.unchecked_scalar_right_shift_async(numerator, numerator_bits - 1, streams)
|
||||
},
|
||||
let streams_1 = match streams.len() {
|
||||
1 => &CudaStreams::new_single_gpu(streams.gpu_indexes[0]),
|
||||
_ => &CudaStreams::new_multi_gpu(),
|
||||
};
|
||||
let streams_2 = match streams.len() {
|
||||
1 => &CudaStreams::new_single_gpu(streams.gpu_indexes[0]),
|
||||
_ => &CudaStreams::new_multi_gpu(),
|
||||
};
|
||||
streams.synchronize();
|
||||
// MULSH(m, n)
|
||||
let mut tmp = self.signed_scalar_mul_high_async(
|
||||
numerator,
|
||||
chosen_multiplier.multiplier,
|
||||
streams_1,
|
||||
);
|
||||
|
||||
// SRA(MULSH(m, n), shpost)
|
||||
self.unchecked_scalar_right_shift_assign_async(
|
||||
&mut tmp,
|
||||
chosen_multiplier.shift_post,
|
||||
streams_1,
|
||||
);
|
||||
// XSIGN is: -1 if x < 0 { -1 } else { 0 }
|
||||
// It is equivalent to SRA(x, N − 1)
|
||||
let xsign =
|
||||
self.unchecked_scalar_right_shift_async(numerator, numerator_bits - 1, streams_2);
|
||||
streams_1.synchronize();
|
||||
streams_2.synchronize();
|
||||
|
||||
self.sub_assign_async(&mut tmp, &xsign, streams);
|
||||
quotient = tmp;
|
||||
} else {
|
||||
// Issue q = SRA(n + MULSH(m − 2^N , n), shpost) − XSIGN(n);
|
||||
// Note from the paper: m - 2^N is negative
|
||||
|
||||
let (mut tmp, xsign) = rayon::join(
|
||||
move || {
|
||||
// The subtraction may overflow.
|
||||
// We then cast the result to a signed type.
|
||||
// Overall, this will work fine due to two's complement representation
|
||||
let cst = chosen_multiplier.multiplier
|
||||
- (<Scalar::Unsigned as Reciprocable>::DoublePrecision::ONE
|
||||
<< numerator_bits);
|
||||
let cst = Scalar::DoublePrecision::cast_from(cst);
|
||||
let streams_1 = match streams.len() {
|
||||
1 => &CudaStreams::new_single_gpu(streams.gpu_indexes[0]),
|
||||
_ => &CudaStreams::new_multi_gpu(),
|
||||
};
|
||||
let streams_2 = match streams.len() {
|
||||
1 => &CudaStreams::new_single_gpu(streams.gpu_indexes[0]),
|
||||
_ => &CudaStreams::new_multi_gpu(),
|
||||
};
|
||||
streams.synchronize();
|
||||
// The subtraction may overflow.
|
||||
// We then cast the result to a signed type.
|
||||
// Overall, this will work fine due to two's complement representation
|
||||
let cst = chosen_multiplier.multiplier
|
||||
- (<Scalar::Unsigned as Reciprocable>::DoublePrecision::ONE << numerator_bits);
|
||||
let cst = Scalar::DoublePrecision::cast_from(cst);
|
||||
|
||||
// MULSH(m - 2^N, n)
|
||||
let mut tmp = self.signed_scalar_mul_high_async(numerator, cst, streams);
|
||||
// MULSH(m - 2^N, n)
|
||||
let mut tmp = self.signed_scalar_mul_high_async(numerator, cst, streams_1);
|
||||
|
||||
// n + MULSH(m − 2^N , n)
|
||||
self.add_assign_async(&mut tmp, numerator, streams);
|
||||
// n + MULSH(m − 2^N , n)
|
||||
self.add_assign_async(&mut tmp, numerator, streams_1);
|
||||
|
||||
// SRA(n + MULSH(m - 2^N, n), shpost)
|
||||
tmp = self.unchecked_scalar_right_shift_async(
|
||||
&tmp,
|
||||
chosen_multiplier.shift_post,
|
||||
streams,
|
||||
);
|
||||
|
||||
tmp
|
||||
},
|
||||
|| {
|
||||
// XSIGN is: -1 if x < 0 { -1 } else { 0 }
|
||||
// It is equivalent to SRA(x, N − 1)
|
||||
self.unchecked_scalar_right_shift_async(numerator, numerator_bits - 1, streams)
|
||||
},
|
||||
// SRA(n + MULSH(m - 2^N, n), shpost)
|
||||
tmp = self.unchecked_scalar_right_shift_async(
|
||||
&tmp,
|
||||
chosen_multiplier.shift_post,
|
||||
streams_1,
|
||||
);
|
||||
|
||||
// XSIGN is: -1 if x < 0 { -1 } else { 0 }
|
||||
// It is equivalent to SRA(x, N − 1)
|
||||
let xsign =
|
||||
self.unchecked_scalar_right_shift_async(numerator, numerator_bits - 1, streams_2);
|
||||
streams_1.synchronize();
|
||||
streams_2.synchronize();
|
||||
|
||||
self.sub_assign_async(&mut tmp, &xsign, streams);
|
||||
quotient = tmp;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user