diff --git a/concrete/common/mlir/graph_converter.py b/concrete/common/mlir/graph_converter.py index 9b714999a..7fdd99d21 100644 --- a/concrete/common/mlir/graph_converter.py +++ b/concrete/common/mlir/graph_converter.py @@ -5,7 +5,7 @@ # pylint: disable=no-name-in-module from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import Any, Dict, List import networkx as nx import zamalang @@ -13,7 +13,7 @@ from mlir.dialects import builtin from mlir.ir import Context, InsertionPoint, Location, Module from ..operator_graph import OPGraph -from ..representation.intermediate import Input +from ..representation.intermediate import Input, IntermediateNode from .conversion_helpers import value_to_mlir_type from .node_converter import IntermediateNodeConverter @@ -35,6 +35,15 @@ class OPGraphConverter(ABC): additional_conversion_info = self._generate_additional_info_dict(op_graph) + # { node1: "%arg0", node2: "%0", node3: "%1" } + nodes_to_mlir_names: Dict[IntermediateNode, str] = {} + + # { "%arg0": "i5", "%0": "tensor<2x3x!HLFHE.eint<4>>" } + mlir_names_to_mlir_types: Dict[str, str] = {} + + # { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor + scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {} + with Context() as ctx, Location.unknown(): zamalang.register_dialects(ctx) @@ -51,12 +60,24 @@ class OPGraphConverter(ABC): for arg_num, node in op_graph.input_nodes.items(): ir_to_mlir[node] = arg[arg_num] + mlir_name = f"%arg{arg_num}" + nodes_to_mlir_names[node] = mlir_name + mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num]) + for node in nx.topological_sort(op_graph.graph): if isinstance(node, Input): continue preds = [ir_to_mlir[pred] for pred in op_graph.get_ordered_preds(node)] - node_converter = IntermediateNodeConverter(ctx, op_graph, node, preds) + node_converter = IntermediateNodeConverter( + ctx, + op_graph, + node, + preds, + nodes_to_mlir_names, + mlir_names_to_mlir_types, + 
scalar_to_1d_tensor_conversion_hacks, + ) ir_to_mlir[node] = node_converter.convert(additional_conversion_info) results = ( @@ -64,7 +85,39 @@ class OPGraphConverter(ABC): ) return results - return str(module) + module_lines_after_hacks_are_applied = [] + for line in str(module).split("\n"): + mlir_name = line.split("=")[0].strip() + if mlir_name not in scalar_to_1d_tensor_conversion_hacks: + module_lines_after_hacks_are_applied.append(line) + continue + + to_be_replaced = scalar_to_1d_tensor_conversion_hacks[mlir_name] + for arg_name in to_be_replaced: + new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}" + mlir_type = mlir_names_to_mlir_types[arg_name] + + hack_line = ( + f" {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>" + ) + module_lines_after_hacks_are_applied.append(hack_line) + + line = line.replace(arg_name, new_name) + + new_arg_types = [] + + arg_types = line.split(":")[1].split("->")[0].strip()[1:-1] + for arg in arg_types.split(", "): + if arg.startswith("tensor"): + new_arg_types.append(arg) + else: + new_arg_types.append(f"tensor<1x{arg}>") + + line = line.replace(arg_types, ", ".join(new_arg_types)) + + module_lines_after_hacks_are_applied.append(line) + + return "\n".join(module_lines_after_hacks_are_applied) @staticmethod @abstractmethod diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py index 4ccf5db3c..d8457b6c6 100644 --- a/concrete/common/mlir/node_converter.py +++ b/concrete/common/mlir/node_converter.py @@ -50,8 +50,19 @@ class IntermediateNodeConverter: all_of_the_inputs_are_tensors: bool one_of_the_inputs_is_a_tensor: bool + nodes_to_mlir_names: Dict[IntermediateNode, str] + mlir_names_to_mlir_types: Dict[str, str] + scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] + def __init__( - self, ctx: Context, op_graph: OPGraph, node: IntermediateNode, preds: List[OpResult] + self, + ctx: Context, + op_graph: OPGraph, + node: IntermediateNode, + 
preds: List[OpResult], + nodes_to_mlir_names: Dict[IntermediateNode, str], + mlir_names_to_mlir_types: Dict[str, str], + scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]], ): self.ctx = ctx self.op_graph = op_graph @@ -75,6 +86,10 @@ class IntermediateNodeConverter: # this branch is not covered as there are only TensorValues for now self.all_of_the_inputs_are_tensors = False + self.nodes_to_mlir_names = nodes_to_mlir_names + self.mlir_names_to_mlir_types = mlir_names_to_mlir_types + self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks + def convert(self, additional_conversion_info: Dict[str, Any]) -> OpResult: """Convert an intermediate node to its corresponding MLIR representation. @@ -111,6 +126,20 @@ class IntermediateNodeConverter: # this branch is not covered as unsupported opeations fail on check mlir compatibility raise NotImplementedError(f"{type(self.node)} nodes cannot be converted to MLIR yet") + mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip() + + self.nodes_to_mlir_names[self.node] = mlir_name + self.mlir_names_to_mlir_types[mlir_name] = str(result.type) + + if isinstance(self.node, (Add, Mul, Sub, Dot)): + if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors: + to_be_converted = [] + for (pred, output) in self.op_graph.get_ordered_inputs_of(self.node): + inp = pred.outputs[output] + if isinstance(inp, TensorValue) and inp.is_scalar: + to_be_converted.append(self.nodes_to_mlir_names[pred]) + self.scalar_to_1d_tensor_conversion_hacks[mlir_name] = to_be_converted + return result def convert_add(self) -> OpResult: diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 176aeeadf..b62ce99de 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -736,11 +736,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu( @pytest.mark.parametrize( "function,parameters,inputset,test_input,expected_output", [ - # TODO: find a way to 
support this case - # https://github.com/zama-ai/concretefhe-internal/issues/837 - # - # the problem is that compiler doesn't support combining scalars and tensors - # but they do support broadcasting, so scalars should be converted to (1,) shaped tensors pytest.param( lambda x: x + 1, { @@ -759,7 +754,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [7, 2], [3, 6], ], - marks=pytest.mark.xfail(strict=True), ), pytest.param( lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32), @@ -780,11 +774,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [5, 6], ], ), - # TODO: find a way to support this case - # https://github.com/zama-ai/concretefhe-internal/issues/837 - # - # the problem is that compiler doesn't support combining scalars and tensors - # but they do support broadcasting, so scalars should be converted to (1,) shaped tensors pytest.param( lambda x, y: x + y, { @@ -811,7 +800,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [8, 3], [4, 7], ], - marks=pytest.mark.xfail(strict=True), ), pytest.param( lambda x, y: x + y, @@ -844,11 +832,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [5, 9], ], ), - # TODO: find a way to support this case - # https://github.com/zama-ai/concretefhe-internal/issues/837 - # - # the problem is that compiler doesn't support combining scalars and tensors - # but they do support broadcasting, so scalars should be converted to (1,) shaped tensors pytest.param( lambda x: 100 - x, { @@ -867,7 +850,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [94, 99], [98, 95], ], - marks=pytest.mark.xfail(strict=True), ), pytest.param( lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x, @@ -888,11 +870,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [8, 25], ], ), - # TODO: find a way to support this case - # https://github.com/zama-ai/concretefhe-internal/issues/837 - # - # the problem is that compiler doesn't support combining 
scalars and tensors - # but they do support broadcasting, so scalars should be converted to (1,) shaped tensors pytest.param( lambda x: x * 2, { @@ -911,7 +888,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu( [12, 2], [4, 10], ], - marks=pytest.mark.xfail(strict=True), ), pytest.param( lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32), @@ -994,7 +970,10 @@ def test_compile_and_run_tensor_correctness( default_compilation_configuration, ) - numpy_test_input = (numpy.array(item, dtype=numpy.uint8) for item in test_input) + numpy_test_input = ( + item if isinstance(item, int) else numpy.array(item, dtype=numpy.uint8) + for item in test_input + ) assert numpy.array_equal( circuit.run(*numpy_test_input), numpy.array(expected_output, dtype=numpy.uint8),