mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
chore: clarify these tests a bit
This commit is contained in:
committed by
Benoit Chevallier
parent
93820e1588
commit
1394dd6db5
@@ -682,10 +682,25 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
|
||||
assert match in str(exc_info)
|
||||
|
||||
|
||||
def subtest_tracing_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form numpy.something"""
|
||||
for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
|
||||
output_node = op_graph.output_nodes[0]
|
||||
|
||||
node_results = op_graph.evaluate({0: input_})
|
||||
evaluated_output = node_results[output_node]
|
||||
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
|
||||
assert numpy.array_equal(expected_output, evaluated_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,input_value_input_and_expected_output_tuples",
|
||||
[
|
||||
# Indirect calls, like numpy.function(x, ...)
|
||||
(
|
||||
lambda x: numpy.transpose(x),
|
||||
[
|
||||
@@ -751,7 +766,19 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
|
||||
),
|
||||
],
|
||||
),
|
||||
# Direct calls, like x.function(...)
|
||||
],
|
||||
)
|
||||
def test_tracing_numpy_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form numpy.something"""
|
||||
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,input_value_input_and_expected_output_tuples",
|
||||
[
|
||||
(
|
||||
lambda x: x.transpose() + 42,
|
||||
[
|
||||
@@ -810,20 +837,12 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracing_generic_function_memory_ops(
|
||||
def test_tracing_ndarray_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node"""
|
||||
for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
|
||||
output_node = op_graph.output_nodes[0]
|
||||
|
||||
node_results = op_graph.evaluate({0: input_})
|
||||
evaluated_output = node_results[output_node]
|
||||
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
|
||||
assert numpy.array_equal(expected_output, evaluated_output)
|
||||
"""Test memory function managed by GenericFunction node of the form ndarray.something"""
|
||||
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user