diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py index b64aa08fc..fbfaf6dd3 100644 --- a/concrete/common/mlir/node_converter.py +++ b/concrete/common/mlir/node_converter.py @@ -422,6 +422,8 @@ class IntermediateNodeConverter: str: textual MLIR representation corresponding to self.node """ + # pylint: disable=too-many-locals + assert_true(len(self.node.inputs) == 1) assert_true(len(self.node.outputs) == 1) @@ -448,10 +450,11 @@ class IntermediateNodeConverter: sizes = [] strides = [] - can_be_converted = True + destroyed_dimensions = [] for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)): if isinstance(indexing_element, int): + destroyed_dimensions.append(dimension) size = 1 stride = 1 offset = ( @@ -466,10 +469,6 @@ class IntermediateNodeConverter: input_shape, index_str, ) - if size == 1: - can_be_converted = False - break - stride = indexing_element.step if isinstance(indexing_element.step, int) else 1 offset = ( ( @@ -484,20 +483,43 @@ class IntermediateNodeConverter: else: # pragma: no cover # this branch is impossible to reach with all the previous checks # but let's keep it as an extra measure - can_be_converted = False - break + raise NotImplementedError( + f"Indexing of {input_value} with {index_str} cannot be converted to MLIR", + ) offsets.append(offset) sizes.append(size) strides.append(stride) - if not can_be_converted: - raise NotImplementedError( - f"Indexing of {input_value} with {index_str} cannot be converted to MLIR yet", - ) + if len(destroyed_dimensions) == 0: + return tensor.ExtractSliceOp( + tensor_type, + pred, + [], + [], + [], + ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]), + ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]), + ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]), + ).result - return tensor.ExtractSliceOp( - tensor_type, + output_value = cast(TensorValue, self.node.outputs[0]) + + intermediate_shape = list(output_value.shape) + for dimension in destroyed_dimensions: + intermediate_shape.insert(dimension, 1) + + intermediate_type = RankedTensorType.get( + intermediate_shape, + integer_to_mlir_type( + self.ctx, + cast(Integer, output_value.dtype), + output_value.is_encrypted, + ), + ) + + intermediate = tensor.ExtractSliceOp( + intermediate_type, pred, [], [], @@ -507,6 +529,36 @@ class IntermediateNodeConverter: ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]), ).result + reassociaton = [] + + current_intermediate_dimension = 0 + for _ in range(len(output_value.shape)): + indices = [current_intermediate_dimension] + while current_intermediate_dimension in destroyed_dimensions: + current_intermediate_dimension += 1 + indices.append(current_intermediate_dimension) + + reassociaton.append(indices) + current_intermediate_dimension += 1 + while current_intermediate_dimension < len(intermediate_shape): + reassociaton[-1].append(current_intermediate_dimension) + current_intermediate_dimension += 1 + + return linalg.TensorCollapseShapeOp( + tensor_type, + intermediate, + ArrayAttr.get( + [ + ArrayAttr.get( + [IntegerAttr.get(index_type, index) for index in indices], + ) + for indices in reassociaton + ], + ), + ).result + + # pylint: enable=too-many-locals + def convert_matmul(self) -> OpResult: """Convert a MatMul node to its corresponding MLIR representation. diff --git a/tests/numpy/test_compile_constant_indexing.py b/tests/numpy/test_compile_constant_indexing.py index 606922afa..41fc73c64 100644 --- a/tests/numpy/test_compile_constant_indexing.py +++ b/tests/numpy/test_compile_constant_indexing.py @@ -688,7 +688,6 @@ def test_invalid_constant_indexing_with_numpy_values( [np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)], ([[11, 12], [21, 22], [31, 32]],), [11, 12], - marks=pytest.mark.xfail(strict=True), ), pytest.param( lambda x: x[:, 0], @@ -698,7 +697,6 @@ def test_invalid_constant_indexing_with_numpy_values( [np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)], ([[11, 12], [21, 22], [31, 32]],), [11, 21, 31], - marks=pytest.mark.xfail(strict=True), ), ], ) @@ -727,39 +725,3 @@ def test_constant_indexing_run_correctness( expected = np.array(expected_output, dtype=np.uint8) check_array_equality(output, expected) - - -@pytest.mark.parametrize( - "function,parameters,inputset,match", - [ - pytest.param( - lambda x: x[0:1], - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)), - }, - [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(10)], - ( - "Indexing of EncryptedTensor with [0:1] " - "cannot be converted to MLIR yet" - ), - ), - ], -) -def test_constant_indexing_failed_compilation( - function, - parameters, - inputset, - match, - default_compilation_configuration, -): - """Test compilation failures of compiled function with constant indexing""" - - with pytest.raises(RuntimeError) as excinfo: - compile_numpy_function( - function, - parameters, - inputset, - default_compilation_configuration, - ) - - assert str(excinfo.value) == match, str(excinfo.value)