diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index fd0f7ce11..381fdffaf 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -682,10 +682,25 @@ def test_nptracer_unsupported_operands(operation, exception_type, match): assert match in str(exc_info) +def subtest_tracing_calls( + function_to_trace, + input_value_input_and_expected_output_tuples, +): + """Test memory function managed by GenericFunction node of the form numpy.something""" + for input_value, input_, expected_output in input_value_input_and_expected_output_tuples: + + op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value}) + output_node = op_graph.output_nodes[0] + + node_results = op_graph.evaluate({0: input_}) + evaluated_output = node_results[output_node] + assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output) + assert numpy.array_equal(expected_output, evaluated_output) + + @pytest.mark.parametrize( "function_to_trace,input_value_input_and_expected_output_tuples", [ - # Indirect calls, like numpy.function(x, ...) ( lambda x: numpy.transpose(x), [ @@ -751,7 +766,19 @@ def test_nptracer_unsupported_operands(operation, exception_type, match): ), ], ), - # Direct calls, like x.function(...) + ], +) +def test_tracing_numpy_calls( + function_to_trace, + input_value_input_and_expected_output_tuples, +): + """Test memory function managed by GenericFunction node of the form numpy.something""" + subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples) + + +@pytest.mark.parametrize( + "function_to_trace,input_value_input_and_expected_output_tuples", + [ ( lambda x: x.transpose() + 42, [ @@ -810,20 +837,12 @@ def test_nptracer_unsupported_operands(operation, exception_type, match): ), ], ) -def test_tracing_generic_function_memory_ops( +def test_tracing_ndarray_calls( function_to_trace, input_value_input_and_expected_output_tuples, ): - """Test memory function managed by GenericFunction node""" - for input_value, input_, expected_output in input_value_input_and_expected_output_tuples: - - op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value}) - output_node = op_graph.output_nodes[0] - - node_results = op_graph.evaluate({0: input_}) - evaluated_output = node_results[output_node] - assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output) - assert numpy.array_equal(expected_output, evaluated_output) + """Test memory function managed by GenericFunction node of the form ndarray.something""" + subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples) @pytest.mark.parametrize(