chore: clarify these tests a bit

This commit is contained in:
Benoit Chevallier-Mames
2021-11-12 17:40:50 +01:00
committed by Benoit Chevallier
parent 93820e1588
commit 1394dd6db5

View File

@@ -682,10 +682,25 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
assert match in str(exc_info)
def subtest_tracing_calls(
function_to_trace,
input_value_input_and_expected_output_tuples,
):
"""Test memory function managed by GenericFunction node of the form numpy.something"""
for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
output_node = op_graph.output_nodes[0]
node_results = op_graph.evaluate({0: input_})
evaluated_output = node_results[output_node]
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
assert numpy.array_equal(expected_output, evaluated_output)
@pytest.mark.parametrize(
"function_to_trace,input_value_input_and_expected_output_tuples",
[
# Indirect calls, like numpy.function(x, ...)
(
lambda x: numpy.transpose(x),
[
@@ -751,7 +766,19 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
),
],
),
# Direct calls, like x.function(...)
],
)
def test_tracing_numpy_calls(
function_to_trace,
input_value_input_and_expected_output_tuples,
):
"""Test memory function managed by GenericFunction node of the form numpy.something"""
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
@pytest.mark.parametrize(
"function_to_trace,input_value_input_and_expected_output_tuples",
[
(
lambda x: x.transpose() + 42,
[
@@ -810,20 +837,12 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
),
],
)
def test_tracing_generic_function_memory_ops(
def test_tracing_ndarray_calls(
function_to_trace,
input_value_input_and_expected_output_tuples,
):
"""Test memory function managed by GenericFunction node"""
for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
output_node = op_graph.output_nodes[0]
node_results = op_graph.evaluate({0: input_})
evaluated_output = node_results[output_node]
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
assert numpy.array_equal(expected_output, evaluated_output)
"""Test memory function managed by GenericFunction node of the form ndarray.something"""
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
@pytest.mark.parametrize(