diff --git a/concrete/common/data_types/dtypes_helpers.py b/concrete/common/data_types/dtypes_helpers.py index 94f43ead0..add842be9 100644 --- a/concrete/common/data_types/dtypes_helpers.py +++ b/concrete/common/data_types/dtypes_helpers.py @@ -83,6 +83,19 @@ def value_is_scalar(value_to_check: BaseValue) -> bool: return isinstance(value_to_check, TensorValue) and value_to_check.is_scalar +def value_is_integer(value_to_check: BaseValue) -> bool: + """Check that a value is of type Integer. + + Args: + value_to_check (BaseValue): The value to check + + Returns: + bool: True if the passed value_to_check is of type Integer + """ + + return isinstance(value_to_check.dtype, INTEGER_TYPES) + + def value_is_unsigned_integer(value_to_check: BaseValue) -> bool: """Check that a value is of type Integer and is unsigned. diff --git a/concrete/common/mlir/converters.py b/concrete/common/mlir/converters.py index fe691c1d6..569fa84ad 100644 --- a/concrete/common/mlir/converters.py +++ b/concrete/common/mlir/converters.py @@ -17,6 +17,7 @@ from zamalang.dialects import hlfhe from ..data_types.dtypes_helpers import ( value_is_clear_scalar_integer, value_is_clear_tensor_integer, + value_is_encrypted_scalar_integer, value_is_encrypted_scalar_unsigned_integer, value_is_encrypted_tensor_integer, ) @@ -30,18 +31,18 @@ def add(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None): """Convert an addition intermediate node.""" assert_true(len(node.inputs) == 2, "addition should have two inputs") assert_true(len(node.outputs) == 1, "addition should have a single output") - if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer( + if value_is_encrypted_scalar_integer(node.inputs[0]) and value_is_clear_scalar_integer( node.inputs[1] ): return _add_eint_int(node, preds, ir_to_mlir_node, ctx) - if value_is_encrypted_scalar_unsigned_integer(node.inputs[1]) and value_is_clear_scalar_integer( + if value_is_encrypted_scalar_integer(node.inputs[1]) and value_is_clear_scalar_integer( node.inputs[0] ): # flip lhs and rhs return _add_eint_int(node, preds[::-1], ir_to_mlir_node, ctx) - if value_is_encrypted_scalar_unsigned_integer( - node.inputs[0] - ) and value_is_encrypted_scalar_unsigned_integer(node.inputs[1]): + if value_is_encrypted_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_integer( + node.inputs[1] + ): return _add_eint_eint(node, preds, ir_to_mlir_node, ctx) raise TypeError( f"Don't support addition between {str(node.inputs[0])} and {str(node.inputs[1])}" @@ -74,7 +75,7 @@ def sub(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None): """Convert a subtraction intermediate node.""" assert_true(len(node.inputs) == 2, "subtraction should have two inputs") assert_true(len(node.outputs) == 1, "subtraction should have a single output") - if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_unsigned_integer( + if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_integer( node.inputs[1] ): return _sub_int_eint(node, preds, ir_to_mlir_node, ctx) @@ -98,11 +99,11 @@ def mul(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None): """Convert a multiplication intermediate node.""" assert_true(len(node.inputs) == 2, "multiplication should have two inputs") assert_true(len(node.outputs) == 1, "multiplication should have a single output") - if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer( + if value_is_encrypted_scalar_integer(node.inputs[0]) and value_is_clear_scalar_integer( node.inputs[1] ): return _mul_eint_int(node, preds, ir_to_mlir_node, ctx) - if value_is_encrypted_scalar_unsigned_integer(node.inputs[1]) and value_is_clear_scalar_integer( + if value_is_encrypted_scalar_integer(node.inputs[1]) and value_is_clear_scalar_integer( node.inputs[0] ): # flip lhs and rhs @@ -131,8 +132,6 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No value = cast(TensorValue, value) dtype = cast(Integer, value.dtype) - if dtype.is_signed: - raise TypeError("Don't support signed constant integer") data = node.constant_data int_type = IntegerType.get_signless(dtype.bit_width, context=ctx) @@ -142,8 +141,6 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No value = cast(TensorValue, value) dtype = cast(Integer, value.dtype) - if dtype.is_signed: - raise TypeError("Don't support signed constant integer tensor") data = node.constant_data int_type = IntegerType.get_signless(dtype.bit_width, context=ctx) diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index cc75113c0..b46d9e355 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -7,6 +7,7 @@ from ..data_types.dtypes_helpers import ( value_is_clear_tensor_integer, value_is_encrypted_scalar_integer, value_is_encrypted_tensor_integer, + value_is_integer, value_is_scalar, value_is_unsigned_integer, ) @@ -58,8 +59,10 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool) elif isinstance(node, intermediate.Constant): # constraints for constants assert_true(len(outputs) == 1) - if not value_is_unsigned_integer(outputs[0]): - return "only unsigned integer constants are supported" + # We currently can't fail on the following assert, but let it for possible changes in the + # future + if not value_is_integer(outputs[0]): + return "only integer constants are supported" # pragma: no cover elif isinstance(node, intermediate.UnivariateFunction): # constraints for univariate functions assert_true(len(inputs) == 1) @@ -84,8 +87,10 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool) return "only scalar unsigned integer outputs are supported" else: for out in outputs: - if not value_is_unsigned_integer(out): - return "only unsigned integer intermediates are supported" + # We currently can't fail on the following assert, but let it for possible changes in + # the future + if not value_is_integer(out): + return "only integer intermediates are supported" # pragma: no cover # pylint: enable=too-many-branches,too-many-return-statements diff --git a/tests/common/mlir/test_converters.py b/tests/common/mlir/test_converters.py index d9f590669..8a292e6b5 100644 --- a/tests/common/mlir/test_converters.py +++ b/tests/common/mlir/test_converters.py @@ -37,14 +37,6 @@ def test_fail_non_integer_const(): constant(MockNode(outputs=[ClearTensor(Float(32), shape=(2,))]), None, None, None) -def test_fail_signed_integer_const(): - """Test failing constant converter with non-integer""" - with pytest.raises(TypeError, match=r"Don't support signed constant integer"): - constant(MockNode(outputs=[ClearScalar(Integer(8, True))]), None, None, None) - with pytest.raises(TypeError, match=r"Don't support signed constant integer tensor"): - constant(MockNode(outputs=[ClearTensor(Integer(8, True), shape=(2,))]), None, None, None) - - @pytest.mark.parametrize( "input_node", [ diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 95a21786e..1f134850d 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -482,6 +482,15 @@ def test_compile_function_multiple_outputs( @pytest.mark.parametrize( "function,input_ranges,list_of_arg_names", [ + pytest.param(lambda x: (-27) + 4 * (x + 8), ((0, 10),), ["x"]), + pytest.param(lambda x: x + (-33), ((40, 60),), ["x"]), + pytest.param(lambda x: 17 - (0 - x), ((0, 10),), ["x"]), + pytest.param(lambda x: 42 + x * (-3), ((0, 10),), ["x"]), + pytest.param(lambda x: 43 + (-4) * x, ((0, 10),), ["x"]), + pytest.param(lambda x: 3 - (-5) * x, ((0, 10),), ["x"]), + pytest.param(lambda x: (-2) * (-5) * x, ((0, 10),), ["x"]), + pytest.param(lambda x: (-2) * x * (-5), ((0, 10),), ["x"]), + pytest.param(lambda x, y: 40 - (-3 * x) + (-2 * y), ((0, 20), (0, 20)), ["x", "y"]), pytest.param(lambda x: x + numpy.int32(42), ((0, 10),), ["x"]), pytest.param(lambda x: x + 64, ((0, 10),), ["x"]), pytest.param(lambda x: x * 3, ((0, 40),), ["x"]), @@ -747,20 +756,6 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura "return(%2)\n" ), ), - pytest.param( - lambda x: x + (-1), - {"x": EncryptedScalar(Integer(4, is_signed=False))}, - [(i,) for i in range(1, 2 ** 4)], - ( - "function you are trying to compile isn't supported for MLIR lowering\n" - "\n" - "%0 = x # EncryptedScalar>\n" # noqa: E501 - "%1 = Constant(-1) # ClearScalar>\n" # noqa: E501 - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501 - "%2 = Add(%0, %1) # EncryptedScalar>\n" # noqa: E501 - "return(%2)\n" - ), - ), pytest.param( lambda x: x + 1, {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))}, @@ -869,16 +864,16 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura "function you are trying to compile isn't supported for MLIR lowering\n\n" "%0 = x # EncryptedScalar>\n" # noqa: E501 "%1 = Constant(2.8) # ClearScalar>\n" - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported\n" # noqa: E501 "%2 = y # EncryptedScalar>\n" # noqa: E501 "%3 = Constant(9.3) # ClearScalar>\n" - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported\n" # noqa: E501 "%4 = Add(%0, %1) # EncryptedScalar>\n" # noqa: E501 - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported\n" # noqa: E501 "%5 = Add(%2, %3) # EncryptedScalar>\n" # noqa: E501 - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported\n" # noqa: E501 "%6 = Add(%4, %5) # EncryptedScalar>\n" # noqa: E501 - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported\n" # noqa: E501 "%7 = astype(int32)(%6) # EncryptedScalar>\n" # noqa: E501 "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501 "return(%7)\n" @@ -930,9 +925,7 @@ def test_fail_with_intermediate_signed_values(default_compilation_configuration) "%1 = Constant(10) # ClearScalar>\n" # noqa: E501 "%2 = x # EncryptedScalar>\n" # noqa: E501 "%3 = np.negative(%2) # EncryptedScalar>\n" # noqa: E501 - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 "%4 = Mul(%3, %1) # EncryptedScalar>\n" # noqa: E501 - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 "%5 = np.absolute(%4) # EncryptedScalar>\n" # noqa: E501 "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501 "%6 = astype(int32)(%5) # EncryptedScalar>\n" # noqa: E501 @@ -1103,3 +1096,38 @@ def test_compile_too_high_bitwidth(default_compilation_configuration): data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), default_compilation_configuration, ) + + +def test_failure_for_signed_output(default_compilation_configuration): + """Test that we don't accept signed output""" + function = lambda x: x + (-3) # pylint: disable=unnecessary-lambda # noqa: E731 + input_ranges = ((0, 10),) + list_of_arg_names = ["x"] + + def data_gen(args): + for prod in itertools.product(*args): + yield prod + + function_parameters = { + arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names + } + + with pytest.raises(RuntimeError) as excinfo: + compile_numpy_function( + function, + function_parameters, + data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), + default_compilation_configuration, + ) + + # pylint: disable=line-too-long + assert ( + str(excinfo.value) + == "function you are trying to compile isn't supported for MLIR lowering\n\n" # noqa: E501 + "%0 = x # EncryptedScalar>\n" # noqa: E501 + "%1 = Constant(-3) # ClearScalar>\n" # noqa: E501 + "%2 = Add(%0, %1) # EncryptedScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported\n" # noqa: E501 + "return(%2)\n" + ) + # pylint: enable=line-too-long