diff --git a/hdk/common/operator_graph.py b/hdk/common/operator_graph.py index 31927c1ad..4a93182c1 100644 --- a/hdk/common/operator_graph.py +++ b/hdk/common/operator_graph.py @@ -30,10 +30,6 @@ class OPGraph: if len(self.graph.pred[node]) == 0 and isinstance(node, ir.Input) } - graph_outputs = set(node for node in self.graph.nodes() if len(self.graph.succ[node]) == 0) - - assert set(self.output_nodes.values()) == graph_outputs - def evaluate(self, inputs: Mapping[int, Any]) -> Dict[ir.IntermediateNode, Any]: """Function to evaluate a graph and get intermediate values for all nodes diff --git a/tests/hnumpy/test_debugging.py b/tests/hnumpy/test_debugging.py index 9187a91a1..ea611a791 100644 --- a/tests/hnumpy/test_debugging.py +++ b/tests/hnumpy/test_debugging.py @@ -58,6 +58,10 @@ from hdk.hnumpy import tracing lambda x, y: (y, x), "\n%0 = y\n%1 = x\nreturn(%0, %1)", ), + ( + lambda x, y: (x, x + 1), + "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%0, %2)", + ), ], ) @pytest.mark.parametrize(