diff --git a/concrete/numpy/compilation/utils.py b/concrete/numpy/compilation/utils.py index 1c6455722..a0245ad99 100644 --- a/concrete/numpy/compilation/utils.py +++ b/concrete/numpy/compilation/utils.py @@ -22,6 +22,10 @@ def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None): artifacts (Optional[DebugArtifacts], default = None): compilation artifacts to store information about the fusing process + + Raises: + RuntimeError: + if there is a subgraph which needs to be fused cannot be fused """ nx_graph = graph.graph @@ -31,12 +35,12 @@ def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None): while True: subgraph_to_fuse = ( find_float_subgraph_with_unique_terminal_node( - nx_graph, + graph, processed_terminal_nodes, ) if fusing_floats else find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor( - nx_graph, + graph, processed_terminal_nodes, ) ) @@ -51,16 +55,12 @@ def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None): all_nodes, start_nodes, terminal_node = subgraph_to_fuse processed_terminal_nodes.add(terminal_node) - subgraph_conversion_result = convert_subgraph_to_subgraph_node( - nx_graph, + fused_node, node_before_subgraph = convert_subgraph_to_subgraph_node( + graph, all_nodes, start_nodes, terminal_node, ) - if subgraph_conversion_result is None: - continue - - fused_node, node_before_subgraph = subgraph_conversion_result nx_graph.add_node(fused_node) if terminal_node in graph.output_nodes.values(): @@ -89,14 +89,14 @@ def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None): def find_float_subgraph_with_unique_terminal_node( - nx_graph: nx.MultiDiGraph, + graph: Graph, processed_terminal_nodes: Set[Node], ) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]: """ Find a subgraph with float computations that end with an integer output. Args: - nx_graph (nx.MultiDiGraph): + graph (Graph): graph to search processed_terminal_nodes (Set[Node]): @@ -109,6 +109,7 @@ def find_float_subgraph_with_unique_terminal_node( and terminal node of the subgraph otherwise """ + nx_graph = graph.graph terminal_nodes = ( node for node in nx_graph.nodes() @@ -128,7 +129,7 @@ def find_float_subgraph_with_unique_terminal_node( start_single_int_output_nodes_search_from = terminal_node while True: all_nodes, start_nodes = find_closest_integer_output_nodes( - nx_graph, + graph, [start_single_int_output_nodes_search_from], all_nodes, ) @@ -141,14 +142,14 @@ def find_float_subgraph_with_unique_terminal_node( # find a common ancestor as we need a single variable input node # lca == lowest common ancestor - lca = find_single_lca(nx_graph, variable_start_nodes) + lca = find_single_lca(graph, variable_start_nodes) # if subgraph cannot be fused because there is no way to find a common ancestor, break if lca is None: break # add the nodes from the `start_nodes` to `lca`, to `all_nodes` - all_nodes = add_nodes_from_to(nx_graph, start_nodes, {lca: None}, all_nodes) + all_nodes = add_nodes_from_to(graph, start_nodes, {lca: None}, all_nodes) # if `lca` is a valid starting node for fusing break if isinstance(lca.output.dtype, Integer): @@ -164,7 +165,7 @@ def find_float_subgraph_with_unique_terminal_node( def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor( - nx_graph: nx.MultiDiGraph, + graph: Graph, processed_terminal_nodes: Set[Node], ) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]: """ @@ -172,7 +173,7 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc where all variable inputs share a common ancestor. Args: - nx_graph (nx.MultiDiGraph): + graph (Graph): graph to search processed_terminal_nodes (Set[Node]): @@ -185,6 +186,7 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc and terminal node of the subgraph otherwise """ + nx_graph = graph.graph terminal_nodes = ( node for node in nx_graph.nodes() @@ -215,16 +217,17 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc # find a common ancestor as we need a single variable input node # lca == lowest common ancestor - lca = find_single_lca(nx_graph, variable_start_nodes) + lca = find_single_lca(graph, variable_start_nodes) # if subgraph cannot be fused because there is no way to find a common ancestor, break if lca is None: - start_nodes = {} + start_nodes = {node: None for node in variable_start_nodes} + all_nodes = {node: None for node in variable_start_nodes + [terminal_node]} break # add the nodes from the `start_nodes` to `lca`, to `all_nodes` all_nodes = add_nodes_from_to( - nx_graph, + graph, list(nx_graph.predecessors(terminal_node)), {lca: None}, all_nodes, @@ -240,12 +243,12 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc return all_nodes, start_nodes, terminal_node -def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[Node]: +def find_single_lca(graph: Graph, nodes: List[Node]) -> Optional[Node]: """ Find the single lowest common ancestor of a list of nodes. Args: - nx_graph (nx.MultiDiGraph): + graph (Graph): graph to search for single lca nodes (List[Node]): @@ -256,6 +259,8 @@ def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[No single lca if it exists, None otherwise """ + nx_graph = graph.graph + # find all ancestors of `nodes` # nodes themselves need to be in this set because the single lca can be within `nodes` all_ancestors = [set(list(nx.ancestors(nx_graph, node)) + [node]) for node in nodes] @@ -278,7 +283,7 @@ def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[No continue # check if node is a single common ancestor of `nodes` - if is_single_common_ancestor(nx_graph, candidate, nodes): + if is_single_common_ancestor(graph, candidate, nodes): # if so, it's the single lca of `nodes` # so return it return candidate @@ -289,7 +294,7 @@ def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[No def is_single_common_ancestor( - nx_graph: nx.MultiDiGraph, + graph: Graph, candidate: Node, nodes: List[Node], ) -> bool: @@ -299,7 +304,7 @@ def is_single_common_ancestor( Note that this function doesn't care about `lowest` property of `lca`. Args: - nx_graph (nx.MultiDiGraph): + graph (Graph): graph to perform the check candidate (Node): @@ -313,6 +318,8 @@ def is_single_common_ancestor( True if `candidate` is a single common ancestor of `nodes`, False otherwise """ + nx_graph = graph.graph + # create a subgraph with `candidate` node subgraph = nx.DiGraph() subgraph.add_node(candidate) @@ -420,7 +427,7 @@ def is_single_common_ancestor( def find_closest_integer_output_nodes( - nx_graph: nx.MultiDiGraph, + graph: Graph, start_nodes: List[Node], all_nodes: Dict[Node, None], ) -> Tuple[Dict[Node, None], Dict[Node, None]]: @@ -428,7 +435,7 @@ def find_closest_integer_output_nodes( Find the closest upstream integer output nodes to a set of start nodes in a graph. Args: - nx_graph (nx.MultiDiGraph): + graph (Graph): graph to search start_nodes (List[Node]): @@ -442,6 +449,8 @@ def find_closest_integer_output_nodes( tuple containing extended `all_nodes` and integer output nodes closest to `start_nodes` """ + nx_graph = graph.graph + closest_integer_output_nodes: Dict[Node, None] = {} visited_nodes: Set[Node] = set() @@ -465,7 +474,7 @@ def find_closest_integer_output_nodes( def add_nodes_from_to( - nx_graph: nx.MultiDiGraph, + graph: Graph, from_nodes: Iterable[Node], to_nodes: Dict[Node, None], all_nodes: Dict[Node, None], @@ -474,7 +483,7 @@ def add_nodes_from_to( Add nodes from `from_nodes` to `to_nodes`, to `all_nodes`. Args: - nx_graph (nx.MultiDiGraph): + graph (Graph): graph to traverse from_nodes (Iterable[Node]): @@ -491,6 +500,8 @@ def add_nodes_from_to( extended `all_nodes` """ + nx_graph = graph.graph + all_nodes.update(to_nodes) visited_nodes: Set[Node] = set() @@ -511,17 +522,17 @@ def add_nodes_from_to( def convert_subgraph_to_subgraph_node( - nx_graph: nx.MultiDiGraph, + graph: Graph, all_nodes: Dict[Node, None], start_nodes: Dict[Node, None], terminal_node: Node, -) -> Optional[Tuple[Node, Node]]: +) -> Tuple[Node, Node]: """ Convert a subgraph to Operation.Generic node. Args: - nx_graph (nx.MultiDiGraph): - orginal networkx graph + graph (Graph): + orginal graph all_nodes (Dict[Node, None]): all nodes in the subgraph @@ -532,19 +543,32 @@ def convert_subgraph_to_subgraph_node( terminal_node (Node): terminal node of the subgraph + Raises: + RuntimeError: + if subgraph is not fusable + Returns: - Optional[Tuple[Node, Node]]: + Tuple[Node, Node]: None if the subgraph cannot be fused, subgraph node and its predecessor otherwise """ + nx_graph = graph.graph + variable_input_nodes = [node for node in start_nodes if node.operation != Operation.Constant] if len(variable_input_nodes) != 1: - return None + base_highlighted_nodes = {node: ["within this subgraph"] for node in all_nodes} + for variable_input_node in variable_input_nodes: + base_highlighted_nodes[variable_input_node] = ["this is one of the input nodes"] + + raise RuntimeError( + "A subgraph within the function you are trying to compile cannot be fused " + "because it has multiple input nodes\n\n" + + graph.format(highlighted_nodes=base_highlighted_nodes) + ) variable_input_node = variable_input_nodes[0] - if not subgraph_can_be_fused(all_nodes, variable_input_node): - return None + check_subgraph_fusability(graph, all_nodes, variable_input_node) nx_subgraph = nx.MultiDiGraph(nx_graph) nodes_to_remove = [node for node in nx_subgraph.nodes() if node not in all_nodes] @@ -583,10 +607,11 @@ def convert_subgraph_to_subgraph_node( return subgraph_node, variable_input_node -def subgraph_can_be_fused( +def check_subgraph_fusability( + graph: Graph, all_nodes: Dict[Node, None], variable_input_node: Node, -) -> bool: +): """ Determine if a subgraph can be fused. @@ -596,35 +621,42 @@ def subgraph_can_be_fused( between each cell of the input and each cell of the output for table lookups Args: + graph (Graph): + original graph + all_nodes (Dict[Node, None]): all nodes in the subgraph variable_input_node (Node): variable input node to the subgraph - Returns: - bool: - True if subgraph can be fused, - False otherwise + Raises: + RuntimeError: + if subgraph is not fusable """ - constant_nodes_with_bigger_size_than_variable_input = [ - node - for node in all_nodes - if ( - node.operation == Operation.Constant - and node.output.size > variable_input_node.output.size - ) - ] - if len(constant_nodes_with_bigger_size_than_variable_input) > 0: - return False + base_highlighted_nodes = {node: ["within this subgraph"] for node in all_nodes} + base_highlighted_nodes[variable_input_node] = ["with this input node"] non_constant_nodes = (node for node in all_nodes if node.operation != Operation.Constant) for node in non_constant_nodes: if node == variable_input_node: continue - if not node.is_fusable or node.output.shape != variable_input_node.output.shape: - return False + if not node.is_fusable: + base_highlighted_nodes[node] = ["this node is not fusable"] + raise RuntimeError( + "A subgraph within the function you are trying to compile cannot be fused " + "because of a node, which is marked explicitly as non-fusable\n\n" + + graph.format(highlighted_nodes=base_highlighted_nodes) + ) + + if node.output.shape != variable_input_node.output.shape: + base_highlighted_nodes[node] = ["this node has a different shape than the input node"] + raise RuntimeError( + "A subgraph within the function you are trying to compile cannot be fused " + "because of a node, which is has a different shape than the input node\n\n" + + graph.format(highlighted_nodes=base_highlighted_nodes) + ) return True diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index 53880e6c5..c467ce496 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -165,8 +165,7 @@ class GraphConverter: for idx, pred in enumerate(graph.ordered_preds_of(node)) if not pred.operation == Operation.Constant ] - if len(variable_input_indices) != 1: - return "only single input table lookups are supported" + assert_that(len(variable_input_indices) == 1) if len(inputs) > 0 and all(input.is_clear for input in inputs): return "one of the operands must be encrypted" diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index 33703104a..65b799329 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -700,29 +700,36 @@ def test_others_bad_fusing(helpers): # pylint: disable=line-too-long """ -Function you are trying to compile cannot be converted to MLIR +A subgraph within the function you are trying to compile cannot be fused because it has multiple input nodes %0 = 10 # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph %1 = 10 # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph %2 = 2 # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph %3 = 2 # ClearScalar - %4 = x # EncryptedScalar - %5 = y # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph + %4 = x # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes + %5 = y # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes %6 = sin(%4) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph %7 = cos(%5) # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph %8 = power(%6, %2) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph %9 = power(%7, %3) # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph %10 = multiply(%0, %8) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph %11 = multiply(%1, %9) # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph %12 = add(%10, %11) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported -%13 = astype(%12, dtype=int_) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph +%13 = astype(%12, dtype=int_) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph return %13 """, # noqa: E501 @@ -730,62 +737,67 @@ return %13 str(excinfo.value), ) - # big intermediate constants - # -------------------------- - - @cnp.compiler({"x": "encrypted"}) - def function2(x): - return (np.sin(x) * [[1, 2], [3, 4]]).astype(np.int64) - - with pytest.raises(RuntimeError) as excinfo: - inputset = range(100) - function2.compile(inputset, configuration) - - helpers.check_str( - # pylint: disable=line-too-long - """ - -Function you are trying to compile cannot be converted to MLIR - -%0 = [[1 2] [3 4]] # ClearTensor -%1 = x # EncryptedScalar -%2 = sin(%1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported -%3 = multiply(%2, %0) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported -%4 = astype(%3, dtype=int_) # EncryptedTensor -return %4 - - """, # noqa: E501 - # pylint: enable=line-too-long - str(excinfo.value), - ) - # intermediates with different shape # ---------------------------------- @cnp.compiler({"x": "encrypted"}) - def function3(x): + def function2(x): return np.abs(np.sin(x)).reshape((2, 3)).astype(np.int64) with pytest.raises(RuntimeError) as excinfo: inputset = [np.random.randint(0, 2**7, size=(3, 2)) for _ in range(100)] + function2.compile(inputset, configuration) + + helpers.check_str( + # pylint: disable=line-too-long + """ + +A subgraph within the function you are trying to compile cannot be fused because of a node, which is has a different shape than the input node + +%0 = x # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with this input node +%1 = sin(%0) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph +%2 = absolute(%1) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph +%3 = reshape(%2, newshape=(2, 3)) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph +%4 = astype(%3, dtype=int_) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node has a different shape than the input node +return %4 + + """, # noqa: E501 + # pylint: enable=line-too-long + str(excinfo.value), + ) + + # non-fusable operation + # --------------------- + + @cnp.compiler({"x": "encrypted"}) + def function3(x): + return np.abs(np.sin(x)).transpose().astype(np.int64) + + with pytest.raises(RuntimeError) as excinfo: + inputset = [np.random.randint(0, 2**7, size=(2, 2)) for _ in range(100)] function3.compile(inputset, configuration) helpers.check_str( # pylint: disable=line-too-long """ -Function you are trying to compile cannot be converted to MLIR +A subgraph within the function you are trying to compile cannot be fused because of a node, which is marked explicitly as non-fusable -%0 = x # EncryptedTensor -%1 = sin(%0) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported -%2 = absolute(%1) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported -%3 = reshape(%2, newshape=(2, 3)) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported -%4 = astype(%3, dtype=int_) # EncryptedTensor +%0 = x # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with this input node +%1 = sin(%0) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph +%2 = absolute(%1) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph +%3 = transpose(%2) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is not fusable +%4 = astype(%3, dtype=int_) # EncryptedTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph return %4 """, # noqa: E501 @@ -793,6 +805,36 @@ return %4 str(excinfo.value), ) + # integer two variable inputs + # --------------------------- + + @cnp.compiler({"x": "encrypted", "y": "clear"}) + def function4(x, y): + return np.maximum(x, y) + + with pytest.raises(RuntimeError) as excinfo: + inputset = [(i, i) for i in range(100)] + function4.compile(inputset, configuration) + + helpers.check_str( + # pylint: disable=line-too-long + """ + +A subgraph within the function you are trying to compile cannot be fused because it has multiple input nodes + +%0 = x # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes +%1 = y # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes +%2 = maximum(%0, %1) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph +return %2 + + """, # noqa: E501 + # pylint: enable=line-too-long + str(excinfo.value), + ) + def test_others_bad_univariate(helpers): """ diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index 319afe653..7bce4cc67 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -311,26 +311,6 @@ return %1 """, # noqa: E501 ), - pytest.param( - lambda x, y: np.maximum(x, y), - {"x": "encrypted", "y": "clear"}, - [ - (np.random.randint(0, 2, size=(1,)), np.random.randint(0, 2, size=(1,))) - for _ in range(100) - ], - RuntimeError, - """ - -Function you are trying to compile cannot be converted to MLIR - -%0 = x # EncryptedTensor -%1 = y # ClearTensor -%2 = maximum(%0, %1) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only single input table lookups are supported -return %2 - - """, # noqa: E501 - ), pytest.param( lambda x: np.maximum(x, np.array([3])), {"x": "clear"},