fix(frontend-python): correctly handle signedness during encrypted multiplication

This commit is contained in:
Umut
2023-05-01 11:58:55 +02:00
committed by Quentin Bourgerie
parent eb3aca19ef
commit e162c58dfb
2 changed files with 50 additions and 3 deletions

View File

@@ -1401,9 +1401,6 @@ class Context:
assert self.is_bit_width_compatible(resulting_type, x, y)
x = self.to_signedness(x, of=resulting_type)
y = self.to_signedness(y, of=resulting_type)
use_linalg = x.is_tensor or y.is_tensor
x = self.tensorize(x) if use_linalg else x
@@ -1415,6 +1412,29 @@ class Context:
dialect = fhelinalg if use_linalg else fhe
operation = dialect.MulEintIntOp if y.is_clear else dialect.MulEintOp
if (x.is_signed or y.is_signed) and resulting_type.is_unsigned:
x = self.to_signed(x)
y = self.to_signed(y)
signed_resulting_type = self.typeof(
Value(
dtype=Integer(is_signed=True, bit_width=resulting_type.bit_width),
shape=resulting_type.shape,
is_encrypted=resulting_type.is_encrypted,
)
)
intermediate_result = self.operation(
operation,
signed_resulting_type,
x.result,
y.result,
)
return self.to_unsigned(intermediate_result)
x = self.to_signedness(x, of=resulting_type)
y = self.to_signedness(y, of=resulting_type)
return self.operation(
operation,
resulting_type,

View File

@@ -169,3 +169,30 @@ def test_mul(function, parameters, helpers):
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)
@pytest.mark.parametrize(
"parameter_encryption_statuses,function,inputs",
[
pytest.param(
{"x": "encrypted", "y": "encrypted"},
lambda x, y: (x - y) * ((x - y) > 0),
[
np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]),
],
id="(x - y) * ((x - y) > 0)",
),
],
)
def test_mul_specific(parameter_encryption_statuses, function, inputs, helpers):
"""
Test mul with specific inputs.
"""
configuration = helpers.configuration()
compiler = fhe.Compiler(function, parameter_encryption_statuses)
circuit = compiler.compile([tuple(inputs)], configuration)
helpers.check_execution(circuit, function, inputs)