Merge pull request #1265 from zama-ai/fix/tfhers-sign

fix(frontend/py): extend sign bit in tfhers.to_native
This commit is contained in:
Quentin Bourgerie
2025-05-15 09:26:19 +02:00
committed by GitHub
3 changed files with 41 additions and 6 deletions

View File

@@ -825,7 +825,7 @@ sum(ctx: Context, node: Node, preds: list[Conversion]) → Conversion
---
<a href="../../frontends/concrete-python/concrete/fhe/mlir/converter.py#L1002"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../frontends/concrete-python/concrete/fhe/mlir/converter.py#L1009"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
### <kbd>method</kbd> `tfhers_from_native`

View File

@@ -975,6 +975,7 @@ class Converter:
# we want to set the padding bit if the native type is signed
# and the ciphertext is negative (sign bit set to 1)
# we also want to extend the sign bit if we are going for a larger bitwidth
if dtype.is_signed:
# select MSBs of all tfhers ciphetexts
index = [slice(0, dim_size) for dim_size in tfhers_int.shape[:-1]] + [
@@ -985,15 +986,21 @@ class Converter:
tfhers_int,
index=index,
)
# construct padding bits based on sign bits (carry would be considered negative)
padding_bit_table = [
# extend sign bit up to the padding bit (carry would be considered negative)
sign_ext_bit_table = [
0,
] * 2 ** (msg_width - 1) + [
2**result_bit_width,
# we fill all new msbs (including the padding bit) with the sign bit
2**result_bit_width # padding bit
+ ( # sign bit extension
2**result_bit_width - 2**dtype.bit_width
if result_bit_width > dtype.bit_width
else 0
),
] * (2 ** (carry_width + msg_width) - 2 ** (msg_width - 1))
padding_bits_inc = ctx.tlu(result_type, msbs, padding_bit_table)
sign_ext_bits = ctx.tlu(result_type, msbs, sign_ext_bit_table)
# set padding bits (where necessary) in the final result
result = ctx.add(result_type, sum_result, padding_bits_inc)
result = ctx.add(result_type, sum_result, sign_ext_bits)
else:
result = sum_result

View File

@@ -72,6 +72,34 @@ def is_input_and_output_tfhers(
tfhers_int6_2_3 = partial(tfhers.TFHERSIntegerType, True, 6, 2, 3)
@pytest.mark.parametrize("native_bitwidth", [8, 10, 11, 12, 13, 14, 15, 16])
def test_tfhers_input_sign_ext(native_bitwidth):
"""Test sign extension of tfhers input"""
dtype_spec = parameterize_partial_dtype(tfhers.int8_2_2)
dtype = partial(tfhers.TFHERSInteger, dtype_spec)
inputset = (dtype(-120),)
# we want get the native ciphertext to native_bitwidth bits
multiplier = 2 ** (native_bitwidth - 8) - 1
def compute(x):
x = tfhers.to_native(x)
return x * multiplier
compiler = fhe.Compiler(compute, {"x": "encrypted"})
fhe_circuit = compiler.compile(inputset)
input_x = -1
tfhers_x = (dtype_spec.encode(input_x),)
print(fhe_circuit.mlir)
result = fhe_circuit.encrypt_run_decrypt(*tfhers_x)
assert result == input_x * multiplier
@pytest.mark.parametrize(
"function, parameters, input_dtype, output_dtype",
[