mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
fix(frontend-python): correctly handle signedness during encrypted multiplication
This commit is contained in:
@@ -1401,9 +1401,6 @@ class Context:
|
||||
|
||||
assert self.is_bit_width_compatible(resulting_type, x, y)
|
||||
|
||||
x = self.to_signedness(x, of=resulting_type)
|
||||
y = self.to_signedness(y, of=resulting_type)
|
||||
|
||||
use_linalg = x.is_tensor or y.is_tensor
|
||||
|
||||
x = self.tensorize(x) if use_linalg else x
|
||||
@@ -1415,6 +1412,29 @@ class Context:
|
||||
dialect = fhelinalg if use_linalg else fhe
|
||||
operation = dialect.MulEintIntOp if y.is_clear else dialect.MulEintOp
|
||||
|
||||
if (x.is_signed or y.is_signed) and resulting_type.is_unsigned:
|
||||
x = self.to_signed(x)
|
||||
y = self.to_signed(y)
|
||||
|
||||
signed_resulting_type = self.typeof(
|
||||
Value(
|
||||
dtype=Integer(is_signed=True, bit_width=resulting_type.bit_width),
|
||||
shape=resulting_type.shape,
|
||||
is_encrypted=resulting_type.is_encrypted,
|
||||
)
|
||||
)
|
||||
intermediate_result = self.operation(
|
||||
operation,
|
||||
signed_resulting_type,
|
||||
x.result,
|
||||
y.result,
|
||||
)
|
||||
|
||||
return self.to_unsigned(intermediate_result)
|
||||
|
||||
x = self.to_signedness(x, of=resulting_type)
|
||||
y = self.to_signedness(y, of=resulting_type)
|
||||
|
||||
return self.operation(
|
||||
operation,
|
||||
resulting_type,
|
||||
|
||||
@@ -169,3 +169,30 @@ def test_mul(function, parameters, helpers):
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"parameter_encryption_statuses,function,inputs",
|
||||
[
|
||||
pytest.param(
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
lambda x, y: (x - y) * ((x - y) > 0),
|
||||
[
|
||||
np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
|
||||
np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]),
|
||||
],
|
||||
id="(x - y) * ((x - y) > 0)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mul_specific(parameter_encryption_statuses, function, inputs, helpers):
|
||||
"""
|
||||
Test mul with specific inputs.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
circuit = compiler.compile([tuple(inputs)], configuration)
|
||||
|
||||
helpers.check_execution(circuit, function, inputs)
|
||||
|
||||
Reference in New Issue
Block a user