mirror of https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-11 07:38:08 -05:00

Compare commits (1 commit)
tm/split-o ... al/fix_shi

| Author | SHA1 | Date |
|---|---|---|
| | 4258f9c53c | |

```diff
@@ -12,8 +12,8 @@ enum OUTPUT_CARRY { NONE = 0, GENERATED = 1, PROPAGATED = 2 };
 enum SHIFT_OR_ROTATE_TYPE {
   LEFT_SHIFT = 0,
   RIGHT_SHIFT = 1,
-  LEFT_ROTATE = 2,
-  RIGHT_ROTATE = 3
+  ROTATE_LEFT = 2,
+  ROTATE_RIGHT = 3
 };
 enum LUT_TYPE { OPERATOR = 0, MAXVALUE = 1, ISNONZERO = 2, BLOCKSLEN = 3 };
 enum BITOP_TYPE {
@@ -1309,7 +1309,7 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {
 
     std::function<Torus(Torus, Torus)> shift_lut_f;
 
-    if (shift_type == LEFT_SHIFT) {
+    if (shift_type == LEFT_SHIFT || shift_type == ROTATE_LEFT) {
       shift_lut_f = [shift_within_block,
                      params](Torus current_block,
                              Torus previous_block) -> Torus {
@@ -1395,7 +1395,7 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {
 
     std::function<Torus(Torus, Torus)> shift_lut_f;
 
-    if (shift_type == LEFT_SHIFT) {
+    if (shift_type == LEFT_SHIFT || shift_type == ROTATE_LEFT) {
      shift_lut_f = [shift_within_block,
                     params](Torus current_block,
                             Torus previous_block) -> Torus {
```
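Note (illustration, not part of the patch): `shift_lut_f` in `int_logical_scalar_shift_buffer` maps `(current_block, previous_block)` to the shifted block, which is why `ROTATE_LEFT` can now reuse the `LEFT_SHIFT` table. Below is a clear-value sketch of that per-block function; the 2-bit block width and all names are assumptions for illustration, not values taken from the diff.

```cpp
// Clear-value sketch of the per-block function shift_lut_f encodes.
// Assumptions: 2-bit message blocks, blocks stored LSB-first, so a left
// shift/rotate pulls the top bits of the previous (less significant) block
// into the current one. The real LUT acts on encrypted blocks.
#include <cstdint>
#include <cstdio>

constexpr uint64_t kMessageBits = 2; // assumed block width
constexpr uint64_t kMessageMask = (1u << kMessageBits) - 1;

uint64_t left_shift_block(uint64_t current_block, uint64_t previous_block,
                          uint64_t shift_within_block) {
  uint64_t from_current = (current_block << shift_within_block) & kMessageMask;
  uint64_t from_previous = previous_block >> (kMessageBits - shift_within_block);
  return from_current | from_previous;
}

int main() {
  // Shift the block pair (current = 0b01, previous = 0b11) left by 1 bit:
  // the top bit of the previous block moves into the current block.
  std::printf("%llu\n", (unsigned long long)left_shift_block(0b01, 0b11, 1)); // 3
  return 0;
}
```

The shift-versus-rotate difference (zero-fill versus wrap-around) is handled at the block level rather than in this table; note the `cuda_memset_async` calls in the shift paths further down, which the rotate path does not have.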
```diff
@@ -348,7 +348,7 @@ host_integer_div_rem_kb(cudaStream_t *streams, uint32_t *gpu_indexes,
                            interesting_remainder1.len - 1, streams[0],
                            gpu_indexes[0]);
 
-    host_radix_blocks_rotate_left(
+    host_radix_blocks_rotate_right(
         streams, gpu_indexes, gpu_count, interesting_remainder1.data,
         tmp_radix.data, 1, interesting_remainder1.len, big_lwe_size);
 
@@ -63,10 +63,10 @@ __global__ void radix_blocks_rotate_left(Torus *dst, Torus *src, uint32_t value,
 // calculation is not inplace, so `dst` and `src` must not be the same
 template <typename Torus>
 __host__ void
-host_radix_blocks_rotate_right(cudaStream_t *streams, uint32_t *gpu_indexes,
-                               uint32_t gpu_count, Torus *dst, Torus *src,
-                               uint32_t value, uint32_t blocks_count,
-                               uint32_t lwe_size) {
+host_radix_blocks_rotate_left(cudaStream_t *streams, uint32_t *gpu_indexes,
+                              uint32_t gpu_count, Torus *dst, Torus *src,
+                              uint32_t value, uint32_t blocks_count,
+                              uint32_t lwe_size) {
   if (src == dst) {
     PANIC("Cuda error (blocks_rotate_right): the source and destination "
           "pointers should be different");
@@ -80,10 +80,10 @@ host_radix_blocks_rotate_right(cudaStream_t *streams, uint32_t *gpu_indexes,
 // calculation is not inplace, so `dst` and `src` must not be the same
 template <typename Torus>
 __host__ void
-host_radix_blocks_rotate_left(cudaStream_t *streams, uint32_t *gpu_indexes,
-                              uint32_t gpu_count, Torus *dst, Torus *src,
-                              uint32_t value, uint32_t blocks_count,
-                              uint32_t lwe_size) {
+host_radix_blocks_rotate_right(cudaStream_t *streams, uint32_t *gpu_indexes,
+                               uint32_t gpu_count, Torus *dst, Torus *src,
+                               uint32_t value, uint32_t blocks_count,
+                               uint32_t lwe_size) {
   if (src == dst) {
     PANIC("Cuda error (blocks_rotate_left): the source and destination "
           "pointers should be different");
@@ -456,9 +456,9 @@ void host_propagate_single_carry(cudaStream_t *streams, uint32_t *gpu_indexes,
   }
 
   cudaSetDevice(gpu_indexes[0]);
-  host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count, step_output,
-                                 generates_or_propagates, 1, num_blocks,
-                                 big_lwe_size);
+  host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count, step_output,
+                                generates_or_propagates, 1, num_blocks,
+                                big_lwe_size);
   cuda_memset_async(step_output, 0, big_lwe_size_bytes, streams[0],
                     gpu_indexes[0]);
 
@@ -522,9 +522,9 @@ void host_propagate_single_sub_borrow(cudaStream_t *streams,
       overflowed, &generates_or_propagates[big_lwe_size * (num_blocks - 1)],
       big_lwe_size_bytes, streams[0], gpu_indexes[0]);
 
-  host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count, step_output,
-                                 generates_or_propagates, 1, num_blocks,
-                                 big_lwe_size);
+  host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count, step_output,
+                                generates_or_propagates, 1, num_blocks,
+                                big_lwe_size);
   cuda_memset_async(step_output, 0, big_lwe_size_bytes, streams[0],
                     gpu_indexes[0]);
 
@@ -99,9 +99,9 @@ __host__ void host_integer_scalar_mul_radix(
         preshifted_buffer + (i % msg_bits) * num_radix_blocks * lwe_size;
     T *block_shift_buffer =
        all_shifted_buffer + j * num_radix_blocks * lwe_size;
-    host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count,
-                                   block_shift_buffer, preshifted_radix_ct,
-                                   i / msg_bits, num_radix_blocks, lwe_size);
+    host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count,
+                                  block_shift_buffer, preshifted_radix_ct,
+                                  i / msg_bits, num_radix_blocks, lwe_size);
     // create trivial assign for value = 0
     cuda_memset_async(block_shift_buffer, 0, (i / msg_bits) * lwe_size_bytes,
                       streams[0], gpu_indexes[0]);
```
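Note (illustration, not part of the patch): the hunks at lines 63 and 80 exchange the names of `host_radix_blocks_rotate_right` and `host_radix_blocks_rotate_left` while keeping their bodies, and every caller is updated to match. As a plain-vector analogue of what a block rotation does, under the assumption (not stated in the diff) that blocks are stored LSB-first and "left" means toward more significant blocks:

```cpp
// Host-side analogue of rotating an array of radix blocks; the CUDA helpers
// do the same over LWE ciphertext blocks. Assumption: blocks are stored
// LSB-first and a "left" rotation moves block i to position i + value
// (wrapping). dst and src are distinct, mirroring the "not inplace" note.
#include <cstdint>
#include <vector>

template <typename Torus>
void rotate_blocks_left(std::vector<Torus> &dst, const std::vector<Torus> &src,
                        uint32_t value) {
  const uint32_t blocks_count = static_cast<uint32_t>(src.size());
  if (blocks_count == 0)
    return;
  dst.resize(blocks_count);
  for (uint32_t i = 0; i < blocks_count; ++i)
    dst[i] = src[(i + blocks_count - value % blocks_count) % blocks_count];
}

template <typename Torus>
void rotate_blocks_right(std::vector<Torus> &dst, const std::vector<Torus> &src,
                         uint32_t value) {
  const uint32_t blocks_count = static_cast<uint32_t>(src.size());
  if (blocks_count == 0)
    return;
  dst.resize(blocks_count);
  for (uint32_t i = 0; i < blocks_count; ++i)
    dst[i] = src[(i + value) % blocks_count];
}
```

Under that convention a logical left shift of the whole integer is a left block rotation followed by zeroing the wrapped-around low blocks, which is the pattern visible in the scalar-shift hunks below.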
```diff
@@ -59,7 +59,31 @@ __host__ void host_integer_radix_scalar_rotate_kb_inplace(
   // 256 threads are used in every block
   // block_count blocks will be used in the grid
   // one block is responsible to process single lwe ciphertext
-  if (mem->shift_type == LEFT_SHIFT) {
+  if (mem->shift_type == ROTATE_LEFT) {
     host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count,
                                   rotated_buffer, lwe_array, rotations,
                                   num_blocks, big_lwe_size);
+
+    cuda_memcpy_async_gpu_to_gpu(lwe_array, rotated_buffer,
+                                 num_blocks * big_lwe_size_bytes, streams[0],
+                                 gpu_indexes[0]);
+
+    if (shift_within_block == 0) {
+      return;
+    }
+
+    auto receiver_blocks = lwe_array;
+    auto giver_blocks = rotated_buffer;
+    host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count, giver_blocks,
+                                  lwe_array, 1, num_blocks, big_lwe_size);
+
+    integer_radix_apply_bivariate_lookup_table_kb<Torus>(
+        streams, gpu_indexes, gpu_count, lwe_array, receiver_blocks,
+        giver_blocks, bsk, ksk, num_blocks, lut_bivariate,
+        lut_bivariate->params.message_modulus);
+
+  } else {
+    // left rotate
+    host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count,
+                                   rotated_buffer, lwe_array, rotations,
+                                   num_blocks, big_lwe_size);
@@ -82,30 +106,6 @@ __host__ void host_integer_radix_scalar_rotate_kb_inplace(
         streams, gpu_indexes, gpu_count, lwe_array, receiver_blocks,
         giver_blocks, bsk, ksk, num_blocks, lut_bivariate,
         lut_bivariate->params.message_modulus);
 
-  } else {
-    // left shift
-    host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count,
-                                  rotated_buffer, lwe_array, rotations,
-                                  num_blocks, big_lwe_size);
-
-    cuda_memcpy_async_gpu_to_gpu(lwe_array, rotated_buffer,
-                                 num_blocks * big_lwe_size_bytes, streams[0],
-                                 gpu_indexes[0]);
-
-    if (shift_within_block == 0) {
-      return;
-    }
-
-    auto receiver_blocks = lwe_array;
-    auto giver_blocks = rotated_buffer;
-    host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count, giver_blocks,
-                                  lwe_array, 1, num_blocks, big_lwe_size);
-
-    integer_radix_apply_bivariate_lookup_table_kb<Torus>(
-        streams, gpu_indexes, gpu_count, lwe_array, receiver_blocks,
-        giver_blocks, bsk, ksk, num_blocks, lut_bivariate,
-        lut_bivariate->params.message_modulus);
-
   }
 }
@@ -63,9 +63,9 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace(
   // block_count blocks will be used in the grid
   // one block is responsible to process single lwe ciphertext
   if (mem->shift_type == LEFT_SHIFT) {
-    host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count,
-                                   rotated_buffer, lwe_array, rotations,
-                                   num_blocks, big_lwe_size);
+    host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count,
+                                  rotated_buffer, lwe_array, rotations,
+                                  num_blocks, big_lwe_size);
 
     // create trivial assign for value = 0
     cuda_memset_async(rotated_buffer, 0, rotations * big_lwe_size_bytes,
@@ -92,9 +92,9 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace(
 
   } else {
     // right shift
-    host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count,
-                                  rotated_buffer, lwe_array, rotations,
-                                  num_blocks, big_lwe_size);
+    host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count,
+                                   rotated_buffer, lwe_array, rotations,
+                                   num_blocks, big_lwe_size);
 
     // rotate left as the blocks are from LSB to MSB
     // create trivial assign for value = 0
@@ -173,9 +173,9 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
   auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
 
   if (mem->shift_type == RIGHT_SHIFT) {
-    host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count,
-                                  rotated_buffer, lwe_array, rotations,
-                                  num_blocks, big_lwe_size);
+    host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count,
+                                   rotated_buffer, lwe_array, rotations,
+                                   num_blocks, big_lwe_size);
     cuda_memcpy_async_gpu_to_gpu(lwe_array, rotated_buffer,
                                  num_blocks * big_lwe_size_bytes, streams[0],
                                  gpu_indexes[0]);
@@ -90,9 +90,9 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace(
     auto rotations = 1 << d;
     switch (mem->shift_type) {
     case LEFT_SHIFT:
-      host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count,
-                                     rotated_input, input_bits_b, rotations,
-                                     total_nb_bits, big_lwe_size);
+      host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count,
+                                    rotated_input, input_bits_b, rotations,
+                                    total_nb_bits, big_lwe_size);
 
       if (mem->is_signed && mem->shift_type == RIGHT_SHIFT)
         for (int i = 0; i < rotations; i++)
@@ -104,9 +104,9 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace(
                         streams[0], gpu_indexes[0]);
       break;
     case RIGHT_SHIFT:
-      host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count,
-                                    rotated_input, input_bits_b, rotations,
-                                    total_nb_bits, big_lwe_size);
+      host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count,
+                                     rotated_input, input_bits_b, rotations,
+                                     total_nb_bits, big_lwe_size);
 
       if (mem->is_signed)
         for (int i = 0; i < rotations; i++)
@@ -118,16 +118,16 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace(
           rotated_input + (total_nb_bits - rotations) * big_lwe_size, 0,
           rotations * big_lwe_size_bytes, streams[0], gpu_indexes[0]);
       break;
-    case LEFT_ROTATE:
-      host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count,
-                                     rotated_input, input_bits_b, rotations,
-                                     total_nb_bits, big_lwe_size);
-      break;
-    case RIGHT_ROTATE:
+    case ROTATE_LEFT:
+      host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count,
+                                    rotated_input, input_bits_b, rotations,
+                                    total_nb_bits, big_lwe_size);
+      break;
+    case ROTATE_RIGHT:
       host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count,
                                      rotated_input, input_bits_b, rotations,
                                      total_nb_bits, big_lwe_size);
       break;
     default:
       PANIC("Unknown operation")
     }
```
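Note (illustration, not part of the patch): `host_integer_radix_scalar_rotate_kb_inplace` splits a rotation into a whole-block rotation by `rotations`, followed, when `shift_within_block != 0`, by a one-block rotation that produces the "giver" blocks and a bivariate lookup table that combines each receiver block with its giver. A clear-value sketch of that decomposition, assuming 2-bit blocks stored LSB-first (all names illustrative, not the library's API):

```cpp
// Clear-value sketch of the scalar-rotate decomposition: rotate whole blocks,
// then fix the remaining in-block amount by combining each "receiver" block
// with its "giver" (the block one position below it, wrapping around).
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr uint32_t kBlockBits = 2; // assumed block width
constexpr uint32_t kBlockMask = (1u << kBlockBits) - 1;

std::vector<uint32_t> rotate_left_blocks(const std::vector<uint32_t> &v,
                                         uint32_t amount) {
  std::vector<uint32_t> out(v.size());
  for (size_t i = 0; i < v.size(); ++i)
    out[i] = v[(i + v.size() - amount % v.size()) % v.size()];
  return out;
}

std::vector<uint32_t> rotate_left_bits(std::vector<uint32_t> blocks,
                                       uint32_t shift) {
  uint32_t rotations = shift / kBlockBits;          // whole-block part
  uint32_t shift_within_block = shift % kBlockBits; // in-block part
  blocks = rotate_left_blocks(blocks, rotations);
  if (shift_within_block == 0)
    return blocks;
  // "Giver" blocks: the block one position below each receiver, wrapping.
  std::vector<uint32_t> givers = rotate_left_blocks(blocks, 1);
  for (size_t i = 0; i < blocks.size(); ++i)
    // Plays the role of the bivariate lookup table: combine receiver + giver.
    blocks[i] = ((blocks[i] << shift_within_block) |
                 (givers[i] >> (kBlockBits - shift_within_block))) &
                kBlockMask;
  return blocks;
}

int main() {
  // 0b10010011 stored LSB-first as 2-bit blocks {3, 0, 1, 2}.
  std::vector<uint32_t> out = rotate_left_bits({3, 0, 1, 2}, 3);
  for (uint32_t b : out)
    std::printf("%u ", b); // prints: 0 3 1 2  (blocks of 0b10011100)
  std::printf("\n");
  return 0;
}
```

The giver of the lowest block wraps around from the top block, which is what makes this a rotate rather than a shift; the shift paths instead zero-fill via `cuda_memset_async`.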
```diff
@@ -37,8 +37,8 @@ pub enum PBSType {
 pub enum ShiftRotateType {
     LeftShift = 0,
     RightShift = 1,
-    LeftRotate = 2,
-    RightRotate = 3,
+    RotateLeft = 2,
+    RotateRight = 3,
 }
 
 #[repr(u32)]
@@ -1481,7 +1481,7 @@ pub unsafe fn unchecked_rotate_right_integer_radix_kb_assign_async<
         message_modulus.0 as u32,
         carry_modulus.0 as u32,
         pbs_type as u32,
-        ShiftRotateType::RightRotate as u32,
+        ShiftRotateType::RotateRight as u32,
         is_signed,
         true,
     );
@@ -1570,7 +1570,7 @@ pub unsafe fn unchecked_rotate_left_integer_radix_kb_assign_async<
         message_modulus.0 as u32,
         carry_modulus.0 as u32,
         pbs_type as u32,
-        ShiftRotateType::LeftRotate as u32,
+        ShiftRotateType::RotateLeft as u32,
         is_signed,
         true,
     );
@@ -1750,7 +1750,7 @@ pub unsafe fn unchecked_scalar_rotate_left_integer_radix_kb_assign_async<
         message_modulus.0 as u32,
         carry_modulus.0 as u32,
         pbs_type as u32,
-        ShiftRotateType::LeftShift as u32,
+        ShiftRotateType::RotateLeft as u32,
         true,
     );
     cuda_integer_radix_scalar_rotate_kb_64_inplace(
@@ -1832,7 +1832,7 @@ pub unsafe fn unchecked_scalar_rotate_right_integer_radix_kb_assign_async<
         message_modulus.0 as u32,
         carry_modulus.0 as u32,
         pbs_type as u32,
-        ShiftRotateType::RightShift as u32,
+        ShiftRotateType::RotateRight as u32,
         true,
     );
     cuda_integer_radix_scalar_rotate_kb_64_inplace(
```
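Note (illustration, not part of the patch): the Rust `ShiftRotateType` discriminants cross the FFI boundary as plain `u32` values and end up compared against `mem->shift_type` on the CUDA side, so they have to stay numerically in lockstep with `SHIFT_OR_ROTATE_TYPE`. A sketch of that contract on the C++ side (the decoding helper is hypothetical):

```cpp
// The u32 passed from the Rust bindings selects the operation; its value must
// match the C++ enumerators renamed in this commit (0..3).
#include <cstdint>

enum SHIFT_OR_ROTATE_TYPE {
  LEFT_SHIFT = 0,
  RIGHT_SHIFT = 1,
  ROTATE_LEFT = 2,
  ROTATE_RIGHT = 3
};

// Hypothetical decoding helper: recover the enum from the raw FFI value.
inline SHIFT_OR_ROTATE_TYPE decode_shift_type(uint32_t raw) {
  return static_cast<SHIFT_OR_ROTATE_TYPE>(raw);
}

static_assert(ROTATE_LEFT == 2 && ROTATE_RIGHT == 3,
              "must match ShiftRotateType::RotateLeft / ::RotateRight");
```

This also explains the changes at lines 1750 and 1832: the scalar rotate entry points previously passed `LeftShift`/`RightShift` (0/1) and now pass `RotateLeft`/`RotateRight` (2/3), matching the `mem->shift_type == ROTATE_LEFT` test introduced in `host_integer_radix_scalar_rotate_kb_inplace`.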