From 07d6293ca87187902d435e1b463c28038b4794fc Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 4 Dec 2023 12:41:22 +0300 Subject: [PATCH] fix(frontend-python): assigning signed values to unsigned tensors --- .../concrete-python/concrete/fhe/mlir/context.py | 1 + frontends/concrete-python/tests/conftest.py | 9 +++++---- .../tests/execution/test_static_assignment.py | 16 ++++++++++++++++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index 583fa85d8..ebd8cce13 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -1559,6 +1559,7 @@ class Context: np.broadcast_to(np.zeros(y.shape), required_y_shape) y = self.broadcast_to(y, required_y_shape) + x = self.to_signedness(x, of=resulting_type) y = self.to_signedness(y, of=resulting_type) return self.operation( diff --git a/frontends/concrete-python/tests/conftest.py b/frontends/concrete-python/tests/conftest.py index 9004a5a60..6bd011511 100644 --- a/frontends/concrete-python/tests/conftest.py +++ b/frontends/concrete-python/tests/conftest.py @@ -5,6 +5,7 @@ Configuration of `pytest`. import json import os import random +from copy import deepcopy from pathlib import Path from typing import Any, Callable, Dict, List, Tuple, Union @@ -295,8 +296,8 @@ class Helpers: if not only_simulation: for i in range(retries): - expected = sanitize(function(*sample)) - actual = sanitize(circuit.encrypt_run_decrypt(*sample)) + expected = sanitize(function(*deepcopy(sample))) + actual = sanitize(circuit.encrypt_run_decrypt(*deepcopy(sample))) if all(np.array_equal(e, a) for e, a in zip(expected, actual)): break @@ -317,8 +318,8 @@ class Helpers: circuit.enable_fhe_simulation() for i in range(retries): - expected = sanitize(function(*sample)) - actual = sanitize(circuit.simulate(*sample)) + expected = sanitize(function(*deepcopy(sample))) + actual = sanitize(circuit.simulate(*deepcopy(sample))) if all(np.array_equal(e, a) for e, a in zip(expected, actual)): break diff --git a/frontends/concrete-python/tests/execution/test_static_assignment.py b/frontends/concrete-python/tests/execution/test_static_assignment.py index b16b859e4..d568d3593 100644 --- a/frontends/concrete-python/tests/execution/test_static_assignment.py +++ b/frontends/concrete-python/tests/execution/test_static_assignment.py @@ -443,6 +443,21 @@ def assignment_case_28(): return shape, assign +def assignment_case_29(): + """ + Assignment test case. + """ + + shape = (5,) + value = -20 + + def assign(x): + x[0] = value + return x + + return shape, assign + + @pytest.mark.parametrize( "shape,function", [ @@ -475,6 +490,7 @@ def assignment_case_28(): pytest.param(*assignment_case_26()), pytest.param(*assignment_case_27()), pytest.param(*assignment_case_28()), + pytest.param(*assignment_case_29()), ], ) def test_static_assignment(shape, function, helpers):