feat: improve fusing related error messages

This commit is contained in:
Umut
2022-10-27 11:56:39 +02:00
parent 0827cb9360
commit f0014dc05f
4 changed files with 180 additions and 127 deletions

View File

@@ -22,6 +22,10 @@ def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
artifacts (Optional[DebugArtifacts], default = None):
compilation artifacts to store information about the fusing process
Raises:
RuntimeError:
if there is a subgraph which needs to be fused cannot be fused
"""
nx_graph = graph.graph
@@ -31,12 +35,12 @@ def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
while True:
subgraph_to_fuse = (
find_float_subgraph_with_unique_terminal_node(
nx_graph,
graph,
processed_terminal_nodes,
)
if fusing_floats
else find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor(
nx_graph,
graph,
processed_terminal_nodes,
)
)
@@ -51,16 +55,12 @@ def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
all_nodes, start_nodes, terminal_node = subgraph_to_fuse
processed_terminal_nodes.add(terminal_node)
subgraph_conversion_result = convert_subgraph_to_subgraph_node(
nx_graph,
fused_node, node_before_subgraph = convert_subgraph_to_subgraph_node(
graph,
all_nodes,
start_nodes,
terminal_node,
)
if subgraph_conversion_result is None:
continue
fused_node, node_before_subgraph = subgraph_conversion_result
nx_graph.add_node(fused_node)
if terminal_node in graph.output_nodes.values():
@@ -89,14 +89,14 @@ def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
def find_float_subgraph_with_unique_terminal_node(
nx_graph: nx.MultiDiGraph,
graph: Graph,
processed_terminal_nodes: Set[Node],
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
"""
Find a subgraph with float computations that end with an integer output.
Args:
nx_graph (nx.MultiDiGraph):
graph (Graph):
graph to search
processed_terminal_nodes (Set[Node]):
@@ -109,6 +109,7 @@ def find_float_subgraph_with_unique_terminal_node(
and terminal node of the subgraph otherwise
"""
nx_graph = graph.graph
terminal_nodes = (
node
for node in nx_graph.nodes()
@@ -128,7 +129,7 @@ def find_float_subgraph_with_unique_terminal_node(
start_single_int_output_nodes_search_from = terminal_node
while True:
all_nodes, start_nodes = find_closest_integer_output_nodes(
nx_graph,
graph,
[start_single_int_output_nodes_search_from],
all_nodes,
)
@@ -141,14 +142,14 @@ def find_float_subgraph_with_unique_terminal_node(
# find a common ancestor as we need a single variable input node
# lca == lowest common ancestor
lca = find_single_lca(nx_graph, variable_start_nodes)
lca = find_single_lca(graph, variable_start_nodes)
# if subgraph cannot be fused because there is no way to find a common ancestor, break
if lca is None:
break
# add the nodes from the `start_nodes` to `lca`, to `all_nodes`
all_nodes = add_nodes_from_to(nx_graph, start_nodes, {lca: None}, all_nodes)
all_nodes = add_nodes_from_to(graph, start_nodes, {lca: None}, all_nodes)
# if `lca` is a valid starting node for fusing break
if isinstance(lca.output.dtype, Integer):
@@ -164,7 +165,7 @@ def find_float_subgraph_with_unique_terminal_node(
def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor(
nx_graph: nx.MultiDiGraph,
graph: Graph,
processed_terminal_nodes: Set[Node],
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
"""
@@ -172,7 +173,7 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
where all variable inputs share a common ancestor.
Args:
nx_graph (nx.MultiDiGraph):
graph (Graph):
graph to search
processed_terminal_nodes (Set[Node]):
@@ -185,6 +186,7 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
and terminal node of the subgraph otherwise
"""
nx_graph = graph.graph
terminal_nodes = (
node
for node in nx_graph.nodes()
@@ -215,16 +217,17 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
# find a common ancestor as we need a single variable input node
# lca == lowest common ancestor
lca = find_single_lca(nx_graph, variable_start_nodes)
lca = find_single_lca(graph, variable_start_nodes)
# if subgraph cannot be fused because there is no way to find a common ancestor, break
if lca is None:
start_nodes = {}
start_nodes = {node: None for node in variable_start_nodes}
all_nodes = {node: None for node in variable_start_nodes + [terminal_node]}
break
# add the nodes from the `start_nodes` to `lca`, to `all_nodes`
all_nodes = add_nodes_from_to(
nx_graph,
graph,
list(nx_graph.predecessors(terminal_node)),
{lca: None},
all_nodes,
@@ -240,12 +243,12 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
return all_nodes, start_nodes, terminal_node
def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[Node]:
def find_single_lca(graph: Graph, nodes: List[Node]) -> Optional[Node]:
"""
Find the single lowest common ancestor of a list of nodes.
Args:
nx_graph (nx.MultiDiGraph):
graph (Graph):
graph to search for single lca
nodes (List[Node]):
@@ -256,6 +259,8 @@ def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[No
single lca if it exists, None otherwise
"""
nx_graph = graph.graph
# find all ancestors of `nodes`
# nodes themselves need to be in this set because the single lca can be within `nodes`
all_ancestors = [set(list(nx.ancestors(nx_graph, node)) + [node]) for node in nodes]
@@ -278,7 +283,7 @@ def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[No
continue
# check if node is a single common ancestor of `nodes`
if is_single_common_ancestor(nx_graph, candidate, nodes):
if is_single_common_ancestor(graph, candidate, nodes):
# if so, it's the single lca of `nodes`
# so return it
return candidate
@@ -289,7 +294,7 @@ def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[No
def is_single_common_ancestor(
nx_graph: nx.MultiDiGraph,
graph: Graph,
candidate: Node,
nodes: List[Node],
) -> bool:
@@ -299,7 +304,7 @@ def is_single_common_ancestor(
Note that this function doesn't care about `lowest` property of `lca`.
Args:
nx_graph (nx.MultiDiGraph):
graph (Graph):
graph to perform the check
candidate (Node):
@@ -313,6 +318,8 @@ def is_single_common_ancestor(
True if `candidate` is a single common ancestor of `nodes`, False otherwise
"""
nx_graph = graph.graph
# create a subgraph with `candidate` node
subgraph = nx.DiGraph()
subgraph.add_node(candidate)
@@ -420,7 +427,7 @@ def is_single_common_ancestor(
def find_closest_integer_output_nodes(
nx_graph: nx.MultiDiGraph,
graph: Graph,
start_nodes: List[Node],
all_nodes: Dict[Node, None],
) -> Tuple[Dict[Node, None], Dict[Node, None]]:
@@ -428,7 +435,7 @@ def find_closest_integer_output_nodes(
Find the closest upstream integer output nodes to a set of start nodes in a graph.
Args:
nx_graph (nx.MultiDiGraph):
graph (Graph):
graph to search
start_nodes (List[Node]):
@@ -442,6 +449,8 @@ def find_closest_integer_output_nodes(
tuple containing extended `all_nodes` and integer output nodes closest to `start_nodes`
"""
nx_graph = graph.graph
closest_integer_output_nodes: Dict[Node, None] = {}
visited_nodes: Set[Node] = set()
@@ -465,7 +474,7 @@ def find_closest_integer_output_nodes(
def add_nodes_from_to(
nx_graph: nx.MultiDiGraph,
graph: Graph,
from_nodes: Iterable[Node],
to_nodes: Dict[Node, None],
all_nodes: Dict[Node, None],
@@ -474,7 +483,7 @@ def add_nodes_from_to(
Add nodes from `from_nodes` to `to_nodes`, to `all_nodes`.
Args:
nx_graph (nx.MultiDiGraph):
graph (Graph):
graph to traverse
from_nodes (Iterable[Node]):
@@ -491,6 +500,8 @@ def add_nodes_from_to(
extended `all_nodes`
"""
nx_graph = graph.graph
all_nodes.update(to_nodes)
visited_nodes: Set[Node] = set()
@@ -511,17 +522,17 @@ def add_nodes_from_to(
def convert_subgraph_to_subgraph_node(
nx_graph: nx.MultiDiGraph,
graph: Graph,
all_nodes: Dict[Node, None],
start_nodes: Dict[Node, None],
terminal_node: Node,
) -> Optional[Tuple[Node, Node]]:
) -> Tuple[Node, Node]:
"""
Convert a subgraph to Operation.Generic node.
Args:
nx_graph (nx.MultiDiGraph):
orginal networkx graph
graph (Graph):
orginal graph
all_nodes (Dict[Node, None]):
all nodes in the subgraph
@@ -532,19 +543,32 @@ def convert_subgraph_to_subgraph_node(
terminal_node (Node):
terminal node of the subgraph
Raises:
RuntimeError:
if subgraph is not fusable
Returns:
Optional[Tuple[Node, Node]]:
Tuple[Node, Node]:
None if the subgraph cannot be fused,
subgraph node and its predecessor otherwise
"""
nx_graph = graph.graph
variable_input_nodes = [node for node in start_nodes if node.operation != Operation.Constant]
if len(variable_input_nodes) != 1:
return None
base_highlighted_nodes = {node: ["within this subgraph"] for node in all_nodes}
for variable_input_node in variable_input_nodes:
base_highlighted_nodes[variable_input_node] = ["this is one of the input nodes"]
raise RuntimeError(
"A subgraph within the function you are trying to compile cannot be fused "
"because it has multiple input nodes\n\n"
+ graph.format(highlighted_nodes=base_highlighted_nodes)
)
variable_input_node = variable_input_nodes[0]
if not subgraph_can_be_fused(all_nodes, variable_input_node):
return None
check_subgraph_fusability(graph, all_nodes, variable_input_node)
nx_subgraph = nx.MultiDiGraph(nx_graph)
nodes_to_remove = [node for node in nx_subgraph.nodes() if node not in all_nodes]
@@ -583,10 +607,11 @@ def convert_subgraph_to_subgraph_node(
return subgraph_node, variable_input_node
def subgraph_can_be_fused(
def check_subgraph_fusability(
graph: Graph,
all_nodes: Dict[Node, None],
variable_input_node: Node,
) -> bool:
):
"""
Determine if a subgraph can be fused.
@@ -596,35 +621,42 @@ def subgraph_can_be_fused(
between each cell of the input and each cell of the output for table lookups
Args:
graph (Graph):
original graph
all_nodes (Dict[Node, None]):
all nodes in the subgraph
variable_input_node (Node):
variable input node to the subgraph
Returns:
bool:
True if subgraph can be fused,
False otherwise
Raises:
RuntimeError:
if subgraph is not fusable
"""
constant_nodes_with_bigger_size_than_variable_input = [
node
for node in all_nodes
if (
node.operation == Operation.Constant
and node.output.size > variable_input_node.output.size
)
]
if len(constant_nodes_with_bigger_size_than_variable_input) > 0:
return False
base_highlighted_nodes = {node: ["within this subgraph"] for node in all_nodes}
base_highlighted_nodes[variable_input_node] = ["with this input node"]
non_constant_nodes = (node for node in all_nodes if node.operation != Operation.Constant)
for node in non_constant_nodes:
if node == variable_input_node:
continue
if not node.is_fusable or node.output.shape != variable_input_node.output.shape:
return False
if not node.is_fusable:
base_highlighted_nodes[node] = ["this node is not fusable"]
raise RuntimeError(
"A subgraph within the function you are trying to compile cannot be fused "
"because of a node, which is marked explicitly as non-fusable\n\n"
+ graph.format(highlighted_nodes=base_highlighted_nodes)
)
if node.output.shape != variable_input_node.output.shape:
base_highlighted_nodes[node] = ["this node has a different shape than the input node"]
raise RuntimeError(
"A subgraph within the function you are trying to compile cannot be fused "
"because of a node, which is has a different shape than the input node\n\n"
+ graph.format(highlighted_nodes=base_highlighted_nodes)
)
return True

View File

@@ -165,8 +165,7 @@ class GraphConverter:
for idx, pred in enumerate(graph.ordered_preds_of(node))
if not pred.operation == Operation.Constant
]
if len(variable_input_indices) != 1:
return "only single input table lookups are supported"
assert_that(len(variable_input_indices) == 1)
if len(inputs) > 0 and all(input.is_clear for input in inputs):
return "one of the operands must be encrypted"

View File

@@ -700,29 +700,36 @@ def test_others_bad_fusing(helpers):
# pylint: disable=line-too-long
"""
Function you are trying to compile cannot be converted to MLIR
A subgraph within the function you are trying to compile cannot be fused because it has multiple input nodes
%0 = 10 # ClearScalar<uint4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%1 = 10 # ClearScalar<uint4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%2 = 2 # ClearScalar<uint2>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%3 = 2 # ClearScalar<uint2>
%4 = x # EncryptedScalar<uint7>
%5 = y # ClearScalar<uint7>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%4 = x # EncryptedScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
%5 = y # ClearScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
%6 = sin(%4) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%7 = cos(%5) # ClearScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%8 = power(%6, %2) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%9 = power(%7, %3) # ClearScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%10 = multiply(%0, %8) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%11 = multiply(%1, %9) # ClearScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%12 = add(%10, %11) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
%13 = astype(%12, dtype=int_) # EncryptedScalar<uint4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%13 = astype(%12, dtype=int_) # EncryptedScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
return %13
""", # noqa: E501
@@ -730,62 +737,67 @@ return %13
str(excinfo.value),
)
# big intermediate constants
# --------------------------
@cnp.compiler({"x": "encrypted"})
def function2(x):
return (np.sin(x) * [[1, 2], [3, 4]]).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
inputset = range(100)
function2.compile(inputset, configuration)
helpers.check_str(
# pylint: disable=line-too-long
"""
Function you are trying to compile cannot be converted to MLIR
%0 = [[1 2] [3 4]] # ClearTensor<uint3, shape=(2, 2)>
%1 = x # EncryptedScalar<uint7>
%2 = sin(%1) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
%3 = multiply(%2, %0) # EncryptedTensor<float64, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
%4 = astype(%3, dtype=int_) # EncryptedTensor<int3, shape=(2, 2)>
return %4
""", # noqa: E501
# pylint: enable=line-too-long
str(excinfo.value),
)
# intermediates with different shape
# ----------------------------------
@cnp.compiler({"x": "encrypted"})
def function3(x):
def function2(x):
return np.abs(np.sin(x)).reshape((2, 3)).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
inputset = [np.random.randint(0, 2**7, size=(3, 2)) for _ in range(100)]
function2.compile(inputset, configuration)
helpers.check_str(
# pylint: disable=line-too-long
"""
A subgraph within the function you are trying to compile cannot be fused because of a node, which is has a different shape than the input node
%0 = x # EncryptedTensor<uint7, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with this input node
%1 = sin(%0) # EncryptedTensor<float64, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%2 = absolute(%1) # EncryptedTensor<float64, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%3 = reshape(%2, newshape=(2, 3)) # EncryptedTensor<float64, shape=(2, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%4 = astype(%3, dtype=int_) # EncryptedTensor<uint1, shape=(2, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node has a different shape than the input node
return %4
""", # noqa: E501
# pylint: enable=line-too-long
str(excinfo.value),
)
# non-fusable operation
# ---------------------
@cnp.compiler({"x": "encrypted"})
def function3(x):
return np.abs(np.sin(x)).transpose().astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
inputset = [np.random.randint(0, 2**7, size=(2, 2)) for _ in range(100)]
function3.compile(inputset, configuration)
helpers.check_str(
# pylint: disable=line-too-long
"""
Function you are trying to compile cannot be converted to MLIR
A subgraph within the function you are trying to compile cannot be fused because of a node, which is marked explicitly as non-fusable
%0 = x # EncryptedTensor<uint7, shape=(3, 2)>
%1 = sin(%0) # EncryptedTensor<float64, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
%2 = absolute(%1) # EncryptedTensor<float64, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
%3 = reshape(%2, newshape=(2, 3)) # EncryptedTensor<float64, shape=(2, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
%4 = astype(%3, dtype=int_) # EncryptedTensor<uint1, shape=(2, 3)>
%0 = x # EncryptedTensor<uint7, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with this input node
%1 = sin(%0) # EncryptedTensor<float64, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%2 = absolute(%1) # EncryptedTensor<float64, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%3 = transpose(%2) # EncryptedTensor<float64, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is not fusable
%4 = astype(%3, dtype=int_) # EncryptedTensor<uint1, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
return %4
""", # noqa: E501
@@ -793,6 +805,36 @@ return %4
str(excinfo.value),
)
# integer two variable inputs
# ---------------------------
@cnp.compiler({"x": "encrypted", "y": "clear"})
def function4(x, y):
return np.maximum(x, y)
with pytest.raises(RuntimeError) as excinfo:
inputset = [(i, i) for i in range(100)]
function4.compile(inputset, configuration)
helpers.check_str(
# pylint: disable=line-too-long
"""
A subgraph within the function you are trying to compile cannot be fused because it has multiple input nodes
%0 = x # EncryptedScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
%1 = y # ClearScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
%2 = maximum(%0, %1) # EncryptedScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
return %2
""", # noqa: E501
# pylint: enable=line-too-long
str(excinfo.value),
)
def test_others_bad_univariate(helpers):
"""

View File

@@ -311,26 +311,6 @@ return %1
""", # noqa: E501
),
pytest.param(
lambda x, y: np.maximum(x, y),
{"x": "encrypted", "y": "clear"},
[
(np.random.randint(0, 2, size=(1,)), np.random.randint(0, 2, size=(1,)))
for _ in range(100)
],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedTensor<uint1, shape=(1,)>
%1 = y # ClearTensor<uint1, shape=(1,)>
%2 = maximum(%0, %1) # EncryptedTensor<uint1, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only single input table lookups are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x: np.maximum(x, np.array([3])),
{"x": "clear"},