diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py
index 35024b44d..3393662c9 100644
--- a/concrete/common/mlir/node_converter.py
+++ b/concrete/common/mlir/node_converter.py
@@ -7,7 +7,7 @@
 from typing import Any, Dict, List, Tuple, cast
 
 import numpy
-from mlir.dialects import arith, tensor
+from mlir.dialects import arith, linalg, tensor
 from mlir.ir import (
     ArrayAttr,
     Attribute,
@@ -117,7 +117,11 @@ class IntermediateNodeConverter:
             result = self.convert_dot()
 
         elif isinstance(self.node, GenericFunction):
-            result = self.convert_generic_function(additional_conversion_info)
+            if self.node.op_name in ["flatten", "reshape"]:
+                # note that flatten() == reshape(-1), and convert_reshape can handle that
+                result = self.convert_reshape()
+            else:
+                result = self.convert_generic_function(additional_conversion_info)
 
         elif isinstance(self.node, IndexConstant):
             result = self.convert_index_constant()
@@ -481,6 +485,129 @@ class IntermediateNodeConverter:
 
         return result
 
+    def convert_reshape(self) -> OpResult:
+        """Convert a "reshape" node to its corresponding MLIR representation.
+
+        Returns:
+            OpResult: the MLIR value corresponding to self.node
+        """
+
+        assert_true(len(self.node.inputs) == 1)
+        assert_true(len(self.node.outputs) == 1)
+
+        assert_true(isinstance(self.node.inputs[0], TensorValue))
+        input_shape = cast(TensorValue, self.node.inputs[0]).shape
+
+        assert_true(isinstance(self.node.outputs[0], TensorValue))
+        output_shape = cast(TensorValue, self.node.outputs[0]).shape
+
+        pred = self.preds[0]
+        if input_shape == output_shape:
+            return pred
+
+        # we can either collapse or expand, which changes the number of dimensions
+        # this is a limitation of the current compiler, and it will be improved in the future (#1060)
+        can_be_converted_directly = len(input_shape) != len(output_shape)
+
+        reassociation: List[List[int]] = []
+        if can_be_converted_directly:
+            if len(output_shape) == 1:
+                # output is 1-dimensional, so collapse every dimension into the same dimension
+                reassociation.append(list(range(len(input_shape))))
+            else:
+                # input is m-dimensional
+                # output is n-dimensional
+                # and m is different from n
+
+                # we don't want to duplicate code, so we forget about input and output
+                # and we focus on the smaller shape and the bigger shape
+
+                smaller_shape, bigger_shape = (
+                    (output_shape, input_shape)
+                    if len(output_shape) < len(input_shape)
+                    else (input_shape, output_shape)
+                )
+                s_index, b_index = 0, 0
+
+                # now we will figure out how to group the bigger shape to get the smaller shape
+                # think of the algorithm below as
+                #     keep merging the dimensions of the bigger shape
+                #     until we have a match on the smaller shape
+                #     then try to match the next dimension of the smaller shape
+                #     if all dimensions of the smaller shape are matched
+                #     we can convert it
+
+                group = []
+                size = 1
+                while s_index < len(smaller_shape) and b_index < len(bigger_shape):
+                    # dimension `b_index` of `bigger_shape` belongs to the current group
+                    group.append(b_index)
+
+                    # and the current group now has `size * bigger_shape[b_index]` elements
+                    size *= bigger_shape[b_index]
+
+                    # if the current group size matches dimension `s_index` of `smaller_shape`
+                    if size == smaller_shape[s_index]:
+                        # we finalize this group and reset everything
+                        size = 1
+                        reassociation.append(group)
+                        group = []
+
+                        # now try to match the next dimension of `smaller_shape`
+                        s_index += 1
+
+                    # now process the next dimension of `bigger_shape`
+                    b_index += 1
+
+                # handle the case where the bigger shape has trailing 1s
+                # e.g., (5,) -> (5, 1)
+                while b_index < len(bigger_shape) and bigger_shape[b_index] == 1:
+                    reassociation[-1].append(b_index)
+                    b_index += 1
+
+                # if not all dimensions of both shapes are processed exactly
+                if s_index != len(smaller_shape) or b_index != len(bigger_shape):
+                    # we cannot convert
+                    can_be_converted_directly = False
+
+        index_type = IndexType.parse("index")
+        resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
+
+        if can_be_converted_directly:
+            reassociation_attr = ArrayAttr.get(
+                [
+                    ArrayAttr.get([IntegerAttr.get(index_type, dimension) for dimension in group])
+                    for group in reassociation
+                ]
+            )
+            if len(output_shape) < len(input_shape):
+                return linalg.TensorCollapseShapeOp(resulting_type, pred, reassociation_attr).result
+            return linalg.TensorExpandShapeOp(resulting_type, pred, reassociation_attr).result
+
+        flattened_type = value_to_mlir_type(
+            self.ctx,
+            TensorValue(
+                self.node.inputs[0].dtype,
+                self.node.inputs[0].is_encrypted,
+                (numpy.prod(input_shape),),
+            ),
+        )
+        flattened_result = linalg.TensorCollapseShapeOp(
+            flattened_type,
+            pred,
+            ArrayAttr.get(
+                [ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(input_shape))])]
+            ),
+        ).result
+
+        return linalg.TensorExpandShapeOp(
+            resulting_type,
+            flattened_result,
+            ArrayAttr.get(
+                [ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(output_shape))])]
+            ),
+        ).result
+
     def convert_sub(self) -> OpResult:
         """Convert a Sub node to its corresponding MLIR representation.
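Note: the grouping pass that `convert_reshape` performs can be sanity-checked outside the compiler. Below is a minimal standalone sketch of the same logic (the helper name `reassociation_groups` is illustrative, not part of this diff): keep merging dimensions of the bigger shape until the running product matches the next dimension of the smaller shape, absorbing trailing 1s into the last group.

```python
from typing import List, Optional, Tuple


def reassociation_groups(
    input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]
) -> Optional[List[List[int]]]:
    """Group dimensions of the bigger shape so that each group's product
    matches one dimension of the smaller shape; None if no grouping exists."""
    # equal lengths never reach this path in the converter
    smaller, bigger = sorted((input_shape, output_shape), key=len)

    groups: List[List[int]] = []
    group: List[int] = []
    size = 1
    s_index = b_index = 0
    while s_index < len(smaller) and b_index < len(bigger):
        # dimension `b_index` of the bigger shape joins the current group
        group.append(b_index)
        size *= bigger[b_index]
        if size == smaller[s_index]:
            # group complete: its product matches dimension `s_index`
            groups.append(group)
            group, size = [], 1
            s_index += 1
        b_index += 1

    # absorb trailing 1s of the bigger shape, e.g., (5,) -> (5, 1)
    while b_index < len(bigger) and bigger[b_index] == 1 and groups:
        groups[-1].append(b_index)
        b_index += 1

    if s_index != len(smaller) or b_index != len(bigger):
        return None  # not expressible as a single collapse/expand
    return groups


# (2, 3, 4) -> (6, 4): dimensions 0 and 1 collapse together
assert reassociation_groups((2, 3, 4), (6, 4)) == [[0, 1], [2]]
# (2, 3, 4) -> (4, 6): no contiguous grouping exists
assert reassociation_groups((2, 3, 4), (4, 6)) is None
```

The second case is exactly when the converter takes the two-step fallback: collapse everything into a 1-D tensor with `TensorCollapseShapeOp`, then re-expand it to the output shape with `TensorExpandShapeOp`.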
diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py
index 1283963e0..b1242ab2c 100644
--- a/concrete/common/mlir/utils.py
+++ b/concrete/common/mlir/utils.py
@@ -89,7 +89,8 @@ def check_node_compatibility_with_mlir(
                 == 1
             )
         else:
-            return f"{node.op_name} is not supported for the time being"
+            if node.op_name not in ["flatten", "reshape"]:
+                return f"{node.op_name} is not supported for the time being"
     elif isinstance(node, intermediate.Dot):
         # constraints for dot product
         assert_true(len(inputs) == 2)
diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py
index adc9d5ef0..d96efd66a 100644
--- a/concrete/numpy/tracing.py
+++ b/concrete/numpy/tracing.py
@@ -362,13 +362,13 @@ class NPTracer(BaseTracer):
         )
         return output_tracer
 
-    def reshape(self, arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
+    def reshape(self, newshape: Tuple[Any, ...], **kwargs) -> "NPTracer":
         """Trace x.reshape.
 
         Returns:
             NPTracer: The output NPTracer containing the traced function
         """
-        return self.numpy_reshape(self, arg1, **kwargs)
+        return self.numpy_reshape(self, newshape, **kwargs)
 
     def numpy_reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
         """Trace numpy.reshape.
@@ -390,15 +390,13 @@
         assert_true(isinstance(first_arg_output, TensorValue))
         first_arg_output = cast(TensorValue, first_arg_output)
 
-        # Make numpy.reshape(x, (170)) and numpy.reshape(x, 170) work,
-        # while classical form is numpy.reshape(x, (170,))
-        newshape = deepcopy(arg1) if not isinstance(arg1, int) else (arg1,)
-
-        # Check shape compatibility
-        assert_true(
-            numpy.product(newshape) == first_arg_output.size,
-            f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {newshape})",
-        )
+        try:
+            # calculate the new shape using numpy, to handle edge cases such as `-1`s within the new shape
+            newshape = numpy.zeros(first_arg_output.shape).reshape(arg1).shape
+        except Exception as error:
+            raise ValueError(
+                f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {arg1})"
+            ) from error
 
         reshape_is_fusable = newshape == first_arg_output.shape
 
diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py
index 023f99b69..49146f91f 100644
--- a/tests/numpy/test_compile.py
+++ b/tests/numpy/test_compile.py
@@ -1675,23 +1675,6 @@ function you are trying to compile isn't supported for MLIR lowering
 %0 = x                            # EncryptedTensor
 %1 = ravel(%0)                    # EncryptedTensor
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ravel is not supported for the time being
-return %1
-
-            """.strip()  # noqa: E501
-        ),
-    ),
-    pytest.param(
-        lambda x: numpy.reshape(x, (2, 6)),
-        {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 4))},
-        [numpy.random.randint(0, 2 ** 3, size=(3, 4)) for i in range(10)],
-        (
-            """
-
-function you are trying to compile isn't supported for MLIR lowering
-
-%0 = x                            # EncryptedTensor
-%1 = reshape(%0, newshape=(2, 6)) # EncryptedTensor
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ reshape is not supported for the time being
 return %1
 
             """.strip()  # noqa: E501
diff --git a/tests/numpy/test_compile_memory_operations.py b/tests/numpy/test_compile_memory_operations.py
new file mode 100644
index 000000000..68dfb1421
--- /dev/null
+++ b/tests/numpy/test_compile_memory_operations.py
@@ -0,0 +1,265 @@
+"""Test module for memory operations."""
+
+import numpy
+import pytest
+
+from concrete.common.data_types import UnsignedInteger
+from concrete.common.values import EncryptedTensor
+from concrete.numpy import compile_numpy_function
+
+
+@pytest.mark.parametrize(
+    "function,parameters,inputset,test_input,expected_output",
+    [
+        pytest.param(
+            lambda x: x.flatten(),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
+            [[0, 1], [1, 2], [2, 3]],
+            [0, 1, 1, 2, 2, 3],
+        ),
+        pytest.param(
+            lambda x: x.flatten(),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
+            (numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
+            (numpy.arange(720) % 10),
+        ),
+        pytest.param(
+            lambda x: x.reshape((1, 3)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
+            [5, 9, 1],
+            [[5, 9, 1]],
+        ),
+        pytest.param(
+            lambda x: x.reshape((3, 1)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
+            [5, 9, 1],
+            [[5], [9], [1]],
+        ),
+        pytest.param(
+            lambda x: x.reshape((3, 2)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
+            [[0, 1], [1, 2], [2, 3]],
+            [[0, 1], [1, 2], [2, 3]],
+        ),
+        pytest.param(
+            lambda x: x.reshape((3, 2)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(2, 3)) for _ in range(10)],
+            [[0, 1, 1], [2, 2, 3]],
+            [[0, 1], [1, 2], [2, 3]],
+        ),
+        pytest.param(
+            lambda x: x.reshape(-1),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
+            [[0, 1], [1, 2], [2, 3]],
+            [0, 1, 1, 2, 2, 3],
+        ),
+        pytest.param(
+            lambda x: x.reshape((2, 2, 3)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(4, 3)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(4, 3)) for _ in range(10)],
+            (numpy.arange(12) % 10).reshape((4, 3)),
+            (numpy.arange(12) % 10).reshape((2, 2, 3)),
+        ),
+        pytest.param(
+            lambda x: x.reshape((4, 3)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 2, 3)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(2, 2, 3)) for _ in range(10)],
+            (numpy.arange(12) % 10).reshape((2, 2, 3)),
+            (numpy.arange(12) % 10).reshape((4, 3)),
+        ),
+        pytest.param(
+            lambda x: x.reshape((3, 2, 2)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 4)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(3, 4)) for _ in range(10)],
+            (numpy.arange(12) % 10).reshape((3, 4)),
+            (numpy.arange(12) % 10).reshape((3, 2, 2)),
+        ),
+        pytest.param(
+            lambda x: x.reshape((3, 4)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2, 2)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(3, 2, 2)) for _ in range(10)],
+            (numpy.arange(12) % 10).reshape((3, 2, 2)),
+            (numpy.arange(12) % 10).reshape((3, 4)),
+        ),
+        pytest.param(
+            lambda x: x.reshape((5, 3, 2)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(6, 5)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(6, 5)) for _ in range(10)],
+            (numpy.arange(30) % 10).reshape((6, 5)),
+            (numpy.arange(30) % 10).reshape((5, 3, 2)),
+        ),
+        pytest.param(
+            lambda x: x.reshape((5, 6)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 5)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 5)) for _ in range(10)],
+            (numpy.arange(30) % 10).reshape((2, 3, 5)),
+            (numpy.arange(30) % 10).reshape((5, 6)),
+        ),
+        pytest.param(
+            lambda x: x.reshape((6, 4, 30)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
+            (numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
+            (numpy.arange(720) % 10).reshape((6, 4, 30)),
+        ),
+        pytest.param(
+            lambda x: x.reshape((2, 60, 6)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
+            (numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
+            (numpy.arange(720) % 10).reshape((2, 60, 6)),
+        ),
+        pytest.param(
+            lambda x: x.reshape((6, 6, -1)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
+            (numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
+            (numpy.arange(144) % 10).reshape((6, 6, -1)),
+        ),
+        pytest.param(
+            lambda x: x.reshape((6, -1, 12)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
+            (numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
+            (numpy.arange(144) % 10).reshape((6, -1, 12)),
+        ),
+        pytest.param(
+            lambda x: x.reshape((-1, 18, 4)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
+            (numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
+            (numpy.arange(144) % 10).reshape((-1, 18, 4)),
+        ),
+    ],
+)
+def test_memory_operation_run_correctness(
+    function,
+    parameters,
+    inputset,
+    test_input,
+    expected_output,
+    default_compilation_configuration,
+    check_array_equality,
+):
+    """
+    Test correctness of results when running a compiled function with memory operations.
+
+    e.g.,
+    - flatten
+    - reshape
+    """
+    circuit = compile_numpy_function(
+        function,
+        parameters,
+        inputset,
+        default_compilation_configuration,
+    )
+
+    actual = circuit.run(numpy.array(test_input, dtype=numpy.uint8))
+    expected = numpy.array(expected_output, dtype=numpy.uint8)
+
+    check_array_equality(actual, expected)
+
+
+@pytest.mark.parametrize(
+    "function,parameters,inputset,error,match",
+    [
+        pytest.param(
+            lambda x: x.reshape((-1, -1, 2)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
+            ValueError,
+            "shapes are not compatible (old shape (2, 3, 4), new shape (-1, -1, 2))",
+        ),
+        pytest.param(
+            lambda x: x.reshape((3, -1, 3)),
+            {
+                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
+            },
+            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
+            ValueError,
+            "shapes are not compatible (old shape (2, 3, 4), new shape (3, -1, 3))",
+        ),
+    ],
+)
+def test_memory_operation_failed_compilation(
+    function,
+    parameters,
+    inputset,
+    error,
+    match,
+    default_compilation_configuration,
+):
+    """
+    Test compilation failures of functions with memory operations.
+
+    e.g.,
+    - reshape
+    """
+
+    with pytest.raises(error) as excinfo:
+        compile_numpy_function(
+            function,
+            parameters,
+            inputset,
+            default_compilation_configuration,
+        )
+
+    assert (
+        str(excinfo.value) == match
+    ), f"""
+
+Actual Output
+=============
+{excinfo.value}
+
+Expected Output
+===============
+{match}
+
+    """
diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py
index 69c8e907e..b4050a3de 100644
--- a/tests/numpy/test_tracing.py
+++ b/tests/numpy/test_tracing.py
@@ -644,7 +644,7 @@ def test_tracing_numpy_calls(
                 None,
             )
         ],
-        marks=pytest.mark.xfail(strict=True, raises=AssertionError),
+        marks=pytest.mark.xfail(strict=True, raises=ValueError),
     ),
     pytest.param(
         lambda x: x.flatten(),
@@ -985,7 +985,7 @@ def test_tracing_ndarray_calls(
 )
 def test_errors_with_generic_function(lambda_f, params):
     "Test some errors with generic function"
-    with pytest.raises(AssertionError) as excinfo:
+    with pytest.raises(ValueError) as excinfo:
         tracing.trace_numpy_function(lambda_f, params)
 
     assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
diff --git a/tests/numpy/test_tracing_calls.py b/tests/numpy/test_tracing_calls.py
index 3e78fb8fa..473a6bc85 100644
--- a/tests/numpy/test_tracing_calls.py
+++ b/tests/numpy/test_tracing_calls.py
@@ -294,7 +294,7 @@ def test_tracing_numpy_calls(
                 None,
             )
         ],
-        marks=pytest.mark.xfail(strict=True, raises=AssertionError),
+        marks=pytest.mark.xfail(strict=True, raises=ValueError),
     ),
 ],
 )
diff --git a/tests/numpy/test_tracing_failures.py b/tests/numpy/test_tracing_failures.py
index d1cbbf560..7a01f8418 100644
--- a/tests/numpy/test_tracing_failures.py
+++ b/tests/numpy/test_tracing_failures.py
@@ -178,7 +178,7 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
 )
 def test_errors_with_generic_function(lambda_f, params):
     "Test some errors with generic function"
-    with pytest.raises(AssertionError) as excinfo:
+    with pytest.raises(ValueError) as excinfo:
         tracing.trace_numpy_function(lambda_f, params)
 
     assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
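The tracing.py change above delegates new-shape validation to numpy itself: reshaping a zero-filled stand-in array both resolves `-1` wildcards and rejects incompatible shapes in one step. A minimal sketch of that trick in plain numpy (`resolve_newshape` is a hypothetical helper mirroring the new `numpy_reshape` logic, not part of this diff):

```python
import numpy


def resolve_newshape(old_shape, newshape):
    """Resolve `newshape` (a tuple possibly containing one -1, or a bare int)
    against `old_shape`; raise ValueError if the shapes are incompatible."""
    try:
        # a zero-filled stand-in is enough, since only the shape matters here
        return numpy.zeros(old_shape).reshape(newshape).shape
    except Exception as error:
        raise ValueError(
            f"shapes are not compatible (old shape {old_shape}, new shape {newshape})"
        ) from error


assert resolve_newshape((2, 3, 4), (6, -1)) == (6, 4)  # the -1 is inferred
assert resolve_newshape((2, 3, 4), -1) == (24,)        # a bare int also works
# resolve_newshape((2, 3, 4), (3, -1, 3))  # raises ValueError, as the new tests expect
```

Because the incompatibility is now reported as a ValueError raised from numpy's own check, the expected exception in the tests changes from AssertionError to ValueError.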