mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: improve fusing related error messages
This commit is contained in:
@@ -22,6 +22,10 @@ def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
|
||||
|
||||
artifacts (Optional[DebugArtifacts], default = None):
|
||||
compilation artifacts to store information about the fusing process
|
||||
|
||||
Raises:
|
||||
RuntimeError:
|
||||
if there is a subgraph which needs to be fused cannot be fused
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
@@ -31,12 +35,12 @@ def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
|
||||
while True:
|
||||
subgraph_to_fuse = (
|
||||
find_float_subgraph_with_unique_terminal_node(
|
||||
nx_graph,
|
||||
graph,
|
||||
processed_terminal_nodes,
|
||||
)
|
||||
if fusing_floats
|
||||
else find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor(
|
||||
nx_graph,
|
||||
graph,
|
||||
processed_terminal_nodes,
|
||||
)
|
||||
)
|
||||
@@ -51,16 +55,12 @@ def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
|
||||
all_nodes, start_nodes, terminal_node = subgraph_to_fuse
|
||||
processed_terminal_nodes.add(terminal_node)
|
||||
|
||||
subgraph_conversion_result = convert_subgraph_to_subgraph_node(
|
||||
nx_graph,
|
||||
fused_node, node_before_subgraph = convert_subgraph_to_subgraph_node(
|
||||
graph,
|
||||
all_nodes,
|
||||
start_nodes,
|
||||
terminal_node,
|
||||
)
|
||||
if subgraph_conversion_result is None:
|
||||
continue
|
||||
|
||||
fused_node, node_before_subgraph = subgraph_conversion_result
|
||||
nx_graph.add_node(fused_node)
|
||||
|
||||
if terminal_node in graph.output_nodes.values():
|
||||
@@ -89,14 +89,14 @@ def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
|
||||
|
||||
|
||||
def find_float_subgraph_with_unique_terminal_node(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
graph: Graph,
|
||||
processed_terminal_nodes: Set[Node],
|
||||
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
|
||||
"""
|
||||
Find a subgraph with float computations that end with an integer output.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph):
|
||||
graph (Graph):
|
||||
graph to search
|
||||
|
||||
processed_terminal_nodes (Set[Node]):
|
||||
@@ -109,6 +109,7 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
and terminal node of the subgraph otherwise
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
terminal_nodes = (
|
||||
node
|
||||
for node in nx_graph.nodes()
|
||||
@@ -128,7 +129,7 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
start_single_int_output_nodes_search_from = terminal_node
|
||||
while True:
|
||||
all_nodes, start_nodes = find_closest_integer_output_nodes(
|
||||
nx_graph,
|
||||
graph,
|
||||
[start_single_int_output_nodes_search_from],
|
||||
all_nodes,
|
||||
)
|
||||
@@ -141,14 +142,14 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
|
||||
# find a common ancestor as we need a single variable input node
|
||||
# lca == lowest common ancestor
|
||||
lca = find_single_lca(nx_graph, variable_start_nodes)
|
||||
lca = find_single_lca(graph, variable_start_nodes)
|
||||
|
||||
# if subgraph cannot be fused because there is no way to find a common ancestor, break
|
||||
if lca is None:
|
||||
break
|
||||
|
||||
# add the nodes from the `start_nodes` to `lca`, to `all_nodes`
|
||||
all_nodes = add_nodes_from_to(nx_graph, start_nodes, {lca: None}, all_nodes)
|
||||
all_nodes = add_nodes_from_to(graph, start_nodes, {lca: None}, all_nodes)
|
||||
|
||||
# if `lca` is a valid starting node for fusing break
|
||||
if isinstance(lca.output.dtype, Integer):
|
||||
@@ -164,7 +165,7 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
|
||||
|
||||
def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
graph: Graph,
|
||||
processed_terminal_nodes: Set[Node],
|
||||
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
|
||||
"""
|
||||
@@ -172,7 +173,7 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
|
||||
where all variable inputs share a common ancestor.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph):
|
||||
graph (Graph):
|
||||
graph to search
|
||||
|
||||
processed_terminal_nodes (Set[Node]):
|
||||
@@ -185,6 +186,7 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
|
||||
and terminal node of the subgraph otherwise
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
terminal_nodes = (
|
||||
node
|
||||
for node in nx_graph.nodes()
|
||||
@@ -215,16 +217,17 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
|
||||
|
||||
# find a common ancestor as we need a single variable input node
|
||||
# lca == lowest common ancestor
|
||||
lca = find_single_lca(nx_graph, variable_start_nodes)
|
||||
lca = find_single_lca(graph, variable_start_nodes)
|
||||
|
||||
# if subgraph cannot be fused because there is no way to find a common ancestor, break
|
||||
if lca is None:
|
||||
start_nodes = {}
|
||||
start_nodes = {node: None for node in variable_start_nodes}
|
||||
all_nodes = {node: None for node in variable_start_nodes + [terminal_node]}
|
||||
break
|
||||
|
||||
# add the nodes from the `start_nodes` to `lca`, to `all_nodes`
|
||||
all_nodes = add_nodes_from_to(
|
||||
nx_graph,
|
||||
graph,
|
||||
list(nx_graph.predecessors(terminal_node)),
|
||||
{lca: None},
|
||||
all_nodes,
|
||||
@@ -240,12 +243,12 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
|
||||
return all_nodes, start_nodes, terminal_node
|
||||
|
||||
|
||||
def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[Node]:
|
||||
def find_single_lca(graph: Graph, nodes: List[Node]) -> Optional[Node]:
|
||||
"""
|
||||
Find the single lowest common ancestor of a list of nodes.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph):
|
||||
graph (Graph):
|
||||
graph to search for single lca
|
||||
|
||||
nodes (List[Node]):
|
||||
@@ -256,6 +259,8 @@ def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[No
|
||||
single lca if it exists, None otherwise
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
|
||||
# find all ancestors of `nodes`
|
||||
# nodes themselves need to be in this set because the single lca can be within `nodes`
|
||||
all_ancestors = [set(list(nx.ancestors(nx_graph, node)) + [node]) for node in nodes]
|
||||
@@ -278,7 +283,7 @@ def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[No
|
||||
continue
|
||||
|
||||
# check if node is a single common ancestor of `nodes`
|
||||
if is_single_common_ancestor(nx_graph, candidate, nodes):
|
||||
if is_single_common_ancestor(graph, candidate, nodes):
|
||||
# if so, it's the single lca of `nodes`
|
||||
# so return it
|
||||
return candidate
|
||||
@@ -289,7 +294,7 @@ def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[No
|
||||
|
||||
|
||||
def is_single_common_ancestor(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
graph: Graph,
|
||||
candidate: Node,
|
||||
nodes: List[Node],
|
||||
) -> bool:
|
||||
@@ -299,7 +304,7 @@ def is_single_common_ancestor(
|
||||
Note that this function doesn't care about `lowest` property of `lca`.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph):
|
||||
graph (Graph):
|
||||
graph to perform the check
|
||||
|
||||
candidate (Node):
|
||||
@@ -313,6 +318,8 @@ def is_single_common_ancestor(
|
||||
True if `candidate` is a single common ancestor of `nodes`, False otherwise
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
|
||||
# create a subgraph with `candidate` node
|
||||
subgraph = nx.DiGraph()
|
||||
subgraph.add_node(candidate)
|
||||
@@ -420,7 +427,7 @@ def is_single_common_ancestor(
|
||||
|
||||
|
||||
def find_closest_integer_output_nodes(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
graph: Graph,
|
||||
start_nodes: List[Node],
|
||||
all_nodes: Dict[Node, None],
|
||||
) -> Tuple[Dict[Node, None], Dict[Node, None]]:
|
||||
@@ -428,7 +435,7 @@ def find_closest_integer_output_nodes(
|
||||
Find the closest upstream integer output nodes to a set of start nodes in a graph.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph):
|
||||
graph (Graph):
|
||||
graph to search
|
||||
|
||||
start_nodes (List[Node]):
|
||||
@@ -442,6 +449,8 @@ def find_closest_integer_output_nodes(
|
||||
tuple containing extended `all_nodes` and integer output nodes closest to `start_nodes`
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
|
||||
closest_integer_output_nodes: Dict[Node, None] = {}
|
||||
visited_nodes: Set[Node] = set()
|
||||
|
||||
@@ -465,7 +474,7 @@ def find_closest_integer_output_nodes(
|
||||
|
||||
|
||||
def add_nodes_from_to(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
graph: Graph,
|
||||
from_nodes: Iterable[Node],
|
||||
to_nodes: Dict[Node, None],
|
||||
all_nodes: Dict[Node, None],
|
||||
@@ -474,7 +483,7 @@ def add_nodes_from_to(
|
||||
Add nodes from `from_nodes` to `to_nodes`, to `all_nodes`.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph):
|
||||
graph (Graph):
|
||||
graph to traverse
|
||||
|
||||
from_nodes (Iterable[Node]):
|
||||
@@ -491,6 +500,8 @@ def add_nodes_from_to(
|
||||
extended `all_nodes`
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
|
||||
all_nodes.update(to_nodes)
|
||||
visited_nodes: Set[Node] = set()
|
||||
|
||||
@@ -511,17 +522,17 @@ def add_nodes_from_to(
|
||||
|
||||
|
||||
def convert_subgraph_to_subgraph_node(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
graph: Graph,
|
||||
all_nodes: Dict[Node, None],
|
||||
start_nodes: Dict[Node, None],
|
||||
terminal_node: Node,
|
||||
) -> Optional[Tuple[Node, Node]]:
|
||||
) -> Tuple[Node, Node]:
|
||||
"""
|
||||
Convert a subgraph to Operation.Generic node.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph):
|
||||
orginal networkx graph
|
||||
graph (Graph):
|
||||
orginal graph
|
||||
|
||||
all_nodes (Dict[Node, None]):
|
||||
all nodes in the subgraph
|
||||
@@ -532,19 +543,32 @@ def convert_subgraph_to_subgraph_node(
|
||||
terminal_node (Node):
|
||||
terminal node of the subgraph
|
||||
|
||||
Raises:
|
||||
RuntimeError:
|
||||
if subgraph is not fusable
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[Node, Node]]:
|
||||
Tuple[Node, Node]:
|
||||
None if the subgraph cannot be fused,
|
||||
subgraph node and its predecessor otherwise
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
|
||||
variable_input_nodes = [node for node in start_nodes if node.operation != Operation.Constant]
|
||||
if len(variable_input_nodes) != 1:
|
||||
return None
|
||||
base_highlighted_nodes = {node: ["within this subgraph"] for node in all_nodes}
|
||||
for variable_input_node in variable_input_nodes:
|
||||
base_highlighted_nodes[variable_input_node] = ["this is one of the input nodes"]
|
||||
|
||||
raise RuntimeError(
|
||||
"A subgraph within the function you are trying to compile cannot be fused "
|
||||
"because it has multiple input nodes\n\n"
|
||||
+ graph.format(highlighted_nodes=base_highlighted_nodes)
|
||||
)
|
||||
|
||||
variable_input_node = variable_input_nodes[0]
|
||||
if not subgraph_can_be_fused(all_nodes, variable_input_node):
|
||||
return None
|
||||
check_subgraph_fusability(graph, all_nodes, variable_input_node)
|
||||
|
||||
nx_subgraph = nx.MultiDiGraph(nx_graph)
|
||||
nodes_to_remove = [node for node in nx_subgraph.nodes() if node not in all_nodes]
|
||||
@@ -583,10 +607,11 @@ def convert_subgraph_to_subgraph_node(
|
||||
return subgraph_node, variable_input_node
|
||||
|
||||
|
||||
def subgraph_can_be_fused(
|
||||
def check_subgraph_fusability(
|
||||
graph: Graph,
|
||||
all_nodes: Dict[Node, None],
|
||||
variable_input_node: Node,
|
||||
) -> bool:
|
||||
):
|
||||
"""
|
||||
Determine if a subgraph can be fused.
|
||||
|
||||
@@ -596,35 +621,42 @@ def subgraph_can_be_fused(
|
||||
between each cell of the input and each cell of the output for table lookups
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
original graph
|
||||
|
||||
all_nodes (Dict[Node, None]):
|
||||
all nodes in the subgraph
|
||||
|
||||
variable_input_node (Node):
|
||||
variable input node to the subgraph
|
||||
|
||||
Returns:
|
||||
bool:
|
||||
True if subgraph can be fused,
|
||||
False otherwise
|
||||
Raises:
|
||||
RuntimeError:
|
||||
if subgraph is not fusable
|
||||
"""
|
||||
|
||||
constant_nodes_with_bigger_size_than_variable_input = [
|
||||
node
|
||||
for node in all_nodes
|
||||
if (
|
||||
node.operation == Operation.Constant
|
||||
and node.output.size > variable_input_node.output.size
|
||||
)
|
||||
]
|
||||
if len(constant_nodes_with_bigger_size_than_variable_input) > 0:
|
||||
return False
|
||||
base_highlighted_nodes = {node: ["within this subgraph"] for node in all_nodes}
|
||||
base_highlighted_nodes[variable_input_node] = ["with this input node"]
|
||||
|
||||
non_constant_nodes = (node for node in all_nodes if node.operation != Operation.Constant)
|
||||
for node in non_constant_nodes:
|
||||
if node == variable_input_node:
|
||||
continue
|
||||
|
||||
if not node.is_fusable or node.output.shape != variable_input_node.output.shape:
|
||||
return False
|
||||
if not node.is_fusable:
|
||||
base_highlighted_nodes[node] = ["this node is not fusable"]
|
||||
raise RuntimeError(
|
||||
"A subgraph within the function you are trying to compile cannot be fused "
|
||||
"because of a node, which is marked explicitly as non-fusable\n\n"
|
||||
+ graph.format(highlighted_nodes=base_highlighted_nodes)
|
||||
)
|
||||
|
||||
if node.output.shape != variable_input_node.output.shape:
|
||||
base_highlighted_nodes[node] = ["this node has a different shape than the input node"]
|
||||
raise RuntimeError(
|
||||
"A subgraph within the function you are trying to compile cannot be fused "
|
||||
"because of a node, which is has a different shape than the input node\n\n"
|
||||
+ graph.format(highlighted_nodes=base_highlighted_nodes)
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -165,8 +165,7 @@ class GraphConverter:
|
||||
for idx, pred in enumerate(graph.ordered_preds_of(node))
|
||||
if not pred.operation == Operation.Constant
|
||||
]
|
||||
if len(variable_input_indices) != 1:
|
||||
return "only single input table lookups are supported"
|
||||
assert_that(len(variable_input_indices) == 1)
|
||||
|
||||
if len(inputs) > 0 and all(input.is_clear for input in inputs):
|
||||
return "one of the operands must be encrypted"
|
||||
|
||||
@@ -700,29 +700,36 @@ def test_others_bad_fusing(helpers):
|
||||
# pylint: disable=line-too-long
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
A subgraph within the function you are trying to compile cannot be fused because it has multiple input nodes
|
||||
|
||||
%0 = 10 # ClearScalar<uint4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%1 = 10 # ClearScalar<uint4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%2 = 2 # ClearScalar<uint2>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%3 = 2 # ClearScalar<uint2>
|
||||
%4 = x # EncryptedScalar<uint7>
|
||||
%5 = y # ClearScalar<uint7>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%4 = x # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
|
||||
%5 = y # ClearScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
|
||||
%6 = sin(%4) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%7 = cos(%5) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%8 = power(%6, %2) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%9 = power(%7, %3) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%10 = multiply(%0, %8) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%11 = multiply(%1, %9) # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%12 = add(%10, %11) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%13 = astype(%12, dtype=int_) # EncryptedScalar<uint4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%13 = astype(%12, dtype=int_) # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
return %13
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -730,62 +737,67 @@ return %13
|
||||
str(excinfo.value),
|
||||
)
|
||||
|
||||
# big intermediate constants
|
||||
# --------------------------
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function2(x):
|
||||
return (np.sin(x) * [[1, 2], [3, 4]]).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = range(100)
|
||||
function2.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = [[1 2] [3 4]] # ClearTensor<uint3, shape=(2, 2)>
|
||||
%1 = x # EncryptedScalar<uint7>
|
||||
%2 = sin(%1) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%3 = multiply(%2, %0) # EncryptedTensor<float64, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%4 = astype(%3, dtype=int_) # EncryptedTensor<int3, shape=(2, 2)>
|
||||
return %4
|
||||
|
||||
""", # noqa: E501
|
||||
# pylint: enable=line-too-long
|
||||
str(excinfo.value),
|
||||
)
|
||||
|
||||
# intermediates with different shape
|
||||
# ----------------------------------
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function3(x):
|
||||
def function2(x):
|
||||
return np.abs(np.sin(x)).reshape((2, 3)).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = [np.random.randint(0, 2**7, size=(3, 2)) for _ in range(100)]
|
||||
function2.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
"""
|
||||
|
||||
A subgraph within the function you are trying to compile cannot be fused because of a node, which is has a different shape than the input node
|
||||
|
||||
%0 = x # EncryptedTensor<uint7, shape=(3, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with this input node
|
||||
%1 = sin(%0) # EncryptedTensor<float64, shape=(3, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%2 = absolute(%1) # EncryptedTensor<float64, shape=(3, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%3 = reshape(%2, newshape=(2, 3)) # EncryptedTensor<float64, shape=(2, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%4 = astype(%3, dtype=int_) # EncryptedTensor<uint1, shape=(2, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node has a different shape than the input node
|
||||
return %4
|
||||
|
||||
""", # noqa: E501
|
||||
# pylint: enable=line-too-long
|
||||
str(excinfo.value),
|
||||
)
|
||||
|
||||
# non-fusable operation
|
||||
# ---------------------
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function3(x):
|
||||
return np.abs(np.sin(x)).transpose().astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = [np.random.randint(0, 2**7, size=(2, 2)) for _ in range(100)]
|
||||
function3.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
A subgraph within the function you are trying to compile cannot be fused because of a node, which is marked explicitly as non-fusable
|
||||
|
||||
%0 = x # EncryptedTensor<uint7, shape=(3, 2)>
|
||||
%1 = sin(%0) # EncryptedTensor<float64, shape=(3, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%2 = absolute(%1) # EncryptedTensor<float64, shape=(3, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%3 = reshape(%2, newshape=(2, 3)) # EncryptedTensor<float64, shape=(2, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer operations are supported
|
||||
%4 = astype(%3, dtype=int_) # EncryptedTensor<uint1, shape=(2, 3)>
|
||||
%0 = x # EncryptedTensor<uint7, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with this input node
|
||||
%1 = sin(%0) # EncryptedTensor<float64, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%2 = absolute(%1) # EncryptedTensor<float64, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
%3 = transpose(%2) # EncryptedTensor<float64, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is not fusable
|
||||
%4 = astype(%3, dtype=int_) # EncryptedTensor<uint1, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
return %4
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -793,6 +805,36 @@ return %4
|
||||
str(excinfo.value),
|
||||
)
|
||||
|
||||
# integer two variable inputs
|
||||
# ---------------------------
|
||||
|
||||
@cnp.compiler({"x": "encrypted", "y": "clear"})
|
||||
def function4(x, y):
|
||||
return np.maximum(x, y)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = [(i, i) for i in range(100)]
|
||||
function4.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
"""
|
||||
|
||||
A subgraph within the function you are trying to compile cannot be fused because it has multiple input nodes
|
||||
|
||||
%0 = x # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
|
||||
%1 = y # ClearScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
|
||||
%2 = maximum(%0, %1) # EncryptedScalar<uint1>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
# pylint: enable=line-too-long
|
||||
str(excinfo.value),
|
||||
)
|
||||
|
||||
|
||||
def test_others_bad_univariate(helpers):
|
||||
"""
|
||||
|
||||
@@ -311,26 +311,6 @@ return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: np.maximum(x, y),
|
||||
{"x": "encrypted", "y": "clear"},
|
||||
[
|
||||
(np.random.randint(0, 2, size=(1,)), np.random.randint(0, 2, size=(1,)))
|
||||
for _ in range(100)
|
||||
],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedTensor<uint1, shape=(1,)>
|
||||
%1 = y # ClearTensor<uint1, shape=(1,)>
|
||||
%2 = maximum(%0, %1) # EncryptedTensor<uint1, shape=(1,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only single input table lookups are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.maximum(x, np.array([3])),
|
||||
{"x": "clear"},
|
||||
|
||||
Reference in New Issue
Block a user