diff --git a/concrete/common/data_types/dtypes_helpers.py b/concrete/common/data_types/dtypes_helpers.py index b1ae53264..bf48470ca 100644 --- a/concrete/common/data_types/dtypes_helpers.py +++ b/concrete/common/data_types/dtypes_helpers.py @@ -71,6 +71,30 @@ def value_is_scalar_integer(value_to_check: BaseValue) -> bool: ) +def value_is_scalar(value_to_check: BaseValue) -> bool: + """Check that a value is a scalar. + + Args: + value_to_check (BaseValue): The value to check + + Returns: + bool: True if the passed value_to_check is a scalar + """ + return isinstance(value_to_check, TensorValue) and value_to_check.is_scalar + + +def value_is_integer(value_to_check: BaseValue) -> bool: + """Check that a value is of type Integer. + + Args: + value_to_check (BaseValue): The value to check + + Returns: + bool: True if the passed value_to_check is of type Integer + """ + return isinstance(value_to_check.dtype, INTEGER_TYPES) + + def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool: """Check that a value is an encrypted TensorValue of type Integer. diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index 6c37cddf1..ad2e302b8 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -1,5 +1,5 @@ """Utilities for MLIR conversion.""" -from typing import cast +from typing import Dict, Optional, cast from ..data_types import Integer from ..data_types.dtypes_helpers import ( @@ -7,32 +7,41 @@ from ..data_types.dtypes_helpers import ( value_is_clear_tensor_integer, value_is_encrypted_scalar_integer, value_is_encrypted_tensor_integer, - value_is_scalar_integer, + value_is_integer, + value_is_scalar, ) from ..debugging.custom_assert import assert_true from ..operator_graph import OPGraph -from ..representation.intermediate import UnivariateFunction +from ..representation.intermediate import IntermediateNode, UnivariateFunction # TODO: should come from compiler, through an API, #402 ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7 -def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool: +def check_graph_values_compatibility_with_mlir( + op_graph: OPGraph, +) -> Optional[Dict[IntermediateNode, str]]: """Make sure the graph outputs are unsigned integers, which is what the compiler supports. Args: op_graph: computation graph to check Returns: - bool: is the graph compatible with the expected MLIR representation + Dict[IntermediateNode, str]: None if the graph is compatible + information about offending nodes otherwise """ - return all( - all( - value_is_scalar_integer(out) and not cast(Integer, out.dtype).is_signed - for out in out_node.outputs - ) - for out_node in op_graph.output_nodes.values() - ) + + offending_nodes = {} + + for out_node in op_graph.output_nodes.values(): + for out in out_node.outputs: + if not value_is_scalar(out): + offending_nodes[out_node] = "non scalar outputs aren't supported" + + if value_is_integer(out) and cast(Integer, out.dtype).is_signed: + offending_nodes[out_node] = "signed integer outputs aren't supported" + + return None if len(offending_nodes) == 0 else offending_nodes def _set_all_bit_width(op_graph: OPGraph, p: int): diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index f45b4df33..fa64e54c3 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -11,11 +11,12 @@ from ..common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_in from ..common.common_helpers import check_op_graph_is_integer_program from ..common.compilation import CompilationArtifacts, CompilationConfiguration from ..common.data_types import Integer +from ..common.debugging import get_printable_graph from ..common.fhe_circuit import FHECircuit from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter from ..common.mlir.utils import ( + check_graph_values_compatibility_with_mlir, extend_direct_lookup_tables, - is_graph_values_compatible_with_mlir, update_bit_width_for_mlir, ) from ..common.operator_graph import OPGraph @@ -162,8 +163,12 @@ def _compile_numpy_function_into_op_graph_internal( compilation_artifacts.add_operation_graph("final", op_graph) # Make sure the graph can be lowered to MLIR - if not is_graph_values_compatible_with_mlir(op_graph): - raise RuntimeError("function you are trying to compile isn't supported for MLIR lowering") + offending_nodes = check_graph_values_compatibility_with_mlir(op_graph) + if offending_nodes is not None: + raise RuntimeError( + "function you are trying to compile isn't supported for MLIR lowering\n\n" + + get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes) + ) # Update bit_width for MLIR update_bit_width_for_mlir(op_graph) diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 9af843c9d..5a4662b4d 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -560,29 +560,50 @@ def test_compile_function_with_direct_tlu_overflow(): @pytest.mark.parametrize( - "function,input_ranges,list_of_arg_names", + "function,parameters,inputset,match", [ - pytest.param(lambda x: x - 10, ((-5, 5),), ["x"]), + pytest.param( + lambda x: 1 - x, + {"x": EncryptedScalar(Integer(3, is_signed=False))}, + [(i,) for i in range(8)], + ( + "function you are trying to compile isn't supported for MLIR lowering\n" + "\n" + "%0 = Constant(1) # ClearScalar>\n" # noqa: E501 + "%1 = x # EncryptedScalar>\n" # noqa: E501 + "%2 = Sub(%0, %1) # EncryptedScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integer outputs aren't supported\n" # noqa: E501 + "return(%2)\n" + ), + ), + pytest.param( + lambda x: x + 1, + {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))}, + [(numpy.random.randint(0, 8, size=(2, 2)),) for i in range(10)], + ( + "function you are trying to compile isn't supported for MLIR lowering\n" + "\n" + "%0 = x # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 + "%1 = Constant(1) # ClearScalar>\n" # noqa: E501 + "%2 = Add(%0, %1) # EncryptedTensor, shape=(2, 2)>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ non scalar outputs aren't supported\n" # noqa: E501 + "return(%2)\n" + ), + ), ], ) -def test_fail_compile(function, input_ranges, list_of_arg_names): +def test_fail_compile(function, parameters, inputset, match): """Test function compile_numpy_function_into_op_graph for a program with signed values""" - def data_gen(args): - for prod in itertools.product(*args): - yield prod - - function_parameters = { - arg_name: EncryptedScalar(Integer(64, True)) for arg_name in list_of_arg_names - } - - with pytest.raises(RuntimeError, match=".*isn't supported for MLIR lowering.*"): + try: compile_numpy_function( function, - function_parameters, - data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), + parameters, + inputset, CompilationConfiguration(dump_artifacts_on_unexpected_failures=False), ) + except RuntimeError as error: + assert str(error) == match def test_small_inputset():