mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-09 12:57:55 -05:00
Merge pull request #1265 from zama-ai/fix/tfhers-sign
fix(frontend/py): extend sign bit in tfhers.to_native
This commit is contained in:
@@ -825,7 +825,7 @@ sum(ctx: Context, node: Node, preds: list[Conversion]) → Conversion
|
||||
|
||||
---
|
||||
|
||||
<a href="../../frontends/concrete-python/concrete/fhe/mlir/converter.py#L1002"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
|
||||
<a href="../../frontends/concrete-python/concrete/fhe/mlir/converter.py#L1009"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
|
||||
|
||||
### <kbd>method</kbd> `tfhers_from_native`
|
||||
|
||||
|
||||
@@ -975,6 +975,7 @@ class Converter:
|
||||
|
||||
# we want to set the padding bit if the native type is signed
|
||||
# and the ciphertext is negative (sign bit set to 1)
|
||||
# we also want to extend the sign bit if we are going for a larger bitwidth
|
||||
if dtype.is_signed:
|
||||
# select MSBs of all tfhers ciphetexts
|
||||
index = [slice(0, dim_size) for dim_size in tfhers_int.shape[:-1]] + [
|
||||
@@ -985,15 +986,21 @@ class Converter:
|
||||
tfhers_int,
|
||||
index=index,
|
||||
)
|
||||
# construct padding bits based on sign bits (carry would be considered negative)
|
||||
padding_bit_table = [
|
||||
# extend sign bit up to the padding bit (carry would be considered negative)
|
||||
sign_ext_bit_table = [
|
||||
0,
|
||||
] * 2 ** (msg_width - 1) + [
|
||||
2**result_bit_width,
|
||||
# we fill all new msbs (including the padding bit) with the sign bit
|
||||
2**result_bit_width # padding bit
|
||||
+ ( # sign bit extension
|
||||
2**result_bit_width - 2**dtype.bit_width
|
||||
if result_bit_width > dtype.bit_width
|
||||
else 0
|
||||
),
|
||||
] * (2 ** (carry_width + msg_width) - 2 ** (msg_width - 1))
|
||||
padding_bits_inc = ctx.tlu(result_type, msbs, padding_bit_table)
|
||||
sign_ext_bits = ctx.tlu(result_type, msbs, sign_ext_bit_table)
|
||||
# set padding bits (where necessary) in the final result
|
||||
result = ctx.add(result_type, sum_result, padding_bits_inc)
|
||||
result = ctx.add(result_type, sum_result, sign_ext_bits)
|
||||
else:
|
||||
result = sum_result
|
||||
|
||||
|
||||
@@ -72,6 +72,34 @@ def is_input_and_output_tfhers(
|
||||
tfhers_int6_2_3 = partial(tfhers.TFHERSIntegerType, True, 6, 2, 3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("native_bitwidth", [8, 10, 11, 12, 13, 14, 15, 16])
|
||||
def test_tfhers_input_sign_ext(native_bitwidth):
|
||||
"""Test sign extension of tfhers input"""
|
||||
|
||||
dtype_spec = parameterize_partial_dtype(tfhers.int8_2_2)
|
||||
|
||||
dtype = partial(tfhers.TFHERSInteger, dtype_spec)
|
||||
inputset = (dtype(-120),)
|
||||
|
||||
# we want get the native ciphertext to native_bitwidth bits
|
||||
multiplier = 2 ** (native_bitwidth - 8) - 1
|
||||
|
||||
def compute(x):
|
||||
x = tfhers.to_native(x)
|
||||
return x * multiplier
|
||||
|
||||
compiler = fhe.Compiler(compute, {"x": "encrypted"})
|
||||
fhe_circuit = compiler.compile(inputset)
|
||||
|
||||
input_x = -1
|
||||
tfhers_x = (dtype_spec.encode(input_x),)
|
||||
|
||||
print(fhe_circuit.mlir)
|
||||
result = fhe_circuit.encrypt_run_decrypt(*tfhers_x)
|
||||
|
||||
assert result == input_x * multiplier
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function, parameters, input_dtype, output_dtype",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user