diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index c464d9d43..b6297eef8 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -115,8 +115,14 @@ def check_node_compatibility_with_mlir( if is_output: for out in outputs: - if not value_is_unsigned_integer(out): - return "only unsigned integer outputs are supported" + # For signed values and waiting for a real fix (#845): what is returned by the compiler + # is not the (possibly negative) result r, but the always-positive (r mod 2**t), where t + # is the bitwidth of r + + # We currently can't fail on the following assert, but let it for possible changes in + # the future + if not value_is_integer(out): + return "only integer outputs are supported" # pragma: no cover else: for out in outputs: # We currently can't fail on the following assert, but let it for possible changes in diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 40be43ddb..4ef224683 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1033,24 +1033,6 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura @pytest.mark.parametrize( "function,parameters,inputset,match", [ - pytest.param( - lambda x: 1 - x, - {"x": EncryptedScalar(Integer(3, is_signed=False))}, - [(i,) for i in range(8)], - ( - """ - -function you are trying to compile isn't supported for MLIR lowering - -%0 = 1 # ClearScalar -%1 = x # EncryptedScalar -%2 = sub(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported -return %2 - - """.strip() # noqa: E501 - ), - ), pytest.param( lambda x, y: numpy.dot(x, y), { @@ -1476,44 +1458,6 @@ return %2 ) -def test_failure_for_signed_output(default_compilation_configuration): - """Test that we don't accept signed output""" - function = lambda x: x + (-3) # noqa: E731 - input_ranges = ((0, 10),) - list_of_arg_names = ["x"] - - def data_gen(args): - for prod in itertools.product(*args): - yield prod - - function_parameters = { - arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names - } - - with pytest.raises(RuntimeError) as excinfo: - compile_numpy_function( - function, - function_parameters, - data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), - default_compilation_configuration, - ) - - assert ( - str(excinfo.value) - == """ - -function you are trying to compile isn't supported for MLIR lowering - -%0 = x # EncryptedScalar -%1 = -3 # ClearScalar -%2 = add(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported -return %2 - -""".strip() # noqa: E501 # pylint: disable=line-too-long - ) - - def test_compile_with_random_inputset(default_compilation_configuration): """Test function for compile with random input set""" @@ -1649,3 +1593,41 @@ def test_compile_and_run_correctness_with_negative_values( args = [random.randint(low, high) for (low, high) in input_ranges] assert compiler_engine.run(*args) == function(*args) + + +def check_equality_modulo(a, b, modulus): + """Check that (a mod modulus) == (b mod modulus)""" + return (a % modulus) == (b % modulus) + + +@pytest.mark.parametrize( + "function,input_ranges,list_of_arg_names,modulus", + [ + pytest.param(lambda x: x + (-20), ((0, 10),), ["x"], 128), + pytest.param(lambda x: 10 + x * (-3), ((0, 20),), ["x"], 128), + ], +) +def test_compile_and_run_correctness_with_negative_results( + function, input_ranges, list_of_arg_names, modulus, default_compilation_configuration +): + """Test correctness of computations when the result is possibly negative: until #845 is fixed, + results are currently only correct modulo a power of 2 (given by `modulus` parameter). Eg, + instead of returning -3, the execution may return -3 mod 128 = 125.""" + + def data_gen(args): + for prod in itertools.product(*args): + yield prod + + function_parameters = { + arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names + } + + compiler_engine = compile_numpy_function( + function, + function_parameters, + data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), + default_compilation_configuration, + ) + + args = [random.randint(low, high) for (low, high) in input_ranges] + assert check_equality_modulo(compiler_engine.run(*args), function(*args), modulus)