diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index a0087b4ac..0a3d3355a 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -27,7 +27,7 @@ from .extensions import ( zero, zeros, ) -from .mlir.utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH +from .mlir.utils import MAXIMUM_TLU_BIT_WIDTH from .representation import Graph from .tracing.typing import ( f32, diff --git a/concrete/numpy/compilation/client.py b/concrete/numpy/compilation/client.py index 85173dce5..da7b7a61f 100644 --- a/concrete/numpy/compilation/client.py +++ b/concrete/numpy/compilation/client.py @@ -197,37 +197,58 @@ class Client: """ self.keygen(force=False) - results = ClientSupport.decrypt_result(self._keyset, result) - if not isinstance(results, tuple): - results = (results,) + outputs = ClientSupport.decrypt_result(self._keyset, result) + if not isinstance(outputs, tuple): + outputs = (outputs,) - sanitized_results: List[Union[int, np.ndarray]] = [] + sanitized_outputs: List[Union[int, np.ndarray]] = [] client_parameters_json = json.loads(self.specs.client_parameters.serialize()) assert_that("outputs" in client_parameters_json) output_specs = client_parameters_json["outputs"] - for index, spec in enumerate(output_specs): - n = spec["shape"]["width"] - expected_dtype = ( - SignedInteger(n) if self.specs.output_signs[index] else UnsignedInteger(n) + for index, output in enumerate(outputs): + is_signed = self.specs.output_signs[index] + crt_decomposition = ( + output_specs[index].get("encryption", {}).get("encoding", {}).get("crt", []) ) - result = results[index] % (2**n) - if expected_dtype.is_signed: - if isinstance(result, int): - sanititzed_result = result if result < (2 ** (n - 1)) else result - (2**n) - sanitized_results.append(sanititzed_result) + if is_signed: + if crt_decomposition: + if isinstance(output, int): + sanititzed_output = ( + output + if output < (int(np.prod(crt_decomposition)) // 2) + else -int(np.prod(crt_decomposition)) + output + ) + else: + output = output.astype(np.longlong) # to prevent overflows in numpy + sanititzed_output = np.where( + output < (np.prod(crt_decomposition) // 2), + output, + -np.prod(crt_decomposition) + output, + ).astype(np.int64) + + sanitized_outputs.append(sanititzed_output) + else: - result = result.astype(np.longlong) # to prevent overflows in numpy - sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n)) - sanitized_results.append(sanititzed_result.astype(np.int64)) + n = output_specs[index]["shape"]["width"] + output %= 2**n + if isinstance(output, int): + sanititzed_output = output if output < (2 ** (n - 1)) else output - (2**n) + sanitized_outputs.append(sanititzed_output) + else: + output = output.astype(np.longlong) # to prevent overflows in numpy + sanititzed_output = np.where( + output < (2 ** (n - 1)), output, output - (2**n) + ).astype(np.int64) + sanitized_outputs.append(sanititzed_output) else: - sanitized_results.append( - result if isinstance(result, int) else result.astype(np.uint64) + sanitized_outputs.append( + output if isinstance(output, int) else output.astype(np.uint64) ) - return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results) + return sanitized_outputs[0] if len(sanitized_outputs) == 1 else tuple(sanitized_outputs) @property def evaluation_keys(self) -> EvaluationKeys: diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index dd4c7c8bb..2330331bb 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -28,7 +28,7 @@ from ..internal.utils import assert_that from ..representation import Graph, Node, Operation from ..values import ClearScalar, EncryptedScalar from .node_converter import NodeConverter -from .utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH +from .utils import MAXIMUM_TLU_BIT_WIDTH # pylint: enable=no-member,no-name-in-module @@ -261,14 +261,6 @@ class GraphConverter: first_tlu_node.location, ] - if first_signed_node is not None and max_bit_width > MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS: - offending_nodes[first_signed_node] = [ - f"signed integers are only supported " - f"up to {MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS}-bits " - f"on circuits with table lookups", - first_signed_node.location, - ] - if len(offending_nodes) != 0: raise RuntimeError( "Function you are trying to compile cannot be converted to MLIR:\n\n" diff --git a/concrete/numpy/mlir/utils.py b/concrete/numpy/mlir/utils.py index 74a832853..5a645fbdb 100644 --- a/concrete/numpy/mlir/utils.py +++ b/concrete/numpy/mlir/utils.py @@ -14,7 +14,6 @@ from ..internal.utils import assert_that from ..representation import Node, Operation MAXIMUM_TLU_BIT_WIDTH = 16 -MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS = 8 class HashableNdarray: diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index a917e243d..87a5d6358 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -635,6 +635,13 @@ def deterministic_unary_function(x): }, id="np.expand_dims(x, axis=(0, 1, 2))", ), + pytest.param( + lambda x: x**3, + { + "x": {"status": "encrypted", "range": [-30, 30]}, + }, + id="x ** 3", + ), ], ) def test_others(function, parameters, helpers): diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index bfcfa6ac6..3b1c0cc1c 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -399,35 +399,6 @@ Subgraphs: %5 = astype(%4, dtype=int_) # EncryptedScalar return %5 - """, # noqa: E501 - ), - pytest.param( - lambda x: (10 * np.sin(x + 300)).astype(np.int64), - {"x": "encrypted"}, - range(2**10, 2**11), - RuntimeError, - """ - -Function you are trying to compile cannot be converted to MLIR: - -%0 = x # EncryptedScalar ∈ [1024, 2047] -%1 = 300 # ClearScalar ∈ [300, 300] -%2 = add(%0, %1) # EncryptedScalar ∈ [1324, 2347] -%3 = subgraph(%2) # EncryptedScalar ∈ [-9, 9] -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integers are only supported up to 8-bits on circuits with table lookups -return %3 - -Subgraphs: - - %3 = subgraph(%2): - - %0 = input # EncryptedScalar - %1 = sin(%0) # EncryptedScalar - %2 = 10 # ClearScalar - %3 = multiply(%2, %1) # EncryptedScalar - %4 = astype(%3, dtype=int_) # EncryptedScalar - return %4 - """, # noqa: E501 ), ],