diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py index a4d48ee3f..80affd9a7 100644 --- a/concrete/common/mlir/node_converter.py +++ b/concrete/common/mlir/node_converter.py @@ -7,11 +7,13 @@ from typing import Any, Dict, List, Tuple, cast import numpy -from mlir.dialects import arith +from mlir.dialects import arith, tensor from mlir.ir import ( + ArrayAttr, Attribute, Context, DenseElementsAttr, + IndexType, IntegerAttr, IntegerType, OpResult, @@ -21,12 +23,14 @@ from zamalang.dialects import hlfhe, hlfhelinalg from ..data_types import Integer from ..debugging import assert_true +from ..helpers.indexing_helpers import determine_new_dimension_size from ..operator_graph import OPGraph from ..representation.intermediate import ( Add, Constant, Dot, GenericFunction, + IndexConstant, IntermediateNode, MatMul, Mul, @@ -101,6 +105,8 @@ class IntermediateNodeConverter: str: textual MLIR representation corresponding to self.node """ + # pylint: disable=too-many-branches + if isinstance(self.node, Add): result = self.convert_add() @@ -113,6 +119,9 @@ class IntermediateNodeConverter: elif isinstance(self.node, GenericFunction): result = self.convert_generic_function(additional_conversion_info) + elif isinstance(self.node, IndexConstant): + result = self.convert_index_constant() + elif isinstance(self.node, MatMul): result = self.convert_matmul() @@ -126,6 +135,8 @@ class IntermediateNodeConverter: # this branch is not covered as unsupported opeations fail on check mlir compatibility raise NotImplementedError(f"{type(self.node)} nodes cannot be converted to MLIR yet") + # pylint: enable=too-many-branches + mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip() self.nodes_to_mlir_names[self.node] = mlir_name @@ -319,6 +330,98 @@ class IntermediateNodeConverter: return result + def convert_index_constant(self) -> OpResult: + """Convert a IndexConstant node to its corresponding MLIR representation. + + Returns: + str: textual MLIR representation corresponding to self.node + """ + + assert_true(len(self.node.inputs) == 1) + assert_true(len(self.node.outputs) == 1) + + tensor_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) + pred = self.preds[0] + + input_value = cast(TensorValue, self.node.inputs[0]) + input_shape = input_value.shape + + index = cast(IndexConstant, self.node).index + index_str = self.node.text_for_formatting([""], 0) + + index_type = IndexType.parse("index") + + if len(index) == len(input_shape) and all(isinstance(i, int) for i in index): + indices = [] + for value, dimension_size in zip(index, input_shape): + assert isinstance(value, int) # mypy + attr = IntegerAttr.get(index_type, value if value >= 0 else value + dimension_size) + indices.append(arith.ConstantOp(index_type, attr).result) + return tensor.ExtractOp(tensor_type, pred, indices).result + + offsets = [] + sizes = [] + strides = [] + + can_be_converted = True + for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)): + + if isinstance(indexing_element, int): + size = 1 + stride = 1 + offset = ( + indexing_element if indexing_element >= 0 else indexing_element + dimension_size + ) + + elif isinstance(indexing_element, slice): + size = determine_new_dimension_size( + indexing_element, + dimension_size, + dimension, + input_shape, + index_str, + ) + if size == 1: + can_be_converted = False + break + + stride = indexing_element.step if isinstance(indexing_element.step, int) else 1 + offset = ( + ( + indexing_element.start + if indexing_element.start >= 0 + else indexing_element.start + dimension_size + ) + if isinstance(indexing_element.start, int) + else (0 if stride > 0 else dimension_size - 1) + ) + + else: # pragma: no cover + # this branch is impossible to reach with all the previous checks + # but let's keep it as an extra measure + can_be_converted = False + break + + offsets.append(offset) + sizes.append(size) + strides.append(stride) + + if not can_be_converted: + raise NotImplementedError( + f"Indexing of {input_value} with {index_str} cannot be converted to MLIR yet", + ) + + return tensor.ExtractSliceOp( + tensor_type, + pred, + [], + [], + [], + ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]), + ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]), + ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]), + ).result + def convert_matmul(self) -> OpResult: """Convert a MatMul node to its corresponding MLIR representation. diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index 6fb6f957f..62c386774 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -99,7 +99,6 @@ def check_node_compatibility_with_mlir( elif isinstance(node, intermediate.IndexConstant): # constraints for constant indexing assert_true(len(outputs) == 1) - return "indexing is not supported for the time being" elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication assert_true(len(inputs) == 2) diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 6c0d40f85..9afd9e5b3 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -7,7 +7,7 @@ import numpy import pytest from concrete.common.compilation import CompilationConfiguration -from concrete.common.data_types.integers import Integer, UnsignedInteger +from concrete.common.data_types.integers import Integer, SignedInteger, UnsignedInteger from concrete.common.debugging import draw_graph, format_operation_graph from concrete.common.extensions.multi_table import MultiLookupTable from concrete.common.extensions.table import LookupTable @@ -1403,24 +1403,6 @@ return %2 """.strip() # noqa: E501 ), ), - pytest.param( - lambda x: x[0], - {"x": EncryptedTensor(Integer(3, is_signed=True), shape=(2, 2))}, - [(numpy.random.randint(-4, 2 ** 2, size=(2, 2)),) for i in range(10)], - ( - """ - -function you are trying to compile isn't supported for MLIR lowering - -%0 = x # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported -%1 = %0[0] # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ indexing is not supported for the time being -return %1 - - """.strip() # noqa: E501 - ), - ), pytest.param( no_fuse_unhandled, {"x": EncryptedScalar(Integer(2, False)), "y": EncryptedScalar(Integer(2, False))}, @@ -1539,6 +1521,52 @@ def test_fail_compile(function, parameters, inputset, match, default_compilation assert str(excinfo.value) == match, str(excinfo.value) +@pytest.mark.parametrize( + "function,parameters,inputset,match", + [ + pytest.param( + lambda x: (x * 1.5)[0, 1], + {"x": EncryptedTensor(SignedInteger(3), shape=(2, 2))}, + [(numpy.random.randint(-4, 3, size=(2, 2)),) for i in range(10)], + ( + """ + +function you are trying to compile isn't supported for MLIR lowering + +%0 = x # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported +%1 = 1.5 # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported +%2 = mul(%0, %1) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported +%3 = %2[0, 1] # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer outputs are supported +return %3 + + """.strip() # noqa: E501 + ), + ), + ], +) +def test_fail_compile_while_fusing_is_disabled( + function, parameters, inputset, match, default_compilation_configuration +): + """Test compile_numpy_function without fusing and with failing inputs""" + + configuration_to_use = deepcopy(default_compilation_configuration) + configuration_to_use.enable_topological_optimizations = False + + with pytest.raises(RuntimeError) as excinfo: + compile_numpy_function( + function, + parameters, + inputset, + configuration_to_use, + ) + + assert str(excinfo.value) == match, str(excinfo.value) + + def test_small_inputset_no_fail(): """Test function compile_numpy_function_into_op_graph with an unacceptably small inputset""" compile_numpy_function_into_op_graph_and_measure_bounds( diff --git a/tests/numpy/test_compile_constant_indexing.py b/tests/numpy/test_compile_constant_indexing.py index acf8b3788..6158cd200 100644 --- a/tests/numpy/test_compile_constant_indexing.py +++ b/tests/numpy/test_compile_constant_indexing.py @@ -5,7 +5,10 @@ import pytest from concrete.common.data_types import UnsignedInteger from concrete.common.values import EncryptedScalar, EncryptedTensor -from concrete.numpy import compile_numpy_function_into_op_graph_and_measure_bounds +from concrete.numpy import ( + compile_numpy_function, + compile_numpy_function_into_op_graph_and_measure_bounds, +) @pytest.mark.parametrize( @@ -595,3 +598,183 @@ def test_invalid_constant_indexing_with_numpy_values( except Exception as error: assert str(error) == expected_error_message raise + + +@pytest.mark.parametrize( + "function,parameters,inputset,test_input,expected_output", + [ + pytest.param( + lambda x: x[0], + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)), + }, + [(np.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)], + ([4, 2, 6],), + 4, + ), + pytest.param( + lambda x: x[-1], + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)), + }, + [(np.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)], + ([4, 2, 6],), + 6, + ), + pytest.param( + lambda x: x[:3], + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)), + }, + [(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)], + ([4, 2, 6, 1],), + [4, 2, 6], + ), + pytest.param( + lambda x: x[2:], + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)), + }, + [(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)], + ([4, 2, 6, 1],), + [6, 1], + ), + pytest.param( + lambda x: x[1:3], + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)), + }, + [(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)], + ([4, 2, 6, 1],), + [2, 6], + ), + pytest.param( + lambda x: x[::2], + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)), + }, + [(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)], + ([4, 2, 6, 1],), + [4, 6], + ), + pytest.param( + lambda x: x[::-1], + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)), + }, + [(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)], + ([4, 2, 6, 1],), + [1, 6, 2, 4], + ), + pytest.param( + lambda x: x[1, 0], + { + "x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)), + }, + [(np.random.randint(0, 2 ** 6, size=(3, 2)),) for _ in range(10)], + ([[11, 12], [21, 22], [31, 32]],), + 21, + ), + pytest.param( + lambda x: x[:, :], + { + "x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)), + }, + [(np.random.randint(0, 2 ** 6, size=(3, 2)),) for _ in range(10)], + ([[11, 12], [21, 22], [31, 32]],), + [[11, 12], [21, 22], [31, 32]], + ), + pytest.param( + lambda x: x[0, :], + { + "x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)), + }, + [(np.random.randint(0, 2 ** 6, size=(3, 2)),) for _ in range(10)], + ([[11, 12], [21, 22], [31, 32]],), + [11, 12], + marks=pytest.mark.xfail(strict=True), + ), + pytest.param( + lambda x: x[:, 0], + { + "x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)), + }, + [(np.random.randint(0, 2 ** 6, size=(3, 2)),) for _ in range(10)], + ([[11, 12], [21, 22], [31, 32]],), + [11, 21, 31], + marks=pytest.mark.xfail(strict=True), + ), + ], +) +def test_constant_indexing_run_correctness( + function, + parameters, + inputset, + test_input, + expected_output, + default_compilation_configuration, +): + """Test correctness of results when running a compiled function with tensor operators""" + circuit = compile_numpy_function( + function, + parameters, + inputset, + default_compilation_configuration, + ) + + numpy_test_input = tuple( + item if isinstance(item, int) else np.array(item, dtype=np.uint8) for item in test_input + ) + + output = circuit.run(*numpy_test_input) + expected = np.array(expected_output, dtype=np.uint8) + + assert np.array_equal( + output, expected + ), f""" + +Actual Output +============= +{output} + +Expected Output +=============== +{expected} + + """ + + +@pytest.mark.parametrize( + "function,parameters,inputset,match", + [ + pytest.param( + lambda x: x[0:1], + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)), + }, + [(np.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)], + ( + "Indexing of EncryptedTensor with [0:1] " + "cannot be converted to MLIR yet" + ), + ), + ], +) +def test_constant_indexing_failed_compilation( + function, + parameters, + inputset, + match, + default_compilation_configuration, +): + """Test compilation failures of compiled function with constant indexing""" + + with pytest.raises(RuntimeError) as excinfo: + compile_numpy_function( + function, + parameters, + inputset, + default_compilation_configuration, + ) + + assert str(excinfo.value) == match, str(excinfo.value)