diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index b1242ab2c..e2c588554 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -102,8 +102,6 @@ def check_node_compatibility_with_mlir( elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication assert_true(len(inputs) == 2) - if not value_is_unsigned_integer(inputs[0]) or not value_is_unsigned_integer(inputs[1]): - return "only unsigned integer matrix multiplication is supported" else: # pragma: no cover assert_not_reached("Non IntermediateNode object in the OPGraph") diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index b80fdc71f..ed23f208e 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -14,6 +14,7 @@ from concrete.common.extensions.table import LookupTable from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor from concrete.numpy import tracing from concrete.numpy.compile import ( + FHECircuit, compile_numpy_function, compile_numpy_function_into_op_graph_and_measure_bounds, ) @@ -1353,6 +1354,11 @@ def test_compile_and_run_constant_dot_correctness( (5, 5), (0, 4), ), + pytest.param( + (3, 2), + (2, 3), + (-4, 3), + ), ], ) def test_compile_and_run_matmul_correctness( @@ -1362,10 +1368,16 @@ def test_compile_and_run_matmul_correctness( low, high = input_range + check_mod = low < 0 or high < 0 + + max_abs = max(abs(low), abs(high)) + + # Inputset for x as lhs of matmul lhs_inputset = [ numpy.zeros(lhs_shape, dtype=numpy.uint32), numpy.ones(lhs_shape, dtype=numpy.uint32) * high, ] + # Inputset for x as rhs of matmul rhs_inputset = [ numpy.zeros(rhs_shape, dtype=numpy.uint32), numpy.ones(rhs_shape, dtype=numpy.uint32) * high, @@ -1377,6 +1389,32 @@ def test_compile_and_run_matmul_correctness( left_constant = numpy.random.randint(low, high + 1, size=lhs_shape) right_constant = numpy.random.randint(low, high + 1, size=rhs_shape) + # Generate worst case inputsets for bit widths, replacing negative values by 0 and putting + # the max value elsewhere, and then doing the same for positive values + rhs_inputset.extend( + [ + numpy.where(right_constant < 0, 0, max_abs), + numpy.where(right_constant > 0, 0, max_abs), + ] + ) + lhs_inputset.extend( + [ + numpy.where(left_constant < 0, 0, max_abs), + numpy.where(left_constant > 0, 0, max_abs), + ] + ) + + # Keep inputset positive + rhs_inputset = [numpy.clip(val, 0, high) for val in rhs_inputset] + lhs_inputset = [numpy.clip(val, 0, high) for val in lhs_inputset] + + def get_output_mod(circuit: FHECircuit): + assert len(circuit.op_graph.output_nodes) == 1 + assert isinstance( + output_dtype := circuit.op_graph.get_ordered_outputs()[0].outputs[0].dtype, Integer + ) + return 2 ** output_dtype.bit_width + def using_operator_left(x): return x @ right_constant @@ -1414,13 +1452,28 @@ def test_compile_and_run_matmul_correctness( default_compilation_configuration, ) - lhs_arg = numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8) - check_array_equality(operator_left_circuit.run(lhs_arg), using_operator_left(lhs_arg)) - check_array_equality(function_left_circuit.run(lhs_arg), using_function_left(lhs_arg)) + def check_result(circuit: FHECircuit, func, arg): + # Stay positive for input to FHE circuit + arg = numpy.clip(arg, 0, high).astype(numpy.uint8) - rhs_arg = numpy.random.randint(low, high + 1, size=rhs_shape, dtype=numpy.uint8) - check_array_equality(operator_right_circuit.run(rhs_arg), using_operator_right(rhs_arg)) - check_array_equality(function_right_circuit.run(rhs_arg), using_function_right(rhs_arg)) + circuit_output = circuit.run(arg) + func_output = func(arg) + + if check_mod: + output_mod = get_output_mod(circuit) + + circuit_output %= output_mod + func_output %= output_mod + + check_array_equality(circuit_output, func_output) + + arg = numpy.random.randint(low, high + 1, size=lhs_shape) + check_result(operator_left_circuit, using_operator_left, arg) + check_result(function_left_circuit, using_function_left, arg) + + arg = numpy.random.randint(low, high + 1, size=rhs_shape) + check_result(operator_right_circuit, using_operator_right, arg) + check_result(function_right_circuit, using_function_right, arg) @pytest.mark.parametrize( @@ -1643,24 +1696,6 @@ return %9 """.strip() # noqa: E501 ), ), - pytest.param( - lambda x: x @ -numpy.ones(shape=(2, 3), dtype=numpy.int32), - {"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2))}, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)], - ( - """ - -function you are trying to compile isn't supported for MLIR lowering - -%0 = x # EncryptedTensor -%1 = [[-1 -1 -1] [-1 -1 -1]] # ClearTensor -%2 = matmul(%0, %1) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer matrix multiplication is supported -return %2 - - """.strip() # noqa: E501 - ), - ), pytest.param( lambda x: numpy.transpose(x), {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},