diff --git a/concrete/numpy/compilation/client.py b/concrete/numpy/compilation/client.py index da7b7a61f..10c90afa9 100644 --- a/concrete/numpy/compilation/client.py +++ b/concrete/numpy/compilation/client.py @@ -158,14 +158,13 @@ class Client: ) if is_valid: - if isinstance(arg, int) and arg < 0: - sanitized_args[index] = 2 * (expected_max + 1) + arg + is_signed = self.specs.input_signs[index] + sanitizer = 0 if not is_signed else 2 ** (width - 1) + + if isinstance(arg, int): + sanitized_args[index] = arg + sanitizer else: - sanitized_args[index] = np.where( - arg >= 0, - arg, - 2 * (expected_max + 1) + arg, - ).astype(np.uint64) + sanitized_args[index] = (arg + sanitizer).astype(np.uint64) if not is_valid: actual_value = Value.of(arg, is_encrypted=is_encrypted) diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index 2330331bb..a13029b36 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -13,9 +13,10 @@ import numpy as np from concrete.lang.dialects import fhe, fhelinalg from mlir.dialects import arith, func from mlir.ir import ( + Attribute, Context, - DenseElementsAttr, InsertionPoint, + IntegerAttr, IntegerType, Location, Module, @@ -551,25 +552,26 @@ class GraphConverter: if input_dtype.is_signed: assert_that(input_value.is_encrypted) - n = input_dtype.bit_width - lut_range = np.arange(2**n) - lut_values = np.where(lut_range < (2 ** (n - 1)), lut_range, lut_range - (2**n)) - lut_type = RankedTensorType.get( - (2**n,), IntegerType.get_signless(64, context=ctx) - ) - lut_attr = DenseElementsAttr.get(lut_values, context=ctx) - # ConstantOp is being decorated, and the init function is supposed to take more - # arguments than those pylint is considering + sanitizer_type = IntegerType.get_signless(n + 1) + sanitizer = 2 ** (n - 1) + + if input_value.is_scalar: + sanitizer_attr = IntegerAttr.get(sanitizer_type, sanitizer) + else: + sanitizer_type = RankedTensorType.get((1,), sanitizer_type) + sanitizer_attr = Attribute.parse(f"dense<[{sanitizer}]> : {sanitizer_type}") + # pylint: disable=too-many-function-args - lut = arith.ConstantOp(lut_type, lut_attr).result + sanitizer_cst = arith.ConstantOp(sanitizer_type, sanitizer_attr) # pylint: enable=too-many-function-args + resulting_type = NodeConverter.value_to_mlir_type(ctx, input_value) if input_value.is_scalar: - sanitized = fhe.ApplyLookupTableEintOp(resulting_type, arg, lut).result + sanitized = fhe.SubEintIntOp(resulting_type, arg, sanitizer_cst).result else: - sanitized = fhelinalg.ApplyLookupTableEintOp(resulting_type, arg, lut).result + sanitized = fhelinalg.SubEintIntOp(resulting_type, arg, sanitizer_cst).result sanitized_args.append(sanitized) else: