diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py
index 207841f6b..bae096d3e 100644
--- a/concrete/numpy/mlir/graph_converter.py
+++ b/concrete/numpy/mlir/graph_converter.py
@@ -5,7 +5,7 @@ Declaration of `GraphConverter` class.
 # pylint: disable=no-member,no-name-in-module
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, cast
 
 import concrete.lang as concretelang
 import networkx as nx
@@ -299,11 +299,63 @@ class GraphConverter:
         nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
 
     @staticmethod
-    def _sanitize_signed_inputs(
-        graph: Graph,
-        args: List[Any],
-        ctx: Context,
-    ) -> Tuple[List[Any], List[str]]:
+    def _tensorize_scalars_for_fhelinalg(graph: Graph):
+        """
+        Tensorize scalars if they are used within fhelinalg operations.
+
+        Args:
+            graph (Graph):
+                computation graph to update
+        """
+
+        # pylint: disable=invalid-name
+        OPS_TO_TENSORIZE = ["add", "dot", "multiply", "subtract"]
+        # pylint: enable=invalid-name
+
+        tensorized_scalars: Dict[Node, Node] = {}
+
+        nx_graph = graph.graph
+        for node in list(nx_graph.nodes):
+            if node.operation == Operation.Generic and node.properties["name"] in OPS_TO_TENSORIZE:
+                assert_that(len(node.inputs) == 2)
+
+                if set(inp.is_scalar for inp in node.inputs) != {True, False}:
+                    continue
+
+                pred_to_tensorize: Optional[Node] = None
+                pred_to_tensorize_index = 0
+
+                preds = graph.ordered_preds_of(node)
+                for index, pred in enumerate(preds):
+                    if pred.output.is_scalar:
+                        pred_to_tensorize = pred
+                        pred_to_tensorize_index = index
+                        break
+
+                assert pred_to_tensorize is not None
+
+                tensorized_pred = tensorized_scalars.get(pred_to_tensorize)
+                if tensorized_pred is None:
+                    tensorized_value = deepcopy(pred_to_tensorize.output)
+                    tensorized_value.shape = (1,)
+
+                    tensorized_pred = Node.generic(
+                        "array",
+                        [pred_to_tensorize.output],
+                        tensorized_value,
+                        lambda *args: np.array(args),
+                    )
+                    nx_graph.add_edge(pred_to_tensorize, tensorized_pred, input_idx=0)
+
+                    tensorized_scalars[pred_to_tensorize] = tensorized_pred
+
+                assert tensorized_pred is not None
+
+                nx_graph.remove_edge(pred_to_tensorize, node)
+                nx_graph.add_edge(tensorized_pred, node, input_idx=pred_to_tensorize_index)
+
+    @staticmethod
+    def _sanitize_signed_inputs(graph: Graph, args: List[Any], ctx: Context) -> List[Any]:
         """
         Apply table lookup to signed inputs in the beginning of evaluation to sanitize them.
 
@@ -341,8 +393,6 @@ class GraphConverter:
         """
 
         sanitized_args = []
-        arg_mlir_names = []
-
         for i, arg in enumerate(args):
             input_node = graph.input_nodes[i]
             input_value = input_node.output
@@ -370,14 +420,10 @@ class GraphConverter:
 
                 sanitized = fhelinalg.ApplyLookupTableEintOp(resulting_type, arg, lut).result
 
                 sanitized_args.append(sanitized)
-                mlir_name = str(sanitized).replace("Value(", "").split("=", maxsplit=1)[0].strip()
             else:
                 sanitized_args.append(arg)
-                mlir_name = f"%arg{i}"
-            arg_mlir_names.append(mlir_name)
-
-        return sanitized_args, arg_mlir_names
+        return sanitized_args
 
     @staticmethod
     def convert(graph: Graph, virtual: bool = False) -> str:
@@ -404,21 +450,7 @@ class GraphConverter:
 
         GraphConverter._update_bit_widths(graph)
         GraphConverter._offset_negative_lookup_table_inputs(graph)
-
-        # There are no tensor +*- scalar operations in the compiler
-        # But such operations are used commonly, so we need to support them
-        # So, we implemented some workarounds (pull request #970)
-        # Once we have native support, this workaround shall be removed (issue #837)
-        # (most changes in #970 shall be reverted)
-
-        # { node1: "%arg0", node2: "%0", node3: "%1" }
-        nodes_to_mlir_names: Dict[Node, str] = {}
-
-        # { "%arg0": "i5", "%0": "tensor<2x3x!FHE.eint<4>>" }
-        mlir_names_to_mlir_types: Dict[str, str] = {}
-
-        # { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor
-        scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {}
+        GraphConverter._tensorize_scalars_for_fhelinalg(graph)
 
         # { "%0": "tensor.from_elements ..." } == we need to convert the part after "=" for %0
         direct_replacements: Dict[str, str] = {}
@@ -435,20 +467,12 @@ class GraphConverter:
 
                 @builtin.FuncOp.from_py_func(*parameters)
                 def main(*args):
-                    sanitized_args, arg_mlir_names = GraphConverter._sanitize_signed_inputs(
-                        graph,
-                        args,
-                        ctx,
-                    )
+                    sanitized_args = GraphConverter._sanitize_signed_inputs(graph, args, ctx)
 
                     ir_to_mlir = {}
                     for arg_num, node in graph.input_nodes.items():
                         ir_to_mlir[node] = sanitized_args[arg_num]
 
-                        mlir_name = arg_mlir_names[arg_num]
-                        nodes_to_mlir_names[node] = mlir_name
-                        mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num])
-
                     for node in nx.topological_sort(graph.graph):
                         if node.operation == Operation.Input:
                             continue
@@ -459,9 +483,6 @@ class GraphConverter:
                             graph,
                             node,
                             preds,
-                            nodes_to_mlir_names,
-                            mlir_names_to_mlir_types,
-                            scalar_to_1d_tensor_conversion_hacks,
                             direct_replacements,
                         )
                         ir_to_mlir[node] = node_converter.convert()
@@ -472,39 +493,11 @@ class GraphConverter:
         module_lines_after_hacks_are_applied = []
         for line in str(module).split("\n"):
             mlir_name = line.split("=")[0].strip()
-
-            if mlir_name in direct_replacements:
-                new_value = direct_replacements[mlir_name]
-                module_lines_after_hacks_are_applied.append(f" {mlir_name} = {new_value}")
-                continue
-
-            if mlir_name not in scalar_to_1d_tensor_conversion_hacks:
+            if mlir_name not in direct_replacements:
                 module_lines_after_hacks_are_applied.append(line)
                 continue
 
-            to_be_replaced = scalar_to_1d_tensor_conversion_hacks[mlir_name]
-            for arg_name in to_be_replaced:
-                new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}"
-                mlir_type = mlir_names_to_mlir_types[arg_name]
-
-                hack_line = (
-                    f" {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>"
-                )
-                module_lines_after_hacks_are_applied.append(hack_line)
-
-                line = line.replace(arg_name, new_name)
-
-            new_arg_types = []
-
-            arg_types = line.split(":")[1].split("->")[0].strip()[1:-1]
-            for arg in arg_types.split(", "):
-                if arg.startswith("tensor"):
-                    new_arg_types.append(arg)
-                else:
-                    new_arg_types.append(f"tensor<1x{arg}>")
-
-            line = line.replace(arg_types, ", ".join(new_arg_types))
-
-            module_lines_after_hacks_are_applied.append(line)
+            new_value = direct_replacements[mlir_name]
+            module_lines_after_hacks_are_applied.append(f" {mlir_name} = {new_value}")
 
         return "\n".join(module_lines_after_hacks_are_applied).strip()
diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py
index 7601d7059..036bd82fb 100644
--- a/concrete/numpy/mlir/node_converter.py
+++ b/concrete/numpy/mlir/node_converter.py
@@ -51,9 +51,6 @@ class NodeConverter:
 
     all_of_the_inputs_are_tensors: bool
    one_of_the_inputs_is_a_tensor: bool
-    nodes_to_mlir_names: Dict[Node, str]
-    mlir_names_to_mlir_types: Dict[str, str]
-    scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]]
     direct_replacements: Dict[str, str]
 
     # pylint: enable=too-many-instance-attributes
@@ -115,9 +112,6 @@ class NodeConverter:
         graph: Graph,
         node: Node,
         preds: List[OpResult],
-        nodes_to_mlir_names: Dict[OpResult, str],
-        mlir_names_to_mlir_types: Dict[str, str],
-        scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]],
         direct_replacements: Dict[str, str],
     ):
         self.ctx = ctx
@@ -138,9 +132,6 @@ class NodeConverter:
         else:
             self.one_of_the_inputs_is_a_tensor = True
 
-        self.nodes_to_mlir_names = nodes_to_mlir_names
-        self.mlir_names_to_mlir_types = mlir_names_to_mlir_types
-        self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks
         self.direct_replacements = direct_replacements
 
     def convert(self) -> OpResult:
@@ -216,21 +207,6 @@ class NodeConverter:
             assert_that(self.node.converted_to_table_lookup)
             result = self._convert_tlu()
 
-        mlir_name = NodeConverter.mlir_name(result)
-
-        self.nodes_to_mlir_names[self.node] = mlir_name
-        self.mlir_names_to_mlir_types[mlir_name] = str(result.type)
-
-        if self.node.operation == Operation.Generic:
-            name = self.node.properties["name"]
-            if name in ["add", "dot", "multiply", "subtract"]:
-                if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors:
-                    to_be_converted = []
-                    for pred in self.graph.ordered_preds_of(self.node):
-                        if pred.output.is_scalar:
-                            to_be_converted.append(self.nodes_to_mlir_names[pred])
-                    self.scalar_to_1d_tensor_conversion_hacks[mlir_name] = to_be_converted
-
         return result
 
     # pylint: enable=too-many-branches
@@ -284,7 +260,7 @@ class NodeConverter:
 
         pred_names = []
         for pred, value in zip(preds, self.node.inputs):
-            if value.is_encrypted:
+            if value.is_encrypted or self.node.output.is_clear:
                 pred_names.append(NodeConverter.mlir_name(pred))
                 continue