feat(compiler/backend): support tfhers int/uint of 2-16 bits

This commit is contained in:
youben11
2025-02-21 11:02:51 +01:00
parent ff62daaf57
commit 970148ff4d
4 changed files with 1018 additions and 405 deletions

View File

@@ -123,7 +123,7 @@ extra_bindings = []
[parse.expand]
crates = []
crates = ["concrete-cpu"]
all_features = false
default_features = true
features = []

View File

@@ -369,12 +369,96 @@ void concrete_cpu_keyswitch_lwe_ciphertext_u64(uint64_t *ct_out,
size_t input_dimension,
size_t output_dimension);
size_t concrete_cpu_lwe_array_to_tfhers_int10(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_int12(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_int14(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_int16(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_int2(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_int4(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_int6(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_int8(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_uint10(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_uint12(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_uint14(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_uint16(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_uint2(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_uint4(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_uint6(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
size_t n_elem,
struct TfhersFheIntDescription desc);
size_t concrete_cpu_lwe_array_to_tfhers_uint8(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
@@ -422,12 +506,96 @@ size_t concrete_cpu_serialize_lwe_secret_key_u64(const uint64_t *lwe_sk,
size_t concrete_cpu_tfhers_fheint_buffer_size_u64(size_t lwe_size, size_t n_cts, size_t n_elem);
int64_t concrete_cpu_tfhers_int10_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_int12_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_int14_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_int16_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_int2_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_int4_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_int6_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_int8_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_uint10_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_uint12_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_uint14_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_uint16_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_uint2_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_uint4_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_uint6_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
size_t n_elem,
struct TfhersFheIntDescription desc);
int64_t concrete_cpu_tfhers_uint8_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,

File diff suppressed because it is too large Load Diff

View File

@@ -209,12 +209,54 @@ Result<TransportValue> importTfhersInteger(llvm::ArrayRef<uint8_t> buffer,
std::function<int64_t(const uint8_t *, size_t, uint64_t *, size_t,
TfhersFheIntDescription)>
conversion_func;
if (integerDesc.width == 8) {
if (integerDesc.width == 2) {
if (integerDesc.is_signed) { // fheint2
conversion_func = concrete_cpu_tfhers_int2_to_lwe_array;
} else { // fheuint2
conversion_func = concrete_cpu_tfhers_uint2_to_lwe_array;
}
} else if (integerDesc.width == 4) {
if (integerDesc.is_signed) { // fheint4
conversion_func = concrete_cpu_tfhers_int4_to_lwe_array;
} else { // fheuint4
conversion_func = concrete_cpu_tfhers_uint4_to_lwe_array;
}
} else if (integerDesc.width == 6) {
if (integerDesc.is_signed) { // fheint6
conversion_func = concrete_cpu_tfhers_int6_to_lwe_array;
} else { // fheuint6
conversion_func = concrete_cpu_tfhers_uint6_to_lwe_array;
}
} else if (integerDesc.width == 8) {
if (integerDesc.is_signed) { // fheint8
conversion_func = concrete_cpu_tfhers_int8_to_lwe_array;
} else { // fheuint8
conversion_func = concrete_cpu_tfhers_uint8_to_lwe_array;
}
} else if (integerDesc.width == 10) {
if (integerDesc.is_signed) { // fheint10
conversion_func = concrete_cpu_tfhers_int10_to_lwe_array;
} else { // fheuint10
conversion_func = concrete_cpu_tfhers_uint10_to_lwe_array;
}
} else if (integerDesc.width == 12) {
if (integerDesc.is_signed) { // fheint12
conversion_func = concrete_cpu_tfhers_int12_to_lwe_array;
} else { // fheuint12
conversion_func = concrete_cpu_tfhers_uint12_to_lwe_array;
}
} else if (integerDesc.width == 14) {
if (integerDesc.is_signed) { // fheint14
conversion_func = concrete_cpu_tfhers_int14_to_lwe_array;
} else { // fheuint14
conversion_func = concrete_cpu_tfhers_uint14_to_lwe_array;
}
} else if (integerDesc.width == 16) {
if (integerDesc.is_signed) { // fheint16
conversion_func = concrete_cpu_tfhers_int16_to_lwe_array;
} else { // fheuint16
conversion_func = concrete_cpu_tfhers_uint16_to_lwe_array;
}
} else {
std::ostringstream stringStream;
stringStream << "importTfhersInteger: no support for " << integerDesc.width
@@ -273,12 +315,54 @@ exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc) {
std::function<size_t(const uint64_t *, uint8_t *, size_t, size_t,
TfhersFheIntDescription)>
conversion_func;
if (integerDesc.width == 8) {
if (integerDesc.width == 2) {
if (integerDesc.is_signed) { // fheint2
conversion_func = concrete_cpu_lwe_array_to_tfhers_int2;
} else { // fheuint2
conversion_func = concrete_cpu_lwe_array_to_tfhers_uint2;
}
} else if (integerDesc.width == 4) {
if (integerDesc.is_signed) { // fheint4
conversion_func = concrete_cpu_lwe_array_to_tfhers_int4;
} else { // fheuint4
conversion_func = concrete_cpu_lwe_array_to_tfhers_uint4;
}
} else if (integerDesc.width == 6) {
if (integerDesc.is_signed) { // fheint6
conversion_func = concrete_cpu_lwe_array_to_tfhers_int6;
} else { // fheuint6
conversion_func = concrete_cpu_lwe_array_to_tfhers_uint6;
}
} else if (integerDesc.width == 8) {
if (integerDesc.is_signed) { // fheint8
conversion_func = concrete_cpu_lwe_array_to_tfhers_int8;
} else { // fheuint8
conversion_func = concrete_cpu_lwe_array_to_tfhers_uint8;
}
} else if (integerDesc.width == 10) {
if (integerDesc.is_signed) { // fheint10
conversion_func = concrete_cpu_lwe_array_to_tfhers_int10;
} else { // fheuint10
conversion_func = concrete_cpu_lwe_array_to_tfhers_uint10;
}
} else if (integerDesc.width == 12) {
if (integerDesc.is_signed) { // fheint12
conversion_func = concrete_cpu_lwe_array_to_tfhers_int12;
} else { // fheuint12
conversion_func = concrete_cpu_lwe_array_to_tfhers_uint12;
}
} else if (integerDesc.width == 14) {
if (integerDesc.is_signed) { // fheint14
conversion_func = concrete_cpu_lwe_array_to_tfhers_int14;
} else { // fheuint14
conversion_func = concrete_cpu_lwe_array_to_tfhers_uint14;
}
} else if (integerDesc.width == 16) {
if (integerDesc.is_signed) { // fheint16
conversion_func = concrete_cpu_lwe_array_to_tfhers_int16;
} else { // fheuint16
conversion_func = concrete_cpu_lwe_array_to_tfhers_uint16;
}
} else {
std::ostringstream stringStream;
stringStream << "exportTfhersInteger: no support for " << integerDesc.width