diff --git a/concrete/numpy/compilation/client.py b/concrete/numpy/compilation/client.py index 5876f87d2..1dc22e87b 100644 --- a/concrete/numpy/compilation/client.py +++ b/concrete/numpy/compilation/client.py @@ -6,7 +6,7 @@ import json import shutil import tempfile from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np from concrete.compiler import ( @@ -125,7 +125,7 @@ class Client: if len(args) != len(input_specs): raise ValueError(f"Expected {len(input_specs)} inputs but got {len(args)}") - sanitized_args = {} + sanitized_args: Dict[int, Union[int, np.ndarray]] = {} for index, spec in enumerate(input_specs): arg = args[index] is_valid = isinstance(arg, (int, np.integer)) or ( @@ -136,7 +136,9 @@ class Client: shape = tuple(spec["shape"]["dimensions"]) is_encrypted = spec["encryption"] is not None - expected_dtype = UnsignedInteger(width) + expected_dtype = ( + SignedInteger(width) if self.specs.input_signs[index] else UnsignedInteger(width) + ) expected_value = Value(expected_dtype, shape, is_encrypted) if is_valid: expected_min = expected_dtype.min() @@ -153,7 +155,14 @@ class Client: ) if is_valid: - sanitized_args[index] = arg if isinstance(arg, int) else arg.astype(np.uint8) + if isinstance(arg, int) and arg < 0: + sanitized_args[index] = 2 * (expected_max + 1) + arg + else: + sanitized_args[index] = np.where( + arg >= 0, + arg, + 2 * (expected_max + 1) + arg, + ).astype(np.uint8) if not is_valid: actual_value = Value.of(arg, is_encrypted=is_encrypted) diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index 4b2f1959b..a687e358e 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -5,13 +5,22 @@ Declaration of `GraphConverter` class. 
# pylint: disable=no-member,no-name-in-module from copy import deepcopy -from typing import Dict, List, Optional, cast +from typing import Any, Dict, List, Optional, Tuple, cast import concrete.lang as concretelang import networkx as nx import numpy as np -from mlir.dialects import builtin -from mlir.ir import Context, InsertionPoint, Location, Module +from concrete.lang.dialects import fhe, fhelinalg +from mlir.dialects import arith, builtin +from mlir.ir import ( + Context, + DenseElementsAttr, + InsertionPoint, + IntegerType, + Location, + Module, + RankedTensorType, +) from ..dtypes import Integer, SignedInteger from ..internal.utils import assert_that @@ -58,8 +67,8 @@ class GraphConverter: elif node.operation == Operation.Input: assert_that(len(inputs) == 1) assert_that(inputs[0] == output) - if not isinstance(output.dtype, Integer) or output.dtype.is_signed: - return "only unsigned integer inputs are supported" + if not isinstance(output.dtype, Integer): + return "only integer inputs are supported" else: assert_that(node.operation == Operation.Generic) @@ -283,6 +292,85 @@ class GraphConverter: nx_graph.add_edge(add_offset, node, input_idx=variable_input_index) + @staticmethod + def _sanitize_signed_inputs( + graph: Graph, + args: List[Any], + ctx: Context, + ) -> Tuple[List[Any], List[str]]: + """ + Apply table lookup to signed inputs in the beginning of evaluation to sanitize them. + + Sanitization in this context means to apply a table lookup to obtain proper input values. + + "encrypt" method of "Client" class will convert negative inputs to their corresponding + unsigned value in 2s complement representation. 
+ + Here is an example for 3 bits: + 000 = 0 represents 0 + 001 = 1 represents 1 + 010 = 2 represents 2 + 011 = 3 represents 3 + 100 = 4 represents -4 + 101 = 5 represents -3 + 110 = 6 represents -2 + 111 = 7 represents -1 + + And, the following table lookup is applied before anything else to sanitize the inputs: + [0, 1, 2, 3, -4, -3, -2, -1] + + Args: + graph (Graph): + computation graph being converted + + args (List[Any]): + list of arguments from mlir main + + ctx (Context): + mlir context where the conversion is being performed + + Returns: + Tuple[List[Any], List[str]]: + sanitized args and names of the sanitized variables in MLIR + """ + + sanitized_args = [] + arg_mlir_names = [] + + for i, arg in enumerate(args): + input_node = graph.input_nodes[i] + input_value = input_node.output + + assert_that(isinstance(input_value.dtype, Integer)) + input_dtype = cast(Integer, input_value.dtype) + + if input_dtype.is_signed: + n = input_dtype.bit_width + lut_range = np.arange(2**n) + + lut_values = np.where(lut_range < (2 ** (n - 1)), lut_range, lut_range - (2**n)) + lut_type = RankedTensorType.get( + (2**n,), IntegerType.get_signless(64, context=ctx) + ) + lut_attr = DenseElementsAttr.get(lut_values, context=ctx) + lut = arith.ConstantOp(lut_type, lut_attr).result + + resulting_type = NodeConverter.value_to_mlir_type(ctx, input_value) + if input_value.is_scalar: + sanitized = fhe.ApplyLookupTableEintOp(resulting_type, arg, lut).result + else: + sanitized = fhelinalg.ApplyLookupTableEintOp(resulting_type, arg, lut).result + + sanitized_args.append(sanitized) + mlir_name = str(sanitized).replace("Value(", "").split("=", maxsplit=1)[0].strip() + else: + sanitized_args.append(arg) + mlir_name = f"%arg{i}" + + arg_mlir_names.append(mlir_name) + + return sanitized_args, arg_mlir_names + @staticmethod def convert(graph: Graph, virtual: bool = False) -> str: """ @@ -335,12 +423,18 @@ class GraphConverter: ] @builtin.FuncOp.from_py_func(*parameters) - def main(*arg): + def 
main(*args): + sanitized_args, arg_mlir_names = GraphConverter._sanitize_signed_inputs( + graph, + args, + ctx, + ) + ir_to_mlir = {} for arg_num, node in graph.input_nodes.items(): - ir_to_mlir[node] = arg[arg_num] + ir_to_mlir[node] = sanitized_args[arg_num] - mlir_name = f"%arg{arg_num}" + mlir_name = arg_mlir_names[arg_num] nodes_to_mlir_names[node] = mlir_name mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num]) diff --git a/tests/conftest.py b/tests/conftest.py index 2f5e287b7..ee7c68aec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -203,9 +203,6 @@ class Helpers: for description in parameters.values(): minimum, maximum = description.get("range", [0, 127]) - assert minimum >= 0 - assert maximum <= 127 - if "shape" in description: shape = description["shape"] sample.append(np.random.randint(minimum, maximum + 1, size=shape, dtype=np.int64)) diff --git a/tests/execution/test_direct_table_lookup.py b/tests/execution/test_direct_table_lookup.py index 731d875ef..651899ffa 100644 --- a/tests/execution/test_direct_table_lookup.py +++ b/tests/execution/test_direct_table_lookup.py @@ -179,10 +179,34 @@ def test_direct_table_lookup(bits, function, helpers): compiler = cnp.Compiler(function, {"x": "encrypted"}) - inputset = [np.random.randint(0, 2**bits, size=(3, 2), dtype=np.uint8) for _ in range(100)] + inputset = [np.random.randint(0, 2**bits, size=(3, 2)) for _ in range(100)] circuit = compiler.compile(inputset, configuration) - sample = np.random.randint(0, 2**bits, size=(3, 2), dtype=np.uint8) + sample = np.random.randint(0, 2**bits, size=(3, 2)) + helpers.check_execution(circuit, function, sample, retries=10) + + # negative scalar + # --------------- + + compiler = cnp.Compiler(function, {"x": "encrypted"}) + + inputset = range(-(2 ** (bits - 1)), 2 ** (bits - 1)) + circuit = compiler.compile(inputset, configuration) + + sample = int(np.random.randint(-(2 ** (bits - 1)), 2 ** (bits - 1))) + helpers.check_execution(circuit, function, sample, 
retries=10) + + # negative tensor + # --------------- + + compiler = cnp.Compiler(function, {"x": "encrypted"}) + + inputset = [ + np.random.randint(-(2 ** (bits - 1)), 2 ** (bits - 1), size=(3, 2)) for _ in range(100) + ] + circuit = compiler.compile(inputset, configuration) + + sample = np.random.randint(-(2 ** (bits - 1)), 2 ** (bits - 1), size=(3, 2)) helpers.check_execution(circuit, function, sample, retries=10) @@ -209,10 +233,10 @@ def test_direct_multi_table_lookup(helpers): compiler = cnp.Compiler(function, {"x": "encrypted"}) - inputset = [np.random.randint(0, 2**2, size=(3, 2), dtype=np.uint8) for _ in range(100)] + inputset = [np.random.randint(0, 2**2, size=(3, 2)) for _ in range(100)] circuit = compiler.compile(inputset, configuration) - sample = np.random.randint(0, 2**2, size=(3, 2), dtype=np.uint8) + sample = np.random.randint(0, 2**2, size=(3, 2)) helpers.check_execution(circuit, function, sample, retries=10) diff --git a/tests/execution/test_neg.py b/tests/execution/test_neg.py index 0bf0d5a8e..43d042cb3 100644 --- a/tests/execution/test_neg.py +++ b/tests/execution/test_neg.py @@ -17,6 +17,12 @@ import concrete.numpy as cnp { "x": {"range": [0, 64], "status": "encrypted", "shape": (3, 2)}, }, + { + "x": {"range": [-63, 0], "status": "encrypted"}, + }, + { + "x": {"range": [-63, 0], "status": "encrypted", "shape": (3, 2)}, + }, ], ) def test_neg(parameters, helpers): diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index a16c06484..81c6a8544 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -43,7 +43,7 @@ return (%2, %3) Function you are trying to compile cannot be converted to MLIR %0 = x # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer inputs are supported %1 = 1.5 # ClearScalar ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 
only integer constants are supported %2 = multiply(%0, %1) # EncryptedScalar