test: create check_array_equality fixture

This commit is contained in:
Umut
2021-12-06 16:16:50 +03:00
parent ae841cf77f
commit 5aad8c50ac
9 changed files with 85 additions and 45 deletions

View File

@@ -378,6 +378,7 @@ def test_fuse_float_operations(
warning_message,
capfd,
remove_color_codes,
check_array_equality,
):
"""Test function for fuse_float_operations"""
@@ -405,7 +406,7 @@ def test_fuse_float_operations(
input_ = numpy.full(param_input_value.shape, input_, dtype=numpy.int32)
inputs += (input_,)
assert numpy.array_equal(function_to_trace(*inputs), op_graph(*inputs))
check_array_equality(function_to_trace(*inputs), op_graph(*inputs))
def subtest_tensor_no_fuse(fun, tensor_shape):

View File

@@ -220,10 +220,11 @@ def test_evaluate(
node: ir.IntermediateNode,
input_data,
expected_result: int,
check_array_equality,
):
"""Test evaluate methods on IntermediateNodes"""
if isinstance(expected_result, numpy.ndarray):
assert numpy.array_equal(node.evaluate(input_data), expected_result)
check_array_equality(node.evaluate(input_data), expected_result)
else:
assert node.evaluate(input_data) == expected_result

View File

@@ -335,3 +335,30 @@ def check_is_good_execution():
"""Fixture to seed torch"""
return check_is_good_execution_impl
def check_array_equality_impl(actual: Any, expected: Any, verbose: bool = True):
"""Assert that `actual` is equal to `expected`."""
assert numpy.array_equal(actual, expected), (
""
if not verbose
else f"""
Expected Output
===============
{expected}
Actual Output
=============
{actual}
"""
)
@pytest.fixture
def check_array_equality():
"""Fixture to check array equality"""
return check_array_equality_impl

View File

@@ -1096,6 +1096,7 @@ def test_compile_and_run_tensor_correctness(
use_check_good_exec,
default_compilation_configuration,
check_is_good_execution,
check_array_equality,
):
"""Test correctness of results when running a compiled function with tensor operators"""
circuit = compile_numpy_function(
@@ -1113,7 +1114,7 @@ def test_compile_and_run_tensor_correctness(
if use_check_good_exec:
check_is_good_execution(circuit, function, numpy_test_input)
else:
assert numpy.array_equal(
check_array_equality(
circuit.run(*numpy_test_input),
numpy.array(function(*numpy_test_input), dtype=numpy.uint8),
)
@@ -1336,7 +1337,7 @@ def test_compile_and_run_constant_dot_correctness(
],
)
def test_compile_and_run_matmul_correctness(
lhs_shape, rhs_shape, input_range, default_compilation_configuration
lhs_shape, rhs_shape, input_range, default_compilation_configuration, check_array_equality
):
"""Test correctness of results when running a compiled function"""
@@ -1371,8 +1372,8 @@ def test_compile_and_run_matmul_correctness(
)
args = (numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8),)
assert numpy.array_equal(operator_circuit.run(*args), using_operator(*args))
assert numpy.array_equal(function_circuit.run(*args), using_function(*args))
check_array_equality(operator_circuit.run(*args), using_operator(*args))
check_array_equality(function_circuit.run(*args), using_function(*args))
@pytest.mark.parametrize(

View File

@@ -709,6 +709,7 @@ def test_constant_indexing_run_correctness(
test_input,
expected_output,
default_compilation_configuration,
check_array_equality,
):
"""Test correctness of results when running a compiled function with tensor operators"""
circuit = compile_numpy_function(
@@ -725,19 +726,7 @@ def test_constant_indexing_run_correctness(
output = circuit.run(*numpy_test_input)
expected = np.array(expected_output, dtype=np.uint8)
assert np.array_equal(
output, expected
), f"""
Actual Output
=============
{output}
Expected Output
===============
{expected}
"""
check_array_equality(output, expected)
@pytest.mark.parametrize(

View File

@@ -22,13 +22,21 @@ def complicated_topology(x, y):
@pytest.mark.parametrize("input_shape", [(), (3, 1, 2)])
def test_np_fhe_compiler_op_graph(input_shape, default_compilation_configuration):
def test_np_fhe_compiler_op_graph(
input_shape, default_compilation_configuration, check_array_equality
):
"""Test NPFHECompiler in two subtests."""
subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_configuration)
subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_configuration)
subtest_np_fhe_compiler_1_input_op_graph(
input_shape, default_compilation_configuration, check_array_equality
)
subtest_np_fhe_compiler_2_inputs_op_graph(
input_shape, default_compilation_configuration, check_array_equality
)
def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_configuration):
def subtest_np_fhe_compiler_1_input_op_graph(
input_shape, default_compilation_configuration, check_array_equality
):
"""test for NPFHECompiler on one input function"""
def function_to_compile(x):
@@ -48,7 +56,7 @@ def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_co
for i in numpy.arange(5):
i = numpy.ones(input_shape, dtype=numpy.int64) * i
assert numpy.array_equal(compiler(i), function_to_compile(i))
check_array_equality(compiler(i), function_to_compile(i))
# For coverage, check that we flush the inputset when we query the OPGraph
current_op_graph = compiler.op_graph
@@ -60,7 +68,7 @@ def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_co
# Continue a bit more
for i in numpy.arange(5, 10):
i = numpy.ones(input_shape, dtype=numpy.int64) * i
assert numpy.array_equal(compiler(i), function_to_compile(i))
check_array_equality(compiler(i), function_to_compile(i))
if input_shape == ():
assert (
@@ -106,7 +114,9 @@ def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_co
), got
def subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_configuration):
def subtest_np_fhe_compiler_2_inputs_op_graph(
input_shape, default_compilation_configuration, check_array_equality
):
"""test for NPFHECompiler on two inputs function"""
compiler = NPFHECompiler(
@@ -124,7 +134,7 @@ def subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_c
for i, j in zip(numpy.arange(5), numpy.arange(5, 10)):
i = numpy.ones(input_shape, dtype=numpy.int64) * i
j = numpy.ones(input_shape, dtype=numpy.int64) * j
assert numpy.array_equal(compiler(i, j), complicated_topology(i, j))
check_array_equality(compiler(i, j), complicated_topology(i, j))
# For coverage, check that we flush the inputset when we query the OPGraph
current_op_graph = compiler.op_graph
@@ -137,7 +147,7 @@ def subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_c
for i, j in zip(numpy.arange(5, 10), numpy.arange(5)):
i = numpy.ones(input_shape, dtype=numpy.int64) * i
j = numpy.ones(input_shape, dtype=numpy.int64) * j
assert numpy.array_equal(compiler(i, j), complicated_topology(i, j))
check_array_equality(compiler(i, j), complicated_topology(i, j))
if input_shape == ():
assert (
@@ -200,6 +210,7 @@ def test_np_fhe_compiler_auto_flush(
inputset_len,
expected_remaining_inputset_len,
default_compilation_configuration,
check_array_equality,
):
"""Test the auto flush of NPFHECompiler once the inputset is 128 elements."""
@@ -213,7 +224,7 @@ def test_np_fhe_compiler_auto_flush(
)
for i in numpy.arange(inputset_len):
assert numpy.array_equal(compiler(i), function_to_compile(i))
check_array_equality(compiler(i), function_to_compile(i))
# Check the inputset was properly flushed
assert (
@@ -222,7 +233,7 @@ def test_np_fhe_compiler_auto_flush(
)
def test_np_fhe_compiler_full_compilation(default_compilation_configuration):
def test_np_fhe_compiler_full_compilation(default_compilation_configuration, check_array_equality):
"""Test the case where we generate an FHE circuit."""
def function_to_compile(x):
@@ -244,7 +255,7 @@ def test_np_fhe_compiler_full_compilation(default_compilation_configuration):
)
for i in numpy.arange(64):
assert numpy.array_equal(compiler(i), function_to_compile(i))
check_array_equality(compiler(i), function_to_compile(i))
fhe_circuit = compiler.get_compiled_fhe_circuit()

View File

@@ -327,7 +327,7 @@ def test_tracing_astype(
assert expected_output == evaluated_output
def test_tracing_astype_single_element_array_corner_case():
def test_tracing_astype_single_element_array_corner_case(check_array_equality):
"""Test corner case where an array could be transformed to its scalar element"""
a = numpy.array([1], dtype=numpy.float64)
@@ -336,7 +336,7 @@ def test_tracing_astype_single_element_array_corner_case():
)
eval_result = op_graph(a)
assert numpy.array_equal(numpy.array([1], dtype=numpy.int32), eval_result)
check_array_equality(eval_result, numpy.array([1], dtype=numpy.int32))
@pytest.mark.parametrize(
@@ -492,6 +492,7 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
def subtest_tracing_calls(
function_to_trace,
input_value_input_and_expected_output_tuples,
check_array_equality,
):
"""Test memory function managed by GenericFunction node of the form numpy.something"""
for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:
@@ -502,11 +503,7 @@ def subtest_tracing_calls(
node_results = op_graph.evaluate({0: input_})
evaluated_output = node_results[output_node]
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
if not numpy.array_equal(expected_output, evaluated_output):
print("Wrong result")
print(f"Expected: {expected_output}")
print(f"Got : {evaluated_output}")
raise AssertionError
check_array_equality(evaluated_output, expected_output)
@pytest.mark.parametrize(
@@ -582,9 +579,12 @@ def subtest_tracing_calls(
def test_tracing_numpy_calls(
function_to_trace,
input_value_input_and_expected_output_tuples,
check_array_equality,
):
"""Test memory function managed by GenericFunction node of the form numpy.something"""
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
subtest_tracing_calls(
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
)
@pytest.mark.parametrize(
@@ -964,9 +964,12 @@ def test_tracing_numpy_calls(
def test_tracing_ndarray_calls(
function_to_trace,
input_value_input_and_expected_output_tuples,
check_array_equality,
):
"""Test memory function managed by GenericFunction node of the form ndarray.something"""
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
subtest_tracing_calls(
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
)
@pytest.mark.parametrize(

View File

@@ -142,6 +142,7 @@ def test_nptracer_get_tracing_func_for_np_functions(np_function):
def subtest_tracing_calls(
function_to_trace,
input_value_input_and_expected_output_tuples,
check_array_equality,
):
"""Test memory function managed by GenericFunction node of the form numpy.something"""
for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:
@@ -152,7 +153,7 @@ def subtest_tracing_calls(
node_results = op_graph.evaluate({0: input_})
evaluated_output = node_results[output_node]
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
assert numpy.array_equal(expected_output, evaluated_output)
check_array_equality(evaluated_output, expected_output)
@pytest.mark.parametrize(
@@ -228,9 +229,12 @@ def subtest_tracing_calls(
def test_tracing_numpy_calls(
function_to_trace,
input_value_input_and_expected_output_tuples,
check_array_equality,
):
"""Test memory function managed by GenericFunction node of the form numpy.something"""
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
subtest_tracing_calls(
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
)
@pytest.mark.parametrize(
@@ -297,6 +301,9 @@ def test_tracing_numpy_calls(
def test_tracing_ndarray_calls(
function_to_trace,
input_value_input_and_expected_output_tuples,
check_array_equality,
):
"""Test memory function managed by GenericFunction node of the form ndarray.something"""
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
subtest_tracing_calls(
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
)

View File

@@ -20,7 +20,7 @@ N_BITS_ATOL_TUPLE_LIST = [
)
@pytest.mark.parametrize("is_signed", [pytest.param(True), pytest.param(False)])
@pytest.mark.parametrize("values", [pytest.param(numpy.random.randn(2000))])
def test_quant_dequant_update(values, n_bits, atol, is_signed):
def test_quant_dequant_update(values, n_bits, atol, is_signed, check_array_equality):
"""Test the quant and dequant function."""
quant_array = QuantizedArray(n_bits, values, is_signed)
@@ -51,4 +51,4 @@ def test_quant_dequant_update(values, n_bits, atol, is_signed):
assert not numpy.array_equal(new_values, new_values_updated)
# Check that the __call__ returns also the qvalues.
assert numpy.array_equal(quant_array(), new_qvalues)
check_array_equality(quant_array(), new_qvalues)