fix(frontend): non fusable graph is not an error

reported in https://community.zama.ai/t/implementing-comparison-strategies/3469/5
This commit is contained in:
rudy-6-4
2025-10-07 16:55:03 +02:00
committed by rudy
parent ecb729bb0a
commit d8d18022ad
2 changed files with 15 additions and 9 deletions

View File

@@ -248,12 +248,15 @@ def fuse(graph: Graph, artifacts: Optional["FunctionDebugArtifacts"] = None):
all_nodes, start_nodes, terminal_node = subgraph_to_fuse
processed_terminal_nodes.add(terminal_node)
conversion_result = convert_subgraph_to_subgraph_node(
graph,
all_nodes,
start_nodes,
terminal_node,
)
try:
conversion_result = convert_subgraph_to_subgraph_node(
graph,
all_nodes,
start_nodes,
terminal_node,
)
except NotFusable:
conversion_result = None
if conversion_result is None:
continue
@@ -769,7 +772,7 @@ def convert_subgraph_to_subgraph_node(
if terminal_node.properties["name"] == "where":
return None
raise RuntimeError(
raise NotFusable(
"A subgraph within the function you are trying to compile cannot be fused "
"because it has multiple input nodes\n\n"
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
@@ -826,6 +829,8 @@ def convert_subgraph_to_subgraph_node(
return subgraph_node, variable_input_node
class NotFusable(RuntimeError):
pass
def check_subgraph_fusibility(
graph: Graph,
@@ -868,7 +873,7 @@ def check_subgraph_fusibility(
if not node.is_fusable:
base_highlighted_nodes[node] = ["this node is not fusable", node.location]
raise RuntimeError(
raise NotFusable(
"A subgraph within the function you are trying to compile cannot be fused "
"because of a node, which is marked explicitly as non-fusable\n\n"
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)
@@ -879,7 +884,7 @@ def check_subgraph_fusibility(
"this node has a different shape than the input node",
node.location,
]
raise RuntimeError(
raise NotFusable(
"A subgraph within the function you are trying to compile cannot be fused "
"because of a node, which is has a different shape than the input node\n\n"
+ graph.format(highlighted_nodes=base_highlighted_nodes, show_bounds=False)

View File

@@ -64,6 +64,7 @@ for _ in range(35):
("<=", lambda x, y: x <= y),
(">", lambda x, y: x > y),
(">=", lambda x, y: x >= y),
("mixed", lambda x, y: fhe.if_then_else(x < y, x, y))
]
),
# bit widths