diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py index 6e623e06b..a4d48ee3f 100644 --- a/concrete/common/mlir/node_converter.py +++ b/concrete/common/mlir/node_converter.py @@ -270,7 +270,7 @@ class IntermediateNodeConverter: value = self.node.inputs[variable_input_index] assert_true(value.is_encrypted) - if not isinstance(value.dtype, Integer) or value.dtype.is_signed: # pragma: no cover + if not isinstance(value.dtype, Integer): # pragma: no cover # this branch is not covered as it's impossible to get into due to how compilation works # however, it doesn't hurt to keep it as an extra measure raise NotImplementedError(f"Table lookup on {value} cannot be converted to MLIR yet") diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index 2f0944e16..6fb6f957f 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -89,17 +89,6 @@ def check_node_compatibility_with_mlir( ) == 1 ) - - if not value_is_unsigned_integer(inputs[0]): - # this branch is not reachable because compilation fails during inputset evaluation - if node.op_name == "TLU": # pragma: no cover - return "only unsigned integer lookup tables are supported" - - if node.op_name == "MultiTLU": # pragma: no cover - return "only unsigned integer multi lookup tables are supported" - - # e.g., `np.absolute is not supported for the time being` - return f"{node.op_name} is not supported for the time being" else: return f"{node.op_name} is not supported for the time being" diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 7970ee188..383f3e9d7 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -370,8 +370,8 @@ class GenericFunction(IntermediateNode): def get_table(self, ordered_preds: List[IntermediateNode]) -> List[Any]: """Get the table for the current input value of this GenericFunction. 
- This function only works if the GenericFunction variable input value is an unsigned Integer. - It only works if there is a single variable input node among ordered_preds. + This function only works if the GenericFunction variable input value is an Integer. + This function only works if there is a single variable input node among ordered_preds. Args: ordered_preds (List[IntermediateNode]): List of predecessors of the node. This list must @@ -393,16 +393,12 @@ class GenericFunction(IntermediateNode): variable_input_idx = variable_input_indices[0] variable_input_dtype = self.inputs[variable_input_idx].dtype - # Check the input is an unsigned integer to be able to build a table + # Check the input is an integer to be able to build a table assert_true( isinstance(variable_input_dtype, Integer), - f"{self.get_table.__name__} only works for an unsigned Integer input", + f"{self.get_table.__name__} only works for an Integer input", ) variable_input_dtype = cast(Integer, variable_input_dtype) - assert_true( - not variable_input_dtype.is_signed, - f"{self.get_table.__name__} only works for an unsigned Integer input", - ) input_value_constructor = self.inputs[0].underlying_constructor if input_value_constructor is None: diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index eaeea692a..872871482 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -2,7 +2,8 @@ import sys import traceback -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union +from copy import deepcopy +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union, cast import numpy from zamalang import CompilerEngine @@ -21,7 +22,8 @@ from ..common.mlir.utils import ( ) from ..common.operator_graph import OPGraph from ..common.optimization.topological import fuse_float_operations -from ..common.values import BaseValue +from ..common.representation.intermediate import Add, Constant, GenericFunction +from ..common.values import BaseValue, ClearScalar from ..numpy.tracing import trace_numpy_function from 
.np_dtypes_helpers import ( get_base_data_type_for_numpy_or_python_constant_data, @@ -311,7 +313,66 @@ def compile_numpy_function_into_op_graph_and_measure_bounds( raise -def prepare_op_graph_for_mlir(op_graph): +# HACK +# TODO: remove this ugly hack when https://github.com/zama-ai/concretefhe-internal/issues/1001 is +# done +# TODO: https://github.com/zama-ai/concretefhe-internal/issues/1015 +def hack_offset_negative_inputs_to_lookup_tables(op_graph: OPGraph) -> None: + """Hack the op_graph to add offsets to signed inputs to TLUs. + + Args: + op_graph (OPGraph): the OPGraph to hack (its graph is modified in place). + """ + # Ugly hack to add an offset before entering a TLU if its variable input node has a signed + # output. + # This is ugly as this makes hardcoded assumptions about the way bit widths are handled in MLIR. + # This does not update the TLU input values to allow for proper table generation. + # Thankfully we are not supposed to touch the op_graph beyond that point + for node in list((nx_graph := op_graph.graph).nodes): + if isinstance(node, GenericFunction): + ordered_preds_and_inputs = op_graph.get_ordered_preds_and_inputs_of(node) + variable_input_indices = [ + idx + for idx, (pred, _) in enumerate(ordered_preds_and_inputs) + if not isinstance(pred, Constant) + ] + assert_true(len(variable_input_indices) == 1) + variable_input_idx = variable_input_indices[0] + variable_input_node = ordered_preds_and_inputs[variable_input_idx][0] + variable_input_value = variable_input_node.outputs[0] + variable_input_dtype = variable_input_value.dtype + assert_true(isinstance(variable_input_dtype, Integer)) + variable_input_dtype = cast(Integer, variable_input_dtype) + if not variable_input_dtype.is_signed: + continue + + # input_bit_width + 1 to be MLIR compliant + input_bit_width = variable_input_dtype.bit_width + mlir_compliant_int_type = Integer(input_bit_width + 1, True) + + # Manually fix the output values to be MLIR compliant + # offset_constant is set to abs(min_value) for the variable 
input so that the range + # [- 2 ** (n - 1); 2 ** (n - 1) - 1] is mapped to [0; 2 ** n - 1], changing the signed + # TLU to an actual unsigned TLU. The get_table function creates the table from the min + # value to the max value. As we keep the input value as a signed value, it will be from + # - 2 ** (n - 1) to 2 ** (n - 1) - 1. Then, the get_table function stores corresponding + # values in increasing indexes from 0 to 2 ** n - 1. As our signed values have been + # shifted by 2 ** (n - 1), the table will be usable as-is, without needing any change in + # the lambda function of the GenericFunction. + offset_constant = Constant(abs(variable_input_dtype.min_value())) + offset_constant.outputs[0].dtype = deepcopy(mlir_compliant_int_type) + add_offset = Add( + [deepcopy(variable_input_value), ClearScalar(deepcopy(mlir_compliant_int_type))] + ) + add_offset.outputs[0] = deepcopy(variable_input_value) + + nx_graph.remove_edge(variable_input_node, node) + nx_graph.add_edge(variable_input_node, add_offset, input_idx=0, output_idx=0) + nx_graph.add_edge(offset_constant, add_offset, input_idx=1, output_idx=0) + nx_graph.add_edge(add_offset, node, input_idx=variable_input_idx, output_idx=0) + + +def prepare_op_graph_for_mlir(op_graph: OPGraph): """Prepare OPGraph for MLIR lowering. This includes checking compatibility, changing bit-widths, and modifying lookup tables. 
@@ -337,6 +398,12 @@ def prepare_op_graph_for_mlir(op_graph): # TODO: workaround extend LUT #359 extend_direct_lookup_tables(op_graph) + # HACK + # TODO: remove this ugly hack when https://github.com/zama-ai/concretefhe-internal/issues/1001 + # is done + # TODO: https://github.com/zama-ai/concretefhe-internal/issues/1015 + hack_offset_negative_inputs_to_lookup_tables(op_graph) + def _compile_numpy_function_internal( function_to_compile: Callable, diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index ae63a463f..6c0d40f85 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -267,8 +267,8 @@ def check_is_good_execution(compiler_engine, function, args, verbose=True): # Bad computation after nb_tries raise AssertionError( f"bad computation after {nb_tries} tries, which was supposed to happen with a " - f"probability of {expected_bad_luck}.\nLast engine result: {last_engine_result} " - f"last function result: {last_function_result}" + f"probability of {expected_bad_luck}.\nLast engine result:\n{last_engine_result}\n" + f"Last function result:\n{last_function_result}" ) @@ -725,7 +725,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu( @pytest.mark.parametrize( - "function,parameters,inputset,test_input,expected_output", + "function,parameters,inputset,test_input,use_check_good_exec", [ pytest.param( lambda x: x + 1, @@ -740,11 +740,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [2, 5], ], ), - [ - [1, 8], - [7, 2], - [3, 6], - ], + False, ), pytest.param( lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32), @@ -759,11 +755,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [2, 5], ], ), - [ - [1, 7], - [8, 1], - [5, 6], - ], + False, ), pytest.param( lambda x, y: x + y, @@ -786,11 +778,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu( ], 2, ), - [ - [2, 9], - [8, 3], - [4, 7], - ], + False, ), pytest.param( lambda x, y: x + y, @@ -817,11 +805,7 @@ def 
test_compile_and_run_correctness__for_prog_with_tlu( [3, 4], ], ), - [ - [1, 13], - [8, 6], - [5, 9], - ], + False, ), pytest.param( lambda x: 100 - x, @@ -836,11 +820,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [2, 5], ], ), - [ - [100, 93], - [94, 99], - [98, 95], - ], + False, ), pytest.param( lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x, @@ -855,11 +835,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [2, 5], ], ), - [ - [10, 8], - [14, 14], - [8, 25], - ], + False, ), pytest.param( lambda x: x * 2, @@ -874,11 +850,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [2, 5], ], ), - [ - [0, 14], - [12, 2], - [4, 10], - ], + False, ), pytest.param( lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32), @@ -893,11 +865,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [2, 5], ], ), - [ - [4, 14], - [12, 1], - [6, 5], - ], + False, ), pytest.param( lambda x: LookupTable([2, 1, 3, 0])[x], @@ -912,11 +880,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [3, 0], ], ), - [ - [2, 1], - [3, 1], - [0, 2], - ], + True, ), pytest.param( lambda x: numpy.dot(x, 2), @@ -925,7 +889,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu( }, [(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)], ([2, 7, 1],), - [4, 14, 2], + False, ), pytest.param( lambda x: numpy.dot(2, x), @@ -934,7 +898,7 @@ def test_compile_and_run_correctness__for_prog_with_tlu( }, [(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)], ([2, 7, 1],), - [4, 14, 2], + False, ), pytest.param( lambda x: numpy.clip(x, 1, 5), @@ -949,11 +913,22 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [2, 5], ], ), - [ - [1, 5], - [5, 1], - [2, 5], - ], + True, + ), + pytest.param( + lambda x: numpy.clip(x + (-4), -3, 5) + 3, + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), + }, + [(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)], + ( 
+ [ + [0, 7], + [6, 1], + [2, 5], + ], + ), + True, ), pytest.param( lambda x: x.clip(1, 5), @@ -968,16 +943,32 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [2, 5], ], ), - [ - [1, 5], - [5, 1], - [2, 5], - ], + True, + ), + pytest.param( + lambda x: (x + (-4)).clip(-3, 5) + 3, + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), + }, + [(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)], + ( + [ + [0, 7], + [6, 1], + [2, 5], + ], + ), + True, ), ], ) def test_compile_and_run_tensor_correctness( - function, parameters, inputset, test_input, expected_output, default_compilation_configuration + function, + parameters, + inputset, + test_input, + use_check_good_exec, + default_compilation_configuration, ): """Test correctness of results when running a compiled function with tensor operators""" circuit = compile_numpy_function( @@ -987,14 +978,18 @@ def test_compile_and_run_tensor_correctness( default_compilation_configuration, ) - numpy_test_input = ( + numpy_test_input = tuple( item if isinstance(item, int) else numpy.array(item, dtype=numpy.uint8) for item in test_input ) - assert numpy.array_equal( - circuit.run(*numpy_test_input), - numpy.array(expected_output, dtype=numpy.uint8), - ) + + if use_check_good_exec: + check_is_good_execution(circuit, function, numpy_test_input) + else: + assert numpy.array_equal( + circuit.run(*numpy_test_input), + numpy.array(function(*numpy_test_input), dtype=numpy.uint8), + ) @pytest.mark.parametrize( @@ -1544,47 +1539,6 @@ def test_fail_compile(function, parameters, inputset, match, default_compilation assert str(excinfo.value) == match, str(excinfo.value) -def test_fail_with_intermediate_signed_values(default_compilation_configuration): - """Test function with failing compilation due to intermediate signed integers.""" - - def function(x, y): - z = numpy.abs(10 * numpy.negative(x)) - z = z.astype(numpy.int32) + y - return z - - with pytest.raises(RuntimeError): - try: - 
compile_numpy_function( - function, - { - "x": EncryptedScalar(Integer(2, is_signed=False)), - "y": EncryptedScalar(Integer(2, is_signed=False)), - }, - [(i, j) for i in range(2 ** 2) for j in range(2 ** 2)], - default_compilation_configuration, - show_mlir=True, - ) - except RuntimeError as error: - match = """ - -function you are trying to compile isn't supported for MLIR lowering - -%0 = y # EncryptedScalar -%1 = 10 # ClearScalar -%2 = x # EncryptedScalar -%3 = negative(%2) # EncryptedScalar -%4 = mul(%3, %1) # EncryptedScalar -%5 = absolute(%4) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ absolute is not supported for the time being -%6 = astype(%5, dtype=int32) # EncryptedScalar -%7 = add(%6, %0) # EncryptedScalar -return %7 - - """.strip() # noqa: E501 # pylint: disable=line-too-long - assert str(error) == match - raise - - def test_small_inputset_no_fail(): """Test function compile_numpy_function_into_op_graph with an unacceptably small inputset""" compile_numpy_function_into_op_graph_and_measure_bounds(